config.rs 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268
  1. use crate::schema;
  2. use microrm::prelude::*;
  3. use serde::{Deserialize, Serialize};
  /// Fallback for `ServerConfig::auth_token_expiry`, used by serde when the
  /// entry is absent from the loaded configuration.
  ///
  /// NOTE(review): the unit is not stated in this file (presumably seconds) —
  /// confirm against the code that consumes the value.
  fn default_auth_token_expiry() -> u64 {
      const DEFAULT_AUTH_TOKEN_EXPIRY: u64 = 600;
      DEFAULT_AUTH_TOKEN_EXPIRY
  }
  7. #[derive(Serialize, Deserialize)]
  8. pub struct ServerConfig {
  9. pub base_url: String,
  10. #[serde(default = "default_auth_token_expiry")]
  11. pub auth_token_expiry: u64,
  12. }
  13. impl ServerConfig {
  14. pub fn build_from(qi: &microrm::QueryInterface, cfile: Option<&str>) -> Self {
  15. let mut config_map = std::collections::HashMap::<String, String>::new();
  16. // load config keys from query interface
  17. let db_pcs = qi
  18. .get::<schema::PersistentConfig>()
  19. .all()
  20. .expect("couldn't get config keys from database");
  21. config_map.extend(db_pcs.into_iter().map(|pc| {
  22. let pc = pc.wrapped();
  23. (pc.key, pc.value)
  24. }));
  25. if let Some(path) = cfile {
  26. match std::fs::read(&path) {
  27. Ok(data) => {
  28. log::info!("Loading config from {path}...");
  29. let toml_table: toml::Table = toml::from_str(
  30. std::str::from_utf8(data.as_slice())
  31. .expect("couldn't read config file contents as utf-8"),
  32. )
  33. .expect("couldn't parse config toml");
  34. }
  35. Err(e) => {
  36. log::error!("Could not open {path} for reading: {e}");
  37. }
  38. }
  39. }
  40. let mut deser = ConfigDeserializer {
  41. config_map: &config_map,
  42. prefix: "".to_string(),
  43. };
  44. let config = ServerConfig::deserialize(&mut deser).expect("couldn't load configuration");
  45. config
  46. }
  47. }
  /// Deserializer state over a flat string-to-string configuration map.
  struct ConfigDeserializer<'de> {
      /// Borrowed configuration entries, keyed by entry name.
      config_map: &'de std::collections::HashMap<String, String>,
      /// Key prefix that `deserialize_map` uses to select entries.
      prefix: String,
  }
  /// Errors reported while deserializing configuration entries.
  #[derive(Debug)]
  enum ConfigError {
      /// A required entry was not present; holds the entry name.
      Missing(String),
      /// An entry could not be parsed; holds text describing the entry or failure.
      InvalidType(String),
      /// Any other failure; holds the free-form message.
      CustomError(String),
  }

  impl std::fmt::Display for ConfigError {
      /// Renders the one-line, human-readable message for each variant.
      fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
          match self {
              Self::Missing(entry) => write!(f, "Missing required config entry: {}", entry),
              Self::InvalidType(entry) => write!(f, "Could not parse config entry '{}'", entry),
              Self::CustomError(message) => write!(f, "Custom error '{}'", message),
          }
      }
  }

  impl std::error::Error for ConfigError {}
  76. impl serde::de::Error for ConfigError {
  77. fn custom<T>(msg: T) -> Self
  78. where
  79. T: std::fmt::Display,
  80. {
  81. Self::CustomError(msg.to_string())
  82. }
  83. fn invalid_type(_unexp: serde::de::Unexpected, _exp: &dyn serde::de::Expected) -> Self {
  84. Self::InvalidType("".into())
  85. }
  86. fn missing_field(field: &'static str) -> Self {
  87. Self::Missing(field.into())
  88. }
  89. }
  90. impl<'de> serde::Deserializer<'de> for &'de mut ConfigDeserializer<'de> {
  91. type Error = ConfigError;
  92. fn deserialize_any<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
  93. where
  94. V: serde::de::Visitor<'de>,
  95. {
  96. unreachable!("deserialize_any needs context")
  97. }
  98. fn deserialize_struct<V>(
  99. self,
  100. _name: &'static str,
  101. _fields: &'static [&'static str],
  102. visitor: V,
  103. ) -> Result<V::Value, Self::Error>
  104. where
  105. V: serde::de::Visitor<'de>,
  106. {
  107. self.deserialize_map(visitor)
  108. }
  109. fn deserialize_seq<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
  110. where
  111. V: serde::de::Visitor<'de>,
  112. {
  113. todo!("deserialize_seq")
  114. }
  115. fn deserialize_map<V>(mut self, visitor: V) -> Result<V::Value, Self::Error>
  116. where
  117. V: serde::de::Visitor<'de>,
  118. {
  119. let mut map_access = ConfigDeserializerIterator {
  120. it: self
  121. .config_map
  122. .iter()
  123. .filter(|e| {
  124. e.0.starts_with(&self.prefix) && !e.0[self.prefix.len()..].contains(".")
  125. })
  126. .peekable(),
  127. };
  128. visitor.visit_map(&mut map_access)
  129. }
  130. fn deserialize_enum<V>(
  131. self,
  132. _name: &'static str,
  133. _variants: &'static [&'static str],
  134. _visitor: V,
  135. ) -> Result<V::Value, Self::Error>
  136. where
  137. V: serde::de::Visitor<'de>,
  138. {
  139. todo!("deserialize_enum")
  140. }
  141. fn deserialize_tuple<V>(self, _len: usize, _visitor: V) -> Result<V::Value, Self::Error>
  142. where
  143. V: serde::de::Visitor<'de>,
  144. {
  145. todo!("deserialize_tuple")
  146. }
  147. fn deserialize_tuple_struct<V>(
  148. self,
  149. _name: &'static str,
  150. _len: usize,
  151. _visitor: V,
  152. ) -> Result<V::Value, Self::Error>
  153. where
  154. V: serde::de::Visitor<'de>,
  155. {
  156. todo!("deserialize_tuple_struct")
  157. }
  158. serde::forward_to_deserialize_any!(
  159. i8 u8 i16 u16 i32 u32 i64 u64 i128 u128 str string bytes
  160. bool f32 f64 char byte_buf option unit unit_struct
  161. newtype_struct identifier ignored_any
  162. );
  163. }
  /// Deserializer over one borrowed piece of text (a single map key or value).
  struct AtomicForwarder<'de> {
      /// The text that is handed to the visitor.
      to_fwd: &'de str,
  }
  167. impl<'de> serde::Deserializer<'de> for AtomicForwarder<'de> {
  168. type Error = ConfigError;
  169. fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
  170. where
  171. V: serde::de::Visitor<'de>,
  172. {
  173. unreachable!()
  174. }
  175. fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error>
  176. where
  177. V: serde::de::Visitor<'de>,
  178. {
  179. visitor.visit_str(self.to_fwd)
  180. }
  181. fn deserialize_string<V>(self, visitor: V) -> Result<V::Value, Self::Error>
  182. where
  183. V: serde::de::Visitor<'de>,
  184. {
  185. visitor.visit_str(self.to_fwd)
  186. }
  187. fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value, Self::Error>
  188. where
  189. V: serde::de::Visitor<'de>,
  190. {
  191. visitor.visit_str(self.to_fwd)
  192. }
  193. serde::forward_to_deserialize_any!(
  194. i8 u8 i16 u16 i32 u32 i64 u64 i128 u128 bytes
  195. bool f32 f64 char byte_buf unit unit_struct option
  196. newtype_struct ignored_any struct tuple tuple_struct
  197. seq map enum
  198. );
  199. }
  /// Map-access state: a peekable cursor over borrowed `(key, value)` entries.
  struct ConfigDeserializerIterator<'de, I>
  where
      I: Iterator<Item = (&'de String, &'de String)>,
  {
      /// Remaining entries; peeked to read a key, advanced to read its value.
      it: std::iter::Peekable<I>,
  }
  203. impl<'de, I: Iterator<Item = (&'de String, &'de String)>> serde::de::MapAccess<'de>
  204. for ConfigDeserializerIterator<'de, I>
  205. {
  206. type Error = ConfigError;
  207. fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, Self::Error>
  208. where
  209. K: serde::de::DeserializeSeed<'de>,
  210. {
  211. if let Some(e) = self.it.peek() {
  212. let de = AtomicForwarder {
  213. to_fwd: e.0.as_str(),
  214. };
  215. Ok(seed.deserialize(de).ok())
  216. } else {
  217. Ok(None)
  218. }
  219. }
  220. fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Self::Error>
  221. where
  222. V: serde::de::DeserializeSeed<'de>,
  223. {
  224. let value = self.it.next().unwrap();
  225. let de = AtomicForwarder {
  226. to_fwd: value.1.as_str(),
  227. };
  228. seed.deserialize(de)
  229. .map_err(|e| ConfigError::InvalidType(e.to_string()))
  230. }
  231. }