Browse Source

Began moving column type derivation to be more compile-time.

Kestrel 2 years ago
parent
commit
76753a91ac
7 changed files with 101 additions and 179 deletions
  1. 7 0
      Cargo.lock
  2. 28 2
      microrm-macros/src/lib.rs
  3. 1 0
      microrm/Cargo.toml
  4. 1 0
      microrm/src/lib.rs
  5. 6 0
      microrm/src/model.rs
  6. 28 175
      microrm/src/model/create.rs
  7. 30 2
      microrm/src/model/modelable.rs

+ 7 - 0
Cargo.lock

@@ -80,6 +80,12 @@ version = "1.0.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "1aab8fc367588b89dcee83ab0fd66b72b50b72fa1904d7095045ace2b0c81c35"
 
+[[package]]
+name = "lazy_static"
+version = "1.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
+
 [[package]]
 name = "libc"
 version = "0.2.125"
@@ -91,6 +97,7 @@ name = "microrm"
 version = "0.1.2"
 dependencies = [
  "base64",
+ "lazy_static",
  "microrm-macros",
  "serde",
  "serde_bytes",

+ 28 - 2
microrm-macros/src/lib.rs

@@ -72,6 +72,7 @@ pub fn derive_entity(tokens: TokenStream) -> TokenStream {
     let mut variants = Vec::new();
     let mut field_names = Vec::new();
     let mut field_numbers = Vec::new();
+    let mut field_types = Vec::new();
     let mut value_references = Vec::new();
 
     let mut foreign_keys = Vec::new();
@@ -90,6 +91,9 @@ pub fn derive_entity(tokens: TokenStream) -> TokenStream {
         let nn = field_numbers.len() + 1;
         field_numbers.push(quote! { #nn => Self::#converted_case, });
 
+        let ty = &name.ty;
+        field_types.push(quote!{ <#ty as #microrm_ref::model::Modelable>::column_type() });
+
         if parse_fk(&name.attrs) {
             let fk_struct_name = format_ident!("{}{}ForeignKey", struct_name, converted_case);
             let ty = &name.ty;
@@ -115,6 +119,8 @@ pub fn derive_entity(tokens: TokenStream) -> TokenStream {
         value_references.push(quote! { &self. #field_name });
     }
 
+    let column_types_name = format_ident!("{}_COLUMN_TYPES", struct_name.to_string().to_case(Case::ScreamingSnake));
+
     let field_count = fields.named.iter().count();
 
     quote!{
@@ -162,6 +168,17 @@ pub fn derive_entity(tokens: TokenStream) -> TokenStream {
             fn build_from(stmt: &#microrm_ref::re_export::sqlite::Statement, col_offset: usize) -> #microrm_ref::re_export::sqlite::Result<(Self, usize)> where Self: Sized {
                 stmt.read::<i64>(col_offset).map(|x| (#id_name(x), 1))
             }
+            fn column_type() -> &'static str where Self: Sized {
+                "integer"
+            }
+        }
+        #microrm_ref::re_export::lazy_static::lazy_static!{
+            static ref #column_types_name: [&'static str;#field_count + 1] = {
+                [
+                    "id",
+                    #(#field_types),*
+                ]
+            };
         }
 
         // Implementations for #struct_name
@@ -186,7 +203,9 @@ pub fn derive_entity(tokens: TokenStream) -> TokenStream {
             fn values(&self) -> Vec<&dyn #microrm_ref::model::Modelable> {
                 vec![ #(#value_references),* ]
             }
-
+            fn column_types() -> &'static [&'static str] {
+                #column_types_name.as_ref()
+            }
             fn foreign_keys() -> &'static [&'static dyn #microrm_ref::model::EntityForeignKey<Self::Column>] {
                 &[#(#foreign_keys),*]
             }
@@ -197,15 +216,19 @@ pub fn derive_entity(tokens: TokenStream) -> TokenStream {
     }.into()
 }
 
-/// Marks a struct as able to be directly used in an Entity to correspond to a single database column.
+/// Marks a struct or enum as able to be directly used in an Entity to correspond to a single database column.
 #[proc_macro_derive(Modelable, attributes(microrm_internal))]
 pub fn derive_modelable(tokens: TokenStream) -> TokenStream {
     let input = parse_macro_input!(tokens as DeriveInput);
 
+    // TODO: implement unit-only-variant-enum optimization
+    // TODO: if is a struct/newtype AND only has single element, store as that type
+
     let microrm_ref = parse_microrm_ref(&input.attrs);
 
     let ident = input.ident;
 
+
     quote!{
         impl #microrm_ref::model::Modelable for #ident {
             fn bind_to(&self, stmt: &mut #microrm_ref::re_export::sqlite::Statement, col: usize) -> #microrm_ref::re_export::sqlite::Result<()> {
@@ -220,6 +243,9 @@ pub fn derive_modelable(tokens: TokenStream) -> TokenStream {
                 let data = serde_json::from_str(str_data.as_str()).map_err(|e| sqlite::Error { code: None, message: Some(e.to_string()) })?;
                 Ok((data,1))
             }
+            fn column_type() -> &'static str where Self: Sized {
+                "blob"
+            }
         }
     }.into()
 }

+ 1 - 0
microrm/Cargo.toml

@@ -15,5 +15,6 @@ sqlite = "0.26"
 serde = { version = "1.0", features = ["derive"] }
 serde_bytes = { version = "0.11.6" }
 serde_json = { version = "1.0" }
+lazy_static = { version = "1.4.0" }
 
 microrm-macros = { path = "../microrm-macros", version = "0.1.2" }

+ 1 - 0
microrm/src/lib.rs

@@ -23,6 +23,7 @@ pub mod re_export {
     pub use serde;
     pub use serde_json;
     pub use sqlite;
+    pub use lazy_static;
 }
 
 #[derive(Debug)]

+ 6 - 0
microrm/src/model.rs

@@ -55,6 +55,9 @@ pub trait Modelable {
     fn build_from(stmt: &sqlite::Statement, col_offset: usize) -> sqlite::Result<(Self, usize)>
     where
         Self: Sized;
+    fn column_type() -> &'static str
+    where
+        Self: Sized;
 }
 
 /// A database entity, aka a struct representing a row in a table
@@ -63,6 +66,9 @@ pub trait Entity: 'static + for<'de> serde::Deserialize<'de> + serde::Serialize
     type ID: EntityID;
     fn table_name() -> &'static str;
     fn column_count() -> usize
+    where
+        Self: Sized;
+    fn column_types() -> &'static [&'static str]
     where
         Self: Sized;
     fn index(c: Self::Column) -> usize

+ 28 - 175
microrm/src/model/create.rs

@@ -1,173 +1,5 @@
-use serde::de::Visitor;
-
-use std::cell::Cell;
-use std::rc::Rc;
-
-struct EmptySequence {}
-
-impl<'de> serde::de::SeqAccess<'de> for EmptySequence {
-    type Error = super::ModelError;
-
-    fn next_element_seed<T: serde::de::DeserializeSeed<'de>>(
-        &mut self,
-        _seed: T,
-    ) -> Result<Option<T::Value>, Self::Error> {
-        Ok(None)
-    }
-}
-
-#[derive(Debug)]
-pub struct CreateDeserializer<'de> {
-    struct_visited: bool,
-    column_types: Vec<&'static str>,
-    expected_length: Rc<Cell<usize>>,
-    _de: std::marker::PhantomData<&'de u8>,
-}
-
-impl<'de> CreateDeserializer<'de> {
-    fn integral_type(&mut self) {
-        self.column_types.push("integer");
-    }
-}
-
-impl<'de, 'a> serde::de::Deserializer<'de> for &'a mut CreateDeserializer<'de> {
-    type Error = super::ModelError;
-
-    // we (ab)use the forward_to_deserialize_any! macro to stub out the types we don't care about
-    serde::forward_to_deserialize_any! {
-        bool i128 u128 f32 f64 char str
-        option unit unit_struct tuple
-        tuple_struct map enum identifier ignored_any
-    }
-
-    fn deserialize_any<V: Visitor<'de>>(self, _v: V) -> Result<V::Value, Self::Error> {
-        todo!()
-    }
-
-    fn deserialize_u8<V: Visitor<'de>>(self, v: V) -> Result<V::Value, Self::Error> {
-        self.integral_type();
-        v.visit_u8(0)
-    }
-
-    fn deserialize_u16<V: Visitor<'de>>(self, v: V) -> Result<V::Value, Self::Error> {
-        self.integral_type();
-        v.visit_u16(0)
-    }
-
-    fn deserialize_u32<V: Visitor<'de>>(self, v: V) -> Result<V::Value, Self::Error> {
-        self.integral_type();
-        v.visit_u32(0)
-    }
-
-    fn deserialize_u64<V: Visitor<'de>>(self, v: V) -> Result<V::Value, Self::Error> {
-        self.integral_type();
-        v.visit_u64(0)
-    }
-
-    fn deserialize_i8<V: Visitor<'de>>(self, v: V) -> Result<V::Value, Self::Error> {
-        self.integral_type();
-        v.visit_i8(0)
-    }
-
-    fn deserialize_i16<V: Visitor<'de>>(self, v: V) -> Result<V::Value, Self::Error> {
-        self.integral_type();
-        v.visit_i16(0)
-    }
-
-    fn deserialize_i32<V: Visitor<'de>>(self, v: V) -> Result<V::Value, Self::Error> {
-        self.integral_type();
-        v.visit_i32(0)
-    }
-
-    fn deserialize_i64<V: Visitor<'de>>(self, v: V) -> Result<V::Value, Self::Error> {
-        self.integral_type();
-        v.visit_i64(0)
-    }
-
-    fn deserialize_string<V: Visitor<'de>>(self, v: V) -> Result<V::Value, Self::Error> {
-        self.column_types.push("text");
-        v.visit_string("".to_owned())
-    }
-
-    fn deserialize_bytes<V: Visitor<'de>>(self, v: V) -> Result<V::Value, Self::Error> {
-        self.column_types.push("blob");
-        v.visit_bytes(&[])
-    }
-
-    fn deserialize_byte_buf<V: Visitor<'de>>(self, v: V) -> Result<V::Value, Self::Error> {
-        self.column_types.push("blob");
-        v.visit_bytes(&[])
-    }
-
-    fn deserialize_seq<V: Visitor<'de>>(self, v: V) -> Result<V::Value, Self::Error> {
-        // we store sequences as JSON-encoded strings, so...
-        self.column_types.push("text");
-        let es = EmptySequence {};
-        v.visit_seq(es)
-    }
-
-    fn deserialize_struct<V: Visitor<'de>>(
-        self,
-        _name: &'static str,
-        fields: &'static [&'static str],
-        v: V,
-    ) -> Result<V::Value, Self::Error> {
-        if self.struct_visited {
-            panic!("Nested structs not allowed!");
-        } else {
-            self.struct_visited = true;
-            self.expected_length.set(fields.len());
-            v.visit_seq(self)
-        }
-    }
-
-    fn deserialize_newtype_struct<V: Visitor<'de>>(
-        self,
-        _name: &'static str,
-        v: V,
-    ) -> Result<V::Value, Self::Error> {
-        let elength = self.expected_length.clone();
-        let old_elength = elength.get();
-        elength.set(1);
-        let ret = v.visit_seq(self);
-
-        elength.set(old_elength);
-
-        ret
-    }
-}
-
-impl<'de> serde::de::SeqAccess<'de> for CreateDeserializer<'de> {
-    type Error = super::ModelError;
-
-    fn next_element_seed<T: serde::de::DeserializeSeed<'de>>(
-        &mut self,
-        seed: T,
-    ) -> Result<Option<T::Value>, Self::Error> {
-        if self.expected_length.get() == 0 {
-            return Err(Self::Error::CreateError);
-        }
-
-        self.expected_length.set(self.expected_length.get() - 1);
-
-        seed.deserialize(self).map(Some)
-    }
-}
-
 pub fn sql_for_table<T: crate::model::Entity>() -> (String, String) {
-    let elength = Rc::new(Cell::new(0));
-
-    let mut cd = CreateDeserializer {
-        struct_visited: false,
-        column_types: Vec::new(),
-        expected_length: elength,
-        _de: std::marker::PhantomData {},
-    };
-
-    T::deserialize(&mut cd).expect("SQL creation failed!");
-
-    // +1 to account for id column that is included in column_count
-    assert_eq!(T::column_count(), cd.column_types.len() + 1);
+    let types = T::column_types();
 
     let mut columns = vec!["id integer primary key".to_owned()];
 
@@ -190,7 +22,7 @@ pub fn sql_for_table<T: crate::model::Entity>() -> (String, String) {
         columns.push(format!(
             "\"{}\" {}{}",
             T::name(col),
-            cd.column_types[i - 1],
+            types[i],
             fk.last().unwrap_or_else(|| "".to_string())
         ));
     }
@@ -364,11 +196,32 @@ mod test {
     #[test]
     fn test_vec() {
         assert_eq!(
-            super::sql_for_table::<VecTest>(),
-            (
-                r#"DROP TABLE IF EXISTS "vec_test""#.to_owned(),
-                r#"CREATE TABLE IF NOT EXISTS "vec_test" (id integer primary key,"e" integer,"test" text)"#.to_owned()
-            )
+            super::sql_for_table::<VecTest>().1,
+            r#"CREATE TABLE IF NOT EXISTS "vec_test" (id integer primary key,"e" integer,"test" blob)"#.to_owned()
         );
     }
+
+    #[derive(crate::Modelable, serde::Deserialize, serde::Serialize)]
+    #[microrm_internal]
+    pub enum TestEnum {
+        A,
+        B,
+        C
+    }
+
+    #[derive(serde::Serialize, serde::Deserialize, crate::Entity)]
+    #[microrm_internal]
+    pub struct EnumContainer {
+        before: usize,
+        e: TestEnum,
+        after: usize
+    }
+
+    #[test]
+    fn test_enum() {
+        assert_eq!(
+            super::sql_for_table::<EnumContainer>().1,
+            r#"CREATE TABLE IF NOT EXISTS "enum_container" (id integer primary key,"before" integer,"e" blob,"after" integer)"#.to_owned()
+        )
+    }
 }

+ 30 - 2
microrm/src/model/modelable.rs

@@ -16,6 +16,10 @@ macro_rules! integral {
             {
                 stmt.read::<i64>(col_offset).map(|x| (x as Self, 1))
             }
+
+            fn column_type() -> &'static str where Self: Sized {
+                "integer"
+            }
         }
     };
 }
