Bläddra i källkod

Initial version of microrm, successfully can store and load very simple entities.

Kestrel 2 år sedan
incheckning
efa9c37f32

+ 1 - 0
.gitignore

@@ -0,0 +1 @@
+.*.sw?

+ 2 - 0
microrm-macros/.gitignore

@@ -0,0 +1,2 @@
+/target
+Cargo.lock

+ 16 - 0
microrm-macros/Cargo.toml

@@ -0,0 +1,16 @@
+[package]
+name = "microrm-macros"
+version = "0.1.0"
+edition = "2021"
+
+# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
+
+[lib]
+proc-macro = true
+
+[dependencies]
+# proc_macro = "*"
+syn = { version = "1.0", features = ["derive"] }
+quote = "1.0"
+convert_case = "0.5"
+proc-macro2 = "1.0"

+ 77 - 0
microrm-macros/src/lib.rs

@@ -0,0 +1,77 @@
+use proc_macro::TokenStream;
+use syn::{parse_macro_input, DeriveInput};
+use quote::{quote,format_ident};
+
+use convert_case::{Case, Casing};
+
+#[proc_macro_derive(Model, attributes(microrm))]
+pub fn derive_model(tokens: TokenStream) -> TokenStream {
+    let input = parse_macro_input!(tokens as DeriveInput);
+
+
+    let struct_name = &input.ident;
+    let enum_name = format_ident!("{}Columns", &input.ident);
+
+    let table_name = format!("{}", struct_name).to_case(Case::Snake);
+
+    let entity_impl = quote!{
+        impl crate::model::Entity for #struct_name {
+            fn table_name() -> &'static str { #table_name }
+        }
+    };
+
+    if let syn::Data::Struct(st) = input.data {
+        if let syn::Fields::Named(fields) = st.fields {
+
+            let mut variants = syn::punctuated::Punctuated::<syn::Ident, syn::token::Comma>::new();
+            let mut field_names = syn::punctuated::Punctuated::<proc_macro2::TokenStream, syn::token::Comma>::new();
+            let mut value_references = syn::punctuated::Punctuated::<proc_macro2::TokenStream, syn::token::Comma>::new();
+            for name in fields.named.iter() {
+                let converted_case = format!("{}", name.ident.as_ref().unwrap().clone()).to_case(Case::UpperCamel);
+                let converted_case = format_ident!("{}", converted_case);
+                variants.push(converted_case.clone());
+
+                let field_name = name.ident.as_ref().unwrap().clone();
+                let field_name_str = format!("{}", field_name);
+                field_names.push(quote!{ Self::Column::#converted_case => #field_name_str }.into());
+
+                value_references.push(quote!{ &self. #field_name });
+            }
+
+            println!("field_names: {}", quote!{ #field_names });
+
+            let field_count = fields.named.len();
+
+            let ret = quote!{
+                #entity_impl
+
+                #[derive(Clone,Copy,strum::IntoStaticStr,strum::EnumCount)]
+                enum #enum_name {
+                    #variants
+                }
+
+                impl crate::model::EntityColumn for #struct_name {
+                    type Column = #enum_name;
+                    fn count() -> usize {
+                        <Self::Column as strum::EnumCount>::COUNT
+                    }
+                    fn index(c: Self::Column) -> usize {
+                        c as usize
+                    }
+                    fn name(c: Self::Column) -> &'static str {
+                        match c {
+                            #field_names
+                        }
+                    }
+                    fn values(&self) -> Vec<&dyn rusqlite::ToSql> {
+                        vec![ #value_references ]
+                    }
+                }
+            }.into();
+
+            ret
+        }
+        else { panic!("Can only use derive(Model) on non-unit structs with named fields!") }
+    }
+    else { panic!("Can only use derive(Model) on structs!") }
+}

+ 2 - 0
microrm/.gitignore

@@ -0,0 +1,2 @@
+/target
+Cargo.lock

+ 16 - 0
microrm/Cargo.toml

@@ -0,0 +1,16 @@
+[package]
+name = "microrm"
+version = "0.1.0"
+edition = "2021"
+
+# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
+
+[dependencies]
+base64 = "0.13"
+sha2 = "0.10"
+rusqlite = "0.27"
+serde = { version = "1.0", features = ["derive"] }
+
+strum = { version = "0.24", features = ["derive"] }
+
+microrm-macros = { path = "../microrm-macros" }

+ 90 - 0
microrm/src/lib.rs

@@ -0,0 +1,90 @@
+pub mod model;
+pub mod query;
+
+pub use microrm_macros::Model;
+
+#[derive(Debug,serde::Serialize,serde::Deserialize,Model)]
+struct Metaschema {
+    key: String,
+    value: String
+}
+
+pub struct DB {
+    conn: rusqlite::Connection,
+    schema_hash: String,
+    schema: model::SchemaModel
+}
+
+impl DB {
+    pub fn new(schema: model::SchemaModel, path: &str, allow_recreate: bool) -> Self {
+        Self::from_connection(rusqlite::Connection::open(path).expect("Opening database connection failed"), schema, allow_recreate)
+    }
+
+    /// For use in tests
+    pub fn new_in_memory(schema: model::SchemaModel) -> Self {
+        Self::from_connection(rusqlite::Connection::open_in_memory().expect("Opening database connection failed"), schema, true)
+    }
+
+    fn from_connection(conn: rusqlite::Connection, schema: model::SchemaModel, allow_recreate: bool) -> Self {
+        let sig = Self::calculate_schema_hash(&schema);
+        let ret = Self { conn, schema_hash: sig, schema: schema.add::<Metaschema>() };
+        ret.check_schema(allow_recreate);
+        ret
+    }
+
+    fn calculate_schema_hash(schema: &model::SchemaModel) -> String {
+        use sha2::Digest;
+
+        let mut hasher = sha2::Sha256::new();
+        schema.create().iter().map(|sql| hasher.update(sql.as_bytes()));
+
+        base64::encode(hasher.finalize())
+    }
+
+    fn check_schema(&self, allow_recreate: bool) {
+        let hash = query::get_one_by::<Metaschema, _>(self, MetaschemaColumns::Key, "schema_hash");
+
+        if hash.is_none() || hash.unwrap().value != self.schema_hash {
+            if !allow_recreate {
+                panic!("No schema version in database, and not allowed to create!");
+            }
+            println!("Failed to retrieve schema; probably is empty database");
+
+            for ds in self.schema.drop() {
+                let prepared = self.conn.prepare(ds);
+                let result = prepared.unwrap().execute([]).expect("Creation sql failed");
+            }
+
+            for cs in self.schema.create() {
+                let prepared = self.conn.prepare(cs);
+                let result = prepared.unwrap().execute([]).expect("Creation sql failed");
+            }
+
+            query::add(self, &Metaschema { key: "schema_hash".to_string(), value: self.schema_hash.clone() });
+
+            println!("re-search results: {:?}", query::get_one_by::<Metaschema, _>(self, MetaschemaColumns::Key, "schema_hash"));
+        }
+
+        // println!("schema: {:?}", schema);
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use super::DB;
+
+    #[derive(serde::Deserialize,crate::Model)]
+    struct S1 {
+        id: i32,
+    }
+
+    fn simple_schema() -> super::model::SchemaModel {
+        super::model::SchemaModel::new()
+            .add::<S1>()
+    }
+
+    #[test]
+    fn in_memory_schema() {
+        let db = DB::new_in_memory(simple_schema());
+    }
+}

+ 0 - 0
microrm/src/lookup.rs


+ 72 - 0
microrm/src/model.rs

@@ -0,0 +1,72 @@
+pub(crate) mod load;
+mod create;
+pub(crate) mod store;
+
+#[derive(Debug)]
+pub enum ModelError {
+    DBError(rusqlite::Error),
+    LoadError(String),
+    EmptyStoreError
+}
+
+impl From<rusqlite::Error> for ModelError {
+    fn from(e: rusqlite::Error) -> Self {
+        Self::DBError(e)
+    }
+}
+
+impl std::fmt::Display for ModelError {
+    fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
+        fmt.write_fmt(format_args!("{:?}", self))
+    }
+}
+
+impl serde::ser::Error for ModelError {
+    fn custom<T: std::fmt::Display>(msg: T) -> Self {
+        Self::LoadError(format!("{}", msg))
+    }
+}
+
+impl serde::de::Error for ModelError {
+    fn custom<T: std::fmt::Display>(msg: T) -> Self {
+        Self::LoadError(format!("{}", msg))
+    }
+}
+
+impl std::error::Error for ModelError { }
+
+/// A database entity, aka a struct representing a row in a table
+pub trait Entity {
+    fn table_name() -> &'static str;
+}
+
+/// How we describe an entire schema
+#[derive(Debug)]
+pub struct SchemaModel {
+    drop: Vec<String>,
+    create: Vec<String>,
+}
+
+impl SchemaModel {
+    pub fn new() -> Self {
+        Self { drop: Vec::new(), create: Vec::new() }
+    }
+
+    pub fn add<'de, E: Entity + EntityColumn + serde::Deserialize<'de>>(mut self) -> Self {
+        let (drop, create) = create::sql_for::<E>();
+        self.drop.push(drop);
+        self.create.push(create);
+        self
+    }
+
+    pub fn drop(&self) -> &Vec<String> { &self.drop }
+    pub fn create(&self) -> &Vec<String> { &self.create }
+}
+
+pub trait EntityColumn {
+    type Column;
+    fn count() -> usize where Self: Sized;
+    fn index(c: Self::Column) -> usize where Self: Sized;
+    fn name(c: Self::Column) -> &'static str where Self: Sized;
+    fn values(&self) -> Vec<&dyn rusqlite::ToSql>;
+}

+ 94 - 0
microrm/src/model/create.rs

@@ -0,0 +1,94 @@
+use serde::de::Visitor;
+
+#[derive(Debug)]
+pub struct CreateDeserializer<'de> {
+    table_name: Option<&'static str>,
+    column_names: Option<&'static [&'static str]>,
+    column_types: Vec<String>,
+    _de: std::marker::PhantomData<&'de u8>
+}
+
+impl<'de, 'a> serde::de::Deserializer<'de> for &'a mut CreateDeserializer<'de> {
+    type Error = super::ModelError;
+
+    // we (ab)use the forward_to_deserialize_any! macro to stub out the types we don't care about
+    serde::forward_to_deserialize_any! {
+        bool i8 i16 i128 u8 u16 u32 u64 u128 f32 f64 char str
+        bytes byte_buf option unit unit_struct newtype_struct seq tuple
+        tuple_struct map enum identifier ignored_any
+    }
+
+    fn deserialize_any<V: Visitor<'de>>(self, _v: V) -> Result<V::Value, Self::Error> {
+        todo!()
+    }
+
+    fn deserialize_i32<V: Visitor<'de>>(mut self, v: V) -> Result<V::Value, Self::Error> {
+        self.column_types.push("integer".to_owned());
+        v.visit_i32(0)
+    }
+
+    fn deserialize_i64<V: Visitor<'de>>(mut self, v: V) -> Result<V::Value, Self::Error> {
+        self.column_types.push("integer".to_owned());
+        v.visit_i64(0)
+    }
+
+    fn deserialize_string<V: Visitor<'de>>(mut self, v: V) -> Result<V::Value, Self::Error> {
+        self.column_types.push("varchar".to_owned());
+        v.visit_string("".to_owned())
+    }
+    
+
+    fn deserialize_struct<V: Visitor<'de>>(self, name: &'static str, fields: &'static [&'static str], v: V) -> Result<V::Value, Self::Error> {
+        self.table_name = Some(name);
+        self.column_names = Some(fields);
+        v.visit_seq(self)
+    }
+}
+
+impl<'de> serde::de::SeqAccess<'de> for CreateDeserializer<'de> {
+    type Error = super::ModelError;
+    
+    fn next_element_seed<T: serde::de::DeserializeSeed<'de>>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error> {
+        seed.deserialize(self).map(Some)
+    }
+}
+
+// trait SQLFor: crate::model::Entity
+
+pub fn sql_for<'de, T: crate::model::EntityColumn + crate::model::Entity + serde::de::Deserialize<'de>>() -> (String,String) {
+    let mut cd = CreateDeserializer { table_name: None, column_names: None, column_types: Vec::new(), _de: std::marker::PhantomData{} };
+
+    T::deserialize(&mut cd).expect("SQL creation failed!");
+
+    (
+        format!("DROP TABLE IF EXISTS {}", <T as crate::model::Entity>::table_name()),
+        format!("CREATE TABLE {} ({})",
+            <T as crate::model::Entity>::table_name(),
+            cd.column_names.unwrap().iter().zip(cd.column_types.iter())
+                .map(|(n,t)| n.to_string() + " " + t).collect::<Vec<_>>().join(",")
+            )
+    )
+}
+
+#[cfg(test)]
+mod test {
+    #[derive(serde::Deserialize,crate::Model)]
+    struct Empty {}
+
+    #[derive(serde::Deserialize,crate::Model)]
+    struct Single {
+        e: i32
+    }
+
+    #[test]
+    fn example_sql_for() {
+        assert_eq!(
+            super::sql_for::<Empty>(),
+            ("DROP TABLE IF EXISTS empty".to_owned(), "CREATE TABLE empty ()".to_owned())
+        );
+        assert_eq!(
+            super::sql_for::<Single>(),
+            ("DROP TABLE IF EXISTS single".to_owned(), "CREATE TABLE single (e integer)".to_owned())
+        );
+    }
+}

