Browse Source

Initial ConnectionPool.

Kestrel 1 month ago
parent
commit
fbe24c2324
3 changed files with 73 additions and 20 deletions
  1. 1 0
      microrm/Cargo.toml
  2. 71 19
      microrm/src/db.rs
  3. 1 1
      microrm/src/schema/build.rs

+ 1 - 0
microrm/Cargo.toml

@@ -23,6 +23,7 @@ serde = { version = "1.0" }
 serde_json = { version = "1.0" }
 time = "0.3"
 itertools = "0.12"
+thread_local = "1.1"
 
 microrm-macros = { path = "../microrm-macros", version = "0.4.1" }
 log = "0.4.17"

+ 71 - 19
microrm/src/db.rs

@@ -1,11 +1,11 @@
 use crate::{DBResult, Error};
 use libsqlite3_sys as sq;
 use std::{
-    cell::Cell,
+    cell::{Cell, RefCell},
     collections::HashMap,
     ffi::{CStr, CString},
     pin::Pin,
-    sync::{Arc, Mutex},
+    sync::Arc,
 };
 
 fn check_rcode<'a>(sql: impl FnOnce() -> Option<&'a str>, rcode: i32) -> Result<(), Error> {
@@ -24,20 +24,18 @@ fn check_rcode<'a>(sql: impl FnOnce() -> Option<&'a str>, rcode: i32) -> Result<
 
 struct ConnectionData {
     sqlite: *mut sq::sqlite3,
-    stmts: HashMap<u64, Statement>,
+    stmts: RefCell<HashMap<u64, Statement>>,
 }
 
 impl Drop for ConnectionData {
     fn drop(&mut self) {
-        self.stmts.clear();
+        self.stmts.borrow_mut().clear();
         unsafe {
             sq::sqlite3_close(self.sqlite);
         }
     }
 }
 
-unsafe impl Send for ConnectionData {}
-
 pub(crate) trait PreparedKey {
     fn into_u64(self) -> u64;
 }
@@ -58,14 +56,13 @@ impl PreparedKey for std::any::TypeId {
     }
 }
 
-/// Represents a single sqlite connection, in SQLITE_SERIALIZED mode.
-///
-/// This translates to a struct that is Send, but not Sync.
+/// Represents a single sqlite connection, in SQLITE_MULTITHREAD mode.
 #[derive(Clone)]
-pub struct Connection(Arc<Mutex<ConnectionData>>);
+pub struct Connection(Arc<ConnectionData>);
 
 impl Connection {
-    /// Establish a new connection to a sqlite database object. Note that this type carries no schema information, unlike [`Database`](../schema/traits.Database.html).
+    /// Establish a new connection to a sqlite database object. Note that this type carries no
+    /// schema information, unlike types implementing [`Database`](../schema/traits.Database.html).
     pub fn new(url: &str) -> Result<Self, Error> {
         let db_ptr = unsafe {
             let url = CString::new(url)?;
@@ -75,7 +72,7 @@ impl Connection {
                 sq::sqlite3_open_v2(
                     url.as_ptr(),
                     &mut db_ptr,
-                    sq::SQLITE_OPEN_READWRITE | /* sq::SQLITE_OPEN_NOMUTEX |*/ sq::SQLITE_OPEN_CREATE,
+                    sq::SQLITE_OPEN_READWRITE | sq::SQLITE_OPEN_NOMUTEX | sq::SQLITE_OPEN_CREATE,
                     std::ptr::null_mut(),
                 ),
             )?;
@@ -92,15 +89,15 @@ impl Connection {
             sq::sqlite3_busy_timeout(db_ptr, 1000);
         }
 
-        Ok(Self(Arc::new(Mutex::new(ConnectionData {
+        Ok(Self(Arc::new(ConnectionData {
             sqlite: db_ptr,
             stmts: Default::default(),
-        }))))
+        })))
     }
 
     /// Execute a raw SQL statement on the database this connection represents. Use with care.
     pub fn execute_raw_sql(&self, sql: impl AsRef<str>) -> DBResult<()> {
-        let data = self.0.lock()?;
+        let data = self.0.as_ref();
 
         log::trace!("executing raw sql: {sql}", sql = sql.as_ref());
 
@@ -144,11 +141,12 @@ impl Connection {
         build_query: impl Fn() -> String,
         run_query: impl Fn(StatementContext) -> DBResult<R>,
     ) -> DBResult<R> {
-        let mut data = self.0.lock()?;
+        let data = self.0.as_ref();
         let conn = data.sqlite;
 
         use std::collections::hash_map::Entry;
-        match data.stmts.entry(hash_key.into_u64()) {
+        let mut stmts = data.stmts.borrow_mut();
+        match stmts.entry(hash_key.into_u64()) {
             Entry::Vacant(e) => {
                 let sql = build_query();
 
@@ -184,6 +182,60 @@ impl Connection {
     }
 }
 
+struct SendWrapper<T: Clone> {
+    value: T,
+}
+
+impl<T: Clone> SendWrapper<T> {
+    fn new(value: T) -> Self {
+        Self { value }
+    }
+
+    fn get(&self) -> T {
+        self.value.clone()
+    }
+}
+
+unsafe impl<T: Clone> Send for SendWrapper<T> {}
+
+/// Multithreading-safe database connection pool.
+pub struct ConnectionPool {
+    uri: String,
+    connections: std::sync::RwLock<HashMap<std::thread::ThreadId, SendWrapper<Connection>>>,
+}
+
+impl ConnectionPool {
+    /// Construct a new pool from a URI
+    pub fn new(uri: &str) -> Self {
+        Self {
+            uri: uri.into(),
+            connections: Default::default(),
+        }
+    }
+
+    /// Retrieve the [`Connection`] for the current thread.
+    pub fn get(&self) -> DBResult<Connection> {
+        let thread_id = std::thread::current().id();
+        // short path: thread already has a connection
+        {
+            let cmap = self.connections.read().expect("poisoned lock");
+            if let Some(conn) = cmap.get(&thread_id) {
+                return Ok(conn.get());
+            }
+        }
+        // long path: need to construct a new connection
+        let nconn = Connection::new(self.uri.as_str())?;
+
+        let mut cmap = self.connections.write().expect("poisoned lock");
+        cmap.insert(thread_id, SendWrapper::new(nconn.clone()));
+
+        Ok(nconn)
+    }
+}
+
+unsafe impl Send for ConnectionPool {}
+unsafe impl Sync for ConnectionPool {}
+
 pub(crate) struct Transaction<'l> {
     db: &'l Connection,
     committed: bool,
@@ -619,11 +671,11 @@ mod sendsync_check {
 
     #[test]
     fn check_send() {
-        let _ = CheckSend::<super::Connection>(Default::default());
+        let _ = CheckSend::<super::ConnectionPool>(Default::default());
     }
 
     #[test]
     fn check_sync() {
-        let _ = CheckSync::<super::Connection>(Default::default());
+        let _ = CheckSync::<super::ConnectionPool>(Default::default());
     }
 }

+ 1 - 1
microrm/src/schema/build.rs

@@ -2,7 +2,7 @@ use crate::{
     prelude::*,
     schema::{
         collect::{EntityStateContainer, PartType},
-        entity::{Entity, EntityPart, EntityPartList, EntityPartVisitor, EntityVisitor},
+        entity::{Entity, EntityPart, EntityPartList, EntityPartVisitor},
         meta, Connection, DatabaseItemVisitor,
     },
     DBResult,