Browse Source

Several enhancements.

Kestrel 9 months ago
parent
commit
ef7dcca3ea

+ 20 - 0
microrm-macros/src/entity.rs

@@ -139,6 +139,17 @@ pub fn derive(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream {
         })
         })
         .collect::<Vec<_>>();
         .collect::<Vec<_>>();
 
 
+    let debug_fields = parts
+        .iter()
+        .map(|part| {
+            let ident = &part.0;
+            let field = ident.to_string();
+            quote! {
+                self . #ident . debug_field(#field, &mut ds);
+            }
+        })
+        .collect::<Vec<_>>();
+
     let parts_list = make_part_list(&parts);
     let parts_list = make_part_list(&parts);
     let uniques_list = make_part_list(&unique_parts);
     let uniques_list = make_part_list(&unique_parts);
 
 
@@ -197,6 +208,15 @@ pub fn derive(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream {
             }
             }
         }
         }
 
 
+        impl ::std::fmt::Debug for #entity_ident {
+            fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> Result<(), ::std::fmt::Error> {
+                use ::microrm::schema::datum::Datum;
+                let mut ds = f.debug_struct(#entity_name);
+                #(#debug_fields)*
+                ds.finish()
+            }
+        }
+
         impl ::microrm::schema::entity::Entity for #entity_ident {
         impl ::microrm::schema::entity::Entity for #entity_ident {
             type Parts = #parts_list;
             type Parts = #parts_list;
             type Uniques = #uniques_list;
             type Uniques = #uniques_list;

+ 1 - 0
microrm/Cargo.toml

@@ -24,6 +24,7 @@ log = "0.4.17"
 topological-sort = { version = "0.2" }
 topological-sort = { version = "0.2" }
 
 
 [dev-dependencies]
 [dev-dependencies]
+test-log = "0.2.15"
 # criterion = "0.5"
 # criterion = "0.5"
 # rand = "0.8.5"
 # rand = "0.8.5"
 # stats_alloc = "0.1.10"
 # stats_alloc = "0.1.10"

+ 132 - 36
microrm/src/db.rs

@@ -7,7 +7,7 @@ use std::{
     sync::{Arc, Mutex},
     sync::{Arc, Mutex},
 };
 };
 
 
-fn check_rcode(sql: Option<&str>, rcode: i32) -> Result<(), Error> {
+fn check_rcode<'a>(sql: impl FnOnce() -> Option<&'a str>, rcode: i32) -> Result<(), Error> {
     if rcode == sq::SQLITE_OK {
     if rcode == sq::SQLITE_OK {
         Ok(())
         Ok(())
     } else {
     } else {
@@ -16,7 +16,7 @@ fn check_rcode(sql: Option<&str>, rcode: i32) -> Result<(), Error> {
             msg: unsafe { CStr::from_ptr(sq::sqlite3_errstr(rcode)) }
             msg: unsafe { CStr::from_ptr(sq::sqlite3_errstr(rcode)) }
                 .to_str()?
                 .to_str()?
                 .to_string(),
                 .to_string(),
-            sql: sql.map(|s| s.to_string()),
+            sql: sql().map(String::from),
         })
         })
     }
     }
 }
 }
@@ -69,7 +69,7 @@ impl Connection {
             let url = CString::new(url)?;
             let url = CString::new(url)?;
             let mut db_ptr = std::ptr::null_mut();
             let mut db_ptr = std::ptr::null_mut();
             check_rcode(
             check_rcode(
-                None,
+                || None,
                 sq::sqlite3_open_v2(
                 sq::sqlite3_open_v2(
                     url.as_ptr(),
                     url.as_ptr(),
                     &mut db_ptr,
                     &mut db_ptr,
@@ -86,6 +86,10 @@ impl Connection {
             ));
             ));
         }
         }
 
 
