update.rs 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310
  1. use super::build::{DerivedQuery, QueryComponent, QueryPart, StaticVersion};
  2. use super::{Filterable, Resolvable};
  3. use crate::entity::EntityColumn;
  4. use crate::model::Modelable;
  5. use crate::{Entity, Error, QueryInterface};
  6. use std::hash::{Hash, Hasher};
  7. use std::marker::PhantomData;
  8. pub struct Update<'r, 'q, T: Entity> {
  9. qi: &'r QueryInterface<'q>,
  10. _ghost: PhantomData<T>,
  11. }
  12. impl<'r, 'q, T: Entity> Update<'r, 'q, T> {
  13. pub fn new(qi: &'r QueryInterface<'q>) -> Self {
  14. Self {
  15. qi,
  16. _ghost: std::marker::PhantomData,
  17. }
  18. }
  19. pub fn to(self, to: &'r T) -> Entire<'r, 'q, T> {
  20. Entire { wrap: self, to }
  21. }
  22. }
  23. impl<'r, 'q, T: Entity> Settable<'r, 'q> for Update<'r, 'q, T>
  24. where
  25. 'q: 'r,
  26. {
  27. type Table = T;
  28. }
  29. impl<'r, 'q, T: Entity> StaticVersion for Update<'r, 'q, T> {
  30. type Is = Update<'static, 'static, T>;
  31. }
  32. impl<'r, 'q, T: Entity> QueryComponent for Update<'r, 'q, T> {
  33. fn derive(&self) -> DerivedQuery {
  34. DerivedQuery::new().add(QueryPart::Root, format!("UPDATE `{}`", T::table_name()))
  35. }
  36. fn contribute<H: Hasher>(&self, hasher: &mut H) {
  37. "update".hash(hasher);
  38. std::any::TypeId::of::<T>().hash(hasher);
  39. }
  40. // next binding point is the first, we do nothing here
  41. fn bind(&self, _stmt: &mut sqlite::Statement<'_>) -> Result<usize, Error> {
  42. Ok(1)
  43. }
  44. }
  45. impl<'r, 'q, T: Entity> Filterable<'r, 'q> for Update<'r, 'q, T> {
  46. type Output = ();
  47. type Table = T;
  48. }
  49. impl<'r, 'q, T: Entity> Resolvable<'r, 'q> for Update<'r, 'q, T> {
  50. type Output = ();
  51. fn qi(&self) -> &'r QueryInterface<'q> {
  52. self.qi
  53. }
  54. }
  55. pub struct Entire<'r, 'q, T: Entity> {
  56. wrap: Update<'r, 'q, T>,
  57. to: &'r T,
  58. }
  59. impl<'r, 'q, T: Entity> StaticVersion for Entire<'r, 'q, T> {
  60. type Is = Entire<'static, 'static, T>;
  61. }
  62. impl<'r, 'q, T: Entity> QueryComponent for Entire<'r, 'q, T> {
  63. fn derive(&self) -> DerivedQuery {
  64. let mut dq = self.wrap.derive();
  65. // skip ID column
  66. let cols = T::columns();
  67. for column in &cols[1..] {
  68. dq = dq.add(QueryPart::Set, format!("`{}` = ?", column.name()));
  69. }
  70. dq
  71. }
  72. fn contribute<H: Hasher>(&self, hasher: &mut H) {
  73. self.wrap.contribute(hasher);
  74. std::any::TypeId::of::<Self::Is>().hash(hasher);
  75. }
  76. fn bind(&self, stmt: &mut sqlite::Statement<'_>) -> Result<usize, Error> {
  77. let mut ind = self.wrap.bind(stmt)?;
  78. self.to.visit_values::<Error, _>(&mut |val| {
  79. val.bind_to(stmt, ind)?;
  80. ind += 1;
  81. Ok(())
  82. })?;
  83. Ok(ind)
  84. }
  85. }
  86. impl<'r, 'q, T: Entity> Resolvable<'r, 'q> for Entire<'r, 'q, T> {
  87. type Output = ();
  88. fn qi(&self) -> &'r QueryInterface<'q> {
  89. self.wrap.qi
  90. }
  91. }
  92. impl<'r, 'q, T: Entity> Filterable<'r, 'q> for Entire<'r, 'q, T> {
  93. type Output = ();
  94. type Table = T;
  95. }
  96. pub trait Settable<'r, 'q>: Resolvable<'r, 'q>
  97. where
  98. 'q: 'r,
  99. {
  100. type Table: Entity;
  101. fn update<C: EntityColumn<Entity = Self::Table>, G: Modelable + ?Sized>(
  102. self,
  103. col: C,
  104. given: &'r G,
  105. ) -> Set<'r, 'q, Self, C, G>
  106. where
  107. Self: Sized,
  108. {
  109. Set {
  110. wrap: self,
  111. col,
  112. given,
  113. _ghost: PhantomData,
  114. }
  115. }
  116. }
  117. /// A concrete SET clause
  118. pub struct Set<'r, 'q, S: Settable<'r, 'q>, C: EntityColumn, G: Modelable + ?Sized>
  119. where
  120. 'q: 'r,
  121. {
  122. wrap: S,
  123. col: C,
  124. given: &'r G,
  125. _ghost: PhantomData<&'q ()>,
  126. }
  127. impl<'r, 'q, S: Settable<'r, 'q>, C: EntityColumn, G: Modelable + ?Sized> StaticVersion
  128. for Set<'r, 'q, S, C, G>
  129. where
  130. <S as StaticVersion>::Is: Settable<'static, 'static>,
  131. {
  132. type Is = Set<'static, 'static, <S as StaticVersion>::Is, C, u64>;
  133. }
  134. impl<'r, 'q, S: Settable<'r, 'q>, C: EntityColumn, G: Modelable + ?Sized> QueryComponent
  135. for Set<'r, 'q, S, C, G>
  136. where
  137. <S as StaticVersion>::Is: Settable<'static, 'static>,
  138. 'q: 'r,
  139. {
  140. fn derive(&self) -> DerivedQuery {
  141. self.wrap
  142. .derive()
  143. .add(QueryPart::Set, format!("`{}` = ?", self.col.name()))
  144. }
  145. fn contribute<H: Hasher>(&self, hasher: &mut H) {
  146. self.wrap.contribute(hasher);
  147. std::any::TypeId::of::<Self::Is>().hash(hasher);
  148. std::any::TypeId::of::<C>().hash(hasher);
  149. }
  150. fn bind(&self, stmt: &mut sqlite::Statement<'_>) -> Result<usize, crate::Error> {
  151. let next_index = self.wrap.bind(stmt)?;
  152. self.given.bind_to(stmt, next_index)?;
  153. Ok(next_index + 1)
  154. }
  155. }
  156. impl<'r, 'q, S: Settable<'r, 'q>, C: EntityColumn, G: Modelable + ?Sized> Resolvable<'r, 'q>
  157. for Set<'r, 'q, S, C, G>
  158. where
  159. <S as StaticVersion>::Is: Settable<'static, 'static>,
  160. 'q: 'r,
  161. {
  162. type Output = ();
  163. fn qi(&self) -> &'r QueryInterface<'q> {
  164. self.wrap.qi()
  165. }
  166. }
  167. impl<'r, 'q, S: Settable<'r, 'q>, C: EntityColumn, G: Modelable + ?Sized> Settable<'r, 'q>
  168. for Set<'r, 'q, S, C, G>
  169. where
  170. <S as StaticVersion>::Is: Settable<'static, 'static>,
  171. 'q: 'r,
  172. {
  173. type Table = C::Entity;
  174. }
  175. impl<'r, 'q, S: Settable<'r, 'q>, C: EntityColumn, G: Modelable + ?Sized> Filterable<'r, 'q>
  176. for Set<'r, 'q, S, C, G>
  177. where
  178. <S as StaticVersion>::Is: Settable<'static, 'static>,
  179. 'q: 'r,
  180. {
  181. type Output = ();
  182. type Table = C::Entity;
  183. }
  #[cfg(test)]
  mod test {
      use crate::prelude::*;
      use crate::query::Resolvable;
      use crate::test_support::KVStore;

      /// Fetches the single `KVStore` row stored under `$key` and yields its value.
      macro_rules! value_of {
          ($qi:expr, $key:expr) => {
              $qi.get()
                  .by(KVStore::Key, $key)
                  .one()
                  .unwrap()
                  .unwrap()
                  .value
          };
      }

      /// Asserts that `$rows` holds exactly one entity with the given key and value.
      macro_rules! assert_single {
          ($rows:expr, $key:expr, $value:expr) => {{
              let rows = $rows;
              assert_eq!(rows.len(), 1);
              assert_eq!(rows[0].key, $key);
              assert_eq!(rows[0].value, $value);
          }};
      }

      #[test]
      fn simple_update() {
          let db = crate::DB::new_in_memory(crate::Schema::new().entity::<KVStore>()).unwrap();
          let qi = db.query_interface();

          // One row under "key", two rows sharing "key2".
          let seed = [("key", "value"), ("key2", "value2"), ("key2", "value2b")];
          for &(key, value) in seed.iter() {
              qi.add(&KVStore {
                  key: key.into(),
                  value: value.into(),
              })
              .unwrap();
          }

          assert_eq!(value_of!(qi, "key"), "value");
          assert_eq!(qi.get().by(KVStore::Key, "key2").all().unwrap().len(), 2);

          // Overwrite the value column of the row keyed "key".
          qi.update()
              .update(KVStore::Value, "newvalue")
              .by(KVStore::Key, "key")
              .exec()
              .unwrap();

          assert_eq!(value_of!(qi, "key"), "newvalue");
      }

      #[test]
      fn swapout() {
          let db = crate::DB::new_in_memory(crate::Schema::new().entity::<KVStore>()).unwrap();
          let qi = db.query_interface();

          let first = KVStore {
              key: "a".into(),
              value: "b".into(),
          };
          let id = qi.add(&first).unwrap();
          assert_single!(qi.get().by(KVStore::ID, &id).all().unwrap(), "a", "b");

          // Replace every non-ID column of that row in a single statement.
          let replacement = KVStore {
              key: "c".into(),
              value: "d".into(),
          };
          qi.update()
              .to(&replacement)
              .by(KVStore::ID, &id)
              .exec()
              .unwrap();

          assert_single!(qi.get().by(KVStore::ID, &id).all().unwrap(), "c", "d");
          // The `by_id` shortcut must observe the same replacement.
          assert_single!(qi.get().by_id(&id).all().unwrap(), "c", "d");
      }
  }