Browse Source

rustfmt pass and get_one_by_multi support.

Kestrel 2 years ago
parent
commit
d06683b9f6
7 changed files with 133 additions and 42 deletions
  1. 1 1
      .vimrc
  2. 15 10
      microrm-macros/src/lib.rs
  3. 28 16
      microrm/src/lib.rs
  4. 2 4
      microrm/src/model.rs
  5. 2 4
      microrm/src/model/create.rs
  6. 32 1
      microrm/src/model/modelable.rs
  7. 53 6
      microrm/src/query.rs

+ 1 - 1
.vimrc

@@ -1 +1 @@
-set wildignore+=target
+set wildignore+=target,archive

+ 15 - 10
microrm-macros/src/lib.rs

@@ -21,8 +21,10 @@ fn parse_microrm_ref(attrs: &[syn::Attribute]) -> proc_macro2::TokenStream {
 
 fn parse_fk(attrs: &[syn::Attribute]) -> bool {
     for attr in attrs {
-        if attr.path.segments.len() == 1 && attr.path.segments.last().unwrap().ident == "microrm_foreign" {
-            return true
+        if attr.path.segments.len() == 1
+            && attr.path.segments.last().unwrap().ident == "microrm_foreign"
+        {
+            return true;
         }
     }
 
@@ -91,7 +93,7 @@ pub fn derive_entity(tokens: TokenStream) -> TokenStream {
         if parse_fk(&name.attrs) {
             let fk_struct_name = format_ident!("{}{}ForeignKey", struct_name, converted_case);
             let ty = &name.ty;
-            foreign_keys.push(quote!{
+            foreign_keys.push(quote! {
                 &#fk_struct_name { col: #enum_name::#converted_case }
             });
             foreign_key_impls.push(quote!{
@@ -222,13 +224,13 @@ pub fn derive_modelable(tokens: TokenStream) -> TokenStream {
     }.into()
 }
 
-type ColumnList = syn::punctuated::Punctuated::<syn::TypePath, syn::Token![,]>;
+type ColumnList = syn::punctuated::Punctuated<syn::TypePath, syn::Token![,]>;
 struct MakeIndexParams {
     unique: Option<syn::Token![!]>,
     name: syn::Ident,
     #[allow(dead_code)]
     comma: syn::Token![,],
-    columns: ColumnList
+    columns: ColumnList,
 }
 
 impl syn::parse::Parse for MakeIndexParams {
@@ -237,7 +239,7 @@ impl syn::parse::Parse for MakeIndexParams {
             unique: input.parse()?,
             name: input.parse()?,
             comma: input.parse()?,
-            columns: ColumnList::parse_separated_nonempty(input)?
+            columns: ColumnList::parse_separated_nonempty(input)?,
         })
     }
 }
@@ -252,7 +254,10 @@ fn do_make_index(tokens: TokenStream, microrm_ref: proc_macro2::TokenStream) ->
 
     // remove variant name
     column_type_path.segments.pop();
-    let last = column_type_path.segments.pop().expect("Full path to EntityColumn variant");
+    let last = column_type_path
+        .segments
+        .pop()
+        .expect("Full path to EntityColumn variant");
     column_type_path.segments.push(last.value().clone());
 
     let index_entity_type_name = format_ident!("{}Entity", index_struct_name);
@@ -304,13 +309,13 @@ fn do_make_index(tokens: TokenStream, microrm_ref: proc_macro2::TokenStream) ->
 /// ```
 #[proc_macro]
 pub fn make_index(tokens: TokenStream) -> TokenStream {
-    do_make_index(tokens, quote!{ ::microrm })
+    do_make_index(tokens, quote! { ::microrm })
 }
 
 /// For internal use inside the microrm library. See `make_index`.
 #[proc_macro]
 pub fn make_index_internal(tokens: TokenStream) -> TokenStream {
-    do_make_index(tokens, quote!{ crate })
+    do_make_index(tokens, quote! { crate })
 }
 
- // , attributes(microrm_internal))]
+// , attributes(microrm_internal))]

+ 28 - 16
microrm/src/lib.rs

@@ -7,9 +7,16 @@ pub mod query;
 use meta::Metaschema;
 use model::Entity;
 
-pub use microrm_macros::{Entity, Modelable, make_index};
+pub use microrm_macros::{make_index, Entity, Modelable};
 pub use query::{QueryInterface, WithID};
 
+#[macro_export]
+macro_rules! value_list {
+    ( $( $element:expr ),* ) => {
+        [ $( ($element) as &dyn $crate::model::Modelable ),* ]
+    }
+}
+
 // no need to show the re-exports in the documentation
 #[doc(hidden)]
 pub mod re_export {
@@ -131,9 +138,8 @@ impl DB {
 
         if !has_metaschema && mode != CreateMode::MustExist {
             return self.create_schema();
-        }
-        else if !has_metaschema && mode == CreateMode::MustExist {
-            return Err(DBError::NoSchema)
+        } else if !has_metaschema && mode == CreateMode::MustExist {
+            return Err(DBError::NoSchema);
         }
 
         let qi = query::QueryInterface::new(self);
@@ -185,7 +191,7 @@ impl DB {
 }
 
 /// Add support for multi-threading to a `DB`.
-/// 
+///
 /// This is a thread-local cache that carefully maintains the property that no
 /// element of the cache will ever be accessed in any way from another thread. The only
 /// way to maintain this property is to leak all data, so this is best used
@@ -205,15 +211,21 @@ pub struct DBPool<'a> {
 
 impl<'a> DBPool<'a> {
     pub fn new(db: &'a DB) -> Self {
-        Self { db: db, qi: std::sync::RwLock::new(Vec::new()) }
+        Self {
+            db: db,
+            qi: std::sync::RwLock::new(Vec::new()),
+        }
     }
 
     /// Get a query interface from this DB pool for the current thread
     pub fn query_interface(&self) -> &query::QueryInterface<'a> {
         let guard = self.qi.read().expect("Couldn't acquire read lock");
         let current_id = std::thread::current().id();
-        if let Some(res) = guard.iter().find_map(|x| if x.0 == current_id { Some(x.1) } else { None }) {
-            return res
+        if let Some(res) = guard
+            .iter()
+            .find_map(|x| if x.0 == current_id { Some(x.1) } else { None })
+        {
+            return res;
         }
 
         drop(guard);
@@ -232,12 +244,12 @@ unsafe impl<'a> Sync for DBPool<'a> {}
 
 #[cfg(test)]
 mod pool_test {
-    trait IsSend: Send { }
-    impl IsSend for super::DB { }
-    impl<'a> IsSend for super::DBPool<'a> { }
+    trait IsSend: Send {}
+    impl IsSend for super::DB {}
+    impl<'a> IsSend for super::DBPool<'a> {}
     // we make sure that DBPool is send / sync safe
-    trait IsSendAndSync : Send + Sync { }
-    impl<'a> IsSendAndSync for super::DBPool<'a> { }
+    trait IsSendAndSync: Send + Sync {}
+    impl<'a> IsSendAndSync for super::DBPool<'a> {}
 }
 
 #[cfg(test)]
@@ -284,11 +296,11 @@ mod test {
 
 #[cfg(test)]
 mod test2 {
-    #[derive(Debug,crate::Entity,serde::Serialize,serde::Deserialize)]
+    #[derive(Debug, crate::Entity, serde::Serialize, serde::Deserialize)]
     #[microrm_internal]
     pub struct KVStore {
         pub key: String,
-        pub value: String
+        pub value: String,
     }
 
     // the !KVStoreIndex here means a type representing a unique index named KVStoreIndex
@@ -310,7 +322,7 @@ mod test2 {
 
         qi.add(&KVStore {
             key: "a_key".to_string(),
-            value: "a_value".to_string()
+            value: "a_value".to_string(),
         });
 
         // because KVStoreIndex indexes key, this is a logarithmic lookup

+ 2 - 4
microrm/src/model.rs

@@ -97,7 +97,7 @@ pub trait EntityForeignKey<T: EntityColumns> {
 
 /// Trait for an index over a column
 pub trait Index {
-    type IndexedEntity : Entity;
+    type IndexedEntity: Entity;
 
     fn index_name() -> &'static str
     where
@@ -132,9 +132,7 @@ impl SchemaModel {
         self
     }
 
-    pub fn index<I: Index>(
-        mut self,
-    ) -> Self {
+    pub fn index<I: Index>(mut self) -> Self {
         let (drop, create) = create::sql_for_index::<I>();
         self.drop.push(drop);
         self.create.push(create);

+ 2 - 4
microrm/src/model/create.rs

@@ -186,9 +186,7 @@ pub fn sql_for_table<T: crate::model::Entity>() -> (String, String) {
     )
 }
 
-pub fn sql_for_index<
-    I: super::Index
->() -> (String, String) {
+pub fn sql_for_index<I: super::Index>() -> (String, String) {
     use super::Entity;
     (
         format!("DROP INDEX IF EXISTS \"{}\"", I::index_name()),
@@ -301,7 +299,7 @@ mod test {
             super::sql_for_table::<Child>(),
             (
                 r#"DROP TABLE IF EXISTS "child""#.to_owned(),
-                r#"CREATE TABLE IF NOT EXISTS "child" (id integer primary key,"parent_id" integer references "single"("id"))"#.to_owned()
+                r#"CREATE TABLE IF NOT EXISTS "child" (id integer primary key,"parent_id" integer references "single"("id") ON DELETE CASCADE)"#.to_owned()
             )
         );
     }

+ 32 - 1
microrm/src/model/modelable.rs

@@ -90,7 +90,7 @@ impl<'a> Modelable for &'a [u8] {
     }
 }
 
-impl Modelable for Vec<u8> {
+/*impl Modelable for Vec<u8> {
     fn bind_to(&self, stmt: &mut sqlite::Statement, col: usize) -> sqlite::Result<()> {
         self.bind(stmt, col)
     }
@@ -100,4 +100,35 @@ impl Modelable for Vec<u8> {
     {
         stmt.read(col_offset).map(|x| (x, 1))
     }
+}*/
+
+impl<'a, T: Modelable + ?Sized> Modelable for &'a T {
+    fn bind_to(&self, stmt: &mut sqlite::Statement, col: usize) -> sqlite::Result<()> {
+        <T as Modelable>::bind_to(self, stmt, col)
+    }
+    fn build_from(stmt: &sqlite::Statement, col_offset: usize) -> sqlite::Result<(Self, usize)>
+    where
+        Self: Sized,
+    {
+        unreachable!();
+    }
+}
+
+impl<T: Modelable + serde::Serialize + serde::de::DeserializeOwned> Modelable for Vec<T> {
+    fn bind_to(&self, stmt: &mut sqlite::Statement, col: usize) -> sqlite::Result<()> {
+        serde_json::to_string(self).unwrap().bind_to(stmt, col)
+    }
+    fn build_from(stmt: &sqlite::Statement, col_offset: usize) -> sqlite::Result<(Self, usize)>
+    where
+        Self: Sized,
+    {
+        let s = String::build_from(stmt, col_offset)?;
+        Ok((
+            serde_json::from_str::<Vec<T>>(s.0.as_str()).map_err(|e| sqlite::Error {
+                code: None,
+                message: Some(e.to_string()),
+            })?,
+            1,
+        ))
+    }
 }

+ 53 - 6
microrm/src/query.rs

@@ -60,7 +60,7 @@ pub struct QueryInterface<'l> {
     cache: std::sync::Mutex<std::collections::HashMap<CacheIndex, sqlite::Statement<'l>>>,
 
     // use a phantom non-Send-able type to implement !Send for QueryInterface
-    prevent_send: std::marker::PhantomData<*mut ()>
+    prevent_send: std::marker::PhantomData<*mut ()>,
 }
 
 const NO_HASH: u64 = 0;
@@ -70,7 +70,7 @@ impl<'l> QueryInterface<'l> {
         Self {
             db,
             cache: std::sync::Mutex::new(std::collections::HashMap::new()),
-            prevent_send: std::marker::PhantomData
+            prevent_send: std::marker::PhantomData,
         }
     }
 
@@ -115,14 +115,16 @@ impl<'l> QueryInterface<'l> {
         &self,
         context: &'static str,
         ty: std::any::TypeId,
-        variant: &Column,
+        variant: &[Column],
         create: &dyn Fn() -> sqlite::Statement<'l>,
         with: &mut dyn FnMut(&mut sqlite::Statement<'l>) -> Return,
     ) -> Return {
         use std::hash::Hasher;
 
         let mut hasher = std::collections::hash_map::DefaultHasher::new();
-        variant.hash(&mut hasher);
+        for v in variant {
+            v.hash(&mut hasher);
+        }
         let hash = hasher.finish();
 
         let mut cache = self.cache.lock().expect("Couldn't acquire cache?");
@@ -149,7 +151,7 @@ impl<'l> QueryInterface<'l> {
         self.cached_query_column(
             "get_one_by",
             std::any::TypeId::of::<T>(),
-            &c,
+            &[c],
             &|| {
                 self.db
                     .conn
@@ -171,6 +173,51 @@ impl<'l> QueryInterface<'l> {
         )
     }
 
+    /// Search for an entity by multiple properties
+    pub fn get_one_by_multi<
+        T: Entity<Column = C>,
+        C: EntityColumns<Entity = T>,
+        V: crate::model::Modelable,
+    >(
+        &self,
+        c: &[C],
+        val: &[V],
+    ) -> Option<WithID<T>> {
+        let table_name = <T as Entity>::table_name();
+
+        assert_eq!(c.len(), val.len());
+
+        self.cached_query_column(
+            "get_one_by",
+            std::any::TypeId::of::<T>(),
+            &c,
+            &|| {
+                self.db
+                    .conn
+                    .prepare(&format!(
+                        "SELECT * FROM \"{}\" WHERE {}",
+                        table_name,
+                        c.iter()
+                            .map(|col| format!("\"{}\" = ?", <T as Entity>::name(col.clone())))
+                            .collect::<Vec<_>>()
+                            .join(",")
+                    ))
+                    .expect("")
+            },
+            &mut |stmt| {
+                for index in 0..val.len() {
+                    val[index].bind_to(stmt, index + 1).ok()?;
+                }
+
+                self.expect_one_result(stmt, &mut |stmt| {
+                    let id: i64 = stmt.read(0).ok()?;
+                    let mut rd = crate::model::load::RowDeserializer::from_row(stmt);
+                    Some(WithID::wrap(T::deserialize(&mut rd).ok()?, id))
+                })
+            },
+        )
+    }
+
     /// Search for an entity by ID
     pub fn get_one_by_id<I: EntityID<Entity = T>, T: Entity>(&self, id: I) -> Option<WithID<T>> {
         let table_name = <T as Entity>::table_name();
@@ -212,7 +259,7 @@ impl<'l> QueryInterface<'l> {
         self.cached_query_column(
             "get_all_by",
             std::any::TypeId::of::<T>(),
-            &c,
+            &[c],
             &|| {
                 self.db
                     .conn