+ 52 - 0
microrm/src/model/load.rs

@@ -0,0 +1,52 @@
+use serde::de::Visitor;
+
+use rusqlite::Row;
+
+pub struct RowDeserializer<'de> {
+    row: &'de Row<'de>,
+    col_index: usize
+}
+
+impl<'de> RowDeserializer<'de> {
+    pub fn from_row(row: &'de Row) -> Self {
+        // we skip the rowid by starting at index 1
+        Self { row, col_index: 1 }
+    }
+
+    fn next_col_value<T>(&mut self) -> T {
+        self.col_index += 1;
+        todo!()
+    }
+}
+
+impl<'de, 'a> serde::de::Deserializer<'de> for &'a mut RowDeserializer<'de> {
+    type Error = super::ModelError;
+
+    fn deserialize_any<V: Visitor<'de>>(self, _v: V) -> Result<V::Value, Self::Error> {
+        todo!()
+    }
+
+    fn deserialize_string<V: Visitor<'de>>(self, v: V) -> Result<V::Value, Self::Error> {
+        let res = v.visit_string(self.row.get(self.col_index)?);
+        self.col_index += 1;
+        res
+    }
+
+    fn deserialize_struct<V: Visitor<'de>>(self, _name: &'static str, _fields: &'static [&'static str], v: V) -> Result<V::Value, Self::Error> {
+        v.visit_seq(self)
+    }
+
+    serde::forward_to_deserialize_any! {
+        bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str
+        bytes byte_buf option unit unit_struct newtype_struct seq tuple
+        tuple_struct map enum identifier ignored_any
+    }
+}
+
+impl<'de> serde::de::SeqAccess<'de> for RowDeserializer<'de> {
+    type Error = super::ModelError;
+    
+    fn next_element_seed<T: serde::de::DeserializeSeed<'de>>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error> {
+        seed.deserialize(self).map(Some)
+    }
+}