@@ -40,6 +44,9 @@ impl Modelable for f64 {
     {
         stmt.read(col_offset).map(|x| (x, 1))
     }
+    fn column_type() -> &'static str where Self: Sized {
+        "numeric"
+    }
 }
 
 impl Modelable for bool {
@@ -53,6 +60,9 @@ impl Modelable for bool {
     {
         unreachable!("sqlite only gives Strings back, not &strs!");
     }
+    fn column_type() -> &'static str where Self: Sized {
+        "integer"
+    }
 }
 
 impl<'a> Modelable for &'a str {
@@ -65,6 +75,10 @@ impl<'a> Modelable for &'a str {
     {
         unreachable!("sqlite only gives Strings back, not &strs!");
     }
+
+    fn column_type() -> &'static str where Self: Sized {
+        "text"
+    }
 }
 
 impl Modelable for std::string::String {
@@ -77,6 +91,9 @@ impl Modelable for std::string::String {
     {
         stmt.read(col_offset).map(|x| (x, 1))
     }
+    fn column_type() -> &'static str where Self: Sized {
+        "text"
+    }
 }
 
 impl<'a> Modelable for &'a [u8] {
@@ -89,6 +106,9 @@ impl<'a> Modelable for &'a [u8] {
     {
         unreachable!("sqlite only gives Vec<u8> back, not &[u8]!");
     }
+    fn column_type() -> &'static str where Self: Sized {
+        "blob"
+    }
 }
 
 /*impl Modelable for Vec<u8> {
@@ -103,16 +123,19 @@ impl<'a> Modelable for &'a [u8] {
     }
 }*/
 