+        unsafe {
+            sq::sqlite3_busy_timeout(db_ptr, 1000);
+        }
+
         Ok(Self(Arc::new(Mutex::new(ConnectionData {
         Ok(Self(Arc::new(Mutex::new(ConnectionData {
             sqlite: db_ptr,
             sqlite: db_ptr,
             stmts: Default::default(),
             stmts: Default::default(),
@@ -95,6 +99,8 @@ impl Connection {
     pub fn execute_raw_sql(&self, sql: impl AsRef<str>) -> DBResult<()> {
     pub fn execute_raw_sql(&self, sql: impl AsRef<str>) -> DBResult<()> {
         let data = self.0.lock()?;
         let data = self.0.lock()?;
 
 
+        println!("executing: {sql}", sql = sql.as_ref());
+
         unsafe {
         unsafe {
             let c_sql = CString::new(sql.as_ref())?;
             let c_sql = CString::new(sql.as_ref())?;
             let mut err = std::ptr::null_mut();
             let mut err = std::ptr::null_mut();
@@ -143,11 +149,13 @@ impl Connection {
             Entry::Vacant(e) => {
             Entry::Vacant(e) => {
                 let sql = build_query();
                 let sql = build_query();
 
 
+                log::trace!("preparing query: {sql}");
+
                 // prepare the statement
                 // prepare the statement
                 let mut stmt = std::ptr::null_mut();
                 let mut stmt = std::ptr::null_mut();
                 unsafe {
                 unsafe {
                     check_rcode(
                     check_rcode(
-                        Some(sql.as_str()),
+                        || Some(sql.as_str()),
                         sq::sqlite3_prepare_v2(
                         sq::sqlite3_prepare_v2(
                             conn,
                             conn,
                             sql.as_ptr().cast(),
                             sql.as_ptr().cast(),
@@ -173,6 +181,59 @@ impl Connection {
     }
     }
 }
 }
 
 
+pub(crate) struct Transaction<'l> {
+    db: &'l Connection,
+    committed: bool,
+}
+
+impl<'l> Transaction<'l> {
+    pub fn new(db: &'l Connection) -> DBResult<Self> {
+        println!("backtrace: {}", std::backtrace::Backtrace::force_capture());
+        db.execute_raw_sql("BEGIN TRANSACTION")?;
+        /*struct BeginQuery;
+        db.with_prepared(
+            std::any::TypeId::of::<BeginQuery>(),
+            || "BEGIN".to_string(),
+            |ctx| {
+                ctx.run().map(|_| ())
+            })?; */
+        Ok(Self {
+            db,
+            committed: false,
+        })
+    }
+
+    pub fn commit(mut self) -> DBResult<()> {
+        self.committed = true;
+
+        self.db.execute_raw_sql("COMMIT")
+        /*
+        struct CommitQuery;
+        self.db.with_prepared(
+            std::any::TypeId::of::<CommitQuery>(),
+            || "COMMIT".to_string(),
+            |ctx| {
+                ctx.run().map(|_| ())
+            })*/
+    }
+}
+
+impl<'l> Drop for Transaction<'l> {
+    fn drop(&mut self) {
+        if !self.committed {
+            /*
+            struct AbortQuery;
+            let _ = self.db.with_prepared(
+                std::any::TypeId::of::<AbortQuery>(),
+                || "ROLLBACK".to_string(),
+                |ctx| {
+                    ctx.run().map(|_| ())
+                });*/
+            let _ = self.db.execute_raw_sql("ROLLBACK");
+        }
+    }
+}
+
 struct Statement {
 struct Statement {
     #[allow(unused)]
     #[allow(unused)]
     sqlite: *mut sq::sqlite3,
     sqlite: *mut sq::sqlite3,
@@ -183,8 +244,11 @@ impl Statement {
     fn make_context(&mut self) -> DBResult<StatementContext> {
     fn make_context(&mut self) -> DBResult<StatementContext> {
         // begin by resetting the statement
         // begin by resetting the statement
         unsafe {
         unsafe {
-            check_rcode(None, sq::sqlite3_reset(self.stmt))?;
+            check_rcode(|| None, sq::sqlite3_reset(self.stmt))?;
         }
         }
+
+        let v = unsafe { CStr::from_ptr(sq::sqlite3_sql(self.stmt)).to_str().unwrap() };
+        println!("making Statement context for SQL: {}", v);
         Ok(StatementContext {
         Ok(StatementContext {
             stmt: self,
             stmt: self,
             owned_strings: Default::default(),
             owned_strings: Default::default(),
@@ -256,7 +320,7 @@ mod test {
 
 
 pub struct StatementRow<'a> {
 pub struct StatementRow<'a> {
     stmt: &'a Statement,
     stmt: &'a Statement,
-    owned: Option<Vec<Pin<Box<String>>>>,
+    _ctx: Option<StatementContext<'a>>,
 }
 }
 
 
 impl<'a> StatementRow<'a> {
 impl<'a> StatementRow<'a> {
@@ -272,7 +336,7 @@ pub struct StatementContext<'a> {
 
 
 impl<'a> StatementContext<'a> {
 impl<'a> StatementContext<'a> {
     pub fn bind<B: Bindable>(&self, index: i32, bindable: B) -> DBResult<()> {
     pub fn bind<B: Bindable>(&self, index: i32, bindable: B) -> DBResult<()> {
-        bindable.bind_to(self, index)
+        bindable.bind(self, index)
     }
     }
 
 
     pub fn transfer(&mut self, s: Pin<Box<String>>) {
     pub fn transfer(&mut self, s: Pin<Box<String>>) {
@@ -281,23 +345,34 @@ impl<'a> StatementContext<'a> {
 
 
     fn step(&self) -> Option<()> {
     fn step(&self) -> Option<()> {
         match unsafe { sq::sqlite3_step(self.stmt.stmt) } {
         match unsafe { sq::sqlite3_step(self.stmt.stmt) } {
-            sq::SQLITE_ROW => Some(()),
-            sq::SQLITE_DONE => None,
-            _ => {
-                // check_rcode(None, v)?;
+            sq::SQLITE_ROW => {
+                println!("sqlite3_step: row");
+                Some(())
+            }
+            sq::SQLITE_DONE => {
+                println!("sqlite3_step: done");
+                None
+            }
+            sq::SQLITE_BUSY => {
+                println!("Concurrent database access!");
+                None
+            }
+            err => {
+                println!("unexpected error during sqlite3_step: {:?}", err);
+                // let _ = check_rcode(|| None, err);
                 // Ok(false)
                 // Ok(false)
                 None
                 None
             }
             }
         }
         }
     }
     }
 
 
-    pub fn run(mut self) -> DBResult<Option<StatementRow<'a>>> {
+    // this needs to be replaced with a "single" version that keeps the StatementContext alive, or
+    // StatementRow needs an optional StatementContext to keep alive
+    pub fn run(self) -> DBResult<Option<StatementRow<'a>>> {
         if self.step().is_some() {
         if self.step().is_some() {
-            let mut owned = vec![];
-            owned.append(&mut self.owned_strings);
             Ok(Some(StatementRow {
             Ok(Some(StatementRow {
-                owned: Some(owned),
                 stmt: self.stmt,
                 stmt: self.stmt,
+                _ctx: Some(self),
             }))
             }))
         } else {
         } else {
             Ok(None)
             Ok(None)
@@ -312,7 +387,7 @@ impl<'a> StatementContext<'a> {
 
 
             fn next(&mut self) -> Option<Self::Item> {
             fn next(&mut self) -> Option<Self::Item> {
                 self.0.step().map(|_| StatementRow {
                 self.0.step().map(|_| StatementRow {
-                    owned: None,
+                    _ctx: None,
                     stmt: self.0.stmt,
                     stmt: self.0.stmt,
                 })
                 })
             }
             }
@@ -326,13 +401,16 @@ impl<'a> Drop for StatementContext<'a> {
     fn drop(&mut self) {
     fn drop(&mut self) {
         // attempt to bind NULLs into each parameter
         // attempt to bind NULLs into each parameter
         unsafe {
         unsafe {
+            println!("clearing bindings...");
+            // clear out the rest of the rows
+            while self.step().is_some() {}
             sq::sqlite3_clear_bindings(self.stmt.stmt);
             sq::sqlite3_clear_bindings(self.stmt.stmt);
         }
         }
     }
     }
 }
 }
 
 
 pub trait Bindable {
 pub trait Bindable {
-    fn bind_to<'ctx, 'data: 'ctx>(
+    fn bind<'ctx, 'data: 'ctx>(
         &'data self,
         &'data self,
         ctx: &StatementContext<'ctx>,
         ctx: &StatementContext<'ctx>,
         index: i32,
         index: i32,
@@ -340,56 +418,61 @@ pub trait Bindable {
 }
 }
 
 
 impl Bindable for () {
 impl Bindable for () {
-    fn bind_to<'ctx, 'data: 'ctx>(
+    fn bind<'ctx, 'data: 'ctx>(
         &'data self,
         &'data self,
         ctx: &StatementContext<'ctx>,
         ctx: &StatementContext<'ctx>,
         index: i32,
         index: i32,
     ) -> DBResult<()> {
     ) -> DBResult<()> {
-        unsafe { check_rcode(None, sq::sqlite3_bind_null(ctx.stmt.stmt, index)) }
+        unsafe { check_rcode(|| None, sq::sqlite3_bind_null(ctx.stmt.stmt, index)) }
     }
     }
 }
 }
 
 
 impl Bindable for i64 {
 impl Bindable for i64 {
-    fn bind_to<'ctx, 'data: 'ctx>(
+    fn bind<'ctx, 'data: 'ctx>(
         &'data self,
         &'data self,
         ctx: &StatementContext<'ctx>,
         ctx: &StatementContext<'ctx>,
         index: i32,
         index: i32,
     ) -> DBResult<()> {
     ) -> DBResult<()> {
-        unsafe { check_rcode(None, sq::sqlite3_bind_int64(ctx.stmt.stmt, index, *self)) }
+        unsafe { check_rcode(|| None, sq::sqlite3_bind_int64(ctx.stmt.stmt, index, *self)) }
     }
     }
 }
 }
 
 
 impl Bindable for usize {
 impl Bindable for usize {
-    fn bind_to<'ctx, 'data: 'ctx>(
+    fn bind<'ctx, 'data: 'ctx>(
         &'data self,
         &'data self,
         ctx: &StatementContext<'ctx>,
         ctx: &StatementContext<'ctx>,
         index: i32,
         index: i32,
     ) -> DBResult<()> {
     ) -> DBResult<()> {
-        (*self as i64).bind_to(ctx, index)
+        (*self as i64).bind(ctx, index)
     }
     }
 }
 }
 
 
 impl Bindable for f32 {
 impl Bindable for f32 {
-    fn bind_to<'ctx, 'data: 'ctx>(&self, ctx: &StatementContext<'ctx>, index: i32) -> DBResult<()> {
-        (*self as f64).bind_to(ctx, index)
+    fn bind<'ctx, 'data: 'ctx>(&self, ctx: &StatementContext<'ctx>, index: i32) -> DBResult<()> {
+        (*self as f64).bind(ctx, index)
     }
     }
 }
 }
 
 
 impl Bindable for f64 {
 impl Bindable for f64 {
-    fn bind_to<'ctx, 'data: 'ctx>(&self, ctx: &StatementContext<'ctx>, index: i32) -> DBResult<()> {
-        unsafe { check_rcode(None, sq::sqlite3_bind_double(ctx.stmt.stmt, index, *self)) }
+    fn bind<'ctx, 'data: 'ctx>(&self, ctx: &StatementContext<'ctx>, index: i32) -> DBResult<()> {
+        unsafe {
+            check_rcode(
+                || None,
+                sq::sqlite3_bind_double(ctx.stmt.stmt, index, *self),
+            )
+        }
     }
     }
 }
 }
 
 
 impl<'a> Bindable for &'a str {
 impl<'a> Bindable for &'a str {
-    fn bind_to<'ctx, 'data: 'ctx>(
+    fn bind<'ctx, 'data: 'ctx>(
         &'data self,
         &'data self,
         ctx: &StatementContext<'ctx>,
         ctx: &StatementContext<'ctx>,
         index: i32,
         index: i32,
     ) -> DBResult<()> {
     ) -> DBResult<()> {
         unsafe {
         unsafe {
             check_rcode(
             check_rcode(
-                None,
+                || None,
                 sq::sqlite3_bind_text(
                 sq::sqlite3_bind_text(
                     ctx.stmt.stmt,
                     ctx.stmt.stmt,
                     index,
                     index,
@@ -403,34 +486,34 @@ impl<'a> Bindable for &'a str {
 }
 }
 
 
 impl Bindable for str {
 impl Bindable for str {
-    fn bind_to<'ctx, 'data: 'ctx>(
+    fn bind<'ctx, 'data: 'ctx>(
         &'data self,
         &'data self,
         ctx: &StatementContext<'ctx>,
         ctx: &StatementContext<'ctx>,
         index: i32,
         index: i32,
     ) -> DBResult<()> {
     ) -> DBResult<()> {
-        <&'_ str>::bind_to(&self, ctx, index)
+        <&'_ str>::bind(&self, ctx, index)
     }
     }
 }
 }
 
 
 impl Bindable for String {
 impl Bindable for String {
-    fn bind_to<'ctx, 'data: 'ctx>(
+    fn bind<'ctx, 'data: 'ctx>(
         &'data self,
         &'data self,
         ctx: &StatementContext<'ctx>,
         ctx: &StatementContext<'ctx>,
         index: i32,
         index: i32,
     ) -> DBResult<()> {
     ) -> DBResult<()> {
-        self.as_str().bind_to(ctx, index)
+        self.as_str().bind(ctx, index)
     }
     }
 }
 }
 
 
 impl<'a> Bindable for &'a [u8] {
 impl<'a> Bindable for &'a [u8] {
-    fn bind_to<'ctx, 'data: 'ctx>(
+    fn bind<'ctx, 'data: 'ctx>(
         &'data self,
         &'data self,
         ctx: &StatementContext<'ctx>,
         ctx: &StatementContext<'ctx>,
         index: i32,
         index: i32,
     ) -> DBResult<()> {
     ) -> DBResult<()> {
         unsafe {
         unsafe {
             check_rcode(
             check_rcode(
-                None,
+                || None,
                 sq::sqlite3_bind_blob64(
                 sq::sqlite3_bind_blob64(
                     ctx.stmt.stmt,
                     ctx.stmt.stmt,
                     index,
                     index,
@@ -447,6 +530,20 @@ pub trait Readable: Sized {
     fn read_from(sr: &StatementRow<'_>, index: i32) -> DBResult<Self>;
     fn read_from(sr: &StatementRow<'_>, index: i32) -> DBResult<Self>;
 }
 }
 
 
+pub struct IsNull(pub bool);
+
+// NULL-checker
+impl Readable for IsNull {
+    fn read_from(sr: &StatementRow<'_>, index: i32) -> DBResult<Self> {
+        let column_type = unsafe { sq::sqlite3_column_type(sr.stmt.stmt, index) };
+        if column_type == sq::SQLITE_NULL {
+            Ok(IsNull(true))
+        } else {
+            Ok(IsNull(false))
+        }
+    }
+}
+
 impl Readable for i64 {
 impl Readable for i64 {
     fn read_from(sr: &StatementRow<'_>, index: i32) -> DBResult<Self> {
     fn read_from(sr: &StatementRow<'_>, index: i32) -> DBResult<Self> {
         unsafe { Ok(sq::sqlite3_column_int64(sr.stmt.stmt, index)) }
         unsafe { Ok(sq::sqlite3_column_int64(sr.stmt.stmt, index)) }
@@ -480,7 +577,6 @@ impl Readable for String {
                     "NULL pointer result from sqlite3_column_text",
                     "NULL pointer result from sqlite3_column_text",
                 ))
                 ))
             } else {
             } else {
-                let cstr = CStr::from_ptr(text.cast());
                 Ok(CStr::from_ptr(text.cast()).to_str()?.to_string())
                 Ok(CStr::from_ptr(text.cast()).to_str()?.to_string())
             }
             }
         }
         }

+ 80 - 36
microrm/src/query.rs

@@ -1,11 +1,10 @@
-use crate::db::{Connection, StatementContext, StatementRow};
+use crate::db::{Connection, StatementContext, StatementRow, Transaction};
 use crate::prelude::IDMap;
 use crate::prelude::IDMap;
 use crate::schema::datum::{QueryEquivalent, QueryEquivalentList};
 use crate::schema::datum::{QueryEquivalent, QueryEquivalentList};
 use crate::schema::entity::helpers::check_assoc;
 use crate::schema::entity::helpers::check_assoc;
-use crate::schema::entity::EntityVisitor;
-use crate::schema::{AssocData, DatumDiscriminator, LocalSide, Stored};
+use crate::schema::{AssocData, LocalSide, Stored};
 use crate::{
 use crate::{
-    schema::datum::{Datum, DatumList},
+    schema::datum::Datum,
     schema::entity::{Entity, EntityID, EntityPart, EntityPartList, EntityPartVisitor},
     schema::entity::{Entity, EntityID, EntityPart, EntityPartList, EntityPartVisitor},
 };
 };
 use crate::{DBResult, Error};
 use crate::{DBResult, Error};
@@ -252,12 +251,6 @@ impl Query {
             Some(v) => v.into_iter().reduce(|a, b| format!("{} {}", a, b)).unwrap(),
             Some(v) => v.into_iter().reduce(|a, b| format!("{} {}", a, b)).unwrap(),
         };
         };
 
 
-        /*println!(
-            "built SQL query: {} {} {} {} {} {} {}",
-            root, columns_, from_, set_, join_, where_, trailing_
-        );*/
-        // log::trace!("built SQL query: {} {} {}", root, set_, where_);
-
         format!(
         format!(
             "{} {} {} {} {} {} {}",
             "{} {} {} {} {} {} {}",
             root, columns_, from_, set_, join_, where_, trailing_
             root, columns_, from_, set_, join_, where_, trailing_
@@ -320,6 +313,30 @@ fn hash_of<T: Hash>(val: T) -> u64 {
     hasher.finish()
     hasher.finish()
 }
 }
 
 
+fn do_connect<Remote: Entity>(
+    adata: &AssocData,
+    an: AssocNames,
+    remote_id: Remote::ID,
+) -> DBResult<()> {
+    adata.conn.with_prepared(
+        hash_of(("connect", an.local_name, an.remote_name, an.part_name)),
+        || {
+            format!(
+                "insert into `{assoc_name}` (`{local_field}`, `{remote_field}`) values (?, ?)",
+                assoc_name = an.assoc_name(),
+                local_field = an.local_field,
+                remote_field = an.remote_field
+            )
+        },
+        |ctx| {
+            ctx.bind(1, adata.local_id)?;
+            ctx.bind(2, remote_id.into_raw())?;
+
+            ctx.run().map(|_| ())
+        },
+    )
+}
+
 pub trait AssocInterface: 'static {
 pub trait AssocInterface: 'static {
     type RemoteEntity: Entity;
     type RemoteEntity: Entity;
     fn get_data(&self) -> DBResult<&AssocData>;
     fn get_data(&self) -> DBResult<&AssocData>;
@@ -333,24 +350,11 @@ pub trait AssocInterface: 'static {
         let adata = self.get_data()?;
         let adata = self.get_data()?;
         let an = AssocNames::collect::<Self>(self)?;
         let an = AssocNames::collect::<Self>(self)?;
 
 
-        // second, add to the assoc table
-        adata.conn.with_prepared(
-            hash_of(("connect", an.local_name, an.remote_name, an.part_name)),
-            || {
-                format!(
-                    "insert into `{assoc_name}` (`{local_field}`, `{remote_field}`) values (?, ?)",
-                    assoc_name = an.assoc_name(),
-                    local_field = an.local_field,
-                    remote_field = an.remote_field
-                )
-            },
-            |ctx| {
-                ctx.bind(1, adata.local_id)?;
-                ctx.bind(2, remote_id.into_raw())?;
+        let txn = Transaction::new(&adata.conn)?;
 
 
-                ctx.run().map(|_| ())
-            },
-        )
+        do_connect::<Self::RemoteEntity>(adata, an, remote_id)?;
+
+        txn.commit()
     }
     }
 
 
     fn disconnect_from(&self, remote_id: <Self::RemoteEntity as Entity>::ID) -> DBResult<()>
     fn disconnect_from(&self, remote_id: <Self::RemoteEntity as Entity>::ID) -> DBResult<()>
@@ -360,6 +364,8 @@ pub trait AssocInterface: 'static {
         let adata = self.get_data()?;
         let adata = self.get_data()?;
         let an = AssocNames::collect::<Self>(self)?;
         let an = AssocNames::collect::<Self>(self)?;
 
 
+        let txn = Transaction::new(&adata.conn)?;
+
         // second, add to the assoc table
         // second, add to the assoc table
         adata.conn.with_prepared(
         adata.conn.with_prepared(
             hash_of(("disconnect", an.local_name, an.remote_name, an.part_name)),
             hash_of(("disconnect", an.local_name, an.remote_name, an.part_name)),
@@ -377,7 +383,9 @@ pub trait AssocInterface: 'static {
 
 
                 ctx.run().map(|_| ())
                 ctx.run().map(|_| ())
             },
             },
-        )
+        )?;
+
+        txn.commit()
     }
     }
 
 
     fn insert(&self, value: Self::RemoteEntity) -> DBResult<<Self::RemoteEntity as Entity>::ID>
     fn insert(&self, value: Self::RemoteEntity) -> DBResult<<Self::RemoteEntity as Entity>::ID>
@@ -389,14 +397,42 @@ pub trait AssocInterface: 'static {
         // - adding the association row into the assoc table
         // - adding the association row into the assoc table
 
 
         let adata = self.get_data()?;
         let adata = self.get_data()?;
+        let an = AssocNames::collect::<Self>(self)?;
+
+        let txn = Transaction::new(&adata.conn)?;
 
 
-        // so first, the remote table
+        // so first, into the remote table
         let remote_id = insert(&adata.conn, &value)?;
         let remote_id = insert(&adata.conn, &value)?;
         // then the association
         // then the association
-        self.connect_to(remote_id)?;
-        // TODO: handle error case of associate_with() fails but insert() succeeds
+        do_connect::<Self::RemoteEntity>(adata, an, remote_id)?;
+
+        txn.commit()?;
+
         Ok(remote_id)
         Ok(remote_id)
     }
     }
+
+    fn insert_and_return(&self, value: Self::RemoteEntity) -> DBResult<Stored<Self::RemoteEntity>>
+    where
+        Self: Sized,
+    {
+        // we're doing two things:
+        // - inserting the entity into the target table
+        // - adding the association row into the assoc table
+
+        let adata = self.get_data()?;
+        let an = AssocNames::collect::<Self>(self)?;
+
+        let txn = Transaction::new(&adata.conn)?;
+
+        // so first, into the remote table
+        let remote = insert_and_return(&adata.conn, value)?;
+        // then the association
+        do_connect::<Self::RemoteEntity>(adata, an, remote.id())?;
+
+        txn.commit()?;
+
+        Ok(remote)
+    }
 }
 }
 
 
 // ----------------------------------------------------------------------
 // ----------------------------------------------------------------------
@@ -455,8 +491,9 @@ pub trait Queryable {
     where
     where
         Self: Sized,
         Self: Sized,
     {
     {
+        let txn = Transaction::new(self.conn())?;
         struct CountTag;
         struct CountTag;
-        self.conn().with_prepared(
+        let out = self.conn().with_prepared(
             std::any::TypeId::of::<(Self::StaticVersion, CountTag)>(),
             std::any::TypeId::of::<(Self::StaticVersion, CountTag)>(),
             || {
             || {
                 self.build()
                 self.build()
@@ -479,15 +516,18 @@ pub trait Queryable {
                     .ok_or(Error::InternalError("no resulting rows from COUNT query"))?
                     .ok_or(Error::InternalError("no resulting rows from COUNT query"))?
                     .read::<i64>(0)? as usize)
                     .read::<i64>(0)? as usize)
             },
             },
-        )
+        )?;
+        txn.commit()?;
+        Ok(out)
     }
     }
     /// Get all entities in the current context.
     /// Get all entities in the current context.
     fn get(self) -> DBResult<Self::OutputContainer>
     fn get(self) -> DBResult<Self::OutputContainer>
     where
     where
         Self: Sized,
         Self: Sized,
     {
     {
+        let txn = Transaction::new(self.conn())?;
         struct GetTag;
         struct GetTag;
-        self.conn().with_prepared(
+        let out = self.conn().with_prepared(
             std::any::TypeId::of::<(Self::StaticVersion, GetTag)>(),
             std::any::TypeId::of::<(Self::StaticVersion, GetTag)>(),
             || self.build().assemble(),
             || self.build().assemble(),
             |mut ctx| {
             |mut ctx| {
@@ -497,13 +537,16 @@ pub trait Queryable {
 
 
                 <Self::OutputContainer>::assemble_from(self.conn(), ctx)
                 <Self::OutputContainer>::assemble_from(self.conn(), ctx)
             },
             },
-        )
+        )?;
+        txn.commit()?;
+        Ok(out)
     }
     }
     /// Delete all entities in the current context.
     /// Delete all entities in the current context.
     fn delete(self) -> DBResult<()>
     fn delete(self) -> DBResult<()>
     where
     where
         Self: Sized,
         Self: Sized,
     {
     {
+        let txn = Transaction::new(self.conn())?;
         struct DeleteTag;
         struct DeleteTag;
         self.conn().with_prepared(
         self.conn().with_prepared(
             std::any::TypeId::of::<(Self::StaticVersion, DeleteTag)>(),
             std::any::TypeId::of::<(Self::StaticVersion, DeleteTag)>(),
@@ -527,7 +570,8 @@ pub trait Queryable {
                 ctx.run()?;
                 ctx.run()?;
                 Ok(())
                 Ok(())
             },
             },
-        )
+        )?;
+        txn.commit()
     }
     }
 
 
     // ----------------------------------------------------------------------
     // ----------------------------------------------------------------------

+ 42 - 5
microrm/src/schema.rs

@@ -9,7 +9,7 @@
 use query::Queryable;
 use query::Queryable;
 
 
 use crate::{
 use crate::{
-    db::{Connection, StatementContext, StatementRow},
+    db::{Connection, StatementContext, StatementRow, Transaction},
     query::{self, AssocInterface},
     query::{self, AssocInterface},
     schema::datum::Datum,
     schema::datum::Datum,
     schema::entity::{Entity, EntityVisitor},
     schema::entity::{Entity, EntityVisitor},
@@ -54,7 +54,18 @@ impl<T: Entity> Stored<T> {
     }
     }
 
 
     pub fn sync(&mut self) -> DBResult<()> {
     pub fn sync(&mut self) -> DBResult<()> {
-        query::update_entity(&self.db, self)
+        let txn = Transaction::new(&self.db)?;
+        query::update_entity(&self.db, self)?;
+        txn.commit()
+    }
+}
+
+impl<T: Entity + std::fmt::Debug> std::fmt::Debug for Stored<T> {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.write_fmt(format_args!(
+            "Stored {{ id: {:?}, value: {:?} }}",
+            self.id, self.wrap
+        ))
     }
     }
 }
 }
 
 
@@ -200,6 +211,12 @@ impl<T: Entity> Datum for AssocMap<T> {
         unreachable!()
         unreachable!()
     }
     }
 
 
+    fn debug_field(&self, _field: &'static str, _fmt: &mut std::fmt::DebugStruct)
+    where
+        Self: Sized,
+    {
+    }
+
     fn accept_entity_visitor(v: &mut impl EntityVisitor) {
     fn accept_entity_visitor(v: &mut impl EntityVisitor) {
         v.visit::<T>();
         v.visit::<T>();
     }
     }
@@ -289,6 +306,12 @@ impl<R: Relation> Datum for AssocDomain<R> {
         unreachable!()
         unreachable!()
     }
     }
 
 
+    fn debug_field(&self, _field: &'static str, _fmt: &mut std::fmt::DebugStruct)
+    where
+        Self: Sized,
+    {
+    }
+
     fn accept_entity_visitor(v: &mut impl EntityVisitor) {
     fn accept_entity_visitor(v: &mut impl EntityVisitor) {
         v.visit::<R::Domain>();
         v.visit::<R::Domain>();
     }
     }
@@ -378,6 +401,12 @@ impl<R: Relation> Datum for AssocRange<R> {
         unreachable!()
         unreachable!()
     }
     }
 
 
+    fn debug_field(&self, _field: &'static str, _fmt: &mut std::fmt::DebugStruct)
+    where
+        Self: Sized,
+    {
+    }
+
     fn accept_entity_visitor(v: &mut impl EntityVisitor) {
     fn accept_entity_visitor(v: &mut impl EntityVisitor) {
         v.visit::<R::Domain>();
         v.visit::<R::Domain>();
     }
     }
@@ -447,7 +476,9 @@ impl<T: serde::Serialize + serde::de::DeserializeOwned> AsMut<T> for Serialized<
     }
     }
 }
 }
 
 
-impl<T: 'static + serde::Serialize + serde::de::DeserializeOwned> Datum for Serialized<T> {
+impl<T: 'static + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug> Datum
+    for Serialized<T>
+{
     fn sql_type() -> &'static str {
     fn sql_type() -> &'static str {
         "text"
         "text"
     }
     }
@@ -504,11 +535,17 @@ impl<T: Entity> IDMap<T> {
 
 
     /// Insert a new Entity into this map, and return its new ID.
     /// Insert a new Entity into this map, and return its new ID.
     pub fn insert(&self, value: T) -> DBResult<T::ID> {
     pub fn insert(&self, value: T) -> DBResult<T::ID> {
-        query::insert(self.conn(), &value)
+        let txn = Transaction::new(self.conn())?;
+        let out = query::insert(self.conn(), &value)?;
+        txn.commit()?;
+        Ok(out)
     }
     }
 
 
     pub fn insert_and_return(&self, value: T) -> DBResult<Stored<T>> {
     pub fn insert_and_return(&self, value: T) -> DBResult<Stored<T>> {
-        query::insert_and_return(self.conn(), value)
+        let txn = Transaction::new(self.conn())?;
+        let out = query::insert_and_return(self.conn(), value)?;
+        txn.commit()?;
+        Ok(out)
     }
     }
 }
 }
 
 

+ 5 - 1
microrm/src/schema/build.rs

@@ -84,8 +84,9 @@ impl DatabaseSchema {
     }
     }
 
 
     pub fn create(&self, db: Connection) -> DBResult<()> {
     pub fn create(&self, db: Connection) -> DBResult<()> {
+        db.execute_raw_sql("BEGIN TRANSACTION")?;
         for query in self.queries.iter() {
         for query in self.queries.iter() {
-            println!("Running {query}");
+            log::trace!("Running creation query {query}");
             db.execute_raw_sql(query)?;
             db.execute_raw_sql(query)?;
         }
         }
 
 
@@ -96,11 +97,14 @@ impl DatabaseSchema {
             db.execute_raw_sql(query)?;
             db.execute_raw_sql(query)?;
         }
         }
 
 
+        db.execute_raw_sql("COMMIT")?;
+
         // store signature
         // store signature
         metadb.metastore.insert(meta::Meta {
         metadb.metastore.insert(meta::Meta {
             key: Self::SCHEMA_SIGNATURE_KEY.into(),
             key: Self::SCHEMA_SIGNATURE_KEY.into(),
             value: format!("{}", self.signature),
             value: format!("{}", self.signature),
         })?;
         })?;
+
         Ok(())
         Ok(())
     }
     }
 }
 }

+ 8 - 1
microrm/src/schema/datum.rs

@@ -14,9 +14,16 @@ mod datum_list;
 // ----------------------------------------------------------------------
 // ----------------------------------------------------------------------
 
 
 /// Represents a data field in an Entity.
 /// Represents a data field in an Entity.
-pub trait Datum {
+pub trait Datum: std::fmt::Debug {
     fn sql_type() -> &'static str;
     fn sql_type() -> &'static str;
 
 
+    fn debug_field(&self, field: &'static str, fmt: &mut std::fmt::DebugStruct)
+    where
+        Self: Sized,
+    {
+        fmt.field(field, self);
+    }
+
     fn bind_to(&self, _stmt: &mut StatementContext, index: i32);
     fn bind_to(&self, _stmt: &mut StatementContext, index: i32);
     fn build_from(adata: AssocData, stmt: &mut StatementRow, index: &mut i32) -> DBResult<Self>
     fn build_from(adata: AssocData, stmt: &mut StatementRow, index: &mut i32) -> DBResult<Self>
     where
     where

+ 17 - 6
microrm/src/schema/datum/datum_common.rs

@@ -1,5 +1,5 @@
 use crate::{
 use crate::{
-    db::{StatementContext, StatementRow},
+    db::{self, Bindable, StatementContext, StatementRow},
     schema::{AssocData, Datum},
     schema::{AssocData, Datum},
     DBResult, Error,
     DBResult, Error,
 };
 };
@@ -142,16 +142,27 @@ impl<T: Datum> Datum for Option<T> {
         T::sql_type()
         T::sql_type()
     }
     }
 
 
-    fn bind_to(&self, _stmt: &mut StatementContext, _index: i32) {
-        todo!()
+    fn bind_to(&self, stmt: &mut StatementContext, index: i32) {
+        if let Some(v) = self.as_ref() {
+            v.bind_to(stmt, index);
+        } else {
+            // bind a NULL
+            ().bind(stmt, index).expect("couldn't bind NULL for None");
+        }
     }
     }
 
 
-    fn build_from(_: AssocData, _stmt: &mut StatementRow, _index: &mut i32) -> DBResult<Self>
+    fn build_from(adata: AssocData, stmt: &mut StatementRow, index: &mut i32) -> DBResult<Self>
     where
     where
         Self: Sized,
         Self: Sized,
     {
     {
-        // Ok((stmt.read::<i64, _>(index)? as u64, index+1))
-        todo!()
+        let rval = if stmt.read::<db::IsNull>(*index)?.0 {
+            *index += 1;
+            Ok(None)
+        } else {
+            T::build_from(adata, stmt, index).map(Some)
+        };
+
+        rval
     }
     }
 }
 }
 
 

+ 12 - 5
microrm/src/schema/tests.rs

@@ -1,5 +1,7 @@
 #![allow(unused)]
 #![allow(unused)]
 
 
+use test_log::test;
+
 fn open_test_db<DB: super::Database>(identifier: &'static str) -> DB {
 fn open_test_db<DB: super::Database>(identifier: &'static str) -> DB {
     let path = format!("/tmp/microrm-{identifier}.db");
     let path = format!("/tmp/microrm-{identifier}.db");
     let _ = std::fs::remove_file(path.as_str());
     let _ = std::fs::remove_file(path.as_str());
@@ -8,19 +10,19 @@ fn open_test_db<DB: super::Database>(identifier: &'static str) -> DB {
 
 
 mod manual_test_db {
 mod manual_test_db {
     // simple hand-built database example
     // simple hand-built database example
-
     use crate::db::{Connection, StatementContext, StatementRow};
     use crate::db::{Connection, StatementContext, StatementRow};
     use crate::schema::datum::Datum;
     use crate::schema::datum::Datum;
     use crate::schema::entity::{
     use crate::schema::entity::{
         Entity, EntityID, EntityPart, EntityPartList, EntityPartVisitor, EntityVisitor,
         Entity, EntityID, EntityPart, EntityPartList, EntityPartVisitor, EntityVisitor,
     };
     };
     use crate::schema::{Database, DatabaseItem, DatabaseItemVisitor, IDMap};
     use crate::schema::{Database, DatabaseItem, DatabaseItemVisitor, IDMap};
+    use test_log::test;
 
 
     struct SimpleEntity {
     struct SimpleEntity {
         name: String,
         name: String,
     }
     }
 
 
-    #[derive(Clone, Copy, PartialEq, PartialOrd, Debug, Hash)]
+    #[derive(Clone, Copy, PartialEq, PartialOrd, Hash, Debug)]
     struct SimpleEntityID(i64);
     struct SimpleEntityID(i64);
 
 
     impl Datum for SimpleEntityID {
     impl Datum for SimpleEntityID {
@@ -176,14 +178,15 @@ mod derive_tests {
     use crate::query::{AssocInterface, Queryable};
     use crate::query::{AssocInterface, Queryable};
     use crate::schema::{AssocMap, Database, IDMap};
     use crate::schema::{AssocMap, Database, IDMap};
     use microrm_macros::{Database, Entity};
     use microrm_macros::{Database, Entity};
+    use test_log::test;
 
 
-    #[derive(Entity, Debug)]
+    #[derive(Entity)]
     struct Role {
     struct Role {
         title: String,
         title: String,
         permissions: String,
         permissions: String,
     }
     }
 
 
-    #[derive(Entity, Debug)]
+    #[derive(Entity)]
     struct Person {
     struct Person {
         #[unique]
         #[unique]
         name: String,
         name: String,
@@ -347,6 +350,7 @@ mod mutual_relationship {
     use crate::query::{AssocInterface, Queryable};
     use crate::query::{AssocInterface, Queryable};
     use crate::schema::{AssocDomain, AssocMap, AssocRange, Database, IDMap};
     use crate::schema::{AssocDomain, AssocMap, AssocRange, Database, IDMap};
     use microrm_macros::{Database, Entity};
     use microrm_macros::{Database, Entity};
+    use test_log::test;
 
 
     struct CR;
     struct CR;
     impl microrm::schema::Relation for CR {
     impl microrm::schema::Relation for CR {
@@ -456,6 +460,7 @@ mod reserved_words {
     use crate::prelude::*;
     use crate::prelude::*;
     use crate::schema::entity::Entity;
     use crate::schema::entity::Entity;
     use crate::schema::{AssocDomain, AssocRange, Database, IDMap};
     use crate::schema::{AssocDomain, AssocRange, Database, IDMap};
+    use test_log::test;
 
 
     #[derive(Entity)]
     #[derive(Entity)]
     struct Select {
     struct Select {
@@ -482,6 +487,7 @@ mod join_test {
     use super::open_test_db;
     use super::open_test_db;
     use crate::prelude::*;
     use crate::prelude::*;
     use crate::schema;
     use crate::schema;
+    use test_log::test;
 
 
     #[derive(Default, Entity)]
     #[derive(Default, Entity)]
     struct Base {
     struct Base {
@@ -629,7 +635,8 @@ mod join_test {
 
 
 mod query_equivalence {
 mod query_equivalence {
     use crate::prelude::*;
     use crate::prelude::*;
-    #[derive(Entity, Debug)]
+    use test_log::test;
+    #[derive(Entity)]
     struct Item {
     struct Item {
         #[unique]
         #[unique]
         s: String,
         s: String,