|
@@ -1,15 +1,28 @@
|
|
|
-use crate::{DBError, DBResult};
|
|
|
-
|
|
|
+use crate::{DBResult, Error};
|
|
|
+use libsqlite3_sys as sq;
|
|
|
use std::{
|
|
|
collections::HashMap,
|
|
|
+ ffi::{CStr, CString},
|
|
|
sync::{Arc, Mutex},
|
|
|
};
|
|
|
|
|
|
-pub type DBConnection = std::sync::Arc<Connection>;
|
|
|
+fn check_rcode(sql: Option<&str>, rcode: i32) -> Result<(), Error> {
|
|
|
+ if rcode == sq::SQLITE_OK {
|
|
|
+ Ok(())
|
|
|
+ } else {
|
|
|
+ Err(Error::Sqlite {
|
|
|
+ code: rcode,
|
|
|
+ msg: unsafe { CStr::from_ptr(sq::sqlite3_errstr(rcode)) }
|
|
|
+ .to_str()?
|
|
|
+ .to_string(),
|
|
|
+ sql: sql.map(|s| s.to_string()),
|
|
|
+ })
|
|
|
+ }
|
|
|
+}
|
|
|
|
|
|
-pub(crate) struct CachedStatement {
|
|
|
- stmt: sqlite::Statement<'static>,
|
|
|
- sql: String,
|
|
|
+struct ConnectionData {
|
|
|
+ sqlite: *mut sq::sqlite3,
|
|
|
+ stmts: HashMap<u64, Statement>,
|
|
|
}
|
|
|
|
|
|
pub(crate) trait PreparedKey {
|
|
@@ -32,6 +45,377 @@ impl PreparedKey for std::any::TypeId {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+/// Represents a single sqlite connection, in SQLITE_MULTITHREADED mode.
|
|
|
+///
|
|
|
+/// This translates to a struct that is Send, but not Sync.
|
|
|
+#[derive(Clone)]
|
|
|
+pub struct Connection(Arc<Mutex<ConnectionData>>);
|
|
|
+
|
|
|
+impl Connection {
|
|
|
+ pub fn new(url: &str) -> Result<Self, Error> {
|
|
|
+ let db_ptr = unsafe {
|
|
|
+ let url = CString::new(url)?;
|
|
|
+ let mut db_ptr = std::ptr::null_mut();
|
|
|
+ check_rcode(
|
|
|
+ None,
|
|
|
+ sq::sqlite3_open_v2(
|
|
|
+ url.as_ptr(),
|
|
|
+ &mut db_ptr,
|
|
|
+ sq::SQLITE_OPEN_READWRITE | sq::SQLITE_OPEN_NOMUTEX | sq::SQLITE_OPEN_CREATE,
|
|
|
+ std::ptr::null_mut(),
|
|
|
+ ),
|
|
|
+ )?;
|
|
|
+ db_ptr
|
|
|
+ };
|
|
|
+
|
|
|
+ if db_ptr.is_null() {
|
|
|
+ return Err(Error::InternalError(
|
|
|
+ "sqlite3_open_v2 returned a NULL connection",
|
|
|
+ ));
|
|
|
+ }
|
|
|
+
|
|
|
+ Ok(Self(Arc::new(Mutex::new(ConnectionData {
|
|
|
+ sqlite: db_ptr,
|
|
|
+ stmts: Default::default(),
|
|
|
+ }))))
|
|
|
+ }
|
|
|
+
|
|
|
+ pub fn execute_raw_sql(&self, sql: impl AsRef<str>) -> DBResult<()> {
|
|
|
+ let data = self.0.lock()?;
|
|
|
+
|
|
|
+ unsafe {
|
|
|
+ let c_sql = CString::new(sql.as_ref())?;
|
|
|
+ let mut err = std::ptr::null_mut();
|
|
|
+ let rcode = sq::sqlite3_exec(
|
|
|
+ data.sqlite,
|
|
|
+ c_sql.as_ptr(),
|
|
|
+ None,
|
|
|
+ std::ptr::null_mut(),
|
|
|
+ &mut err,
|
|
|
+ );
|
|
|
+
|
|
|
+ // special error handling because of the err string
|
|
|
+ if rcode != sq::SQLITE_OK {
|
|
|
+ let e = Error::Sqlite {
|
|
|
+ code: rcode,
|
|
|
+ msg: if err == std::ptr::null_mut() {
|
|
|
+ CStr::from_ptr(sq::sqlite3_errstr(rcode))
|
|
|
+ } else {
|
|
|
+ CStr::from_ptr(err)
|
|
|
+ }
|
|
|
+ .to_str()?
|
|
|
+ .to_string(),
|
|
|
+ sql: Some(sql.as_ref().into()),
|
|
|
+ };
|
|
|
+ if err != std::ptr::null_mut() {
|
|
|
+ sq::sqlite3_free(err.cast());
|
|
|
+ }
|
|
|
+ return Err(e);
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ Ok(())
|
|
|
+ }
|
|
|
+
|
|
|
+ pub(crate) fn with_prepared<R>(
|
|
|
+ &self,
|
|
|
+ hash_key: impl PreparedKey,
|
|
|
+ build_query: impl Fn() -> String,
|
|
|
+ run_query: impl Fn(StatementContext) -> DBResult<R>,
|
|
|
+ ) -> DBResult<R> {
|
|
|
+ let mut data = self.0.lock()?;
|
|
|
+ let conn = data.sqlite;
|
|
|
+
|
|
|
+ use std::collections::hash_map::Entry;
|
|
|
+ match data.stmts.entry(hash_key.into_u64()) {
|
|
|
+ Entry::Vacant(e) => {
|
|
|
+ let sql = build_query();
|
|
|
+
|
|
|
+ // prepare the statement
|
|
|
+ let mut stmt = std::ptr::null_mut();
|
|
|
+ unsafe {
|
|
|
+ check_rcode(
|
|
|
+ Some(sql.as_str()),
|
|
|
+ sq::sqlite3_prepare_v2(
|
|
|
+ conn,
|
|
|
+ sql.as_ptr().cast(),
|
|
|
+ sql.len() as i32,
|
|
|
+ &mut stmt,
|
|
|
+ std::ptr::null_mut(),
|
|
|
+ ),
|
|
|
+ )?;
|
|
|
+ };
|
|
|
+
|
|
|
+ if stmt == std::ptr::null_mut() {
|
|
|
+ return Err(Error::InternalError(
|
|
|
+ "sqlite3_prepare_v2 returned a NULL stmt",
|
|
|
+ ));
|
|
|
+ }
|
|
|
+
|
|
|
+ let stmt = e.insert(Statement { sqlite: conn, stmt });
|
|
|
+
|
|
|
+ run_query(stmt.make_context()?)
|
|
|
+ }
|
|
|
+ Entry::Occupied(mut e) => run_query(e.get_mut().make_context()?),
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+unsafe impl Send for Connection {}
|
|
|
+
|
|
|
+struct Statement {
|
|
|
+ sqlite: *mut sq::sqlite3,
|
|
|
+ stmt: *mut sq::sqlite3_stmt,
|
|
|
+}
|
|
|
+
|
|
|
+impl Statement {
|
|
|
+ fn make_context(&mut self) -> DBResult<StatementContext> {
|
|
|
+ // begin by resetting the statement
|
|
|
+ unsafe {
|
|
|
+ check_rcode(None, sq::sqlite3_reset(self.stmt))?;
|
|
|
+ }
|
|
|
+ Ok(StatementContext { stmt: self })
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+#[cfg(test)]
|
|
|
+mod test {
|
|
|
+ use super::Connection;
|
|
|
+
|
|
|
+ #[test]
|
|
|
+ fn simple_sql() {
|
|
|
+ let c = Connection::new(":memory:").expect("couldn't open test db");
|
|
|
+ c.execute_raw_sql("CREATE TABLE test_table (id integer primary key, value string)")
|
|
|
+ .expect("couldn't execute sql");
|
|
|
+ }
|
|
|
+
|
|
|
+ #[test]
|
|
|
+ fn prepare_stmt() {
|
|
|
+ let c = Connection::new(":memory:").expect("couldn't open test db");
|
|
|
+ c.execute_raw_sql("CREATE TABLE test_table (id integer primary key, value string)")
|
|
|
+ .expect("couldn't execute sql");
|
|
|
+
|
|
|
+ c.with_prepared(
|
|
|
+ 1,
|
|
|
+ || format!("INSERT INTO test_table VALUES (?, ?)"),
|
|
|
+ |ctx| {
|
|
|
+ ctx.bind(1, 1usize)?;
|
|
|
+ ctx.bind(2, "value")?;
|
|
|
+
|
|
|
+ ctx.iter().last();
|
|
|
+
|
|
|
+ Ok(())
|
|
|
+ },
|
|
|
+ )
|
|
|
+ .expect("couldn't run prepared INSERT statement");
|
|
|
+
|
|
|
+ c.with_prepared(
|
|
|
+ 2,
|
|
|
+ || format!("SELECT * FROM test_table"),
|
|
|
+ |ctx| {
|
|
|
+ let count = ctx
|
|
|
+ .iter()
|
|
|
+ .map(|row| {
|
|
|
+ assert_eq!(row.read::<i64>(0).expect("couldn't read row ID"), 1);
|
|
|
+ assert_eq!(
|
|
|
+ row.read::<String>(1).expect("couldn't read row value"),
|
|
|
+ "value"
|
|
|
+ );
|
|
|
+ })
|
|
|
+ .count();
|
|
|
+ assert!(count > 0);
|
|
|
+
|
|
|
+ Ok(())
|
|
|
+ },
|
|
|
+ )
|
|
|
+ .expect("couldn't run prepared SELECT statement");
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+pub struct StatementRow<'a> {
|
|
|
+ stmt: &'a Statement,
|
|
|
+}
|
|
|
+
|
|
|
+impl<'a> StatementRow<'a> {
|
|
|
+ pub fn read<T: Readable>(&self, index: i32) -> DBResult<T> {
|
|
|
+ T::read_from(self, index)
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+pub struct StatementContext<'a> {
|
|
|
+ stmt: &'a Statement,
|
|
|
+}
|
|
|
+
|
|
|
+impl<'a> StatementContext<'a> {
|
|
|
+ pub fn bind<B: Bindable>(&self, index: i32, bindable: B) -> DBResult<()> {
|
|
|
+ bindable.bind_to(self, index)
|
|
|
+ }
|
|
|
+
|
|
|
+ fn step(&self) -> Option<()> {
|
|
|
+ match unsafe { sq::sqlite3_step(self.stmt.stmt) } {
|
|
|
+ sq::SQLITE_ROW => Some(()),
|
|
|
+ sq::SQLITE_DONE => None,
|
|
|
+ _ => {
|
|
|
+ // check_rcode(None, v)?;
|
|
|
+ // Ok(false)
|
|
|
+ None
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ pub fn run(self) -> DBResult<Option<StatementRow<'a>>> {
|
|
|
+ if self.step().is_some() {
|
|
|
+ Ok(Some(StatementRow { stmt: self.stmt }))
|
|
|
+ } else {
|
|
|
+ Ok(None)
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ pub fn iter(self) -> impl Iterator<Item = StatementRow<'a>> {
|
|
|
+ struct I<'a>(StatementContext<'a>);
|
|
|
+
|
|
|
+ impl<'a> Iterator for I<'a> {
|
|
|
+ type Item = StatementRow<'a>;
|
|
|
+
|
|
|
+ fn next(&mut self) -> Option<Self::Item> {
|
|
|
+ self.0.step().map(|_| StatementRow { stmt: self.0.stmt })
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ I(self)
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+impl<'a> Drop for StatementContext<'a> {
|
|
|
+ fn drop(&mut self) {
|
|
|
+ // attempt to bind NULLs into each parameter
|
|
|
+ unsafe {
|
|
|
+ sq::sqlite3_clear_bindings(self.stmt.stmt);
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+pub trait Bindable {
|
|
|
+ fn bind_to(&self, ctx: &StatementContext, index: i32) -> DBResult<()>;
|
|
|
+}
|
|
|
+
|
|
|
+impl<'a> Bindable for () {
|
|
|
+ fn bind_to(&self, ctx: &StatementContext, index: i32) -> DBResult<()> {
|
|
|
+ unsafe { check_rcode(None, sq::sqlite3_bind_null(ctx.stmt.stmt, index)) }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+impl<'a> Bindable for i64 {
|
|
|
+ fn bind_to(&self, ctx: &StatementContext, index: i32) -> DBResult<()> {
|
|
|
+ unsafe { check_rcode(None, sq::sqlite3_bind_int64(ctx.stmt.stmt, index, *self)) }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+impl<'a> Bindable for usize {
|
|
|
+ fn bind_to(&self, ctx: &StatementContext, index: i32) -> DBResult<()> {
|
|
|
+ (*self as i64).bind_to(ctx, index)
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+impl<'a> Bindable for &'a str {
|
|
|
+ fn bind_to(&self, ctx: &StatementContext, index: i32) -> DBResult<()> {
|
|
|
+ unsafe {
|
|
|
+ check_rcode(
|
|
|
+ None,
|
|
|
+ sq::sqlite3_bind_text(
|
|
|
+ ctx.stmt.stmt,
|
|
|
+ index,
|
|
|
+ self.as_ptr().cast(),
|
|
|
+ self.len() as i32,
|
|
|
+ sq::SQLITE_STATIC(),
|
|
|
+ ),
|
|
|
+ )
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+impl Bindable for str {
|
|
|
+ fn bind_to(&self, ctx: &StatementContext, index: i32) -> DBResult<()> {
|
|
|
+ <&'_ str>::bind_to(&self, ctx, index)
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+impl Bindable for String {
|
|
|
+ fn bind_to(&self, ctx: &StatementContext, index: i32) -> DBResult<()> {
|
|
|
+ self.as_str().bind_to(ctx, index)
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+impl<'a> Bindable for &'a [u8] {
|
|
|
+ fn bind_to(&self, ctx: &StatementContext, index: i32) -> DBResult<()> {
|
|
|
+ unsafe {
|
|
|
+ check_rcode(
|
|
|
+ None,
|
|
|
+ sq::sqlite3_bind_blob64(
|
|
|
+ ctx.stmt.stmt,
|
|
|
+ index,
|
|
|
+ self.as_ptr().cast(),
|
|
|
+ self.len() as u64,
|
|
|
+ sq::SQLITE_STATIC(),
|
|
|
+ ),
|
|
|
+ )
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+pub trait Readable: Sized {
|
|
|
+ fn read_from(sr: &StatementRow<'_>, index: i32) -> DBResult<Self>;
|
|
|
+}
|
|
|
+
|
|
|
+impl Readable for i64 {
|
|
|
+ fn read_from(sr: &StatementRow<'_>, index: i32) -> DBResult<Self> {
|
|
|
+ unsafe { Ok(sq::sqlite3_column_int64(sr.stmt.stmt, index)) }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+impl Readable for String {
|
|
|
+ fn read_from(sr: &StatementRow<'_>, index: i32) -> DBResult<Self> {
|
|
|
+ unsafe {
|
|
|
+ let text = sq::sqlite3_column_text(sr.stmt.stmt, index);
|
|
|
+ if text.is_null() {
|
|
|
+ Err(Error::InternalError(
|
|
|
+ "NULL pointer result from sqlite3_column_text",
|
|
|
+ ))
|
|
|
+ } else {
|
|
|
+ Ok(CStr::from_ptr(text.cast()).to_str()?.to_string())
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+impl Readable for Vec<u8> {
|
|
|
+ fn read_from(sr: &StatementRow<'_>, index: i32) -> DBResult<Self> {
|
|
|
+ unsafe {
|
|
|
+ let ptr = sq::sqlite3_column_blob(sr.stmt.stmt, index);
|
|
|
+ let len = sq::sqlite3_column_bytes(sr.stmt.stmt, index);
|
|
|
+
|
|
|
+ if len == 0 {
|
|
|
+ Ok(vec![])
|
|
|
+ } else if len > 0 {
|
|
|
+ Ok(std::slice::from_raw_parts(ptr.cast(), len as usize).to_vec())
|
|
|
+ } else {
|
|
|
+ Err(Error::InternalError(
|
|
|
+ "negative length returned from sqlite3_column_bytes",
|
|
|
+ ))
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+/*
|
|
|
+
|
|
|
+pub type DBConnection = std::sync::Arc<Connection>;
|
|
|
+
|
|
|
+pub(crate) struct CachedStatement {
|
|
|
+ stmt: sqlite::Statement<'static>,
|
|
|
+ sql: String,
|
|
|
+}
|
|
|
+
|
|
|
pub struct Connection {
|
|
|
// we leak the ConnectionThreadSafe and make sure that the only references to it are stored in
|
|
|
// statement_cache, so as long as we drop the statement_cache first there are no correctness
|
|
@@ -54,7 +438,7 @@ mod sendsync_check {
|
|
|
*/
|
|
|
|
|
|
impl Connection {
|
|
|
- pub fn open<U: AsRef<str>>(uri: U) -> Result<DBConnection, DBError> {
|
|
|
+ pub fn open<U: AsRef<str>>(uri: U) -> Result<DBConnection, Error> {
|
|
|
match sqlite::Connection::open_thread_safe_with_flags(
|
|
|
uri.as_ref(),
|
|
|
sqlite::OpenFlags::new()
|
|
@@ -69,11 +453,11 @@ impl Connection {
|
|
|
statement_cache: Default::default(),
|
|
|
}))
|
|
|
}
|
|
|
- Err(e) => Err(DBError::Sqlite(e)),
|
|
|
+ Err(e) => Err(Error::Sqlite(e)),
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- pub(crate) fn execute_raw(&self, sql: &str) -> Result<(), DBError> {
|
|
|
+ pub(crate) fn execute_raw(&self, sql: &str) -> Result<(), Error> {
|
|
|
Ok(self.conn.execute(sql)?)
|
|
|
}
|
|
|
|
|
@@ -90,7 +474,7 @@ impl Connection {
|
|
|
let q: sqlite::Statement<'static> = self
|
|
|
.conn
|
|
|
.prepare(sql.as_str())
|
|
|
- .map_err(|e| DBError::from(e).sqlite_to_query(sql.as_str()))?;
|
|
|
+ .map_err(|e| Error::from(e).sqlite_to_query(sql.as_str()))?;
|
|
|
|
|
|
log::trace!("prepared new SQL query: {sql}");
|
|
|
|
|
@@ -125,3 +509,5 @@ impl std::fmt::Debug for Connection {
|
|
|
f.write_str("microrm::Connection")
|
|
|
}
|
|
|
}
|
|
|
+
|
|
|
+*/
|