Sfoglia il codice sorgente

Improved ergonomics of query interface.

Kestrel 2 anni fa
parent
commit
071ccb2c47
5 ha cambiato i file con 81 aggiunte e 83 eliminazioni
  1. 1 0
      microrm/Cargo.toml
  2. 34 28
      microrm/src/lib.rs
  3. 2 11
      microrm/src/model.rs
  4. 35 8
      microrm/src/model/create.rs
  5. 9 36
      microrm/src/query.rs

+ 1 - 0
microrm/Cargo.toml

@@ -10,5 +10,6 @@ base64 = "0.13"
 sha2 = "0.10"
 rusqlite = "0.27"
 serde = { version = "1.0", features = ["derive"] }
+serde_bytes = { version = "0.11.6" }
 
 microrm-macros = { path = "../microrm-macros" }

+ 34 - 28
microrm/src/lib.rs

@@ -15,7 +15,7 @@ pub struct DB {
 }
 
 impl DB {
-    pub fn new(schema: model::SchemaModel, path: &str, allow_recreate: bool) -> Self {
+    pub fn new(schema: model::SchemaModel, path: &str, allow_recreate: bool) -> Result<Self, &'static str> {
         Self::from_connection(
             rusqlite::Connection::open(path).expect("Opening database connection failed"),
             schema,
@@ -24,7 +24,7 @@ impl DB {
     }
 
     /// For use in tests
-    pub fn new_in_memory(schema: model::SchemaModel) -> Self {
+    pub fn new_in_memory(schema: model::SchemaModel) -> Result<Self, &'static str> {
         Self::from_connection(
             rusqlite::Connection::open_in_memory().expect("Opening database connection failed"),
             schema,
@@ -36,15 +36,15 @@ impl DB {
         conn: rusqlite::Connection,
         schema: model::SchemaModel,
         allow_recreate: bool,
-    ) -> Self {
+    ) -> Result<Self, &'static str> {
         let sig = Self::calculate_schema_hash(&schema);
         let ret = Self {
             conn,
             schema_hash: sig,
             schema: schema.add::<meta::Metaschema>(),
         };
-        ret.check_schema(allow_recreate);
-        ret
+        ret.check_schema(allow_recreate)?;
+        Ok(ret)
     }
 
     fn calculate_schema_hash(schema: &model::SchemaModel) -> String {
@@ -65,38 +65,44 @@ impl DB {
         base64::encode(hasher.finalize())
     }
 
-    fn check_schema(&self, allow_recreate: bool) {
-        let hash = query::get_one_by::<meta::Metaschema, _>(self, meta::MetaschemaColumns::Key, "schema_hash");
+    fn check_schema(&self, allow_recreate: bool) -> Result<(), &'static str> {
+        let hash = query::get_one_by(self, meta::MetaschemaColumns::Key, "schema_hash");
 
         if hash.is_none() || hash.unwrap().value != self.schema_hash {
             if !allow_recreate {
-                panic!("No schema version in database, and not allowed to create!");
+                return Err("No schema version in database, and not allowed to create!")
             }
             println!("Failed to retrieve schema; probably is empty database");
+            self.create_schema();
 
-            for ds in self.schema.drop() {
-                let prepared = self.conn.prepare(ds);
-                prepared.unwrap().execute([]).expect("Dropping sql failed");
-            }
+        }
 
-            for cs in self.schema.create() {
-                let prepared = self.conn.prepare(cs);
-                prepared.unwrap().execute([]).expect("Creation sql failed");
-            }
+        Ok(())
+    }
 
-            query::add(
-                self,
-                &meta::Metaschema {
-                    key: "schema_hash".to_string(),
-                    value: self.schema_hash.clone(),
-                },
-            );
-
-            println!(
-                "re-search results: {:?}",
-                query::get_one_by::<meta::Metaschema, _>(self, meta::MetaschemaColumns::Key, "schema_hash")
-            );
+    fn create_schema(&self) -> Result<(), &'static str> {
+        for ds in self.schema.drop() {
+            let prepared = self.conn.prepare(ds);
+            prepared.unwrap().execute([]).expect("Dropping sql failed");
         }
+
+        for cs in self.schema.create() {
+            let prepared = self.conn.prepare(cs);
+            prepared.unwrap().execute([]).expect("Creation sql failed");
+        }
+
+        query::add(
+            self,
+            &meta::Metaschema {
+                key: "schema_hash".to_string(),
+                value: self.schema_hash.clone(),
+            },
+        );
+
+        let sanity_check = query::get_one_by(self, meta::MetaschemaColumns::Key, "schema_hash");
+        assert_eq!(sanity_check.is_some(), true);
+
+        Ok(())
     }
 }
 

+ 2 - 11
microrm/src/model.rs

@@ -24,15 +24,6 @@ impl From<ModelError> for rusqlite::Error {
     }
 }
 
-/*impl Into<rusqlite::Error> for ModelError {
-    fn into(self) -> rusqlite::Error {
-        match self {
-            Self::DBError(e) => e,
-            _ => panic!()
-        }
-    }
-}*/
-
 impl std::fmt::Display for ModelError {
     fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
         fmt.write_fmt(format_args!("{:?}", self))
@@ -55,7 +46,7 @@ impl std::error::Error for ModelError {}
 
 /// A database entity, aka a struct representing a row in a table
 pub trait Entity: for<'de> serde::Deserialize<'de> + serde::Serialize {
-    type Column;
+    type Column : EntityColumns;
     fn table_name() -> &'static str;
     fn column_count() -> usize
     where
@@ -70,7 +61,7 @@ pub trait Entity: for<'de> serde::Deserialize<'de> + serde::Serialize {
 }
 
 pub trait EntityColumns {
-    type Entity;
+    type Entity : Entity;
 }
 
 /// How we describe an entire schema

+ 35 - 8
microrm/src/model/create.rs

@@ -2,9 +2,9 @@ use serde::de::Visitor;
 
 #[derive(Debug)]
 pub struct CreateDeserializer<'de> {
-    table_name: Option<&'static str>,
-    column_names: Option<&'static [&'static str]>,
+    column_names: Vec<String>,
     column_types: Vec<String>,
+    column_name_stack: Vec<String>,
     _de: std::marker::PhantomData<&'de u8>,
 }
 
@@ -14,7 +14,7 @@ impl<'de, 'a> serde::de::Deserializer<'de> for &'a mut CreateDeserializer<'de> {
     // we (ab)use the forward_to_deserialize_any! macro to stub out the types we don't care about
     serde::forward_to_deserialize_any! {
         bool i8 i16 i128 u8 u16 u32 u64 u128 f32 f64 char str
-        bytes byte_buf option unit unit_struct newtype_struct seq tuple
+        option unit unit_struct tuple
         tuple_struct map enum identifier ignored_any
     }
 
@@ -24,29 +24,57 @@ impl<'de, 'a> serde::de::Deserializer<'de> for &'a mut CreateDeserializer<'de> {
 
     fn deserialize_i32<V: Visitor<'de>>(self, v: V) -> Result<V::Value, Self::Error> {
         self.column_types.push("integer".to_owned());
+        self.column_names.push(self.column_name_stack.pop().unwrap());
         v.visit_i32(0)
     }
 
     fn deserialize_i64<V: Visitor<'de>>(self, v: V) -> Result<V::Value, Self::Error> {
         self.column_types.push("integer".to_owned());
+        self.column_names.push(self.column_name_stack.pop().unwrap());
         v.visit_i64(0)
     }
 
     fn deserialize_string<V: Visitor<'de>>(self, v: V) -> Result<V::Value, Self::Error> {
         self.column_types.push("varchar".to_owned());
