Quellcode durchsuchen

Added entity deletion, count, and join queries.

Kestrel vor 1 Jahr
Ursprung
Commit
9fece07a2b
4 geänderte Dateien mit 331 neuen und 102 gelöschten Zeilen
  1. 75 21
      microrm/src/query.rs
  2. 83 16
      microrm/src/query/components.rs
  3. 6 65
      microrm/src/schema.rs
  4. 167 0
      microrm/src/schema/tests.rs

+ 75 - 21
microrm/src/query.rs

@@ -85,15 +85,13 @@ pub(crate) enum QueryPart {
 }
 
 #[derive(Debug)]
-pub struct Query<'a> {
-    conn: &'a DBConnection,
+pub struct Query {
     parts: HashMap<QueryPart, Vec<String>>,
 }
 
-impl<'a> Query<'a> {
-    pub(crate) fn new(conn: &'a DBConnection) -> Self {
+impl Query {
+    pub(crate) fn new() -> Self {
         Self {
-            conn,
             parts: Default::default(),
         }
     }
@@ -153,12 +151,11 @@ impl<'a> Query<'a> {
 
         let join_ = match self.parts.remove(&QueryPart::Join) {
             None => String::new(),
-            Some(v) => {
-                format!(
-                    "INNER JOIN {}",
-                    v.into_iter().reduce(|a, b| format!("{} {}", a, b)).unwrap()
-                )
-            }
+            Some(v) => v
+                .into_iter()
+                .map(|j| format!("INNER JOIN {}", j))
+                .reduce(|a, b| format!("{} {}", a, b))
+                .unwrap(),
         };
 
         let where_ = match self.parts.remove(&QueryPart::Where) {
@@ -382,8 +379,9 @@ pub trait Queryable {
     type OutputContainer: OutputContainer;
     type StaticVersion: Queryable + 'static;
 
-    fn build<'s, 'q: 's>(&'s self) -> Query<'s>;
+    fn build(&self) -> Query;
     fn bind(&self, stmt: &mut sqlite::Statement, index: &mut usize);
+    fn conn(&self) -> &DBConnection;
 
     // ----------------------------------------------------------------------
     // Verbs
@@ -395,16 +393,41 @@ pub trait Queryable {
     where
         Self: Sized,
     {
-        todo!()
+        struct CountTag;
+        self.conn().with_prepared(
+            std::any::TypeId::of::<(Self::StaticVersion, CountTag)>(),
+            || {
+                self.build()
+                    .replace(
+                        QueryPart::Columns,
+                        format!(
+                            "COUNT(DISTINCT `{}`.`id`)",
+                            Self::EntityOutput::entity_name()
+                        ),
+                    )
+                    .assemble()
+            },
+            |stmt| {
+                stmt.reset()?;
+
+                // starting index is 1
+                let mut index = 1;
+                self.bind(stmt, &mut index);
+
+                stmt.next()?;
+
+                Ok(stmt.read::<i64, _>(0)? as usize)
+            },
+        )
     }
     /// Get all entities in the current context.
     fn get(self) -> DBResult<Self::OutputContainer>
     where
         Self: Sized,
     {
-        let q = self.build();
-        q.conn.with_prepared(
-            std::any::TypeId::of::<Self::StaticVersion>(),
+        struct GetTag;
+        self.conn().with_prepared(
+            std::any::TypeId::of::<(Self::StaticVersion, GetTag)>(),
             || self.build().assemble(),
             |stmt| {
                 stmt.reset()?;
@@ -413,7 +436,7 @@ pub trait Queryable {
                 let mut index = 1;
                 self.bind(stmt, &mut index);
 
-                <Self::OutputContainer>::assemble_from(q.conn, stmt)
+                <Self::OutputContainer>::assemble_from(self.conn(), stmt)
             },
         )
     }
@@ -424,7 +447,33 @@ pub trait Queryable {
     where
         Self: Sized,
     {
-        todo!()
+        struct DeleteTag;
+        self.conn().with_prepared(
+            std::any::TypeId::of::<(Self::StaticVersion, DeleteTag)>(),
+            || {
+                format!(
+                    "DELETE FROM `{}` WHERE `id` = ({})",
+                    Self::EntityOutput::entity_name(),
+                    self.build()
+                        .replace(
+                            QueryPart::Columns,
+                            format!("`{}`.`id`", Self::EntityOutput::entity_name())
+                        )
+                        .assemble()
+                )
+            },
+            |stmt| {
+                stmt.reset()?;
+
+                // starting index is 1
+                let mut index = 1;
+                self.bind(stmt, &mut index);
+
+                stmt.next()?;
+
+                Ok(stmt.read::<i64, _>(0)? as usize)
+            },
+        )
     }
 
     // ----------------------------------------------------------------------
@@ -475,7 +524,7 @@ pub trait Queryable {
     fn join<AD: AssocInterface, EP: EntityPart<Entity = Self::EntityOutput, Datum = AD>>(
         self,
         part: EP,
-    ) -> impl Queryable<EntityOutput = AD::RemoteEntity>
+    ) -> impl Queryable<EntityOutput = AD::RemoteEntity, OutputContainer = Vec<IDWrap<AD::RemoteEntity>>>
     where
         Self: Sized,
     {
@@ -489,7 +538,7 @@ impl<'a, AI: AssocInterface> Queryable for &'a AI {
     type OutputContainer = Vec<IDWrap<AI::RemoteEntity>>;
     type StaticVersion = &'static AI;
 
-    fn build<'s, 'q: 's>(&'s self) -> Query<'s> {
+    fn build(&self) -> Query {
         unreachable!()
     }
 
@@ -497,6 +546,10 @@ impl<'a, AI: AssocInterface> Queryable for &'a AI {
         unreachable!()
     }
 
+    fn conn(&self) -> &DBConnection {
+        &self.get_data().unwrap().conn
+    }
+
     fn count(self) -> DBResult<usize> {
         components::AssocQueryable::new(self).count()
     }
@@ -531,7 +584,8 @@ impl<'a, AI: AssocInterface> Queryable for &'a AI {
     fn join<AD: AssocInterface, EP: EntityPart<Entity = Self::EntityOutput, Datum = AD>>(
         self,
         part: EP,
-    ) -> impl Queryable<EntityOutput = AD::RemoteEntity> {
+    ) -> impl Queryable<EntityOutput = AD::RemoteEntity, OutputContainer = Vec<IDWrap<AD::RemoteEntity>>>
+    {
         components::AssocQueryable::new(self).join(part)
     }
 }

+ 83 - 16
microrm/src/query/components.rs

@@ -6,7 +6,7 @@ use crate::{
     schema::{
         datum::{Datum, DatumList, DatumListRef, DatumVisitor},
         entity::{Entity, EntityPart, EntityPartList, EntityPartVisitor},
-        IDMap, IDWrap,
+        DatumDiscriminator, IDMap, IDWrap,
     },
     DBResult,
 };
@@ -29,13 +29,17 @@ impl<'a, E: Entity> Queryable for MapQueryable<'a, E> {
     type OutputContainer = Vec<IDWrap<E>>;
     type StaticVersion = MapQueryable<'static, E>;
 
-    fn build<'s, 'q: 's>(&'s self) -> Query<'s> {
-        Query::new(&self.map.conn())
-            .attach(QueryPart::Root, "SELECT".into())
+    fn build(&self) -> Query {
+        Query::new()
+            .attach(QueryPart::Root, "SELECT DISTINCT".into())
             .attach(QueryPart::Columns, "*".into())
             .attach(QueryPart::From, format!("`{}`", E::entity_name()))
     }
     fn bind(&self, _stmt: &mut sqlite::Statement, _index: &mut usize) {}
+
+    fn conn(&self) -> &crate::db::DBConnection {
+        self.map.conn()
+    }
 }
 
 /// Concrete implementation of Queryable for an IDMap
@@ -54,15 +58,15 @@ impl<'a, AI: AssocInterface> Queryable for AssocQueryable<'a, AI> {
     type OutputContainer = Vec<IDWrap<AI::RemoteEntity>>;
     type StaticVersion = AssocQueryable<'static, AI>;
 
-    fn build<'s, 'q: 's>(&'s self) -> Query<'s> {
+    fn build(&self) -> Query {
         let adata = self
             .assoc
             .get_data()
             .expect("building query for assoc with no data");
         let anames = super::AssocNames::collect(self.assoc).unwrap();
         let assoc_name = anames.assoc_name();
-        Query::new(&adata.conn)
-            .attach(QueryPart::Root, "SELECT".into())
+        Query::new()
+            .attach(QueryPart::Root, "SELECT DISTINCT".into())
             .attach(QueryPart::Columns, format!("`{}`.*", anames.remote_name))
             .attach(QueryPart::From, format!("`{}`", assoc_name))
             .attach(
@@ -87,6 +91,10 @@ impl<'a, AI: AssocInterface> Queryable for AssocQueryable<'a, AI> {
             .expect("couldn't bind assoc id");
         *index += 1;
     }
+
+    fn conn(&self) -> &crate::db::DBConnection {
+        &self.assoc.get_data().unwrap().conn
+    }
 }
 
 /// Filter on a Datum
@@ -111,7 +119,7 @@ impl<'a, WEP: EntityPart, Parent: Queryable> Queryable for WithComponent<'a, WEP
     type OutputContainer = Parent::OutputContainer;
     type StaticVersion = WithComponent<'static, WEP, Parent::StaticVersion>;
 
-    fn build<'s, 'q: 's>(&'s self) -> Query<'s> {
+    fn build(&self) -> Query {
         self.parent.build().attach(
             QueryPart::Where,
             format!(
@@ -126,6 +134,10 @@ impl<'a, WEP: EntityPart, Parent: Queryable> Queryable for WithComponent<'a, WEP
         self.datum.bind_to(stmt, *index);
         *index += 1;
     }
+
+    fn conn(&self) -> &crate::db::DBConnection {
+        self.parent.conn()
+    }
 }
 
 /// Filter on the unique index
@@ -148,11 +160,11 @@ impl<'a, E: Entity, Parent: Queryable> Queryable for UniqueComponent<'a, E, Pare
     type OutputContainer = Option<IDWrap<E>>;
     type StaticVersion = UniqueComponent<'static, E, Parent::StaticVersion>;
 
-    fn build<'s, 'q: 's>(&'s self) -> Query<'s> {
+    fn build(&self) -> Query {
         let mut query = self.parent.build();
 
-        struct PartVisitor<'a, 'b>(&'a mut Query<'b>);
-        impl<'a, 'b> EntityPartVisitor for PartVisitor<'a, 'b> {
+        struct PartVisitor<'a>(&'a mut Query);
+        impl<'a> EntityPartVisitor for PartVisitor<'a> {
             fn visit<EP: EntityPart>(&mut self) {
                 self.0.attach_mut(
                     QueryPart::Where,
@@ -183,6 +195,10 @@ impl<'a, E: Entity, Parent: Queryable> Queryable for UniqueComponent<'a, E, Pare
 
         self.datum.accept(&mut Visitor(stmt, index));
     }
+
+    fn conn(&self) -> &crate::db::DBConnection {
+        self.parent.conn()
+    }
 }
 
 pub(crate) struct SingleComponent<Parent: Queryable> {
@@ -200,7 +216,7 @@ impl<Parent: Queryable> Queryable for SingleComponent<Parent> {
     type OutputContainer = Option<IDWrap<Self::EntityOutput>>;
     type StaticVersion = SingleComponent<Parent::StaticVersion>;
 
-    fn build<'s, 'q: 's>(&'s self) -> Query<'s> {
+    fn build(&self) -> Query {
         self.parent
             .build()
             .attach(QueryPart::Trailing, "LIMIT 1".into())
@@ -209,6 +225,10 @@ impl<Parent: Queryable> Queryable for SingleComponent<Parent> {
     fn bind(&self, stmt: &mut sqlite::Statement, index: &mut usize) {
         self.parent.bind(stmt, index)
     }
+
+    fn conn(&self) -> &crate::db::DBConnection {
+        self.parent.conn()
+    }
 }
 
 /// Join with another entity via an association
@@ -236,13 +256,60 @@ impl<R: Entity, L: Entity, EP: EntityPart<Entity = L>, Parent: Queryable> Querya
     type OutputContainer = Vec<IDWrap<R>>;
     type StaticVersion = JoinComponent<R, L, EP, Parent::StaticVersion>;
 
-    fn build<'s, 'q: 's>(&'s self) -> Query<'s> {
-        // self.parent.build()
-        todo!()
+    fn build(&self) -> Query {
+        let remote_name = R::entity_name();
+        let local_name = L::entity_name();
+        let part_name = EP::part_name();
+        let assoc_name = format!("{local_name}_{remote_name}_assoc_{part_name}");
+
+        struct Discriminator(Option<(&'static str, &'static str)>);
+        impl DatumDiscriminator for Discriminator {
+            fn visit_entity_id<E: Entity>(&mut self) {
+                unreachable!()
+            }
+            fn visit_serialized<T: serde::Serialize + serde::de::DeserializeOwned>(&mut self) {
+                unreachable!()
+            }
+            fn visit_bare_field<T: Datum>(&mut self) {
+                unreachable!()
+            }
+
+            fn visit_assoc_map<E: Entity>(&mut self) {
+                self.0 = Some(("domain", "range"));
+            }
+            fn visit_assoc_domain<R: crate::schema::Relation>(&mut self) {
+                self.0 = Some(("domain", "range"));
+            }
+            fn visit_assoc_range<R: crate::schema::Relation>(&mut self) {
+                self.0 = Some(("range", "domain"));
+            }
+        }
+
+        let mut d = Discriminator(None);
+        <EP::Datum>::accept_discriminator(&mut d);
+
+        let (local_field, remote_field) = d.0.unwrap();
+
+        self.parent
+            .build()
+            .attach(
+                QueryPart::Join,
+                format!("`{assoc_name}` ON `{local_name}`.`id` = `{assoc_name}`.`{local_field}`"),
+            )
+            .attach(
+                QueryPart::Join,
+                format!(
+                    "`{remote_name}` ON `{assoc_name}`.`{remote_field}` = `{remote_name}`.`id`"
+                ),
+            )
+            .replace(QueryPart::Columns, format!("`{remote_name}`.*"))
     }
     fn bind(&self, stmt: &mut sqlite::Statement, index: &mut usize) {
         self.parent.bind(stmt, index);
-        todo!()
+    }
+
+    fn conn(&self) -> &crate::db::DBConnection {
+        self.parent.conn()
     }
 }
 

+ 6 - 65
microrm/src/schema.rs

@@ -174,14 +174,6 @@ impl<T: Entity> AssocMap<T> {
     }
 }
 
-/*impl<T: Entity> EntityMap for AssocMap<T> {
-    type ContainedEntity = T;
-
-    fn conn(&self) -> &DBConnection {
-        &self.data.as_ref().unwrap().conn
-    }
-}*/
-
 impl<T: Entity> Datum for AssocMap<T> {
     fn sql_type() -> &'static str {
         unreachable!()
@@ -348,14 +340,6 @@ impl<R: Relation> AssocInterface for AssocRange<R> {
     }
 }
 
-/*impl<R: Relation> EntityMap for AssocRange<R> {
-    type ContainedEntity = R::Domain;
-
-    fn conn(&self) -> &DBConnection {
-        &self.data.as_ref().unwrap().conn
-    }
-}*/
-
 impl<R: Relation> Datum for AssocRange<R> {
     fn sql_type() -> &'static str {
         unreachable!()
@@ -460,28 +444,12 @@ impl<T: 'static + serde::Serialize + serde::de::DeserializeOwned> Datum for Seri
 // Database specification types
 // ----------------------------------------------------------------------
 
-/*
-/// Trait for a type that represents a sqlite table that contains entities.
-pub(crate) trait EntityMap {
-    type ContainedEntity: Entity;
-
-    fn conn(&self) -> &DBConnection;
-}
-*/
-
 /// Table with EntityID-based lookup.
 pub struct IDMap<T: Entity> {
     conn: DBConnection,
     _ghost: std::marker::PhantomData<T>,
 }
 
-/*impl<T: Entity> EntityMap for IDMap<T> {
-    type ContainedEntity = T;
-    fn conn(&self) -> &DBConnection {
-        &self.conn
-    }
-}*/
-
 impl<T: Entity> IDMap<T> {
     pub fn build(db: DBConnection) -> Self {
         Self {
@@ -499,37 +467,6 @@ impl<T: Entity> IDMap<T> {
         self.with(id, &id).first().get()
     }
 
-    /*
-    /// Retrieve all entities in this map.
-    pub fn get_all<E: Entity>(&self) -> DBResult<Vec<IDWrap<E>>> {
-        query::get_all(&self.conn)
-    }
-
-
-    /// Look up an Entity in this map by the unique-tagged fields.
-    ///
-    /// Fields are passed to this function in order of specification in the original `struct` definition.
-    pub fn lookup_unique(
-        &self,
-        uniques: &<<T as Entity>::Uniques as EntityPartList>::DatumList,
-    ) -> DBResult<Option<IDWrap<T>>> {
-        query::select_by::<T, T::Uniques>(self, uniques).map(|mut v| {
-            if v.len() > 0 {
-                Some(v.remove(0))
-            } else {
-                None
-            }
-        })
-    }
-
-    pub fn delete_unique(
-        &self,
-        uniques: &<<T as Entity>::Uniques as EntityPartList>::DatumList,
-    ) -> DBResult<()> {
-        query::delete_by::<T, T::Uniques>(&self.conn, uniques)
-    }
-    */
-
     /// Insert a new Entity into this map, and return its new ID.
     pub fn insert(&self, value: T) -> DBResult<T::ID> {
         query::insert(self.conn(), value)
@@ -541,13 +478,17 @@ impl<'a, T: Entity> Queryable for &'a IDMap<T> {
     type OutputContainer = Vec<IDWrap<T>>;
     type StaticVersion = &'static IDMap<T>;
 
-    fn build<'s, 'q: 's>(&'s self) -> query::Query<'s> {
+    fn build(&self) -> query::Query {
         unreachable!()
     }
     fn bind(&self, _stmt: &mut sqlite::Statement, _index: &mut usize) {
         unreachable!()
     }
 
+    fn conn(&self) -> &DBConnection {
+        &self.conn
+    }
+
     fn count(self) -> DBResult<usize> {
         query::components::MapQueryable::new(self).count()
     }
@@ -587,7 +528,7 @@ impl<'a, T: Entity> Queryable for &'a IDMap<T> {
     fn join<AD: AssocInterface, EP: entity::EntityPart<Entity = Self::EntityOutput, Datum = AD>>(
         self,
         part: EP,
-    ) -> impl Queryable<EntityOutput = AD::RemoteEntity>
+    ) -> impl Queryable<EntityOutput = AD::RemoteEntity, OutputContainer = Vec<IDWrap<AD::RemoteEntity>>>
     where
         Self: Sized,
     {

+ 167 - 0
microrm/src/schema/tests.rs

@@ -287,6 +287,24 @@ mod derive_tests {
             permissions: "permissions A".to_string(),
         });
     }
+
+    #[test]
+    fn delete_test() {
+        let db = PeopleDB::open_path(":memory:").expect("couldn't open test db");
+
+        let id = db
+            .people
+            .insert(Person {
+                name: "person_name".to_string(),
+                roles: Default::default(),
+            })
+            .expect("couldn't insert test person");
+        assert!(db.people.by_id(id).expect("couldn't query db").is_some());
+
+        db.people.with(id, &id).delete();
+
+        assert!(db.people.by_id(id).expect("couldn't query db").is_none());
+    }
 }
 
 mod mutual_relationship {
@@ -424,3 +442,152 @@ mod reserved_words {
         ReservedWordDB::open_path(":memory:");
     }
 }
+
+mod join_test {
+    use super::open_test_db;
+    use crate::prelude::*;
+    use crate::schema;
+
+    #[derive(Default, Entity)]
+    struct Base {
+        name: String,
+        targets: AssocMap<Target>,
+    }
+
+    #[derive(Default, Entity)]
+    struct Target {
+        name: String,
+        indirect_targets: AssocMap<IndirectTarget>,
+    }
+
+    #[derive(Default, Entity)]
+    struct IndirectTarget {
+        name: String,
+    }
+
+    #[derive(Database)]
+    struct JoinDB {
+        bases: IDMap<Base>,
+        targets: IDMap<Target>,
+        indirect: IDMap<IndirectTarget>,
+    }
+
+    #[test]
+    fn simple_join() {
+        let db = open_test_db::<JoinDB>("simple_join_test");
+
+        let b1id = db
+            .bases
+            .insert(Base {
+                name: "base1".to_string(),
+                targets: Default::default(),
+            })
+            .expect("couldn't insert base");
+        let b2id = db
+            .bases
+            .insert(Base {
+                name: "base2".to_string(),
+                targets: Default::default(),
+            })
+            .expect("couldn't insert base");
+        let b3id = db
+            .bases
+            .insert(Base {
+                name: "base3".to_string(),
+                targets: Default::default(),
+            })
+            .expect("couldn't insert base");
+
+        let t1id = db
+            .targets
+            .insert(Target {
+                name: "target1".to_string(),
+                ..Default::default()
+            })
+            .expect("couldn't insert target");
+        let t2id = db
+            .targets
+            .insert(Target {
+                name: "target2".to_string(),
+                ..Default::default()
+            })
+            .expect("couldn't insert target");
+        let t3id = db
+            .targets
+            .insert(Target {
+                name: "target3".to_string(),
+                ..Default::default()
+            })
+            .expect("couldn't insert target");
+
+        let it1id = db
+            .indirect
+            .insert(IndirectTarget {
+                name: "itarget1".to_string(),
+                ..Default::default()
+            })
+            .expect("couldn't insert target");
+        let it2id = db
+            .indirect
+            .insert(IndirectTarget {
+                name: "itarget2".to_string(),
+                ..Default::default()
+            })
+            .expect("couldn't insert target");
+        let it3id = db
+            .indirect
+            .insert(IndirectTarget {
+                name: "itarget3".to_string(),
+                ..Default::default()
+            })
+            .expect("couldn't insert target");
+
+        let b1 = db
+            .bases
+            .by_id(b1id)
+            .expect("couldn't get base")
+            .expect("couldn't get base");
+        b1.targets.connect_to(t1id);
+        b1.targets.connect_to(t2id);
+
+        let b2 = db
+            .bases
+            .by_id(b2id)
+            .expect("couldn't get base")
+            .expect("couldn't get base");
+        b2.targets.connect_to(t2id);
+        b2.targets.connect_to(t3id);
+
+        let t1 = db
+            .targets
+            .by_id(t2id)
+            .expect("couldn't get target")
+            .expect("couldn't get target");
+        t1.indirect_targets.connect_to(it1id);
+
+        assert_eq!(
+            db.bases
+                .join(Base::Targets)
+                .get()
+                .expect("couldn't get joined results")
+                .len(),
+            3
+        );
+
+        let double_join = db
+            .bases
+            .join(Base::Targets)
+            .join(Target::IndirectTargets)
+            .get()
+            .expect("couldn't get double-joined results");
+        assert_eq!(double_join.len(), 1);
+
+        let double_join_count = db
+            .bases
+            .join(Base::Targets)
+            .join(Target::IndirectTargets)
+            .count()
+            .expect("couldn't count double-joined results");
+        assert_eq!(double_join_count, 1);
+    }
+}