Kaynağa Gözat

Add thread-safe query caching layer.

Kestrel 2 yıl önce
ebeveyn
işleme
2eb1748461

+ 1 - 1
microrm-macros/src/lib.rs

@@ -116,7 +116,7 @@ pub fn derive_entity(tokens: TokenStream) -> TokenStream {
 
     quote!{
         // Related types for #struct_name
-        #[derive(Clone,Copy,PartialEq)]
+        #[derive(Clone,Copy,PartialEq,Hash)]
         #[allow(unused)]
         #[repr(usize)]
         pub enum #enum_name {

+ 9 - 3
microrm/src/lib.rs

@@ -44,9 +44,11 @@ pub mod model;
 pub mod query;
 
 use meta::Metaschema;
-pub use microrm_macros::{Entity, Modelable};
 use model::Entity;
 
+pub use microrm_macros::{Entity, Modelable};
+pub use query::{QueryInterface, WithID};
+
 // no need to show the re-exports in the documentation
 #[doc(hidden)]
 pub mod re_export {
@@ -109,6 +111,11 @@ impl DB {
         )
     }
 
+    /// Get a query interface for this DB connection
+    pub fn query_interface(&self) -> query::QueryInterface {
+        query::QueryInterface::new(self)
+    }
+
     pub fn recreate_schema(&self) -> Result<(), DBError> {
         self.create_schema()
     }
@@ -169,7 +176,6 @@ impl DB {
 
         let qi = query::QueryInterface::new(self);
         let hash = qi.get_one_by(meta::MetaschemaColumns::Key, "schema_hash");
-        // let hash = query::get_one_by(self, meta::MetaschemaColumns::Key, "schema_hash");
 
         if hash.is_none() {
             if mode == CreateMode::MustExist {
@@ -247,7 +253,7 @@ mod test {
     fn simple_foreign_key() {
         let db = DB::new_in_memory(super::model::SchemaModel::new().add::<S1>().add::<S2>())
             .expect("Can't connect to in-memory DB");
-        let qi = crate::query::QueryInterface::new(&db);
+        let qi = db.query_interface();
 
         let id = qi.add(&S1 { an_id: -1 }).expect("Can't add S1");
         let child_id = qi.add(&S2 { parent_id: id }).expect("Can't add S2");

+ 2 - 2
microrm/src/model.rs

@@ -58,7 +58,7 @@ pub trait Modelable {
 }
 
 /// A database entity, aka a struct representing a row in a table
-pub trait Entity: for<'de> serde::Deserialize<'de> + serde::Serialize {
+pub trait Entity: 'static + for<'de> serde::Deserialize<'de> + serde::Serialize {
     type Column: EntityColumns + 'static + Copy;
     type ID: EntityID;
     fn table_name() -> &'static str;
@@ -77,7 +77,7 @@ pub trait Entity: for<'de> serde::Deserialize<'de> + serde::Serialize {
 }
 
 /// Trait representing the columns of a database entity
-pub trait EntityColumns: PartialEq + From<usize> {
+pub trait EntityColumns: PartialEq + From<usize> + std::hash::Hash + Clone {
     type Entity: Entity;
 }
 

+ 5 - 2
microrm/src/model/modelable.rs

@@ -7,14 +7,17 @@ macro_rules! integral {
             fn bind_to(&self, stmt: &mut sqlite::Statement, col: usize) -> sqlite::Result<()> {
                 (*self as i64).bind(stmt, col)
             }
-            fn build_from(stmt: &sqlite::Statement, col_offset: usize) -> sqlite::Result<(Self, usize)>
+            fn build_from(
+                stmt: &sqlite::Statement,
+                col_offset: usize,
+            ) -> sqlite::Result<(Self, usize)>
             where
                 Self: Sized,
             {
                 stmt.read::<i64>(col_offset).map(|x| (x as Self, 1))
             }
         }
-    }
+    };
 }
 
 integral!(i8);

+ 170 - 83
microrm/src/query.rs

@@ -46,34 +46,98 @@ impl<T: Entity> std::ops::DerefMut for WithID<T> {
     }
 }
 
+type CacheIndex = (&'static str, std::any::TypeId, u64);
+
+/// The query interface for a database.
+///
+/// As the query interface provides some level of caching, try to strive for as much sharing as
+/// possible. Passing around `QueryInterface` references instead of `DB` references is a good way
+/// to achieve this.
 pub struct QueryInterface<'l> {
     db: &'l crate::DB,
+
+    cache: std::sync::Mutex<std::collections::HashMap<CacheIndex, sqlite::Statement<'l>>>,
 }
 
+const NO_HASH: u64 = 0;
+
 impl<'l> QueryInterface<'l> {
     pub fn new(db: &'l crate::DB) -> Self {
-        Self { db }
+        Self {
+            db,
+            cache: std::sync::Mutex::new(std::collections::HashMap::new()),
+        }
     }
 
-    /// Helper function to process an expected one result
+    /// Helper function to process an expected single result
+    /// Note that this errors out if there is more than a single result
     fn expect_one_result<T>(
         &self,
         stmt: &mut sqlite::Statement,
         with_result: &mut dyn FnMut(&mut sqlite::Statement) -> Option<T>,
     ) -> Option<T> {
-        let state = stmt.next();
-        assert!(state.is_ok());
-        assert_eq!(state.ok(), Some(sqlite::State::Row));
+        let state = stmt.next().ok()?;
+        if state != sqlite::State::Row {
+            return None;
+        }
 
         let res = with_result(stmt);
 
-        let state = stmt.next();
-        assert!(state.is_ok());
-        assert_eq!(state.ok(), Some(sqlite::State::Done));
+        let state = stmt.next().ok()?;
+        if state != sqlite::State::Done {
+            return None;
+        }
 
         res
     }
 
+    fn cached_query<Return>(
+        &self,
+        context: &'static str,
+        ty: std::any::TypeId,
+        create: &dyn Fn() -> sqlite::Statement<'l>,
+        with: &mut dyn FnMut(&mut sqlite::Statement<'l>) -> Return,
+    ) -> Return {
+        let mut cache = self.cache.lock().expect("Couldn't acquire cache?");
+        let key = (context, ty, NO_HASH);
+        if !cache.contains_key(&key) {
+            cache.insert(key, create());
+        }
+        let mut query = cache
+            .get_mut(&key)
+            .expect("Just-inserted item not in cache?");
+
+        query.reset().expect("Couldn't reset query");
+        with(&mut query)
+    }
+
+    fn cached_query_column<Column: crate::model::EntityColumns, Return>(
+        &self,
+        context: &'static str,
+        ty: std::any::TypeId,
+        variant: &Column,
+        create: &dyn Fn() -> sqlite::Statement<'l>,
+        with: &mut dyn FnMut(&mut sqlite::Statement<'l>) -> Return,
+    ) -> Return {
+        use std::hash::Hasher;
+
+        let mut hasher = std::collections::hash_map::DefaultHasher::new();
+        variant.hash(&mut hasher);
+        let hash = hasher.finish();
+
+        let mut cache = self.cache.lock().expect("Couldn't acquire cache?");
+        let key = (context, ty, hash);
+        if !cache.contains_key(&key) {
+            cache.insert(key, create());
+        }
+        let mut query = cache
+            .get_mut(&key)
+            .expect("Just-inserted item not in cache?");
+
+        query.reset().expect("Couldn't reset query");
+        with(&mut query)
+    }
+
     /// Search for an entity by a property
     pub fn get_one_by<
         T: Entity<Column = C>,
@@ -85,44 +149,56 @@ impl<'l> QueryInterface<'l> {
         val: V,
     ) -> Option<WithID<T>> {
         let table_name = <T as Entity>::table_name();
-        let column_name = <T as Entity>::name(c);
-
-        let mut prepared = self
-            .db
-            .conn
-            .prepare(&format!(
-                "SELECT * FROM \"{}\" WHERE \"{}\" = ?",
-                table_name, column_name
-            ))
-            .expect("");
-
-        prepared.reset().ok()?;
-
-        val.bind_to(&mut prepared, 1).ok()?;
-
-        self.expect_one_result(&mut prepared, &mut |stmt| {
-            let id: i64 = stmt.read(0).ok()?;
-            let mut rd = crate::model::load::RowDeserializer::from_row(stmt);
-            Some(WithID::wrap(T::deserialize(&mut rd).ok()?, id))
-        })
+        let column_name = <T as Entity>::name(c.clone());
+
+        self.cached_query_column(
+            "get_one_by",
+            std::any::TypeId::of::<T>(),
+            &c,
+            &|| {
+                self.db
+                    .conn
+                    .prepare(&format!(
+                        "SELECT * FROM \"{}\" WHERE \"{}\" = ?",
+                        table_name, column_name
+                    ))
+                    .expect("")
+            },
+            &mut |stmt| {
+                val.bind_to(stmt, 1).ok()?;
+
+                self.expect_one_result(stmt, &mut |stmt| {
+                    let id: i64 = stmt.read(0).ok()?;
+                    let mut rd = crate::model::load::RowDeserializer::from_row(stmt);
+                    Some(WithID::wrap(T::deserialize(&mut rd).ok()?, id))
+                })
+            },
+        )
     }
 
     /// Search for an entity by ID
     pub fn get_one_by_id<I: EntityID<Entity = T>, T: Entity>(&self, id: I) -> Option<WithID<T>> {
         let table_name = <T as Entity>::table_name();
-        let mut prepared = self
-            .db
-            .conn
-            .prepare(&format!("SELECT * FROM \"{}\" WHERE id = ?", table_name))
-            .ok()?;
-
-        id.bind_to(&mut prepared, 1).ok()?;
 
-        self.expect_one_result(&mut prepared, &mut |stmt| {
-            let id: i64 = stmt.read(0).ok()?;
-            let mut rd = crate::model::load::RowDeserializer::from_row(stmt);
-            return Some(WithID::wrap(T::deserialize(&mut rd).ok()?, id));
-        })
+        self.cached_query(
+            "get_one_by_id",
+            std::any::TypeId::of::<T>(),
+            &|| {
+                self.db
+                    .conn
+                    .prepare(&format!("SELECT * FROM \"{}\" WHERE id = ?", table_name))
+                    .expect("")
+            },
+            &mut |stmt| {
+                id.bind_to(stmt, 1).ok()?;
+
+                self.expect_one_result(stmt, &mut |stmt| {
+                    let id: i64 = stmt.read(0).ok()?;
+                    let mut rd = crate::model::load::RowDeserializer::from_row(stmt);
+                    Some(WithID::wrap(T::deserialize(&mut rd).ok()?, id))
+                })
+            },
+        )
     }
 
     /// Search for all entities matching a property
@@ -136,8 +212,35 @@ impl<'l> QueryInterface<'l> {
         val: V,
     ) -> Option<Vec<WithID<T>>> {
         let table_name = <T as Entity>::table_name();
-        let column_name = <T as Entity>::name(c);
+        let column_name = <T as Entity>::name(c.clone());
+
+        self.cached_query_column(
+            "get_all_by",
+            std::any::TypeId::of::<T>(),
+            &c,
+            &|| {
+                self.db
+                    .conn
+                    .prepare(&format!(
+                        "SELECT * FROM \"{}\" WHERE {} = ?",
+                        table_name, column_name
+                    ))
+                    .expect("")
+            },
+            &mut |stmt| {
+                val.bind_to(stmt, 1).ok()?;
+
+                todo!()
+
+                /*self.expect_one_result(stmt, &mut |stmt| {
+                    let id: i64 = stmt.read(0).ok()?;
+                    let mut rd = crate::model::load::RowDeserializer::from_row(stmt);
+                    Some(WithID::wrap(T::deserialize(&mut rd).ok()?, id))
+                })*/
+            },
+        )
 
+        /*
         let mut prepared = self
             .db
             .conn
@@ -150,6 +253,7 @@ impl<'l> QueryInterface<'l> {
         val.bind_to(&mut prepared, 1).ok()?;
 
         todo!();
+        */
 
         /*let rows = prepared
             .query_map([&val], |row| {
@@ -166,47 +270,30 @@ impl<'l> QueryInterface<'l> {
 
     /// Add an entity to its table
     pub fn add<T: Entity + serde::Serialize>(&self, m: &T) -> Option<<T as Entity>::ID> {
-        let placeholders = (0..(<T as Entity>::column_count() - 1))
-            .map(|_| "?".to_string())
-            .collect::<Vec<_>>()
-            .join(",");
-
-        let mut prepared = self
-            .db
-            .conn
-            .prepare(&format!(
-                "INSERT INTO \"{}\" VALUES (NULL, {}) RETURNING \"id\"",
-                <T as Entity>::table_name(),
-                placeholders
-            ))
-            .ok()?;
-
-        crate::model::store::serialize_into(&mut prepared, m).ok()?;
-
-        let rowid = self.expect_one_result(&mut prepared, &mut |stmt| stmt.read::<i64>(0).ok())?;
-
-        Some(<T as Entity>::ID::from_raw_id(rowid))
-
-        /*
-        let row = crate::model::store::serialize_as_row(m);
-
-        let placeholders = (0..(<T as Entity>::column_count() - 1))
-            .map(|n| format!("?{}", n + 1))
-            .collect::<Vec<_>>()
-            .join(",");
-
-        let res = db.conn.prepare(&format!(
-            "INSERT INTO \"{}\" VALUES (NULL, {})",
-            <T as Entity>::table_name(),
-            placeholders
-        ));
-        let mut prepared = res.ok()?;
-
-        // make sure we bound enough things (not including ID column here)
-        assert_eq!(row.len(), <T as Entity>::column_count() - 1);
-        */
-
-        /*let id = prepared.insert(rusqlite::params_from_iter(row)).ok()?;
-        Some(<T as Entity>::ID::from_raw_id(id))*/
+        self.cached_query(
+            "get_all_by",
+            std::any::TypeId::of::<T>(),
+            &|| {
+                let placeholders = (0..(<T as Entity>::column_count() - 1))
+                    .map(|_| "?".to_string())
+                    .collect::<Vec<_>>()
+                    .join(",");
+
+                self.db
+                    .conn
+                    .prepare(&format!(
+                        "INSERT INTO \"{}\" VALUES (NULL, {}) RETURNING \"id\"",
+                        <T as Entity>::table_name(), placeholders
+                    ))
+                    .expect("")
+            },
+            &mut |stmt| {
+                crate::model::store::serialize_into(stmt, m).ok()?;
+
+                let rowid = self.expect_one_result(stmt, &mut |stmt| stmt.read::<i64>(0).ok())?;
+
+                Some(<T as Entity>::ID::from_raw_id(rowid))
+            },
+        )
     }
 }