Forráskód Böngészése

Begin support for newtype structs.

Kestrel 2 éve
szülő
commit
5b2f038efa

+ 24 - 0
Cargo.lock

@@ -126,6 +126,12 @@ dependencies = [
  "hashbrown",
 ]
 
+[[package]]
+name = "itoa"
+version = "1.0.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1aab8fc367588b89dcee83ab0fd66b72b50b72fa1904d7095045ace2b0c81c35"
+
 [[package]]
 name = "libc"
 version = "0.2.125"
@@ -157,6 +163,7 @@ dependencies = [
  "rusqlite",
  "serde",
  "serde_bytes",
+ "serde_json",
  "sha2",
 ]
 
@@ -215,6 +222,12 @@ dependencies = [
  "smallvec",
 ]
 
+[[package]]
+name = "ryu"
+version = "1.0.9"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "73b4b750c782965c211b42f022f59af1fbceabdd026623714f104152f1ec149f"
+
 [[package]]
 name = "serde"
 version = "1.0.137"
@@ -244,6 +257,17 @@ dependencies = [
  "syn",
 ]
 
+[[package]]
+name = "serde_json"
+version = "1.0.81"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9b7ce2b32a1aed03c558dc61a5cd328f15aff2dbc17daad8fb8af04d2100e15c"
+dependencies = [
+ "itoa",
+ "ryu",
+ "serde",
+]
+
 [[package]]
 name = "sha2"
 version = "0.10.2"

+ 41 - 19
microrm-macros/src/lib.rs

@@ -4,6 +4,20 @@ use quote::{quote,format_ident};
 
 use convert_case::{Case, Casing};
 
+fn parse_microrm_ref(attrs: &Vec<syn::Attribute>) -> proc_macro2::TokenStream {
+    for attr in attrs {
+        if attr.path.segments.len() == 0 { continue }
+
+        if attr.tokens.is_empty() {
+            if attr.path.segments.last().unwrap().ident == "microrm_internal" {
+                return quote!{ crate }.into()
+            }
+        }
+    }
+    
+    quote!{ ::microrm }
+}
+
 /// Turns a serializable/deserializable struct into a microrm entity model.
 ///
 /// There are two important visible effects:
@@ -15,30 +29,16 @@ use convert_case::{Case, Casing};
 /// and a struct field named `field_name` is given a variant name of `FieldName`.
 ///
 /// The `#[microrm...]` attributes can be used to control the derivation somewhat.
-/// The following are understood:
+/// The following are understood for the Entity struct:
 /// - `#[microrm_internal]`: this is internal to the microrm crate (of extremely limited usefulness
 /// outside the microrm library)