+ 3 - 0
microrm/src/model/store.rs

@@ -0,0 +1,3 @@
+pub fn serialize_as_row<'data, T: serde::Serialize + crate::model::EntityColumn>(value: &'data T) -> Vec<&'data dyn rusqlite::ToSql> {
+    (value as &dyn crate::model::EntityColumn<Column = _>).values()
+}

+ 78 - 0
microrm/src/query.rs

@@ -0,0 +1,78 @@
+pub use crate::DB;
+
+#[derive(Clone,Copy,Debug)]
+pub struct ID (i64);
+
+#[derive(Debug)]
+pub struct WithID<T: crate::model::Entity + crate::model::EntityColumn> {
+    wrap: T,
+    id: ID
+}
+
+impl<T: crate::model::Entity + crate::model::EntityColumn> WithID<T> {
+    fn wrap(what: T, raw_id: i64) -> Self {
+        Self { wrap: what, id: ID { 0: raw_id } }
+    }
+}
+
+impl<T: crate::model::Entity + crate::model::EntityColumn> WithID<T> {
+    pub fn id(&self) -> ID { self.id }
+}
+
+impl<T: crate::model::Entity + crate::model::EntityColumn> AsRef<T> for WithID<T> {
+    fn as_ref(&self) -> &T { return &self.wrap }
+}
+
+impl<T: crate::model::Entity + crate::model::EntityColumn> std::ops::Deref for WithID<T> {
+    type Target = T;
+    fn deref(&self) -> &Self::Target { &self.wrap }
+}
+
+impl<T: crate::model::Entity + crate::model::EntityColumn> std::ops::DerefMut for WithID<T> {
+    fn deref_mut(&mut self) -> &mut Self::Target { &mut self.wrap }
+}
+
+
+/// Search for an entity by a property
+pub fn get_one_by<T: crate::model::Entity + crate::model::EntityColumn + for<'de> serde::Deserialize<'de>, V: rusqlite::ToSql>(
+    db: &DB, c: <T as crate::model::EntityColumn>::Column, val: V) -> Option<WithID<T>> {
+
+    let table_name = <T as crate::model::Entity>::table_name();
+    let column_name = <T as crate::model::EntityColumn>::name(c);
+    let mut prepared = db.conn.prepare(&format!("SELECT rowid, tbl.* FROM {} tbl WHERE {} = ?1", table_name, column_name)).ok()?;
+
+    let result = prepared.query_row([&val], |row| {
+        let mut deser = crate::model::load::RowDeserializer::from_row(row);
+        Ok(WithID::wrap(T::deserialize(&mut deser).expect("deserialization works"), row.get(0).expect("can get rowid")))
+    });
+
+    result.ok()
+}
+
+/// Add an entity to its table
+pub fn add<T: crate::model::Entity + crate::model::EntityColumn + serde::Serialize>(db: &DB, m: &T) -> Option<ID> {
+    let row = crate::model::store::serialize_as_row(m);
+
+    let placeholders = (0..<T as crate::model::EntityColumn>::count()).map(|n| format!("?{}", n+1)).collect::<Vec<_>>().join(",");
+
+    let res = db.conn.prepare(&format!("INSERT INTO {} VALUES ({})", <T as crate::model::Entity>::table_name(), placeholders));
+    let mut prepared = res.ok()?;
+
+    // make sure we bound enough things
+    assert_eq!(row.len(), <T as crate::model::EntityColumn>::count());
+
+    let id = prepared.insert(rusqlite::params_from_iter(row)).ok()?;
+    Some(ID { 0: id })
+}
+
+pub struct Context<'a> {
+    db: &'a DB
+}
+
+impl<'a> Context<'a> {
+    pub fn get_one_by<T: crate::model::Entity + crate::model::EntityColumn + for<'de> serde::Deserialize<'de>, V: rusqlite::ToSql>(
+        &self, c: <T as crate::model::EntityColumn>::Column, val: V) -> Option<WithID<T>> {
+
+        get_one_by(self.db, c, val)
+    }
+}

+ 18 - 0
microrm/src/schema.rs

@@ -0,0 +1,18 @@
+use serde::{Serialize,Deserialize};
+
+#[derive(Serialize,Deserialize)]
+pub struct Realm {
+    realm_id: i32,
+}
+
+#[derive(Serialize,Deserialize)]
+pub struct Key {
+    key_id: i32,
+    realm_id: i32
+}
+
+pub fn schema() -> super::model::SchemaModel {
+    super::model::SchemaModel::new()
+        .add::<Realm>()
+        .add::<Key>()
+}