Kestrel 1 рік тому
батько
коміт
d678c99300
7 змінених файлів з 292 додано та 138 видалено
  1. 50 31
      src/cli.rs
  2. 2 3
      src/config.rs
  3. 111 50
      src/config/helper.rs
  4. 29 13
      src/jwt.rs
  5. 1 1
      src/server.rs
  6. 51 23
      src/server/oidc.rs
  7. 48 17
      src/token.rs

+ 50 - 31
src/cli.rs

@@ -1,4 +1,8 @@
-use crate::{cert, client_management, schema::{self, schema}, server, user_management, config, token};
+use crate::{
+    cert, client_management, config,
+    schema::{self, schema},
+    server, token, user_management,
+};
 use clap::{Parser, Subcommand};
 use microrm::prelude::*;
 
@@ -145,13 +149,8 @@ impl ClientArgs {
 #[derive(Debug, Subcommand)]
 enum ConfigCommand {
     Dump,
-    Load {
-        toml_path: String
-    },
-    Set {
-        key: String,
-        value: String,
-    }
+    Load { toml_path: String },
+    Set { key: String, value: String },
 }
 
 #[derive(Debug, Parser)]
@@ -167,10 +166,10 @@ impl ConfigArgs {
                 let qi = db.query_interface();
                 let config = config::Config::build_from(&qi, None);
                 println!("config: {:?}", config);
-            },
+            }
             ConfigCommand::Set { key, value } => {
                 todo!()
-            },
+            }
             ConfigCommand::Load { toml_path } => {
                 let config = {
                     let qi = db.query_interface();
@@ -181,7 +180,7 @@ impl ConfigArgs {
                     config.save(&qi);
                     drop(config);
                 }
-            },
+            }
         }
     }
 }
