|
@@ -1,11 +1,11 @@
|
|
|
use crate::{DBResult, Error};
|
|
|
use libsqlite3_sys as sq;
|
|
|
use std::{
|
|
|
- cell::Cell,
|
|
|
+ cell::{Cell, RefCell},
|
|
|
collections::HashMap,
|
|
|
ffi::{CStr, CString},
|
|
|
pin::Pin,
|
|
|
- sync::{Arc, Mutex},
|
|
|
+ sync::Arc,
|
|
|
};
|
|
|
|
|
|
fn check_rcode<'a>(sql: impl FnOnce() -> Option<&'a str>, rcode: i32) -> Result<(), Error> {
|
|
@@ -24,20 +24,18 @@ fn check_rcode<'a>(sql: impl FnOnce() -> Option<&'a str>, rcode: i32) -> Result<
|
|
|
|
|
|
struct ConnectionData {
|
|
|
sqlite: *mut sq::sqlite3,
|
|
|
- stmts: HashMap<u64, Statement>,
|
|
|
+ stmts: RefCell<HashMap<u64, Statement>>,
|
|
|
}
|
|
|
|
|
|
impl Drop for ConnectionData {
|
|
|
fn drop(&mut self) {
|
|
|
- self.stmts.clear();
|
|
|
+ self.stmts.borrow_mut().clear();
|
|
|
unsafe {
|
|
|
sq::sqlite3_close(self.sqlite);
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-unsafe impl Send for ConnectionData {}
|
|
|
-
|
|
|
pub(crate) trait PreparedKey {
|
|
|
fn into_u64(self) -> u64;
|
|
|
}
|
|
@@ -58,14 +56,13 @@ impl PreparedKey for std::any::TypeId {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-/// Represents a single sqlite connection, in SQLITE_SERIALIZED mode.
|
|
|
-///
|
|
|
-/// This translates to a struct that is Send, but not Sync.
|
|
|
+/// Represents a single sqlite connection, in SQLITE_MULTITHREAD mode.
|
|
|
#[derive(Clone)]
|
|
|
-pub struct Connection(Arc<Mutex<ConnectionData>>);
|
|
|
+pub struct Connection(Arc<ConnectionData>);
|
|
|
|
|
|
impl Connection {
|
|
|
- /// Establish a new connection to a sqlite database object. Note that this type carries no schema information, unlike [`Database`](../schema/traits.Database.html).
|
|
|
+ /// Establish a new connection to a sqlite database object. Note that this type carries no
|
|
|
+ /// schema information, unlike types implementing [`Database`](../schema/traits.Database.html).
|
|
|
pub fn new(url: &str) -> Result<Self, Error> {
|
|
|
let db_ptr = unsafe {
|
|
|
let url = CString::new(url)?;
|
|
@@ -75,7 +72,7 @@ impl Connection {
|
|
|
sq::sqlite3_open_v2(
|
|
|
url.as_ptr(),
|
|
|
&mut db_ptr,
|
|
|
- sq::SQLITE_OPEN_READWRITE | /* sq::SQLITE_OPEN_NOMUTEX |*/ sq::SQLITE_OPEN_CREATE,
|
|
|
+ sq::SQLITE_OPEN_READWRITE | sq::SQLITE_OPEN_NOMUTEX | sq::SQLITE_OPEN_CREATE,
|
|
|
std::ptr::null_mut(),
|
|
|
),
|
|
|
)?;
|
|
@@ -92,15 +89,15 @@ impl Connection {
|
|
|
sq::sqlite3_busy_timeout(db_ptr, 1000);
|
|
|
}
|
|
|
|
|
|
- Ok(Self(Arc::new(Mutex::new(ConnectionData {
|
|
|
+ Ok(Self(Arc::new(ConnectionData {
|
|
|
sqlite: db_ptr,
|
|
|
stmts: Default::default(),
|
|
|
- }))))
|
|
|
+ })))
|
|
|
}
|
|
|
|
|
|
/// Execute a raw SQL statement on the database this connection represents. Use with care.
|
|
|
pub fn execute_raw_sql(&self, sql: impl AsRef<str>) -> DBResult<()> {
|
|
|
- let data = self.0.lock()?;
|
|
|
+ let data = self.0.as_ref();
|
|
|
|
|
|
log::trace!("executing raw sql: {sql}", sql = sql.as_ref());
|
|
|
|
|
@@ -144,11 +141,12 @@ impl Connection {
|
|
|
build_query: impl Fn() -> String,
|
|
|
run_query: impl Fn(StatementContext) -> DBResult<R>,
|
|
|
) -> DBResult<R> {
|
|
|
- let mut data = self.0.lock()?;
|
|
|
+ let data = self.0.as_ref();
|
|
|
let conn = data.sqlite;
|
|
|
|
|
|
use std::collections::hash_map::Entry;
|
|
|
- match data.stmts.entry(hash_key.into_u64()) {
|
|
|
+ let mut stmts = data.stmts.borrow_mut();
|
|
|
+ match stmts.entry(hash_key.into_u64()) {
|
|
|
Entry::Vacant(e) => {
|
|
|
let sql = build_query();
|
|
|
|
|
@@ -184,6 +182,60 @@ impl Connection {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+struct SendWrapper<T: Clone> {
|
|
|
+ value: T,
|
|
|
+}
|
|
|
+
|
|
|
+impl<T: Clone> SendWrapper<T> {
|
|
|
+ fn new(value: T) -> Self {
|
|
|
+ Self { value }
|
|
|
+ }
|
|
|
+
|
|
|
+ fn get(&self) -> T {
|
|
|
+ self.value.clone()
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+unsafe impl<T: Clone> Send for SendWrapper<T> {}
|
|
|
+
|
|
|
+/// Multithreading-safe database connection pool.
|
|
|
+pub struct ConnectionPool {
|
|
|
+ uri: String,
|
|
|
+ connections: std::sync::RwLock<HashMap<std::thread::ThreadId, SendWrapper<Connection>>>,
|
|
|
+}
|
|
|
+
|
|
|
+impl ConnectionPool {
|
|
|
+ /// Construct a new pool from a URI
|
|
|
+ pub fn new(uri: &str) -> Self {
|
|
|
+ Self {
|
|
|
+ uri: uri.into(),
|
|
|
+ connections: Default::default(),
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ /// Retrieve the [`Connection`] for the current thread.
|
|
|
+ pub fn get(&self) -> DBResult<Connection> {
|
|
|
+ let thread_id = std::thread::current().id();
|
|
|
+ // short path: thread already has a connection
|
|
|
+ {
|
|
|
+ let cmap = self.connections.read().expect("poisoned lock");
|
|
|
+ if let Some(conn) = cmap.get(&thread_id) {
|
|
|
+ return Ok(conn.get());
|
|
|
+ }
|
|
|
+ }
|
|
|
+ // long path: need to construct a new connection
|
|
|
+ let nconn = Connection::new(self.uri.as_str())?;
|
|
|
+
|
|
|
+ let mut cmap = self.connections.write().expect("poisoned lock");
|
|
|
+ cmap.insert(thread_id, SendWrapper::new(nconn.clone()));
|
|
|
+
|
|
|
+ Ok(nconn)
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+unsafe impl Send for ConnectionPool {}
|
|
|
+unsafe impl Sync for ConnectionPool {}
|
|
|
+
|
|
|
pub(crate) struct Transaction<'l> {
|
|
|
db: &'l Connection,
|
|
|
committed: bool,
|
|
@@ -619,11 +671,11 @@ mod sendsync_check {
|
|
|
|
|
|
#[test]
|
|
|
fn check_send() {
|
|
|
- let _ = CheckSend::<super::Connection>(Default::default());
|
|
|
+ let _ = CheckSend::<super::ConnectionPool>(Default::default());
|
|
|
}
|
|
|
|
|
|
#[test]
|
|
|
fn check_sync() {
|
|
|
- let _ = CheckSync::<super::Connection>(Default::default());
|
|
|
+ let _ = CheckSync::<super::ConnectionPool>(Default::default());
|
|
|
}
|
|
|
}
|