-impl<'a, T: Modelable + ?Sized> Modelable for &'a T {
+impl<'a, T: Modelable> Modelable for &'a T {
     fn bind_to(&self, stmt: &mut sqlite::Statement, col: usize) -> sqlite::Result<()> {
         <T as Modelable>::bind_to(self, stmt, col)
     }
-    fn build_from(stmt: &sqlite::Statement, col_offset: usize) -> sqlite::Result<(Self, usize)>
+    fn build_from(_stmt: &sqlite::Statement, _col_offset: usize) -> sqlite::Result<(Self, usize)>
     where
         Self: Sized,
     {
         unreachable!();
     }
+    fn column_type() -> &'static str where Self: Sized {
+        unreachable!();
+    }
 }
 
 impl<T: Modelable + serde::Serialize + serde::de::DeserializeOwned> Modelable for Vec<T> {
@@ -123,6 +146,8 @@ impl<T: Modelable + serde::Serialize + serde::de::DeserializeOwned> Modelable fo
     where
         Self: Sized,
     {
+        // TODO: add special exception for one-byte datatypes
+
         let s = String::build_from(stmt, col_offset)?;
         Ok((
             serde_json::from_str::<Vec<T>>(s.0.as_str()).map_err(|e| sqlite::Error {
@@ -132,4 +157,7 @@ impl<T: Modelable + serde::Serialize + serde::de::DeserializeOwned> Modelable fo
             1,
         ))
     }
+    fn column_type() -> &'static str where Self: Sized {
+        "blob"
+    }
 }