
Add DBPool for multi-threading support.

Kestrel committed 2 years ago
commit 7e37e322d2
2 changed files with 62 additions and 1 deletion
  1. microrm/src/lib.rs (+56 -0)
  2. microrm/src/query.rs (+6 -1)

+ 56 - 0
microrm/src/lib.rs

@@ -183,6 +183,62 @@ impl DB {
     }
 }
 
+/// Adds multi-threading support to a `DB`.
+///
+/// This is a per-thread cache that carefully maintains the invariant that no element
+/// of the cache is ever accessed in any way from another thread. Maintaining this
+/// invariant requires leaking each cached `QueryInterface`, so the pool is best used
+/// in lightly-threaded programs (or at least contexts where threads are long-lived).
+/// All cached values are assumed to use interior mutability where needed to maintain state.
+///
+/// Leaking the cached interfaces ensures that they can live for the provided lifetime `'a`.
+pub struct DBPool<'a> {
+    // normally DB is not Send because the raw sqlite ptr is not Send
+    // however we assume sqlite is operating in serialized mode, which means
+    // that it is in fact both `Send` and `Sync`
+    db: &'a DB,
+    // we carefully maintain the invariant here that only the thread with the
+    // given `ThreadId` accesses the QueryInterface part of the pair
+    qi: std::sync::RwLock<Vec<(std::thread::ThreadId, &'a QueryInterface<'a>)>>,
+}
+
+impl<'a> DBPool<'a> {
+    pub fn new(db: &'a DB) -> Self {
+        Self { db, qi: std::sync::RwLock::new(Vec::new()) }
+    }
+
+    /// Get a query interface from this DB pool for the current thread.
+    pub fn query_interface(&self) -> &query::QueryInterface<'a> {
+        let current_id = std::thread::current().id();
+        // fast path: this thread already has a cached interface
+        let guard = self.qi.read().expect("Couldn't acquire read lock");
+        if let Some(entry) = guard.iter().find(|entry| entry.0 == current_id) {
+            return entry.1;
+        }
+        drop(guard);
+        // slow path: leak a fresh interface for this thread, then retry the lookup
+        let mut guard = self.qi.write().expect("Couldn't acquire write lock");
+        guard.push((current_id, Box::leak(Box::new(self.db.query_interface()))));
+        drop(guard);
+        self.query_interface()
+    }
+}
+
+// SAFETY: each cached `QueryInterface` is only ever handed out to the thread that
+// created it, and the `DB` is assumed to run sqlite in serialized mode (Send + Sync).
+unsafe impl<'a> Send for DBPool<'a> {}
+unsafe impl<'a> Sync for DBPool<'a> {}
+
+#[cfg(test)]
+mod pool_test {
+    // compile-time assertions that `DB` is `Send` and `DBPool` is both `Send` and `Sync`
+    trait IsSend: Send {}
+    impl IsSend for super::DB {}
+    impl<'a> IsSend for super::DBPool<'a> {}
+    trait IsSendAndSync: Send + Sync {}
+    impl<'a> IsSendAndSync for super::DBPool<'a> {}
+}
+
 #[cfg(test)]
 mod test {
     use super::DB;
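
A minimal usage sketch (not part of this commit), assuming `DB` and `DBPool` are both
exported from the crate root and that a `DB` value has been constructed elsewhere; the
function name and thread count are illustrative:

    // hypothetical caller code, not part of the diff
    fn run_queries(db: &microrm::DB) {
        let pool = microrm::DBPool::new(db);
        // `DBPool` is `Sync`, so each scoped worker thread can borrow it
        std::thread::scope(|s| {
            for _ in 0..4 {
                s.spawn(|| {
                    // each thread lazily creates and caches its own interface
                    let qi = pool.query_interface();
                    // ... run queries through `qi` ...
                    let _ = qi;
                });
            }
        });
    }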

+ 6 - 1
microrm/src/query.rs

@@ -52,11 +52,15 @@ type CacheIndex = (&'static str, std::any::TypeId, u64);
 ///
 /// As the query interface provides some level of caching, try to strive for as much sharing as
 /// possible. Passing around `QueryInterface` references instead of `DB` references is a good way
-/// to achieve this.
+/// to achieve this. However, `QueryInterface` is explicitly `!Send`, so to share one
+/// across threads you may need something like a `DBPool`.
 pub struct QueryInterface<'l> {
     db: &'l crate::DB,
 
     cache: std::sync::Mutex<std::collections::HashMap<CacheIndex, sqlite::Statement<'l>>>,
+
+    // a phantom raw-pointer marker makes `QueryInterface` explicitly `!Send` (and `!Sync`)
+    prevent_send: std::marker::PhantomData<*mut ()>,
 }
 
 const NO_HASH: u64 = 0;
@@ -66,6 +70,7 @@ impl<'l> QueryInterface<'l> {
         Self {
             db,
             cache: std::sync::Mutex::new(std::collections::HashMap::new()),
+            prevent_send: std::marker::PhantomData,
         }
     }
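
For reference, the `PhantomData<*mut ()>` marker works because raw pointers are neither
`Send` nor `Sync`, so the containing struct loses both auto impls. A standalone sketch
(not part of this commit; all names here are hypothetical) illustrating the pattern:

    struct NotThreadSafe {
        // raw pointers are neither `Send` nor `Sync`, so this zero-sized marker
        // suppresses both auto impls for the containing struct
        _marker: std::marker::PhantomData<*mut ()>,
    }

    fn assert_send<T: Send>() {}

    fn main() {
        // uncommenting the next line fails to compile, because `*mut ()`
        // cannot be sent between threads, so `NotThreadSafe` is `!Send`
        // assert_send::<NotThreadSafe>();
    }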