+        self.column_names.push(self.column_name_stack.pop().unwrap());
         v.visit_string("".to_owned())
     }
 
+    fn deserialize_bytes<V: Visitor<'de>>(self, v: V) -> Result<V::Value, Self::Error> {
+        self.column_types.push("blob".to_owned());
+        self.column_names.push(self.column_name_stack.pop().unwrap());
+        v.visit_bytes(&[])
+    }
+
+    fn deserialize_byte_buf<V: Visitor<'de>>(self, v: V) -> Result<V::Value, Self::Error> {
+        self.column_types.push("blob".to_owned());
+        self.column_names.push(self.column_name_stack.pop().unwrap());
+        v.visit_bytes(&[])
+    }
+
+    fn deserialize_seq<V: Visitor<'de>>(
+        self, v: V) -> Result<V::Value, Self::Error> {
+
+        v.visit_seq(self)
+    }
+
     fn deserialize_struct<V: Visitor<'de>>(
         self,
         name: &'static str,
         fields: &'static [&'static str],
         v: V,
     ) -> Result<V::Value, Self::Error> {
-        self.table_name = Some(name);
-        self.column_names = Some(fields);
+        self.column_name_stack.extend(fields.iter().map(|x| x.to_string()));
         v.visit_seq(self)
     }
+
+    fn deserialize_newtype_struct<V: Visitor<'de>>(
+        self,
+        name: &'static str,
+        v: V
+    ) -> Result<V::Value, Self::Error> {
+        unreachable!("microrm cannot store newtype structs")
+    }
 }
 
 impl<'de> serde::de::SeqAccess<'de> for CreateDeserializer<'de> {
