create.rs 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288
  1. use serde::de::Visitor;
  2. use std::cell::Cell;
  3. use std::rc::Rc;
/// A serde `Deserializer` that, instead of reading data, walks an entity type's shape and
/// records the SQL column type of each field it is asked to deserialize.
///
/// The values it hands back to visitors are dummies (zero, empty string, empty bytes), so
/// the deserialized value itself is meaningless; `sql_for` only reads `column_types`
/// afterwards.
#[derive(Debug)]
pub struct CreateDeserializer<'de> {
    /// Checked by `deserialize_struct` to reject structs nested inside the entity struct.
    struct_visited: bool,
    /// SQL type names ("integer", "text", "blob"), one per field, in visiting order.
    column_types: Vec<&'static str>,
    /// Number of sequence elements still permitted; decremented by `next_element_seed`.
    /// `deserialize_newtype_struct` clones this handle to save and restore the count around
    /// its own one-element sequence.
    expected_length: Rc<Cell<usize>>,
    /// Binds the `'de` lifetime parameter, which no other field uses.
    _de: std::marker::PhantomData<&'de u8>,
}
  11. impl<'de> CreateDeserializer<'de> {
  12. fn integral_type(&mut self) {
  13. self.column_types.push("integer");
  14. }
  15. }
  16. impl<'de, 'a> serde::de::Deserializer<'de> for &'a mut CreateDeserializer<'de> {
  17. type Error = super::ModelError;
  18. // we (ab)use the forward_to_deserialize_any! macro to stub out the types we don't care about
  19. serde::forward_to_deserialize_any! {
  20. bool i128 u64 u128 f32 f64 char str
  21. option unit unit_struct tuple
  22. tuple_struct map enum identifier ignored_any
  23. }
  24. fn deserialize_any<V: Visitor<'de>>(self, _v: V) -> Result<V::Value, Self::Error> {
  25. todo!()
  26. }
  27. fn deserialize_u8<V: Visitor<'de>>(self, v: V) -> Result<V::Value, Self::Error> {
  28. self.integral_type();
  29. v.visit_u8(0)
  30. }
  31. fn deserialize_u16<V: Visitor<'de>>(self, v: V) -> Result<V::Value, Self::Error> {
  32. self.integral_type();
  33. v.visit_u16(0)
  34. }
  35. fn deserialize_u32<V: Visitor<'de>>(self, v: V) -> Result<V::Value, Self::Error> {
  36. self.integral_type();
  37. v.visit_u32(0)
  38. }
  39. fn deserialize_i8<V: Visitor<'de>>(self, v: V) -> Result<V::Value, Self::Error> {
  40. self.integral_type();
  41. v.visit_i8(0)
  42. }
  43. fn deserialize_i16<V: Visitor<'de>>(self, v: V) -> Result<V::Value, Self::Error> {
  44. self.integral_type();
  45. v.visit_i16(0)
  46. }
  47. fn deserialize_i32<V: Visitor<'de>>(self, v: V) -> Result<V::Value, Self::Error> {
  48. self.integral_type();
  49. v.visit_i32(0)
  50. }
  51. fn deserialize_i64<V: Visitor<'de>>(self, v: V) -> Result<V::Value, Self::Error> {
  52. self.integral_type();
  53. v.visit_i64(0)
  54. }
  55. fn deserialize_string<V: Visitor<'de>>(self, v: V) -> Result<V::Value, Self::Error> {
  56. self.column_types.push("text");
  57. v.visit_string("".to_owned())
  58. }
  59. fn deserialize_bytes<V: Visitor<'de>>(self, v: V) -> Result<V::Value, Self::Error> {
  60. self.column_types.push("blob");
  61. v.visit_bytes(&[])
  62. }
  63. fn deserialize_byte_buf<V: Visitor<'de>>(self, v: V) -> Result<V::Value, Self::Error> {
  64. self.column_types.push("blob");
  65. v.visit_bytes(&[])
  66. }
  67. fn deserialize_seq<V: Visitor<'de>>(self, v: V) -> Result<V::Value, Self::Error> {
  68. v.visit_seq(self)
  69. }
  70. fn deserialize_struct<V: Visitor<'de>>(
  71. self,
  72. _name: &'static str,
  73. fields: &'static [&'static str],
  74. v: V,
  75. ) -> Result<V::Value, Self::Error> {
  76. if self.struct_visited {
  77. panic!("Nested structs not allowed!");
  78. } else {
  79. self.expected_length.set(fields.len());
  80. v.visit_seq(self)
  81. }
  82. }
  83. fn deserialize_newtype_struct<V: Visitor<'de>>(
  84. self,
  85. _name: &'static str,
  86. v: V,
  87. ) -> Result<V::Value, Self::Error> {
  88. let elength = self.expected_length.clone();
  89. let old_elength = elength.get();
  90. elength.set(1);
  91. let ret = v.visit_seq(self);
  92. elength.set(old_elength);
  93. ret
  94. }
  95. }
  96. impl<'de> serde::de::SeqAccess<'de> for CreateDeserializer<'de> {
  97. type Error = super::ModelError;
  98. fn next_element_seed<T: serde::de::DeserializeSeed<'de>>(
  99. &mut self,
  100. seed: T,
  101. ) -> Result<Option<T::Value>, Self::Error> {
  102. if self.expected_length.get() == 0 {
  103. return Err(Self::Error::CreateError);
  104. }
  105. self.expected_length.set(self.expected_length.get() - 1);
  106. seed.deserialize(self).map(Some)
  107. }
  108. }
  109. pub fn sql_for<T: crate::model::Entity>() -> (String, String) {
  110. let elength = Rc::new(Cell::new(0));
  111. let mut cd = CreateDeserializer {
  112. struct_visited: false,
  113. column_types: Vec::new(),
  114. expected_length: elength,
  115. _de: std::marker::PhantomData {},
  116. };
  117. T::deserialize(&mut cd).expect("SQL creation failed!");
  118. // +1 to account for id column that is included in column_count
  119. assert_eq!(T::column_count(), cd.column_types.len() + 1);
  120. let mut columns = Vec::new();
  121. columns.push("id integer primary key".to_owned());
  122. for i in 1..T::column_count() {
  123. let col = <T::Column as std::convert::TryFrom<usize>>::try_from(i).unwrap();
  124. let fk = T::foreign_keys()
  125. .iter()
  126. .filter(|x| x.local_column() == &col)
  127. .take(1);
  128. let fk = fk.map(|x| {
  129. format!(
  130. " references \"{}\"(\"{}\")",
  131. x.foreign_table_name(),
  132. x.foreign_column_name()
  133. )
  134. });
  135. columns.push(format!(
  136. "\"{}\" {}{}",
  137. T::name(col),
  138. cd.column_types[i - 1],
  139. fk.last().unwrap_or("".to_string())
  140. ));
  141. }
  142. (
  143. format!(
  144. "DROP TABLE IF EXISTS \"{}\"",
  145. <T as crate::model::Entity>::table_name()
  146. ),
  147. format!(
  148. "CREATE TABLE IF NOT EXISTS \"{}\" ({})",
  149. <T as crate::model::Entity>::table_name(),
  150. columns.join(",")
  151. ),
  152. )
  153. }
  154. #[cfg(test)]
  155. mod test {
  156. #[derive(serde::Serialize, serde::Deserialize, crate::Entity)]
  157. #[microrm_internal]
  158. pub struct Empty {}
  159. #[derive(serde::Serialize, serde::Deserialize, crate::Entity)]
  160. #[microrm_internal]
  161. pub struct Single {
  162. e: i32,
  163. }
  164. #[derive(serde::Serialize, serde::Deserialize, crate::Entity)]
  165. #[microrm_internal]
  166. pub struct Reference {
  167. e: SingleID,
  168. }
  169. #[test]
  170. fn example_sql_for() {
  171. assert_eq!(
  172. super::sql_for::<Empty>(),
  173. (
  174. r#"DROP TABLE IF EXISTS "empty""#.to_owned(),
  175. r#"CREATE TABLE IF NOT EXISTS "empty" (id integer primary key)"#.to_owned()
  176. )
  177. );
  178. assert_eq!(
  179. super::sql_for::<Single>(),
  180. (
  181. r#"DROP TABLE IF EXISTS "single""#.to_owned(),
  182. r#"CREATE TABLE IF NOT EXISTS "single" (id integer primary key,"e" integer)"#.to_owned()
  183. )
  184. );
  185. assert_eq!(
  186. super::sql_for::<Reference>(),
  187. (
  188. r#"DROP TABLE IF EXISTS "reference""#.to_owned(),
  189. r#"CREATE TABLE IF NOT EXISTS "reference" (id integer primary key,"e" integer)"#.to_owned()
  190. )
  191. );
  192. }
  193. #[derive(serde::Serialize, serde::Deserialize, crate::Modelable)]
  194. #[microrm_internal]
  195. pub struct Unit(u8);
  196. #[derive(serde::Serialize, serde::Deserialize, crate::Entity)]
  197. #[microrm_internal]
  198. pub struct UnitNewtype {
  199. newtype: Unit,
  200. }
  201. #[test]
  202. fn unit_newtype_struct() {
  203. assert_eq!(
  204. super::sql_for::<UnitNewtype>(),
  205. (
  206. r#"DROP TABLE IF EXISTS "unit_newtype""#.to_owned(),
  207. r#"CREATE TABLE IF NOT EXISTS "unit_newtype" (id integer primary key,"newtype" integer)"#
  208. .to_owned()
  209. )
  210. );
  211. }
  212. #[derive(serde::Serialize, serde::Deserialize, crate::Modelable)]
  213. #[microrm_internal]
  214. pub struct NonUnit(u8, u8);
  215. #[derive(serde::Serialize, serde::Deserialize, crate::Entity)]
  216. #[microrm_internal]
  217. pub struct NonUnitNewtype {
  218. newtype: NonUnit,
  219. }
  220. #[test]
  221. #[should_panic]
  222. fn nonunit_newtype_struct() {
  223. super::sql_for::<NonUnitNewtype>();
  224. }
  225. #[derive(serde::Serialize, serde::Deserialize, crate::Entity)]
  226. #[microrm_internal]
  227. pub struct Child {
  228. #[microrm_foreign]
  229. parent_id: SingleID,
  230. }
  231. #[test]
  232. fn test_foreign_key() {
  233. assert_eq!(
  234. super::sql_for::<Child>(),
  235. (
  236. r#"DROP TABLE IF EXISTS "child""#.to_owned(),
  237. r#"CREATE TABLE IF NOT EXISTS "child" (id integer primary key,"parent_id" integer references "single"("id"))"#.to_owned()
  238. )
  239. );
  240. }
  241. }