Pārlūkot izejas kodu

Make db interface slightly more resilient and fixed a few bugs.

Kestrel 7 mēneši atpakaļ
vecāks
revīzija
bcc23f46b7
4 mainītis faili ar 32 papildinājumiem un 29 dzēšanām
  1. 27 23
      microrm/src/db.rs
  2. 2 5
      microrm/src/query.rs
  3. 1 1
      microrm/src/schema/build.rs
  4. 2 0
      microrm/src/schema/tests.rs

+ 27 - 23
microrm/src/db.rs

@@ -1,10 +1,7 @@
 use crate::{DBResult, Error};
 use libsqlite3_sys as sq;
 use std::{
-    collections::HashMap,
-    ffi::{CStr, CString},
-    pin::Pin,
-    sync::{Arc, Mutex},
+    cell::Cell, collections::HashMap, ffi::{CStr, CString}, pin::Pin, sync::{Arc, Mutex}
 };
 
 fn check_rcode<'a>(sql: impl FnOnce() -> Option<&'a str>, rcode: i32) -> Result<(), Error> {
@@ -175,8 +172,10 @@ impl Connection {
                 let stmt = e.insert(Statement { sqlite: conn, stmt });
 
                 run_query(stmt.make_context()?)
-            }
-            Entry::Occupied(mut e) => run_query(e.get_mut().make_context()?),
+            },
+            Entry::Occupied(mut e) => {
+                run_query(e.get_mut().make_context()?)
+            },
         }
     }
 }
@@ -189,13 +188,6 @@ pub(crate) struct Transaction<'l> {
 impl<'l> Transaction<'l> {
     pub fn new(db: &'l Connection) -> DBResult<Self> {
         db.execute_raw_sql("BEGIN TRANSACTION")?;
-        /*struct BeginQuery;
-        db.with_prepared(
-            std::any::TypeId::of::<BeginQuery>(),
-            || "BEGIN TRANSACTION".to_string(),
-            |ctx| {
-                ctx.run().map(|_| ())
-            })?; */
         Ok(Self {
             db,
             committed: false,
@@ -206,8 +198,8 @@ impl<'l> Transaction<'l> {
         self.committed = true;
 
         self.db.execute_raw_sql("COMMIT")
-        /*
-        struct CommitQuery;
+
+        /*struct CommitQuery;
         self.db.with_prepared(
             std::any::TypeId::of::<CommitQuery>(),
             || "COMMIT".to_string(),
@@ -220,6 +212,7 @@ impl<'l> Transaction<'l> {
 impl<'l> Drop for Transaction<'l> {
     fn drop(&mut self) {
         if !self.committed {
+            let _ = self.db.execute_raw_sql("ROLLBACK");
             /*
             struct AbortQuery;
             let _ = self.db.with_prepared(
@@ -227,8 +220,8 @@ impl<'l> Drop for Transaction<'l> {
                 || "ROLLBACK".to_string(),
                 |ctx| {
                     ctx.run().map(|_| ())
-                });*/
-            let _ = self.db.execute_raw_sql("ROLLBACK");
+                });
+            */
         }
     }
 }
@@ -249,6 +242,7 @@ impl Statement {
         Ok(StatementContext {
             stmt: self,
             owned_strings: Default::default(),
+            done: false.into()
         })
     }
 }
@@ -330,6 +324,7 @@ impl<'a> StatementRow<'a> {
 pub struct StatementContext<'a> {
     stmt: &'a Statement,
     owned_strings: Vec<Pin<Box<String>>>,
+    done: Cell<bool>,
 }
 
 impl<'a> StatementContext<'a> {
@@ -342,17 +337,26 @@ impl<'a> StatementContext<'a> {
     }
 
     fn step(&self) -> DBResult<bool> {
+        if self.done.get() {
+            return Ok(false)
+        }
         match unsafe { sq::sqlite3_step(self.stmt.stmt) } {
             sq::SQLITE_ROW => Ok(true),
-            sq::SQLITE_DONE => Ok(false),
+            sq::SQLITE_DONE => {
+                self.done.set(true);
+                Ok(false)
+            },
             sq::SQLITE_BUSY => {
                 log::trace!("Concurrent database access!");
                 todo!()
-            }
+            },
             sq::SQLITE_CONSTRAINT => {
-                log::trace!("SQLite constraint violation");
-                return Err(Error::LogicError("constraint violation"));
-            }
+                let msg = unsafe {
+                    CStr::from_ptr(sq::sqlite3_errmsg(self.stmt.sqlite))
+                }.to_str().unwrap().to_string();
+                log::trace!("SQLite constraint violation: {msg}");
+                return Err(Error::ConstraintViolation(""))
+            },
             err => {
                 log::trace!("unexpected error during sqlite3_step: {:?}", err);
                 check_rcode(|| None, err)?;
@@ -605,7 +609,7 @@ impl Readable for Vec<u8> {
 #[cfg(test)]
 mod sendsync_check {
     struct CheckSend<T: Send>(std::marker::PhantomData<T>);
-    struct CheckSync<T: Send>(std::marker::PhantomData<T>);
+    struct CheckSync<T: Sync>(std::marker::PhantomData<T>);
 
     #[test]
     fn check_send() {

+ 2 - 5
microrm/src/query.rs

@@ -332,11 +332,9 @@ fn do_connect<Remote: Entity>(
             ctx.bind(1, adata.local_id)?;
             ctx.bind(2, remote_id.into_raw())?;
 
-            println!("bound values ({:?}, {:?})", adata.local_id, remote_id);
-
             ctx.run()?
-                .ok_or(Error::LogicError("Already connected"))
-                .map(|v| println!("v: {:?}", v.read::<i64>(0)))
+                .ok_or(Error::ConstraintViolation("Already connected"))
+                .map(|_| ())
         },
     )
 }
@@ -349,7 +347,6 @@ pub trait AssocInterface: 'static {
 
     fn query_all(&self) -> impl Queryable<EntityOutput = Self::RemoteEntity> {
         components::TableComponent::<Self::RemoteEntity>::new(self.get_data().unwrap().conn.clone())
-        // IDMap::<Self::RemoteEntity>::build(self.get_data().unwrap().conn.clone())
     }
 
     fn connect_to(&self, remote_id: <Self::RemoteEntity as Entity>::ID) -> DBResult<()>

+ 1 - 1
microrm/src/schema/build.rs

@@ -220,7 +220,7 @@ pub(crate) fn collect_from_database<DB: Database>() -> DatabaseSchema {
             table.constraints.push(format!(
                 "/* keying index */ unique({})",
                 key.into_iter()
-                    .map(|s| s.name.to_string())
+                    .map(|s| format!("`{}`", s.name))
                     .reduce(|a, b| format!("{},{}", a, b))
                     .unwrap()
             ));

+ 2 - 0
microrm/src/schema/tests.rs

@@ -596,6 +596,8 @@ mod join_test {
             })
             .expect("couldn't insert target");
 
+        log::trace!("looking up base by ID {:?}", b1id);
+
         let b1 = db
             .bases
             .by_id(b1id)