@@ -62,9 +90,9 @@ impl<'de> serde::de::SeqAccess<'de> for CreateDeserializer<'de> {
 
 pub fn sql_for<T: crate::model::Entity>() -> (String, String) {
     let mut cd = CreateDeserializer {
-        table_name: None,
-        column_names: None,
+        column_names: Vec::new(),
         column_types: Vec::new(),
+        column_name_stack: Vec::new(),
         _de: std::marker::PhantomData {},
     };
 
@@ -79,7 +107,6 @@ pub fn sql_for<T: crate::model::Entity>() -> (String, String) {
             "CREATE TABLE {} ({})",
             <T as crate::model::Entity>::table_name(),
             cd.column_names
-                .unwrap()
                 .iter()
                 .zip(cd.column_types.iter())
                 .map(|(n, t)| n.to_string() + " " + t)

+ 9 - 36
microrm/src/query.rs

@@ -3,11 +3,11 @@ use crate::DB;
 pub mod condition;
 
 #[derive(Clone, Copy, Debug, serde::Serialize, serde::Deserialize)]
-pub struct ID(i64);
+pub struct ID { id: i64 }
 
 impl rusqlite::ToSql for ID {
     fn to_sql(&self) -> Result<rusqlite::types::ToSqlOutput<'_>, rusqlite::Error> {
-        self.0.to_sql()
+        self.id.to_sql()
     }
 }
 
@@ -21,7 +21,7 @@ impl<T: crate::model::Entity> WithID<T> {
     fn wrap(what: T, raw_id: i64) -> Self {
         Self {
             wrap: what,
-            id: ID(raw_id)
+            id: ID { id: raw_id }
         }
     }
 }
@@ -52,9 +52,9 @@ impl<T: crate::model::Entity> std::ops::DerefMut for WithID<T> {
 }
 
 /// Search for an entity by a property
-pub fn get_one_by<T: crate::model::Entity, V: rusqlite::ToSql>(
+pub fn get_one_by<T: crate::model::Entity<Column = C>, C: crate::model::EntityColumns<Entity = T>, V: rusqlite::ToSql>(
     db: &DB,
-    c: <T as crate::model::Entity>::Column,
+    c: C,
     val: V,
 ) -> Option<WithID<T>> {
     let table_name = <T as crate::model::Entity>::table_name();
@@ -79,9 +79,9 @@ pub fn get_one_by<T: crate::model::Entity, V: rusqlite::ToSql>(
 }
 
 /// Search for all entities matching a property
-pub fn get_all_by<T: crate::model::Entity, V: rusqlite::ToSql>(
+pub fn get_all_by<T: crate::model::Entity<Column = C>, C: crate::model::EntityColumns<Entity = T>, V: rusqlite::ToSql>(
     db: &DB,
-    c: <T as crate::model::Entity>::Column,
+    c: C,
     val: V) -> Option<Vec<WithID<T>>> {
 
     let table_name = <T as crate::model::Entity>::table_name();
@@ -119,7 +119,7 @@ pub fn get_one_by_id<T: crate::model::Entity>(
         ))
         .ok()?;
 
-    let result = prepared.query_row([&id.0], |row| {
+    let result = prepared.query_row([&id], |row| {
         let mut deser = crate::model::load::RowDeserializer::from_row(row);
         Ok(WithID::wrap(
             T::deserialize(&mut deser).expect("deserialization works"),
@@ -150,32 +150,5 @@ pub fn add<T: crate::model::Entity + serde::Serialize>(db: &DB, m: &T) -> Option
     assert_eq!(row.len(), <T as crate::model::Entity>::column_count());
 
     let id = prepared.insert(rusqlite::params_from_iter(row)).ok()?;
-    Some(ID(id))
-}
-
-pub struct Context<'a> {
-    db: &'a DB,
-}
-
-impl<'a> Context<'a> {
-    pub fn get_one_by<
-        T: crate::model::Entity + for<'de> serde::Deserialize<'de>,
-        V: rusqlite::ToSql,
-    >(
-        &self,
-        c: <T as crate::model::Entity>::Column,
-        val: V,
-    ) -> Option<WithID<T>> {
-        get_one_by(self.db, c, val)
-    }
-
-    pub fn get_one_by_id<
-        T: crate::model::Entity + for<'de> serde::Deserialize<'de>,
-        V: rusqlite::ToSql,
-    >(
-        &self,
-        id: ID,
-    ) -> Option<WithID<T>> {
-        get_one_by_id(self.db, id)
-    }
+    Some(ID {id})
 }