-#[proc_macro_derive(Entity, attributes(microrm_internal))]
+/// The following are understood on individual fields
+/// - `#[microrm_foreign]`: this is a foreign key (and must be an type implementing `EntityID`)
+#[proc_macro_derive(Entity, attributes(microrm_internal,microrm_foreign))]
 pub fn derive_entity(tokens: TokenStream) -> TokenStream {
     let input = parse_macro_input!(tokens as DeriveInput);
 
-    let mut microrm_ref = quote!{ ::microrm };
-
-    // parse attributes
-    for attr in input.attrs {
-        if attr.path.segments.len() == 0 { continue }
-
-        if attr.tokens.is_empty() {
-            if attr.path.segments.last().unwrap().ident == "microrm_internal" {
-                microrm_ref = quote!{ crate };
-            }
-        }
-        else {
-            let body : Result<syn::Expr,_> = syn::parse2(attr.tokens);
-            if body.is_err() { continue }
-        }
-    }
-
+    let mut microrm_ref = parse_microrm_ref(&input.attrs);
 
     let struct_name = &input.ident;
     let enum_name = format_ident!("{}Columns", &input.ident);
@@ -55,6 +55,10 @@ pub fn derive_entity(tokens: TokenStream) -> TokenStream {
         _ => panic!("Can only use derive(Entity) on non-unit structs with named fields!")
     };
 
+    for name in fields.named.iter() {
+        // println!("ty: {:?}", name.ty);
+    }
+
     let mut variants = syn::punctuated::Punctuated::<syn::Ident, syn::token::Comma>::new();
     let mut field_names = syn::punctuated::Punctuated::<proc_macro2::TokenStream, syn::token::Comma>::new();
     let mut value_references = syn::punctuated::Punctuated::<proc_macro2::TokenStream, syn::token::Comma>::new();
@@ -126,3 +130,21 @@ pub fn derive_entity(tokens: TokenStream) -> TokenStream {
 
     ret
 }
+
+#[proc_macro_derive(Modelable, attributes(microrm_internal))]
+pub fn derive_modelable(tokens: TokenStream) -> TokenStream {
+    let input = parse_macro_input!(tokens as DeriveInput);
+
+    let mut microrm_ref = parse_microrm_ref(&input.attrs);
+
+    let ident = input.ident;
+
+    quote!{
+        impl #microrm_ref::re_export::rusqlite::ToSql for #ident {
+            fn to_sql(&self) -> #microrm_ref::re_export::rusqlite::Result<#microrm_ref::re_export::rusqlite::types::ToSqlOutput<'_>> {
+                use #microrm_ref::re_export::rusqlite::types::{ToSqlOutput,Value};
+                Ok(ToSqlOutput::Owned(Value::Text(#microrm_ref::re_export::serde_json::to_string(self).expect("can be serialized"))))
+            }
+        }
+    }.into()
+}

+ 1 - 0
microrm/Cargo.toml

@@ -11,5 +11,6 @@ sha2 = "0.10"
 rusqlite = "0.27"
 serde = { version = "1.0", features = ["derive"] }
 serde_bytes = { version = "0.11.6" }
+serde_json = { version = "1.0" }
 
 microrm-macros = { path = "../microrm-macros" }

+ 2 - 1
microrm/src/lib.rs

@@ -44,13 +44,14 @@ pub mod model;
 pub mod query;
 mod meta;
 
-pub use microrm_macros::Entity;
+pub use microrm_macros::{Entity,Modelable};
 
 // no need to show the re-exports in the documentation
 #[doc(hidden)]
 pub mod re_export {
     pub use rusqlite;
     pub use serde;
+    pub use serde_json;
 }
 
 #[derive(Debug)]

+ 1 - 0
microrm/src/model.rs

@@ -7,6 +7,7 @@ pub enum ModelError {
     DBError(rusqlite::Error),
     LoadError(String),
     EmptyStoreError,
+    CreateError
 }
 
 impl From<rusqlite::Error> for ModelError {

+ 129 - 18
microrm/src/model/create.rs

@@ -1,19 +1,31 @@
 use serde::de::Visitor;
 
+use std::rc::Rc;
+use std::cell::Cell;
+
 #[derive(Debug)]
 pub struct CreateDeserializer<'de> {
-    column_names: Vec<String>,
-    column_types: Vec<String>,
-    column_name_stack: Vec<String>,
+    struct_visited: bool,
+    column_names: Vec<&'static str>,
+    column_types: Vec<&'static str>,
+    column_name_stack: Vec<&'static str>,
+    expected_length: Rc<Cell<usize>>,
     _de: std::marker::PhantomData<&'de u8>,
 }
 
+impl<'de> CreateDeserializer<'de> {
+    fn integral_type(&mut self) {
+        self.column_types.push("integer");
+        self.column_names.push(self.column_name_stack.pop().unwrap());
+    }
+}
+
 impl<'de, 'a> serde::de::Deserializer<'de> for &'a mut CreateDeserializer<'de> {
     type Error = super::ModelError;
 
     // we (ab)use the forward_to_deserialize_any! macro to stub out the types we don't care about
     serde::forward_to_deserialize_any! {
-        bool i8 i16 i128 u8 u16 u32 u64 u128 f32 f64 char str
+        bool i128 u64 u128 f32 f64 char str
         option unit unit_struct tuple
         tuple_struct map enum identifier ignored_any
     }
@@ -22,32 +34,55 @@ impl<'de, 'a> serde::de::Deserializer<'de> for &'a mut CreateDeserializer<'de> {
         todo!()
     }
 
+    fn deserialize_u8<V: Visitor<'de>>(self, v: V) -> Result<V::Value, Self::Error> {
+        self.integral_type();
+        v.visit_u8(0)
+    }
+
+    fn deserialize_u16<V: Visitor<'de>>(self, v: V) -> Result<V::Value, Self::Error> {
+        self.integral_type();
+        v.visit_u16(0)
+    }
+
+    fn deserialize_u32<V: Visitor<'de>>(self, v: V) -> Result<V::Value, Self::Error> {
+        self.integral_type();
+        v.visit_u32(0)
+    }
+
+    fn deserialize_i8<V: Visitor<'de>>(self, v: V) -> Result<V::Value, Self::Error> {
+        self.integral_type();
+        v.visit_i8(0)
+    }
+
+    fn deserialize_i16<V: Visitor<'de>>(self, v: V) -> Result<V::Value, Self::Error> {
+        self.integral_type();
+        v.visit_i16(0)
+    }
+
     fn deserialize_i32<V: Visitor<'de>>(self, v: V) -> Result<V::Value, Self::Error> {