@@ -207,10 +206,7 @@ struct ServerArgs {
 
 impl ServerArgs {
     async fn run(&self, root: &RootArgs, db: microrm::DB) {
-        let config = config::Config::build_from(
-            &db.query_interface(),
-            None
-        );
+        let config = config::Config::build_from(&db.query_interface(), None);
         server::run_server(db, config, self.port.unwrap_or(2114)).await
     }
 }
@@ -223,7 +219,7 @@ enum TokenCommand {
         #[clap(short, long)]
         username: String,
         #[clap(short, long)]
-        scopes: String
+        scopes: String,
     },
     GenerateRefresh {
         #[clap(short, long)]
@@ -231,9 +227,11 @@ enum TokenCommand {
         #[clap(short, long)]
         username: String,
         #[clap(short, long)]
-        scopes: String
+        scopes: String,
+    },
+    Inspect {
+        token: String,
     },
-    Inspect { token: String },
 }
 
 #[derive(Debug, Parser)]
@@ -246,30 +244,51 @@ impl TokenArgs {
     async fn run(&self, root: &RootArgs, db: microrm::DB) {
         let config = config::Config::build_from(&db.query_interface(), None);
         match &self.command {
-            TokenCommand::GenerateAuth { client, username, scopes } => {
+            TokenCommand::GenerateAuth {
+                client,
+                username,
+                scopes,
+            } => {
                 let qi = db.query_interface();
-                let realm_id = qi.get().by(schema::Realm::Shortname, &root.realm).one().unwrap().expect("no such realm").id();
+                let realm_id = qi
+                    .get()
+                    .by(schema::Realm::Shortname, &root.realm)
+                    .one()
+                    .unwrap()
+                    .expect("no such realm")
+                    .id();
                 let token = token::generate_auth_token(
                     &config,
                     &qi,
                     realm_id,
-                    qi.get().by(schema::Client::Realm, &realm_id).by(schema::Client::Shortname, client.as_str()).one().unwrap().expect("no such client").id(),
-                    qi.get().by(schema::User::Realm, &realm_id).by(schema::User::Username, username.as_str()).one().unwrap().expect("no such user").id(),
+                    qi.get()
+                        .by(schema::Client::Realm, &realm_id)
+                        .by(schema::Client::Shortname, client.as_str())
+                        .one()
+                        .unwrap()
+                        .expect("no such client")
+                        .id(),
+                    qi.get()
+                        .by(schema::User::Realm, &realm_id)
+                        .by(schema::User::Username, username.as_str())
+                        .one()
+                        .unwrap()
+                        .expect("no such user")
+                        .id(),
                     scopes.split_whitespace(),
                 );
                 if let Some(t) = token {
                     println!("token: {t}");
-                }
-                else {
+                } else {
                     println!("Could not generate token");
                 }
-            },
-            TokenCommand::GenerateRefresh { client, username, scopes } => {
-                
-            },
-            TokenCommand::Inspect { token } => {
-                
-            },
+            }
+            TokenCommand::GenerateRefresh {
+                client,
+                username,
+                scopes,
+            } => {}
+            TokenCommand::Inspect { token } => {}
         }
     }
 }

+ 2 - 3
src/config.rs

@@ -43,8 +43,7 @@ impl Config {
                         log::trace!("using config key {} from TOML config...", val.0);
                         if val.1.is_str() {
                             config_map.insert(val.0, val.1.as_str().unwrap().to_string());
-                        }
-                        else {
+                        } else {
                             config_map.insert(val.0, val.1.to_string());
                         }
                     }
@@ -71,7 +70,7 @@ impl Config {
         let ser = helper::ConfigSerializer {
             config: &self,
             qi: &qi,
-            prefix: String::new()
+            prefix: String::new(),
         };
 
         self.serialize(&ser);

+ 111 - 50
src/config/helper.rs

@@ -4,7 +4,7 @@ use crate::schema;
 
 use super::Config;
 
-struct ValueToStringSerializer { }
+struct ValueToStringSerializer {}
 
 impl<'l> serde::Serializer for &'l ValueToStringSerializer {
     type Ok = Option<String>;
@@ -18,7 +18,9 @@ impl<'l> serde::Serializer for &'l ValueToStringSerializer {
     type SerializeTupleVariant = serde::ser::Impossible<Self::Ok, Self::Error>;
     type SerializeStructVariant = serde::ser::Impossible<Self::Ok, Self::Error>;
 
-    fn serialize_bool(self, v: bool) -> Result<Self::Ok, Self::Error> { unreachable!() }
+    fn serialize_bool(self, v: bool) -> Result<Self::Ok, Self::Error> {
+        unreachable!()
+    }
 
     fn serialize_i8(self, v: i8) -> Result<Self::Ok, Self::Error> {
         self.serialize_i64(v as i64)
@@ -32,7 +34,9 @@ impl<'l> serde::Serializer for &'l ValueToStringSerializer {
     fn serialize_i64(self, v: i64) -> Result<Self::Ok, Self::Error> {
         Ok(Some(v.to_string()))
     }
-    fn serialize_i128(self, v: i128) -> Result<Self::Ok, Self::Error> { unreachable!() }
+    fn serialize_i128(self, v: i128) -> Result<Self::Ok, Self::Error> {
+        unreachable!()
+    }
 
     fn serialize_u8(self, v: u8) -> Result<Self::Ok, Self::Error> {
         self.serialize_u64(v as u64)
@@ -46,53 +50,81 @@ impl<'l> serde::Serializer for &'l ValueToStringSerializer {
     fn serialize_u64(self, v: u64) -> Result<Self::Ok, Self::Error> {
         Ok(Some(v.to_string()))
     }
-    fn serialize_u128(self, v: u128) -> Result<Self::Ok, Self::Error> { unreachable!() }
+    fn serialize_u128(self, v: u128) -> Result<Self::Ok, Self::Error> {
+        unreachable!()
+    }
 
-    fn serialize_f32(self, v: f32) -> Result<Self::Ok, Self::Error> { unreachable!() }
-    fn serialize_f64(self, v: f64) -> Result<Self::Ok, Self::Error> { unreachable!() }
+    fn serialize_f32(self, v: f32) -> Result<Self::Ok, Self::Error> {
+        unreachable!()
+    }
+    fn serialize_f64(self, v: f64) -> Result<Self::Ok, Self::Error> {
+        unreachable!()
+    }
 
-    fn serialize_char(self, v: char) -> Result<Self::Ok, Self::Error> { unreachable!() }
+    fn serialize_char(self, v: char) -> Result<Self::Ok, Self::Error> {
+        unreachable!()
+    }
 
-    fn serialize_bytes(self, v: &[u8]) -> Result<Self::Ok, Self::Error> { unreachable!() }
+    fn serialize_bytes(self, v: &[u8]) -> Result<Self::Ok, Self::Error> {
+        unreachable!()
+    }
 
     fn serialize_str(self, v: &str) -> Result<Self::Ok, Self::Error> {
         Ok(Some(v.into()))
     }
 
-    fn serialize_none(self) -> Result<Self::Ok, Self::Error> { Ok(None) }
+    fn serialize_none(self) -> Result<Self::Ok, Self::Error> {
+        Ok(None)
+    }
 
     fn serialize_some<T: ?Sized>(self, value: &T) -> Result<Self::Ok, Self::Error>
     where
-        T: serde::Serialize {
-
+        T: serde::Serialize,
+    {
         value.serialize(self)
     }
 
-    fn serialize_unit(self) -> Result<Self::Ok, Self::Error> { unreachable!() }
+    fn serialize_unit(self) -> Result<Self::Ok, Self::Error> {
+        unreachable!()
+    }
 
-    fn serialize_seq(self, _len: Option<usize>) -> Result<Self::SerializeSeq, Self::Error> { unreachable!() }
-    fn serialize_map(self, _len: Option<usize>) -> Result<Self::SerializeMap, Self::Error> { unreachable!() }
-    fn serialize_tuple(self, _len: usize) -> Result<Self::SerializeTuple, Self::Error> { unreachable!() }
+    fn serialize_seq(self, _len: Option<usize>) -> Result<Self::SerializeSeq, Self::Error> {
+        unreachable!()
+    }
+    fn serialize_map(self, _len: Option<usize>) -> Result<Self::SerializeMap, Self::Error> {
+        unreachable!()
+    }
+    fn serialize_tuple(self, _len: usize) -> Result<Self::SerializeTuple, Self::Error> {
+        unreachable!()
+    }
     fn serialize_struct(
         self,
         _name: &'static str,
         _len: usize,
-    ) -> Result<Self::SerializeStruct, Self::Error> { unreachable!() }
+    ) -> Result<Self::SerializeStruct, Self::Error> {
+        unreachable!()
+    }
 
-    fn serialize_unit_struct(self, name: &'static str) -> Result<Self::Ok, Self::Error> { unreachable!() }
+    fn serialize_unit_struct(self, name: &'static str) -> Result<Self::Ok, Self::Error> {
+        unreachable!()
+    }
 
     fn serialize_unit_variant(
         self,
         name: &'static str,
         variant_index: u32,
         variant: &'static str,
-    ) -> Result<Self::Ok, Self::Error> { unreachable!() }
+    ) -> Result<Self::Ok, Self::Error> {
+        unreachable!()
+    }
 
     fn serialize_tuple_struct(
         self,
         name: &'static str,
         len: usize,
-    ) -> Result<Self::SerializeTupleStruct, Self::Error> { unreachable!() }
+    ) -> Result<Self::SerializeTupleStruct, Self::Error> {
+        unreachable!()
+    }
 
     fn serialize_tuple_variant(
         self,
@@ -110,7 +142,8 @@ impl<'l> serde::Serializer for &'l ValueToStringSerializer {
         value: &T,
     ) -> Result<Self::Ok, Self::Error>
     where
-        T: serde::Serialize {
+        T: serde::Serialize,
+    {
         todo!()
     }
 
@@ -132,8 +165,8 @@ impl<'l> serde::Serializer for &'l ValueToStringSerializer {
         value: &T,
     ) -> Result<Self::Ok, Self::Error>
     where
-        T: serde::Serialize {
-
+        T: serde::Serialize,
+    {
         todo!()
     }
 }
@@ -146,15 +179,21 @@ pub struct ConfigSerializer<'r, 's, 'l> {
 
 impl<'r, 's, 'l> ConfigSerializer<'r, 's, 'l> {
     fn update(&self, key: &str, value: String) {
-        self.qi.delete().by(schema::PersistentConfig::Key, key).exec().expect("couldn't update config");
-        self.qi.add(&schema::PersistentConfig {
-            key: key.into(),
-            value
-        }).expect("couldn't update config");
+        self.qi
+            .delete()
+            .by(schema::PersistentConfig::Key, key)
+            .exec()
+            .expect("couldn't update config");
+        self.qi
+            .add(&schema::PersistentConfig {
+                key: key.into(),
+                value,
+            })
+            .expect("couldn't update config");
     }
 }
 
-impl <'r, 's, 'l> serde::ser::SerializeStruct for ConfigSerializer<'r, 's, 'l> {
+impl<'r, 's, 'l> serde::ser::SerializeStruct for ConfigSerializer<'r, 's, 'l> {
     type Ok = ();
     type Error = ConfigError;
 
@@ -164,11 +203,10 @@ impl <'r, 's, 'l> serde::ser::SerializeStruct for ConfigSerializer<'r, 's, 'l> {
         value: &T,
     ) -> Result<(), Self::Error>
     where
-        T: serde::Serialize {
-
+        T: serde::Serialize,
+    {
         let key = format!("{}{}", self.prefix, key);
 
-
         let value = value.serialize(&ValueToStringSerializer {})?;
         if let Some(value) = value {
             log::trace!("saving config {} = {}", key, value);
@@ -211,7 +249,9 @@ impl<'r, 's, 'l> serde::Serializer for &'r ConfigSerializer<'r, 's, 'l> {
     fn serialize_i64(self, v: i64) -> Result<Self::Ok, Self::Error> {
         todo!()
     }
-    fn serialize_i128(self, v: i128) -> Result<Self::Ok, Self::Error> { unreachable!() }
+    fn serialize_i128(self, v: i128) -> Result<Self::Ok, Self::Error> {
+        unreachable!()
+    }
 
     fn serialize_u8(self, v: u8) -> Result<Self::Ok, Self::Error> {
         self.serialize_u64(v as u64)
@@ -225,12 +265,20 @@ impl<'r, 's, 'l> serde::Serializer for &'r ConfigSerializer<'r, 's, 'l> {
     fn serialize_u64(self, v: u64) -> Result<Self::Ok, Self::Error> {
         todo!()
     }
-    fn serialize_u128(self, v: u128) -> Result<Self::Ok, Self::Error> { unreachable!() }
+    fn serialize_u128(self, v: u128) -> Result<Self::Ok, Self::Error> {
+        unreachable!()
+    }
 
-    fn serialize_f32(self, v: f32) -> Result<Self::Ok, Self::Error> { unreachable!() }
-    fn serialize_f64(self, v: f64) -> Result<Self::Ok, Self::Error> { unreachable!() }
+    fn serialize_f32(self, v: f32) -> Result<Self::Ok, Self::Error> {
+        unreachable!()
+    }
+    fn serialize_f64(self, v: f64) -> Result<Self::Ok, Self::Error> {
+        unreachable!()
+    }
 
-    fn serialize_char(self, v: char) -> Result<Self::Ok, Self::Error> { unreachable!() }
+    fn serialize_char(self, v: char) -> Result<Self::Ok, Self::Error> {
+        unreachable!()
+    }
 
     fn serialize_bytes(self, v: &[u8]) -> Result<Self::Ok, Self::Error> {
         todo!()
@@ -246,16 +294,24 @@ impl<'r, 's, 'l> serde::Serializer for &'r ConfigSerializer<'r, 's, 'l> {
 
     fn serialize_some<T: ?Sized>(self, value: &T) -> Result<Self::Ok, Self::Error>
     where
-        T: serde::Serialize {
-
+        T: serde::Serialize,
+    {
         todo!()
     }
 
-    fn serialize_unit(self) -> Result<Self::Ok, Self::Error> { unreachable!() }
+    fn serialize_unit(self) -> Result<Self::Ok, Self::Error> {
+        unreachable!()
+    }
 
-    fn serialize_seq(self, len: Option<usize>) -> Result<Self::SerializeSeq, Self::Error> { todo!() }
-    fn serialize_map(self, len: Option<usize>) -> Result<Self::SerializeMap, Self::Error> { todo!() }
-    fn serialize_tuple(self, len: usize) -> Result<Self::SerializeTuple, Self::Error> { todo!() }
+    fn serialize_seq(self, len: Option<usize>) -> Result<Self::SerializeSeq, Self::Error> {
+        todo!()
+    }
+    fn serialize_map(self, len: Option<usize>) -> Result<Self::SerializeMap, Self::Error> {
+        todo!()
+    }
+    fn serialize_tuple(self, len: usize) -> Result<Self::SerializeTuple, Self::Error> {
+        todo!()
+    }
     fn serialize_struct(
         self,
         name: &'static str,
@@ -263,7 +319,7 @@ impl<'r, 's, 'l> serde::Serializer for &'r ConfigSerializer<'r, 's, 'l> {
     ) -> Result<Self::SerializeStruct, Self::Error> {
         log::trace!("name: {name}");
 
-        let new_prefix = 
+        let new_prefix =
             // are we at the root?
             if name == "Config" {
                 String::new()
@@ -277,7 +333,7 @@ impl<'r, 's, 'l> serde::Serializer for &'r ConfigSerializer<'r, 's, 'l> {
             qi: self.qi,
             prefix: new_prefix,
         };
-        
+
         Ok(subser)
     }
 
@@ -318,7 +374,8 @@ impl<'r, 's, 'l> serde::Serializer for &'r ConfigSerializer<'r, 's, 'l> {
         value: &T,
     ) -> Result<Self::Ok, Self::Error>
     where
-        T: serde::Serialize {
+        T: serde::Serialize,
+    {
         todo!()
     }
 
@@ -340,8 +397,8 @@ impl<'r, 's, 'l> serde::Serializer for &'r ConfigSerializer<'r, 's, 'l> {
         value: &T,
     ) -> Result<Self::Ok, Self::Error>
     where
-        T: serde::Serialize {
-
+        T: serde::Serialize,
+    {
         todo!()
     }
 }
@@ -503,9 +560,13 @@ impl<'de> serde::Deserializer<'de> for AtomicForwarder<'de> {
 
     fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
     where
-        V: serde::de::Visitor<'de> {
-
-        visitor.visit_u64(self.to_fwd.parse().map_err(|_| ConfigError::InvalidType(String::new()))?)
+        V: serde::de::Visitor<'de>,
+    {
+        visitor.visit_u64(
+            self.to_fwd
+                .parse()
+                .map_err(|_| ConfigError::InvalidType(String::new()))?,
+        )
     }
 
     fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error>

+ 29 - 13
src/jwt.rs

@@ -25,41 +25,52 @@ pub struct JWT {
 }
 
 impl JWT {
-    pub fn verify<B: std::convert::AsRef<[u8]>>(with: &ring::signature::UnparsedPublicKey<B>, from: &str) -> Option<Self> {
+    pub fn verify<B: std::convert::AsRef<[u8]>>(
+        with: &ring::signature::UnparsedPublicKey<B>,
+        from: &str,
+    ) -> Option<Self> {
         let header_split = from.find(".")?;
         let header = &from[0..header_split];
-        let data_split = header_split + 1 + from[header_split+1..].find(".")?;
-        let data = &from[header_split + 1 .. data_split];
-        let signature = &from[data_split + 1 ..];
+        let data_split = header_split + 1 + from[header_split + 1..].find(".")?;
+        let data = &from[header_split + 1..data_split];
+        let signature = &from[data_split + 1..];
 
         let mut to_verify = vec![];
         to_verify.extend(header.as_bytes());
         to_verify.extend(".".as_bytes());
         to_verify.extend(data.as_bytes());
 
-        let decoded_signature = base64::decode_config(signature.as_bytes(), base64::URL_SAFE_NO_PAD).ok()?;
-        with.verify(to_verify.as_ref(), decoded_signature.as_ref()).ok()?;
+        let decoded_signature =
+            base64::decode_config(signature.as_bytes(), base64::URL_SAFE_NO_PAD).ok()?;
+        with.verify(to_verify.as_ref(), decoded_signature.as_ref())
+            .ok()?;
 
         // if we got this far, the verification passed
         Some(Self {
-            header: header.into(), data: data.into(), signature: signature.into()
+            header: header.into(),
+            data: data.into(),
+            signature: signature.into(),
         })
     }
 
     pub fn sign(with: &ring::signature::Ed25519KeyPair, data: JWTData) -> Self {
         let header = base64::encode_config(DEFAULT_HEADER, base64::URL_SAFE_NO_PAD);
-        let data = base64::encode_config(<JWTData as Into::<String>>::into(data), base64::URL_SAFE_NO_PAD);
+        let data = base64::encode_config(
+            <JWTData as Into<String>>::into(data),
+            base64::URL_SAFE_NO_PAD,
+        );
 
         let mut to_sign = vec![];
         to_sign.extend(header.as_bytes());
         to_sign.extend(".".as_bytes());
         to_sign.extend(data.as_bytes());
-        let signature = base64::encode_config(with.sign(&to_sign).as_ref(), base64::URL_SAFE_NO_PAD);
+        let signature =
+            base64::encode_config(with.sign(&to_sign).as_ref(), base64::URL_SAFE_NO_PAD);
 
         Self {
             header,
             data,
-            signature
+            signature,
         }
     }
 
@@ -75,8 +86,10 @@ mod test {
     #[test]
     fn simple_round_trip() {
         let rng = ring::rand::SystemRandom::new();
-        let kpair_raw = ring::signature::Ed25519KeyPair::generate_pkcs8(&rng).expect("couldn't generate ephemeral keypair");
-        let kpair = ring::signature::Ed25519KeyPair::from_pkcs8(kpair_raw.as_ref()).expect("couldn't load keypair");
+        let kpair_raw = ring::signature::Ed25519KeyPair::generate_pkcs8(&rng)
+            .expect("couldn't generate ephemeral keypair");
+        let kpair = ring::signature::Ed25519KeyPair::from_pkcs8(kpair_raw.as_ref())
+            .expect("couldn't load keypair");
 
         let jdata = super::JWTData {
             sub: "sub",
@@ -90,7 +103,10 @@ mod test {
         let generated = super::JWT::sign(&kpair, jdata).into_string();
         println!("generated: {:?}", generated);
 
-        let pubkey = ring::signature::UnparsedPublicKey::new(&ring::signature::ED25519, kpair.public_key().as_ref());
+        let pubkey = ring::signature::UnparsedPublicKey::new(
+            &ring::signature::ED25519,
+            kpair.public_key().as_ref(),
+        );
 
         let vresult = super::JWT::verify(&pubkey, generated.as_str());
 

+ 1 - 1
src/server.rs

@@ -1,4 +1,4 @@
-use crate::{config,schema};
+use crate::{config, schema};
 use microrm::prelude::*;
 
 mod oidc;

+ 51 - 23
src/server/oidc.rs

@@ -82,9 +82,21 @@ fn do_token_authorize(
     }
     let sauth = sauth.unwrap();
 
-    let scopes = qp.scope.as_ref().map(|slist| slist.as_str()).unwrap_or("").split_whitespace();
-
-    let token = token::generate_auth_token(&request.state().core.config, &qi, shelper.get_realm().unwrap(), client.id(), sauth.user, scopes);
+    let scopes = qp
+        .scope
+        .as_ref()
+        .map(|slist| slist.as_str())
+        .unwrap_or("")
+        .split_whitespace();
+
+    let token = token::generate_auth_token(
+        &request.state().core.config,
+        &qi,
+        shelper.get_realm().unwrap(),
+        client.id(),
+        sauth.user,
+        scopes,
+    );
 
     if let Some(token) = token {
         let response_body = serde_json::json!({
@@ -94,9 +106,12 @@ fn do_token_authorize(
             .content_type(tide::http::mime::JSON)
             .body(response_body)
             .build())
-    }
-    else {
-        Err(OIDCError(OIDCErrorType::ServerError, "internal error while generating token".to_string(), state))
+    } else {
+        Err(OIDCError(
+            OIDCErrorType::ServerError,
+            "internal error while generating token".to_string(),
+            state,
+        ))
     }
 }
 
@@ -181,9 +196,9 @@ async fn token(mut request: Request) -> tide::Result<tide::Response> {
     todo!()
 }
 
-const AUTHORIZE_PATH : &'static str = "oidc/authorize";
-const TOKEN_PATH : &'static str = "oidc/token";
-const JWKS_PATH : &'static str = "oidc/jwks";
+const AUTHORIZE_PATH: &'static str = "oidc/authorize";
+const TOKEN_PATH: &'static str = "oidc/token";
+const JWKS_PATH: &'static str = "oidc/jwks";
 
 async fn jwks(request: Request) -> tide::Result<tide::Response> {
     let qi = request.state().core.pool.query_interface();
@@ -191,18 +206,25 @@ async fn jwks(request: Request) -> tide::Result<tide::Response> {
     let shelper = super::session::SessionHelper::new(&request);
     let realm = shelper.get_realm()?;
 
-    let keyinfo = qi.get().by(schema::Key::Realm, &realm).all().expect("couldn't query db").into_iter().map(|key| {
-        let kpair = ring::signature::Ed25519KeyPair::from_pkcs8(&key.keydata).expect("couldn't parse keypair in db");
-        let pubkey_bytes = kpair.public_key().as_ref();
-        assert_eq!(pubkey_bytes.len(), 32);
-
-        serde_json::json!({
-            "kty": "OKP",
-            "crv": "Ed25519",
-            "x": base64::encode(pubkey_bytes),
-            "kid": key.key_id,
-        })
-    });
+    let keyinfo = qi
+        .get()
+        .by(schema::Key::Realm, &realm)
+        .all()
+        .expect("couldn't query db")
+        .into_iter()
+        .map(|key| {
+            let kpair = ring::signature::Ed25519KeyPair::from_pkcs8(&key.keydata)
+                .expect("couldn't parse keypair in db");
+            let pubkey_bytes = kpair.public_key().as_ref();
+            assert_eq!(pubkey_bytes.len(), 32);
+
+            serde_json::json!({
+                "kty": "OKP",
+                "crv": "Ed25519",
+                "x": base64::encode(pubkey_bytes),
+                "kid": key.key_id,
+            })
+        });
 
     let jwks_response = serde_json::json!({
         "keys": keyinfo.collect::<Vec<_>>(),
@@ -213,7 +235,11 @@ async fn jwks(request: Request) -> tide::Result<tide::Response> {
 
 async fn discovery_config(request: Request) -> tide::Result<tide::Response> {
     let server_config = &request.state().core.config;
-    let base_url = format!("{}/{}", server_config.base_url, request.param("realm").unwrap());
+    let base_url = format!(
+        "{}/{}",
+        server_config.base_url,
+        request.param("realm").unwrap()
+    );
 
     let config_response = serde_json::json!({
         "issuer": base_url,
@@ -233,5 +259,7 @@ pub(super) fn oidc_server(mut route: tide::Route<super::ServerStateWrapper>) {
     route.at(AUTHORIZE_PATH).get(authorize).post(authorize);
     route.at(TOKEN_PATH).post(token);
     route.at(JWKS_PATH).get(jwks);
-    route.at(".well-known/openid-configuration").get(discovery_config);
+    route
+        .at(".well-known/openid-configuration")
+        .get(discovery_config);
 }

+ 48 - 17
src/token.rs

@@ -1,24 +1,37 @@
-use crate::{schema, config, jwt};
-use microrm::{prelude::*, entity::EntityID};
+use crate::{config, jwt, schema};
+use microrm::{entity::EntityID, prelude::*};
 
-pub fn generate_auth_token<'a>(config: &config::Config, qi: &microrm::QueryInterface, realm: schema::RealmID, client: schema::ClientID, user: schema::UserID, scopes: impl Iterator<Item = &'a str>) -> Option<String> {
+pub fn generate_auth_token<'a>(
+    config: &config::Config,
+    qi: &microrm::QueryInterface,
+    realm: schema::RealmID,
+    client: schema::ClientID,
+    user: schema::UserID,
+    scopes: impl Iterator<Item = &'a str>,
+) -> Option<String> {
     let realm = qi.get().by_id(&realm).one().ok()??;
     let client = qi.get().by_id(&client).one().ok()??;
     let user = qi.get().by_id(&user).one().ok()??;
 
-    let issuer = format!(
-        "{}/{}",
-        config.base_url,
-        realm.shortname,
-    );
+    let issuer = format!("{}/{}", config.base_url, realm.shortname,);
 
     let iat = std::time::SystemTime::now();
     let exp = iat + std::time::Duration::from_secs(config.auth_token_expiry);
 
     // find all roles the user can possibly have access to
     let mut user_roles = vec![];
-    for group in qi.get().by(schema::GroupMembership::User, &user.id()).all().ok()? {
-        for group_role in qi.get().by(schema::GroupRole::Group, &group.id()).all().ok()? {
+    for group in qi
+        .get()
+        .by(schema::GroupMembership::User, &user.id())
+        .all()
+        .ok()?
+    {
+        for group_role in qi
+            .get()
+            .by(schema::GroupRole::Group, &group.id())
+            .all()
+            .ok()?
+        {
             user_roles.push(group_role.role);
         }
     }
@@ -27,7 +40,12 @@ pub fn generate_auth_token<'a>(config: &config::Config, qi: &microrm::QueryInter
     let mut requested_roles = vec![];
     for scope in scopes {
         if let Some(scope) = qi.get().by(schema::Scope::Shortname, scope).one().ok()? {
-            for scope_role in qi.get().by(schema::ScopeRole::Scope, &scope.id()).all().ok()? {
+            for scope_role in qi
+                .get()
+                .by(schema::ScopeRole::Scope, &scope.id())
+                .all()
+                .ok()?
+            {
                 requested_roles.push(scope_role.role);
             }
         }
@@ -39,9 +57,20 @@ pub fn generate_auth_token<'a>(config: &config::Config, qi: &microrm::QueryInter
     requested_roles.dedup();
 
     // find the intersection between requested roles and the ones the user actually has
-    let resulting_roles = requested_roles.iter().filter(|req| user_roles.contains(req)).map(|role_id| {
-        serde_json::Value::String(qi.get().by_id(role_id).one().unwrap().unwrap().shortname.clone())
-    });
+    let resulting_roles = requested_roles
+        .iter()
+        .filter(|req| user_roles.contains(req))
+        .map(|role_id| {
+            serde_json::Value::String(
+                qi.get()
+                    .by_id(role_id)
+                    .one()
+                    .unwrap()
+                    .unwrap()
+                    .shortname
+                    .clone(),
+            )
+        });
 
     let token = jwt::JWTData {
         sub: user.username.as_str(),
@@ -49,9 +78,11 @@ pub fn generate_auth_token<'a>(config: &config::Config, qi: &microrm::QueryInter
         aud: client.shortname.as_str(),
         iat: iat.duration_since(std::time::UNIX_EPOCH).unwrap().as_secs(),
         exp: exp.duration_since(std::time::UNIX_EPOCH).unwrap().as_secs(),
-        extras: [
-            ("roles", serde_json::Value::Array(resulting_roles.collect::<Vec<_>>())),
-        ].into(),
+        extras: [(
+            "roles",
+            serde_json::Value::Array(resulting_roles.collect::<Vec<_>>()),
+        )]
+        .into(),
     };
 
     let key = qi.get().by(schema::Key::Realm, &realm.id()).one().ok()??;