Sfoglia il codice sorgente

Added UNIQUE constraint on association mappings.

Kestrel 7 mesi fa
parent
commit
bb138d7270

+ 23 - 20
microrm-macros/src/entity.rs

@@ -25,22 +25,24 @@ fn is_elided(attrs: &Vec<syn::Attribute>) -> bool {
     attrs.iter().filter(|a| a.path.is_ident("elide")).count() > 0
 }
 
+fn is_unique(attrs: &Vec<syn::Attribute>) -> bool {
+    attrs.iter().filter(|a| a.path.is_ident("unique")).count() > 0
+}
+
 pub fn derive(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream {
     let input: syn::DeriveInput = syn::parse_macro_input!(tokens);
 
-    let parts = if let syn::Data::Struct(syn::DataStruct {
-        struct_token: _,
-        fields: syn::Fields::Named(fields),
-        semi_token: _,
-    }) = input.data
-    {
-        fields
+    let parts = match input.data {
+        syn::Data::Struct(syn::DataStruct {
+            struct_token: _,
+            fields: syn::Fields::Named(fields),
+            semi_token: _,
+        }) => fields
             .named
             .into_iter()
             .map(|f| (f.ident.unwrap(), f.ty, f.attrs))
-            .collect::<Vec<_>>()
-    } else {
-        panic!("Can only derive Entity on data structs with named fields!");
+            .collect::<Vec<_>>(),
+        _ => panic!("Can only derive Entity on data structs with named fields!"),
     };
 
     let entity_ident = input.ident;
@@ -67,29 +69,22 @@ pub fn derive(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream {
 
     let vis = input.vis;
 
-    let unique_ident = format_ident!("unique");
-
     // collect list of unique parts
     let unique_parts = parts
         .iter()
-        .filter(|part| {
-            part.2.iter().any(|attr| {
-                attr.parse_meta()
-                    .map(|a| a.path().is_ident(&unique_ident))
-                    .unwrap_or(false)
-            })
-        })
+        .filter(|part| is_unique(&part.2))
         .cloned()
         .collect::<Vec<_>>();
 
     let part_defs = parts.iter().map(|part| {
         let part_combined_name = make_combined_name(part);
+        let part_base_ident = &part.0;
         let part_base_name = &part.0.to_string();
         let part_type = &part.1;
 
         let placeholder = format!("${}_{}", entity_ident, part_base_name);
 
-        let unique = unique_parts.iter().any(|p| p.0 == part.0);
+        let unique = is_unique(&part.2);
 
         let doc = extract_doc_comment(&part.2);
 
@@ -111,6 +106,10 @@ pub fn derive(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream {
                 fn desc() -> Option<&'static str> {
                     #doc
                 }
+
+                fn get_datum(from: &Self::Entity) -> &Self::Datum {
+                    &from.#part_base_ident
+                }
             }
         }
     });
@@ -214,6 +213,10 @@ pub fn derive(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream {
             fn placeholder() -> &'static str { "TODO" }
 
             fn desc() -> Option<&'static str> { None }
+
+            fn get_datum(from: &Self::Entity) -> &Self::Datum {
+                unreachable!()
+            }
         }
 
         impl ::microrm::schema::datum::Datum for #id_ident {

+ 12 - 3
microrm/src/query.rs

@@ -322,7 +322,7 @@ fn do_connect<Remote: Entity>(
         hash_of(("connect", an.local_name, an.remote_name, an.part_name)),
         || {
             format!(
-                "insert into `{assoc_name}` (`{local_field}`, `{remote_field}`) values (?, ?)",
+                "insert into `{assoc_name}` (`{local_field}`, `{remote_field}`) values (?, ?) returning (`id`)",
                 assoc_name = an.assoc_name(),
                 local_field = an.local_field,
                 remote_field = an.remote_field
@@ -332,7 +332,11 @@ fn do_connect<Remote: Entity>(
             ctx.bind(1, adata.local_id)?;
             ctx.bind(2, remote_id.into_raw())?;
 
-            ctx.run().map(|_| ())
+            println!("bound values ({:?}, {:?})", adata.local_id, remote_id);
+
+            ctx.run()?
+                .ok_or(Error::LogicError("Already connected"))
+                .map(|v| println!("v: {:?}", v.read::<i64>(0)))
         },
     )
 }
@@ -343,6 +347,11 @@ pub trait AssocInterface: 'static {
     fn get_distinguishing_name(&self) -> DBResult<&'static str>;
     const SIDE: LocalSide;
 
+    fn query_all(&self) -> impl Queryable<EntityOutput = Self::RemoteEntity> {
+        components::TableComponent::<Self::RemoteEntity>::new(self.get_data().unwrap().conn.clone())
+        // IDMap::<Self::RemoteEntity>::build(self.get_data().unwrap().conn.clone())
+    }
+
     fn connect_to(&self, remote_id: <Self::RemoteEntity as Entity>::ID) -> DBResult<()>
     where
         Self: Sized,
@@ -638,7 +647,7 @@ pub trait Queryable: Clone {
     }
 }
 
-// Generic implementation for all IDMaps
+// Generic implementations for all IDMaps
 impl<'a, T: Entity> Queryable for &'a IDMap<T> {
     type EntityOutput = T;
     type OutputContainer = Vec<Stored<T>>;

+ 41 - 0
microrm/src/query/components.rs

@@ -13,6 +13,47 @@ use crate::{
 
 use super::Query;
 
+/// Allow manipulation of an entire table.
+pub(crate) struct TableComponent<E: Entity> {
+    conn: Connection,
+    _ghost: std::marker::PhantomData<E>,
+}
+
+impl<E: Entity> Clone for TableComponent<E> {
+    fn clone(&self) -> Self {
+        Self {
+            conn: self.conn.clone(),
+            _ghost: Default::default(),
+        }
+    }
+}
+
+impl<E: Entity> TableComponent<E> {
+    pub fn new(conn: Connection) -> Self {
+        Self {
+            conn,
+            _ghost: Default::default(),
+        }
+    }
+}
+
+impl<E: Entity> Queryable for TableComponent<E> {
+    type EntityOutput = E;
+    type OutputContainer = Vec<Stored<E>>;
+    type StaticVersion = Self;
+
+    fn build(&self) -> Query {
+        Query::new()
+            .attach(QueryPart::Root, "SELECT DISTINCT".into())
+            .attach(QueryPart::Columns, "*".into())
+            .attach(QueryPart::From, format!("`{}`", E::entity_name()))
+    }
+    fn bind(&self, _stmt: &mut StatementContext, _index: &mut i32) {}
+    fn conn(&self) -> &Connection {
+        &self.conn
+    }
+}
+
 /// Filter on a Datum
 #[derive(Clone)]
 pub(crate) struct WithComponent<WEP: EntityPart, Parent: Queryable, QE: QueryEquivalent<WEP::Datum>>

+ 2 - 1
microrm/src/schema.rs

@@ -313,7 +313,7 @@ impl<R: Relation> Datum for AssocDomain<R> {
     }
 
     fn accept_entity_visitor(v: &mut impl EntityVisitor) {
-        v.visit::<R::Domain>();
+        v.visit::<R::Range>();
     }
 
     fn accept_discriminator(d: &mut impl DatumDiscriminator)
@@ -549,6 +549,7 @@ impl<T: 'static + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Deb
 // ----------------------------------------------------------------------
 
 /// Table with EntityID-based lookup.
+#[derive(Clone)]
 pub struct IDMap<T: Entity> {
     pub(crate) conn: Connection,
     _ghost: std::marker::PhantomData<T>,

+ 14 - 2
microrm/src/schema/build.rs

@@ -19,6 +19,7 @@ struct ColumnInfo {
 struct TableInfo {
     table_name: String,
     columns: Vec<ColumnInfo>,
+    constraints: Vec<String>,
     dependencies: Vec<String>,
 }
 
@@ -27,6 +28,7 @@ impl TableInfo {
         TableInfo {
             table_name: name,
             columns: vec![],
+            constraints: vec![],
             dependencies: vec![],
         }
     }
@@ -49,10 +51,13 @@ impl TableInfo {
         });
 
         format!(
-            "create table `{}` (`id` integer primary key{}{});",
+            "create table `{}` (`id` integer primary key{}{}{});",
             self.table_name,
             columns.collect::<String>(),
-            fkeys.collect::<String>()
+            fkeys.collect::<String>(),
+            self.constraints
+                .iter()
+                .fold(String::new(), |a, b| format!("{}, {}", a, b))
         )
     }
 }
@@ -165,6 +170,10 @@ pub(crate) fn collect_from_database<DB: Database>() -> DatabaseSchema {
                         unique: false,
                     });
 
+                    assoc_table
+                        .constraints
+                        .push(format!("UNIQUE(`range`, `domain`)"));
+
                     tables.insert(assoc_table_name.clone(), assoc_table);
                 }
                 PartType::AssocRange {
@@ -189,6 +198,9 @@ pub(crate) fn collect_from_database<DB: Database>() -> DatabaseSchema {
                         unique: false,
                     });
 
+                    assoc_table
+                        .constraints
+                        .push(format!("unique(`range`, `domain`)"));
                     tables.insert(assoc_table_name.clone(), assoc_table);
                 }
             }

+ 3 - 1
microrm/src/schema/datum.rs

@@ -80,7 +80,9 @@ pub trait DatumList: Clone {
 
 /// A list of concrete datums.
 pub trait ConcreteDatumList: DatumList + Clone {
-    fn build_equivalent<'l>(from: &'l [&'l str]) -> Option<impl QueryEquivalentList<Self> + 'l>;
+    fn build_equivalent<'l>(
+        from: impl Iterator<Item = &'l str>,
+    ) -> Option<impl QueryEquivalentList<Self> + 'l>;
 }
 
 /// A walker for a DatumList instance.

+ 17 - 14
microrm/src/schema/datum/datum_list.rs

@@ -8,6 +8,13 @@ impl DatumList for () {
 
     const LEN: usize = 0;
 }
+impl ConcreteDatumList for () {
+    fn build_equivalent<'l>(
+        _from: impl Iterator<Item = &'l str>,
+    ) -> Option<impl QueryEquivalentList<Self> + 'l> {
+        Some(())
+    }
+}
 
 impl QueryEquivalentList<()> for () {}
 
@@ -20,12 +27,10 @@ impl<T: Datum> DatumList for T {
 }
 
 impl<T: ConcreteDatum> ConcreteDatumList for T {
-    fn build_equivalent<'l>(from: &'l [&'l str]) -> Option<impl QueryEquivalentList<Self> + 'l> {
-        if from.len() != 1 {
-            None
-        } else {
-            Some(StringQuery(from[0]))
-        }
+    fn build_equivalent<'l>(
+        mut from: impl Iterator<Item = &'l str>,
+    ) -> Option<impl QueryEquivalentList<Self> + 'l> {
+        Some(StringQuery(from.next()?))
     }
 }
 
@@ -40,12 +45,10 @@ impl<T0: Datum> DatumList for (T0,) {
 }
 
 impl<T0: ConcreteDatum> ConcreteDatumList for (T0,) {
-    fn build_equivalent<'l>(from: &'l [&'l str]) -> Option<impl QueryEquivalentList<Self> + 'l> {
-        if from.len() != 1 {
-            None
-        } else {
-            Some((StringQuery(from[0]),))
-        }
+    fn build_equivalent<'l>(
+        mut from: impl Iterator<Item = &'l str>,
+    ) -> Option<impl QueryEquivalentList<Self> + 'l> {
+        Some((StringQuery(from.next()?),))
     }
 }
 
@@ -63,9 +66,9 @@ macro_rules! datum_list {
         impl<$( $ty: ConcreteDatum, $e: QueryEquivalent<$ty> ),*> QueryEquivalentList<( $( $ty ),* )> for ( $( $e ),* ) {}
 
         impl<$( $ty: ConcreteDatum ),*> ConcreteDatumList for ($($ty),*) {
-            fn build_equivalent<'l>(from: &'l [&'l str]) -> Option<impl QueryEquivalentList<Self> + 'l> {
+            fn build_equivalent<'l>(mut from: impl Iterator<Item = &'l str>) -> Option<impl QueryEquivalentList<Self> + 'l> {
                 Some((
-                        $( StringQuery( from.get($n)? ) ),*
+                        $( if $n == $n { StringQuery( from.next()? ) } else { panic!() } ),*
                 ))
             }
         }

+ 5 - 3
microrm/src/schema/entity.rs

@@ -2,11 +2,11 @@ use std::{fmt::Debug, hash::Hash};
 
 use crate::{
     db::{Connection, StatementRow},
-    schema::datum::{Datum, DatumList},
+    schema::datum::Datum,
     DBResult,
 };
 
-use super::datum::{ConcreteDatum, QueryEquivalentList};
+use super::datum::{ConcreteDatum, ConcreteDatumList, QueryEquivalentList};
 
 pub(crate) mod helpers;
 
@@ -34,6 +34,8 @@ pub trait EntityPart: Default + Clone + 'static {
     fn placeholder() -> &'static str;
     fn unique() -> bool;
     fn desc() -> Option<&'static str>;
+
+    fn get_datum(from: &Self::Entity) -> &Self::Datum;
 }
 
 /// Visitor for traversing all `EntityPart`s in an `Entity` or `EntityPartList`.
@@ -45,7 +47,7 @@ pub trait EntityPartVisitor {
 
 /// List of EntityParts.
 pub trait EntityPartList: 'static {
-    type DatumList: DatumList + QueryEquivalentList<Self::DatumList>;
+    type DatumList: ConcreteDatumList + QueryEquivalentList<Self::DatumList>;
 
     fn build_datum_list(conn: &Connection, stmt: &mut StatementRow) -> DBResult<Self::DatumList>;
 

+ 1 - 1
microrm/src/schema/meta.rs

@@ -1,6 +1,6 @@
 use crate::schema::IDMap;
 
-#[derive(microrm_macros::Entity)]
+#[derive(Clone, microrm_macros::Entity)]
 pub struct Meta {
     /// Metadata key-value key
     #[unique]

+ 35 - 9
microrm/src/schema/tests.rs

@@ -1,7 +1,5 @@
 #![allow(unused)]
 
-use test_log::test;
-
 fn open_test_db<DB: super::Database>(identifier: &'static str) -> DB {
     let path = format!("/tmp/microrm-{identifier}.db");
     let _ = std::fs::remove_file(path.as_str());
@@ -81,6 +79,10 @@ mod manual_test_db {
         fn desc() -> Option<&'static str> {
             None
         }
+
+        fn get_datum(from: &Self::Entity) -> &Self::Datum {
+            unreachable!()
+        }
     }
 
     #[derive(Clone, Default)]
@@ -100,6 +102,10 @@ mod manual_test_db {
         fn desc() -> Option<&'static str> {
             None
         }
+
+        fn get_datum(from: &Self::Entity) -> &Self::Datum {
+            &from.name
+        }
     }
 
     impl Entity for SimpleEntity {
@@ -426,12 +432,22 @@ mod mutual_relationship {
             .expect("couldn't retrieve customer record")
             .expect("no customer record");
 
+        println!(
+            "connecting customer a (ID {:?}) to receipt a (ID {:?})",
+            ca, ra
+        );
         e_ca.receipts
             .connect_to(ra)
-            .expect("couldn't associate customer with receipt");
+            .expect("couldn't associate customer with receipt a");
+        println!(
+            "connecting customer a (ID {:?}) to receipt b (ID {:?})",
+            ca, rb
+        );
         e_ca.receipts
             .connect_to(rb)
-            .expect("couldn't associate customer with receipt");
+            .expect("couldn't associate customer with receipt b");
+
+        println!("connected!");
 
         // technically this can fail if sqlite gives ra and rb back in the opposite order, which is
         // valid behaviour
@@ -597,23 +613,33 @@ mod join_test {
             .by_id(b1id)
             .expect("couldn't get base")
             .expect("couldn't get base");
-        b1.targets.connect_to(t1id);
-        b1.targets.connect_to(t2id);
+        b1.targets
+            .connect_to(t1id)
+            .expect("couldn't connect b1 to t1id");
+        b1.targets
+            .connect_to(t2id)
+            .expect("couldn't connect b1 to t2id");
 
         let b2 = db
             .bases
             .by_id(b2id)
             .expect("couldn't get base")
             .expect("couldn't get base");
-        b2.targets.connect_to(t2id);
-        b2.targets.connect_to(t3id);
+        b2.targets
+            .connect_to(t2id)
+            .expect("couldn't connect b2 to t2id");
+        b2.targets
+            .connect_to(t3id)
+            .expect("couldn't connect b2 to t3id");
 
         let t1 = db
             .targets
             .by_id(t2id)
             .expect("couldn't get target")
             .expect("couldn't get target");
-        t1.indirect_targets.connect_to(it1id);
+        t1.indirect_targets
+            .connect_to(it1id)
+            .expect("couldn't connect t1 to it1id");
 
         assert_eq!(
             db.bases