-        self.column_types.push("integer".to_owned());
-        self.column_names.push(self.column_name_stack.pop().unwrap());
+        self.integral_type();
         v.visit_i32(0)
     }
 
     fn deserialize_i64<V: Visitor<'de>>(self, v: V) -> Result<V::Value, Self::Error> {
-        self.column_types.push("integer".to_owned());
-        self.column_names.push(self.column_name_stack.pop().unwrap());
+        self.integral_type();
         v.visit_i64(0)
     }
 
     fn deserialize_string<V: Visitor<'de>>(self, v: V) -> Result<V::Value, Self::Error> {
-        self.column_types.push("varchar".to_owned());
+        self.column_types.push("text");
         self.column_names.push(self.column_name_stack.pop().unwrap());
         v.visit_string("".to_owned())
     }
 
     fn deserialize_bytes<V: Visitor<'de>>(self, v: V) -> Result<V::Value, Self::Error> {
-        self.column_types.push("blob".to_owned());
+        self.column_types.push("blob");
         self.column_names.push(self.column_name_stack.pop().unwrap());
         v.visit_bytes(&[])
     }
 
     fn deserialize_byte_buf<V: Visitor<'de>>(self, v: V) -> Result<V::Value, Self::Error> {
-        self.column_types.push("blob".to_owned());
+        self.column_types.push("blob");
         self.column_names.push(self.column_name_stack.pop().unwrap());
         v.visit_bytes(&[])
     }
@@ -60,15 +95,22 @@ impl<'de, 'a> serde::de::Deserializer<'de> for &'a mut CreateDeserializer<'de> {
 
     fn deserialize_struct<V: Visitor<'de>>(
         self,
-        name: &'static str,
+        _name: &'static str,
         fields: &'static [&'static str],
         v: V,
     ) -> Result<V::Value, Self::Error> {
-        // we may not have a prefix if this is the root struct
-        // but if we do, it means we're serializing a sub-structure
-        let prefix = self.column_name_stack.pop().map(|x| x + "_").unwrap_or("".to_string());
-        self.column_name_stack.extend(fields.iter().map(|x| prefix.clone() + x).rev());
-        v.visit_seq(self)
+        if self.struct_visited {
+            let elength = self.expected_length.clone();
+            let old_elength = elength.get();
+            println!("nested deserialize_struct invoked!");
+            todo!();
+        }
+        else {
+            self.column_name_stack.extend(fields.iter().rev());
+            self.expected_length.set(fields.len());
+            let ret = v.visit_seq(self);
+            ret
+        }
     }
 
     fn deserialize_newtype_struct<V: Visitor<'de>>(
