collect.rs 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197
  1. use std::collections::HashMap;
  2. use crate::schema::datum::Datum;
  3. use crate::schema::entity::{Entity, EntityPart, EntityPartList, EntityPartVisitor, EntityVisitor};
  4. use crate::schema::{DatumDiscriminator, Relation};
  /// How a single entity part is represented, as classified by `PartState::build`.
  #[derive(Debug)]
  pub enum PartType {
      /// A directly stored value; holds the SQL type name (the datum's
      /// `sql_type()`, or `"text"` for serialized values).
      Datum(&'static str),
      /// A reference to another entity; holds the referenced entity's name.
      IDReference(&'static str),
      /// The domain side of an association.
      AssocDomain {
          /// Association table name, formatted `<domain>_<range>_assoc_<name>`.
          table_name: String,
          /// Name of the entity on the range side.
          range_name: &'static str,
      },
      /// The range side of an association.
      AssocRange {
          /// Association table name, formatted `<domain>_<range>_assoc_<name>`.
          table_name: String,
          /// Name of the entity on the domain side.
          domain_name: &'static str,
      },
  }
  22. #[derive(Debug)]
  23. pub struct PartState {
  24. pub name: &'static str,
  25. pub ty: PartType,
  26. pub unique: bool,
  27. pub key: bool,
  28. }
  29. impl PartState {
  30. fn build<EP: EntityPart>() -> Self {
  31. struct Discriminator<EP: EntityPart> {
  32. ty: Option<PartType>,
  33. _ghost: std::marker::PhantomData<EP>,
  34. }
  35. impl<EP: EntityPart> DatumDiscriminator for Discriminator<EP> {
  36. fn visit_entity_id<E: Entity>(&mut self) {
  37. self.ty = Some(PartType::IDReference(E::entity_name()));
  38. }
  39. fn visit_bare_field<T: Datum>(&mut self) {
  40. self.ty = Some(PartType::Datum(T::sql_type()));
  41. }
  42. fn visit_serialized<T: serde::Serialize + serde::de::DeserializeOwned>(&mut self) {
  43. self.ty = Some(PartType::Datum("text"));
  44. }
  45. fn visit_assoc_map<E: Entity>(&mut self) {
  46. self.ty = Some(PartType::AssocDomain {
  47. table_name: format!(
  48. "{}_{}_assoc_{}",
  49. EP::Entity::entity_name(),
  50. E::entity_name(),
  51. EP::part_name()
  52. ),
  53. range_name: E::entity_name(),
  54. });
  55. }
  56. fn visit_assoc_domain<R: Relation>(&mut self) {
  57. self.ty = Some(PartType::AssocDomain {
  58. table_name: format!(
  59. "{}_{}_assoc_{}",
  60. R::Domain::entity_name(),
  61. R::Range::entity_name(),
  62. R::NAME
  63. ),
  64. range_name: R::Range::entity_name(),
  65. });
  66. }
  67. fn visit_assoc_range<R: Relation>(&mut self) {
  68. self.ty = Some(PartType::AssocRange {
  69. table_name: format!(
  70. "{}_{}_assoc_{}",
  71. R::Domain::entity_name(),
  72. R::Range::entity_name(),
  73. R::NAME
  74. ),
  75. domain_name: R::Domain::entity_name(),
  76. });
  77. }
  78. }
  79. let mut discrim = Discriminator::<EP> {
  80. ty: None,
  81. _ghost: Default::default(),
  82. };
  83. <EP::Datum>::accept_discriminator(&mut discrim);
  84. if let Some(ty) = discrim.ty {
  85. PartState {
  86. name: EP::part_name(),
  87. ty,
  88. unique: EP::unique(),
  89. key: false,
  90. }
  91. } else {
  92. unreachable!("no PartType extracted from EntityPart")
  93. }
  94. }
  95. }
  96. #[derive(Debug)]
  97. pub struct EntityState {
  98. pub name: &'static str,
  99. typeid: std::any::TypeId,
  100. pub parts: Vec<PartState>,
  101. }
  102. impl EntityState {
  103. fn build<E: Entity>() -> Self {
  104. #[derive(Default)]
  105. struct PartVisitor(Vec<PartState>);
  106. impl EntityPartVisitor for PartVisitor {
  107. fn visit<EP: EntityPart>(&mut self) {
  108. self.0.push(PartState::build::<EP>());
  109. }
  110. }
  111. let mut pv = PartVisitor::default();
  112. E::accept_part_visitor(&mut pv);
  113. struct KeyVisitor<'l>(&'l mut Vec<PartState>);
  114. impl<'l> EntityPartVisitor for KeyVisitor<'l> {
  115. fn visit<EP: EntityPart>(&mut self) {
  116. for part in self.0.iter_mut() {
  117. if part.name == EP::part_name() {
  118. part.key = true;
  119. }
  120. }
  121. }
  122. }
  123. <E::Keys as EntityPartList>::accept_part_visitor(&mut KeyVisitor(&mut pv.0));
  124. Self {
  125. name: E::entity_name(),
  126. typeid: std::any::TypeId::of::<E>(),
  127. parts: pv.0,
  128. }
  129. }
  130. }
  131. #[derive(Default, Debug)]
  132. pub struct EntityStateContainer {
  133. states: HashMap<&'static str, EntityState>,
  134. }
  135. impl EntityStateContainer {
  136. pub fn iter_states(&self) -> impl Iterator<Item = &EntityState> {
  137. self.states.values()
  138. }
  139. pub fn make_context(&mut self) -> EntityContext {
  140. EntityContext { container: self }
  141. }
  142. }
  143. pub struct EntityContext<'a> {
  144. container: &'a mut EntityStateContainer,
  145. }
  146. impl<'a> EntityVisitor for EntityContext<'a> {
  147. fn visit<E: Entity>(&mut self) {
  148. // three cases:
  149. // 1. we haven't seen this entity
  150. // 2. we've seen this entity before
  151. if self.container.states.contains_key(E::entity_name()) {
  152. return;
  153. }
  154. let entry = self.container.states.entry(E::entity_name());
  155. let entry = entry.or_insert_with(EntityState::build::<E>);
  156. // sanity-check
  157. if entry.typeid != std::any::TypeId::of::<E>() {
  158. panic!("Identical entity name but different typeid!");
  159. }
  160. struct RecursiveVisitor<'a, 'b>(&'a mut EntityContext<'b>);
  161. impl<'a, 'b> EntityPartVisitor for RecursiveVisitor<'a, 'b> {
  162. fn visit<EP: EntityPart>(&mut self) {
  163. EP::Datum::accept_entity_visitor(self.0);
  164. }
  165. }
  166. E::accept_part_visitor(&mut RecursiveVisitor(self));
  167. }
  168. }