@@ -76,7 +118,14 @@ impl<'de, 'a> serde::de::Deserializer<'de> for &'a mut CreateDeserializer<'de> {
         name: &'static str,
         v: V
     ) -> Result<V::Value, Self::Error> {
-        unreachable!("microrm cannot store newtype structs")
+        let elength = self.expected_length.clone();
+        let old_elength = elength.get();
+        elength.set(1);
+        let ret = v.visit_seq(self);
+
+        elength.set(old_elength);
+
+        ret
     }
 }
 
@@ -87,15 +136,26 @@ impl<'de> serde::de::SeqAccess<'de> for CreateDeserializer<'de> {
         &mut self,
         seed: T,
     ) -> Result<Option<T::Value>, Self::Error> {
+        if self.expected_length.get() == 0 {
+            return Err(Self::Error::CreateError)
+        }
+
+        self.expected_length.set(self.expected_length.get() - 1);
+
         seed.deserialize(self).map(Some)
     }
 }
 
 pub fn sql_for<T: crate::model::Entity>() -> (String, String) {
+
+    let mut elength = Rc::new(Cell::new(0));
+
     let mut cd = CreateDeserializer {
+        struct_visited: false,
         column_names: Vec::new(),
         column_types: Vec::new(),
         column_name_stack: Vec::new(),
+        expected_length: elength,
         _de: std::marker::PhantomData {},
     };
 
@@ -121,6 +181,8 @@ pub fn sql_for<T: crate::model::Entity>() -> (String, String) {
 
 #[cfg(test)]
 mod test {
+    use serde::Deserialize;
+
     #[derive(serde::Serialize, serde::Deserialize, crate::Entity)]
     #[microrm_internal]
     pub struct Empty {}
@@ -131,6 +193,12 @@ mod test {
         e: i32,
     }
 
+    #[derive(serde::Serialize, serde::Deserialize, crate::Entity)]
+    #[microrm_internal]
+    pub struct Reference {
+        e: SingleID,
+    }
+
     #[test]
     fn example_sql_for() {
         assert_eq!(
@@ -147,5 +215,48 @@ mod test {
                 r#"CREATE TABLE "single" ("e" integer)"#.to_owned()
             )
         );
+
+        assert_eq!(
+            super::sql_for::<Reference>(),
+            (
+                r#"DROP TABLE IF EXISTS "reference""#.to_owned(),
+                r#"CREATE TABLE "reference" ("e" integer)"#.to_owned()
+            )
+        );
+    }
+
+    #[derive(serde::Serialize, serde::Deserialize, crate::Modelable)]
+    #[microrm_internal]
+    pub struct Unit(u8);
+    #[derive(serde::Serialize, serde::Deserialize, crate::Entity)]
+    #[microrm_internal]
+    pub struct UnitNewtype {
+        newtype: Unit,
+    }
+
+    #[test]
+    fn unit_newtype_struct() {
+        assert_eq!(
+            super::sql_for::<UnitNewtype>(),
+            (
+                r#"DROP TABLE IF EXISTS "unit_newtype""#.to_owned(),
+                r#"CREATE TABLE "unit_newtype" ("newtype" integer)"#.to_owned()
+            )
+        );
+    }
+
+    #[derive(serde::Serialize, serde::Deserialize, crate::Modelable)]
+    #[microrm_internal]
+    pub struct NonUnit(u8,u8);
+    #[derive(serde::Serialize, serde::Deserialize, crate::Entity)]
+    #[microrm_internal]
+    pub struct NonUnitNewtype {
+        newtype: NonUnit,
+    }
+
+    #[test]
+    #[should_panic]
+    fn nonunit_newtype_struct() {
+        super::sql_for::<NonUnitNewtype>();
     }
 }

+ 3 - 1
microrm/src/model/store.rs

@@ -1,4 +1,6 @@
-pub fn serialize_as_row<T: crate::model::Entity>(
+use super::{Entity};
+
+pub fn serialize_as_row<T: Entity>(
     value: &T,
 ) -> Vec<&dyn rusqlite::ToSql> {
     value.values()