Переглянути джерело

Added manual token generation and config saving.

Kestrel 1 рік тому
батько
коміт
263f261d8f
11 змінених файлів з 887 додано та 340 видалено
  1. 5 2
      src/cert.rs
  2. 96 5
      src/cli.rs
  3. 71 14
      src/config.rs
  4. 576 0
      src/config/helper.rs
  5. 1 1
      src/jwt.rs
  6. 1 0
      src/main.rs
  7. 1 1
      src/schema.rs
  8. 4 7
      src/server.rs
  9. 0 268
      src/server/config.rs
  10. 71 42
      src/server/oidc.rs
  11. 61 0
      src/token.rs

+ 5 - 2
src/cert.rs

@@ -1,6 +1,6 @@
 use crate::schema;
 use microrm::prelude::*;
-use ring::signature::Ed25519KeyPair;
+use ring::signature::{Ed25519KeyPair, KeyPair};
 use sha2::Digest;
 use std::collections::HashMap;
 
@@ -41,8 +41,11 @@ impl<'a> CertStore<'a> {
 
         let keydata = sign_generated.as_ref().to_owned();
 
+        let loaded_key = Ed25519KeyPair::from_pkcs8(keydata.as_slice()).unwrap();
+        let pubkey = loaded_key.public_key();
+
         let mut key_hasher = sha2::Sha256::new();
-        key_hasher.update(&keydata);
+        key_hasher.update(&pubkey.as_ref());
         let mut key_id = base64::encode(key_hasher.finalize());
         key_id.truncate(16);
 

+ 96 - 5
src/cli.rs

@@ -1,5 +1,6 @@
-use crate::{cert, client_management, schema, server, user_management};
+use crate::{cert, client_management, schema::{self, schema}, server, user_management, config, token};
 use clap::{Parser, Subcommand};
+use microrm::prelude::*;
 
 #[derive(Debug, Parser)]
 #[clap(author, version, about, long_about = None)]
@@ -21,6 +22,7 @@ enum Command {
     Init,
     Cert(CertArgs),
     Client(ClientArgs),
+    Config(ConfigArgs),
     Group(GroupArgs),
     Server(ServerArgs),
     Token(TokenArgs),
@@ -44,6 +46,7 @@ impl RootArgs {
         match &self.command {
             Command::Init => unreachable!(),
             Command::Cert(v) => v.run(&self, storage).await,
+            Command::Config(v) => v.run(&self, storage).await,
             Command::Client(v) => v.run(&self, storage).await,
             Command::Group(v) => v.run(&self, storage).await,
             Command::Server(v) => v.run(&self, storage).await,
@@ -131,6 +134,50 @@ impl ClientArgs {
     }
 }
 
+#[derive(Debug, Subcommand)]
+enum ConfigCommand {
+    Dump,
+    Load {
+        toml_path: String
+    },
+    Set {
+        key: String,
+        value: String,
+    }
+}
+
+#[derive(Debug, Parser)]
+struct ConfigArgs {
+    #[clap(subcommand)]
+    command: ConfigCommand,
+}
+
+impl ConfigArgs {
+    async fn run(&self, root: &RootArgs, db: microrm::DB) {
+        match &self.command {
+            ConfigCommand::Dump => {
+                let qi = db.query_interface();
+                let config = config::Config::build_from(&qi, None);
+                println!("config: {:?}", config);
+            },
+            ConfigCommand::Set { key, value } => {
+                todo!()
+            },
+            ConfigCommand::Load { toml_path } => {
+                let config = {
+                    let qi = db.query_interface();
+                    config::Config::build_from(&qi, Some(toml_path))
+                };
+                {
+                    let qi = db.query_interface();
+                    config.save(&qi);
+                    drop(config);
+                }
+            },
+        }
+    }
+}
+
 #[derive(Debug, Subcommand)]
 enum GroupCommand {}
 
@@ -146,13 +193,15 @@ impl GroupArgs {
 
 #[derive(Debug, Parser)]
 struct ServerArgs {
+    #[clap(short, long)]
     port: Option<u16>,
+    #[clap(short, long)]
     config_path: Option<String>,
 }
 
 impl ServerArgs {
     async fn run(&self, root: &RootArgs, db: microrm::DB) {
-        let config = server::ServerConfig::build_from(
+        let config = config::Config::build_from(
             &db.query_interface(),
             self.config_path.as_ref().map(|x| x.as_str()),
         );
@@ -162,8 +211,22 @@ impl ServerArgs {
 
 #[derive(Debug, Subcommand)]
 enum TokenCommand {
-    GenerateAuth { username: String, scopes: String },
-    GenerateRefresh { username: String, scopes: String },
+    GenerateAuth {
+        #[clap(short, long)]
+        client: String,
+        #[clap(short, long)]
+        username: String,
+        #[clap(short, long)]
+        scopes: String
+    },
+    GenerateRefresh {
+        #[clap(short, long)]
+        client: String,
+        #[clap(short, long)]
+        username: String,
+        #[clap(short, long)]
+        scopes: String
+    },
     Inspect { token: String },
 }
 
@@ -174,7 +237,35 @@ struct TokenArgs {
 }
 
 impl TokenArgs {
-    async fn run(&self, root: &RootArgs, _db: microrm::DB) {}
+    async fn run(&self, root: &RootArgs, db: microrm::DB) {
+        let config = config::Config::build_from(&db.query_interface(), None);
+        match &self.command {
+            TokenCommand::GenerateAuth { client, username, scopes } => {
+                let qi = db.query_interface();
+                let realm_id = qi.get().by(schema::Realm::Shortname, &root.realm).one().unwrap().expect("no such realm").id();
+                let token = token::generate_auth_token(
+                    &config,
+                    &qi,
+                    realm_id,
+                    qi.get().by(schema::Client::Realm, &realm_id).by(schema::Client::Shortname, client.as_str()).one().unwrap().expect("no such client").id(),
+                    qi.get().by(schema::User::Realm, &realm_id).by(schema::User::Username, username.as_str()).one().unwrap().expect("no such user").id(),
+                    scopes.split_whitespace(),
+                );
+                if let Some(t) = token {
+                    println!("token: {t}");
+                }
+                else {
+                    println!("Could not generate token");
+                }
+            },
+            TokenCommand::GenerateRefresh { client, username, scopes } => {
+                
+            },
+            TokenCommand::Inspect { token } => {
+                
+            },
+        }
+    }
 }
 
 #[derive(Debug, Parser)]

+ 71 - 14
src/config.rs

@@ -1,22 +1,79 @@
-#![allow(dead_code)]
+use crate::schema;
+use microrm::prelude::*;
+use serde::{Deserialize, Serialize};
 
-use serde::Deserialize;
+mod helper;
 
-#[derive(Deserialize)]
-pub enum Flow {
-    Challenge(crate::schema::AuthChallengeType),
-    OneOf(Vec<Flow>),
-    AllOf(Vec<Flow>),
+fn default_auth_token_expiry() -> u64 {
+    600
 }
 
-#[derive(Deserialize)]
-pub struct AuthConfig {
-    pbkdf2_rounds: usize,
+#[derive(Debug, Serialize, Deserialize)]
+pub struct Config {
+    pub base_url: String,
 
-    login_flow: Flow,
+    #[serde(default = "default_auth_token_expiry")]
+    pub auth_token_expiry: u64,
 }
 
-#[derive(Deserialize)]
-pub struct GlobalConfig {
-    auth: AuthConfig,
+impl Config {
+    pub fn build_from(qi: &microrm::QueryInterface, cfile: Option<&str>) -> Self {
+        let mut config_map = std::collections::HashMap::<String, String>::new();
+        // load config keys from query interface
+        let db_pcs = qi
+            .get::<schema::PersistentConfig>()
+            .all()
+            .expect("couldn't get config keys from database");
+        config_map.extend(db_pcs.into_iter().map(|pc| {
+            let pc = pc.wrapped();
+            (pc.key, pc.value)
+        }));
+
+        if let Some(path) = cfile {
+            match std::fs::read(&path) {
+                Ok(data) => {
+                    log::info!("Loading config from {path}...");
+                    let toml_table: toml::Table = toml::from_str(
+                        std::str::from_utf8(data.as_slice())
+                            .expect("couldn't read config file contents as utf-8"),
+                    )
+                    .expect("couldn't parse config toml");
+
+                    for val in toml_table {
+                        log::trace!("using config key {} from TOML config...", val.0);
+                        if val.1.is_str() {
+                            config_map.insert(val.0, val.1.as_str().unwrap().to_string());
+                        }
+                        else {
+                            config_map.insert(val.0, val.1.to_string());
+                        }
+                    }
+                }
+                Err(e) => {
+                    log::error!("Could not open {path} for reading: {e}");
+                }
+            }
+        }
+
+        let mut deser = helper::ConfigDeserializer {
+            config_map: &config_map,
+            prefix: "".to_string(),
+        };
+
+        let config = Config::deserialize(&mut deser).expect("couldn't load configuration");
+
+        log::trace!("final configuration: {:?}", config);
+
+        config
+    }
+
+    pub fn save<'config, 'qi>(&'config self, qi: &'config microrm::QueryInterface<'qi>) {
+        let ser = helper::ConfigSerializer {
+            config: &self,
+            qi: &qi,
+            prefix: String::new()
+        };
+
+        self.serialize(&ser);
+    }
 }

+ 576 - 0
src/config/helper.rs

@@ -0,0 +1,576 @@
+use microrm::prelude::*;
+
+use crate::schema;
+
+use super::Config;
+
+struct ValueToStringSerializer { }
+
+impl<'l> serde::Serializer for &'l ValueToStringSerializer {
+    type Ok = Option<String>;
+    type Error = ConfigError;
+
+    type SerializeSeq = serde::ser::Impossible<Self::Ok, Self::Error>;
+    type SerializeMap = serde::ser::Impossible<Self::Ok, Self::Error>;
+    type SerializeTuple = serde::ser::Impossible<Self::Ok, Self::Error>;
+    type SerializeStruct = serde::ser::Impossible<Self::Ok, Self::Error>;
+    type SerializeTupleStruct = serde::ser::Impossible<Self::Ok, Self::Error>;
+    type SerializeTupleVariant = serde::ser::Impossible<Self::Ok, Self::Error>;
+    type SerializeStructVariant = serde::ser::Impossible<Self::Ok, Self::Error>;
+
+    fn serialize_bool(self, v: bool) -> Result<Self::Ok, Self::Error> { unreachable!() }
+
+    fn serialize_i8(self, v: i8) -> Result<Self::Ok, Self::Error> {
+        self.serialize_i64(v as i64)
+    }
+    fn serialize_i16(self, v: i16) -> Result<Self::Ok, Self::Error> {
+        self.serialize_i64(v as i64)
+    }
+    fn serialize_i32(self, v: i32) -> Result<Self::Ok, Self::Error> {
+        self.serialize_i64(v as i64)
+    }
+    fn serialize_i64(self, v: i64) -> Result<Self::Ok, Self::Error> {
+        Ok(Some(v.to_string()))
+    }
+    fn serialize_i128(self, v: i128) -> Result<Self::Ok, Self::Error> { unreachable!() }
+
+    fn serialize_u8(self, v: u8) -> Result<Self::Ok, Self::Error> {
+        self.serialize_u64(v as u64)
+    }
+    fn serialize_u16(self, v: u16) -> Result<Self::Ok, Self::Error> {
+        self.serialize_u64(v as u64)
+    }
+    fn serialize_u32(self, v: u32) -> Result<Self::Ok, Self::Error> {
+        self.serialize_u64(v as u64)
+    }
+    fn serialize_u64(self, v: u64) -> Result<Self::Ok, Self::Error> {
+        Ok(Some(v.to_string()))
+    }
+    fn serialize_u128(self, v: u128) -> Result<Self::Ok, Self::Error> { unreachable!() }
+
+    fn serialize_f32(self, v: f32) -> Result<Self::Ok, Self::Error> { unreachable!() }
+    fn serialize_f64(self, v: f64) -> Result<Self::Ok, Self::Error> { unreachable!() }
+
+    fn serialize_char(self, v: char) -> Result<Self::Ok, Self::Error> { unreachable!() }
+
+    fn serialize_bytes(self, v: &[u8]) -> Result<Self::Ok, Self::Error> { unreachable!() }
+
+    fn serialize_str(self, v: &str) -> Result<Self::Ok, Self::Error> {
+        Ok(Some(v.into()))
+    }
+
+    fn serialize_none(self) -> Result<Self::Ok, Self::Error> { Ok(None) }
+
+    fn serialize_some<T: ?Sized>(self, value: &T) -> Result<Self::Ok, Self::Error>
+    where
+        T: serde::Serialize {
+
+        value.serialize(self)
+    }
+
+    fn serialize_unit(self) -> Result<Self::Ok, Self::Error> { unreachable!() }
+
+    fn serialize_seq(self, _len: Option<usize>) -> Result<Self::SerializeSeq, Self::Error> { unreachable!() }
+    fn serialize_map(self, _len: Option<usize>) -> Result<Self::SerializeMap, Self::Error> { unreachable!() }
+    fn serialize_tuple(self, _len: usize) -> Result<Self::SerializeTuple, Self::Error> { unreachable!() }
+    fn serialize_struct(
+        self,
+        _name: &'static str,
+        _len: usize,
+    ) -> Result<Self::SerializeStruct, Self::Error> { unreachable!() }
+
+    fn serialize_unit_struct(self, name: &'static str) -> Result<Self::Ok, Self::Error> { unreachable!() }
+
+    fn serialize_unit_variant(
+        self,
+        name: &'static str,
+        variant_index: u32,
+        variant: &'static str,
+    ) -> Result<Self::Ok, Self::Error> { unreachable!() }
+
+    fn serialize_tuple_struct(
+        self,
+        name: &'static str,
+        len: usize,
+    ) -> Result<Self::SerializeTupleStruct, Self::Error> { unreachable!() }
+
+    fn serialize_tuple_variant(
+        self,
+        name: &'static str,
+        variant_index: u32,
+        variant: &'static str,
+        len: usize,
+    ) -> Result<Self::SerializeTupleVariant, Self::Error> {
+        todo!()
+    }
+
+    fn serialize_newtype_struct<T: ?Sized>(
+        self,
+        name: &'static str,
+        value: &T,
+    ) -> Result<Self::Ok, Self::Error>
+    where
+        T: serde::Serialize {
+        todo!()
+    }
+
+    fn serialize_struct_variant(
+        self,
+        name: &'static str,
+        variant_index: u32,
+        variant: &'static str,
+        len: usize,
+    ) -> Result<Self::SerializeStructVariant, Self::Error> {
+        todo!()
+    }
+
+    fn serialize_newtype_variant<T: ?Sized>(
+        self,
+        name: &'static str,
+        variant_index: u32,
+        variant: &'static str,
+        value: &T,
+    ) -> Result<Self::Ok, Self::Error>
+    where
+        T: serde::Serialize {
+
+        todo!()
+    }
+}
+
+pub struct ConfigSerializer<'r, 's, 'l> {
+    pub config: &'r Config,
+    pub qi: &'s microrm::QueryInterface<'l>,
+    pub prefix: String,
+}
+
+impl<'r, 's, 'l> ConfigSerializer<'r, 's, 'l> {
+    fn update(&self, key: &str, value: String) {
+        self.qi.delete().by(schema::PersistentConfig::Key, key).exec().expect("couldn't update config");
+        self.qi.add(&schema::PersistentConfig {
+            key: key.into(),
+            value
+        }).expect("couldn't update config");
+    }
+}
+
+impl <'r, 's, 'l> serde::ser::SerializeStruct for ConfigSerializer<'r, 's, 'l> {
+    type Ok = ();
+    type Error = ConfigError;
+
+    fn serialize_field<T: ?Sized>(
+        &mut self,
+        key: &'static str,
+        value: &T,
+    ) -> Result<(), Self::Error>
+    where
+        T: serde::Serialize {
+
+        let key = format!("{}{}", self.prefix, key);
+
+
+        let value = value.serialize(&ValueToStringSerializer {})?;
+        if let Some(value) = value {
+            log::trace!("saving config {} = {}", key, value);
+            self.update(key.as_str(), value);
+        }
+
+        Ok(())
+    }
+
+    fn end(self) -> Result<Self::Ok, Self::Error> {
+        Ok(())
+    }
+}
+
+impl<'r, 's, 'l> serde::Serializer for &'r ConfigSerializer<'r, 's, 'l> {
+    type Ok = ();
+    type Error = ConfigError;
+
+    type SerializeSeq = serde::ser::Impossible<Self::Ok, Self::Error>;
+    type SerializeMap = serde::ser::Impossible<Self::Ok, Self::Error>;
+    type SerializeTuple = serde::ser::Impossible<Self::Ok, Self::Error>;
+    type SerializeStruct = ConfigSerializer<'r, 's, 'l>;
+    type SerializeTupleStruct = serde::ser::Impossible<Self::Ok, Self::Error>;
+    type SerializeTupleVariant = serde::ser::Impossible<Self::Ok, Self::Error>;
+    type SerializeStructVariant = serde::ser::Impossible<Self::Ok, Self::Error>;
+
+    fn serialize_bool(self, v: bool) -> Result<Self::Ok, Self::Error> {
+        unreachable!()
+    }
+
+    fn serialize_i8(self, v: i8) -> Result<Self::Ok, Self::Error> {
+        self.serialize_i64(v as i64)
+    }
+    fn serialize_i16(self, v: i16) -> Result<Self::Ok, Self::Error> {
+        self.serialize_i64(v as i64)
+    }
+    fn serialize_i32(self, v: i32) -> Result<Self::Ok, Self::Error> {
+        self.serialize_i64(v as i64)
+    }
+    fn serialize_i64(self, v: i64) -> Result<Self::Ok, Self::Error> {
+        todo!()
+    }
+    fn serialize_i128(self, v: i128) -> Result<Self::Ok, Self::Error> { unreachable!() }
+
+    fn serialize_u8(self, v: u8) -> Result<Self::Ok, Self::Error> {
+        self.serialize_u64(v as u64)
+    }
+    fn serialize_u16(self, v: u16) -> Result<Self::Ok, Self::Error> {
+        self.serialize_u64(v as u64)
+    }
+    fn serialize_u32(self, v: u32) -> Result<Self::Ok, Self::Error> {
+        self.serialize_u64(v as u64)
+    }
+    fn serialize_u64(self, v: u64) -> Result<Self::Ok, Self::Error> {
+        todo!()
+    }
+    fn serialize_u128(self, v: u128) -> Result<Self::Ok, Self::Error> { unreachable!() }
+
+    fn serialize_f32(self, v: f32) -> Result<Self::Ok, Self::Error> { unreachable!() }
+    fn serialize_f64(self, v: f64) -> Result<Self::Ok, Self::Error> { unreachable!() }
+
+    fn serialize_char(self, v: char) -> Result<Self::Ok, Self::Error> { unreachable!() }
+
+    fn serialize_bytes(self, v: &[u8]) -> Result<Self::Ok, Self::Error> {
+        todo!()
+    }
+
+    fn serialize_str(self, v: &str) -> Result<Self::Ok, Self::Error> {
+        todo!()
+    }
+
+    fn serialize_none(self) -> Result<Self::Ok, Self::Error> {
+        Ok(())
+    }
+
+    fn serialize_some<T: ?Sized>(self, value: &T) -> Result<Self::Ok, Self::Error>
+    where
+        T: serde::Serialize {
+
+        todo!()
+    }
+
+    fn serialize_unit(self) -> Result<Self::Ok, Self::Error> { unreachable!() }
+
+    fn serialize_seq(self, len: Option<usize>) -> Result<Self::SerializeSeq, Self::Error> { todo!() }
+    fn serialize_map(self, len: Option<usize>) -> Result<Self::SerializeMap, Self::Error> { todo!() }
+    fn serialize_tuple(self, len: usize) -> Result<Self::SerializeTuple, Self::Error> { todo!() }
+    fn serialize_struct(
+        self,
+        name: &'static str,
+        _len: usize,
+    ) -> Result<Self::SerializeStruct, Self::Error> {
+        log::trace!("name: {name}");
+
+        let new_prefix = 
+            // are we at the root?
+            if name == "Config" {
+                String::new()
+            }
+            else {
+                format!("{}{}.", self.prefix, name)
+            };
+
+        let subser = ConfigSerializer {
+            config: self.config,
+            qi: self.qi,
+            prefix: new_prefix,
+        };
+        
+        Ok(subser)
+    }
+
+    fn serialize_unit_struct(self, name: &'static str) -> Result<Self::Ok, Self::Error> {
+        todo!()
+    }
+
+    fn serialize_unit_variant(
+        self,
+        name: &'static str,
+        variant_index: u32,
+        variant: &'static str,
+    ) -> Result<Self::Ok, Self::Error> {
+        todo!()
+    }
+
+    fn serialize_tuple_struct(
+        self,
+        name: &'static str,
+        len: usize,
+    ) -> Result<Self::SerializeTupleStruct, Self::Error> {
+        todo!()
+    }
+
+    fn serialize_tuple_variant(
+        self,
+        name: &'static str,
+        variant_index: u32,
+        variant: &'static str,
+        len: usize,
+    ) -> Result<Self::SerializeTupleVariant, Self::Error> {
+        todo!()
+    }
+
+    fn serialize_newtype_struct<T: ?Sized>(
+        self,
+        name: &'static str,
+        value: &T,
+    ) -> Result<Self::Ok, Self::Error>
+    where
+        T: serde::Serialize {
+        todo!()
+    }
+
+    fn serialize_struct_variant(
+        self,
+        name: &'static str,
+        variant_index: u32,
+        variant: &'static str,
+        len: usize,
+    ) -> Result<Self::SerializeStructVariant, Self::Error> {
+        todo!()
+    }
+
+    fn serialize_newtype_variant<T: ?Sized>(
+        self,
+        name: &'static str,
+        variant_index: u32,
+        variant: &'static str,
+        value: &T,
+    ) -> Result<Self::Ok, Self::Error>
+    where
+        T: serde::Serialize {
+
+        todo!()
+    }
+}
+
+pub struct ConfigDeserializer<'de> {
+    pub config_map: &'de std::collections::HashMap<String, String>,
+    pub prefix: String,
+}
+
+#[derive(Debug)]
+pub enum ConfigError {
+    Missing(String),
+    InvalidType(String),
+    CustomError(String),
+}
+
+impl std::fmt::Display for ConfigError {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            Self::Missing(what) => f.write_fmt(format_args!(
+                "Missing required config entry: {}",
+                what.as_str()
+            )),
+            Self::InvalidType(what) => f.write_fmt(format_args!(
+                "Could not parse config entry '{}'",
+                what.as_str()
+            )),
+            Self::CustomError(what) => {
+                f.write_fmt(format_args!("Custom error '{}'", what.as_str()))
+            }
+        }
+    }
+}
+
+impl std::error::Error for ConfigError {}
+
+impl serde::ser::Error for ConfigError {
+    fn custom<T>(msg: T) -> Self
+    where
+        T: std::fmt::Display,
+    {
+        Self::CustomError(msg.to_string())
+    }
+}
+
+impl serde::de::Error for ConfigError {
+    fn custom<T>(msg: T) -> Self
+    where
+        T: std::fmt::Display,
+    {
+        Self::CustomError(msg.to_string())
+    }
+
+    fn invalid_type(_unexp: serde::de::Unexpected, _exp: &dyn serde::de::Expected) -> Self {
+        Self::InvalidType("".into())
+    }
+
+    fn missing_field(field: &'static str) -> Self {
+        Self::Missing(field.into())
+    }
+}
+
+impl<'de> serde::Deserializer<'de> for &'de mut ConfigDeserializer<'de> {
+    type Error = ConfigError;
+
+    fn deserialize_any<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
+    where
+        V: serde::de::Visitor<'de>,
+    {
+        unreachable!("deserialize_any needs context")
+    }
+
+    fn deserialize_struct<V>(
+        self,
+        _name: &'static str,
+        _fields: &'static [&'static str],
+        visitor: V,
+    ) -> Result<V::Value, Self::Error>
+    where
+        V: serde::de::Visitor<'de>,
+    {
+        self.deserialize_map(visitor)
+    }
+
+    fn deserialize_seq<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
+    where
+        V: serde::de::Visitor<'de>,
+    {
+        todo!("deserialize_seq")
+    }
+
+    fn deserialize_map<V>(mut self, visitor: V) -> Result<V::Value, Self::Error>
+    where
+        V: serde::de::Visitor<'de>,
+    {
+        let mut map_access = ConfigDeserializerIterator {
+            it: self
+                .config_map
+                .iter()
+                .filter(|e| {
+                    e.0.starts_with(&self.prefix) && !e.0[self.prefix.len()..].contains(".")
+                })
+                .peekable(),
+        };
+
+        visitor.visit_map(&mut map_access)
+    }
+
+    fn deserialize_enum<V>(
+        self,
+        _name: &'static str,
+        _variants: &'static [&'static str],
+        _visitor: V,
+    ) -> Result<V::Value, Self::Error>
+    where
+        V: serde::de::Visitor<'de>,
+    {
+        todo!("deserialize_enum")
+    }
+
+    fn deserialize_tuple<V>(self, _len: usize, _visitor: V) -> Result<V::Value, Self::Error>
+    where
+        V: serde::de::Visitor<'de>,
+    {
+        todo!("deserialize_tuple")
+    }
+
+    fn deserialize_tuple_struct<V>(
+        self,
+        _name: &'static str,
+        _len: usize,
+        _visitor: V,
+    ) -> Result<V::Value, Self::Error>
+    where
+        V: serde::de::Visitor<'de>,
+    {
+        todo!("deserialize_tuple_struct")
+    }
+
+    serde::forward_to_deserialize_any!(
+        i8 u8 i16 u16 i32 u32 i64 u64 i128 u128 str string bytes
+        bool f32 f64 char byte_buf option unit unit_struct
+        newtype_struct identifier ignored_any
+    );
+}
+
+struct AtomicForwarder<'de> {
+    to_fwd: &'de str,
+}
+
+impl<'de> serde::Deserializer<'de> for AtomicForwarder<'de> {
+    type Error = ConfigError;
+    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
+    where
+        V: serde::de::Visitor<'de>,
+    {
+        unreachable!()
+    }
+
+    fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
+    where
+        V: serde::de::Visitor<'de> {
+
+        visitor.visit_u64(self.to_fwd.parse().map_err(|_| ConfigError::InvalidType(String::new()))?)
+    }
+
+    fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error>
+    where
+        V: serde::de::Visitor<'de>,
+    {
+        visitor.visit_str(self.to_fwd)
+    }
+
+    fn deserialize_string<V>(self, visitor: V) -> Result<V::Value, Self::Error>
+    where
+        V: serde::de::Visitor<'de>,
+    {
+        visitor.visit_str(self.to_fwd)
+    }
+
+    fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value, Self::Error>
+    where
+        V: serde::de::Visitor<'de>,
+    {
+        visitor.visit_str(self.to_fwd)
+    }
+
+    serde::forward_to_deserialize_any!(
+        i8 u8 i16 u16 i32 u32 i64 i128 u128 bytes
+        bool f32 f64 char byte_buf unit unit_struct option
+        newtype_struct ignored_any struct tuple tuple_struct
+        seq map enum
+    );
+}
+
+struct ConfigDeserializerIterator<'de, I: Iterator<Item = (&'de String, &'de String)>> {
+    it: std::iter::Peekable<I>,
+}
+
+impl<'de, I: Iterator<Item = (&'de String, &'de String)>> serde::de::MapAccess<'de>
+    for ConfigDeserializerIterator<'de, I>
+{
+    type Error = ConfigError;
+
+    fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, Self::Error>
+    where
+        K: serde::de::DeserializeSeed<'de>,
+    {
+        if let Some(e) = self.it.peek() {
+            let de = AtomicForwarder {
+                to_fwd: e.0.as_str(),
+            };
+            Ok(seed.deserialize(de).ok())
+        } else {
+            Ok(None)
+        }
+    }
+
+    fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Self::Error>
+    where
+        V: serde::de::DeserializeSeed<'de>,
+    {
+        let value = self.it.next().unwrap();
+
+        let de = AtomicForwarder {
+            to_fwd: value.1.as_str(),
+        };
+
+        seed.deserialize(de)
+            .map_err(|e| ConfigError::InvalidType(e.to_string()))
+    }
+}

+ 1 - 1
src/jwt.rs

@@ -7,7 +7,7 @@ pub struct JWTData<'l> {
     pub exp: u64,
 
     #[serde(flatten)]
-    pub extras: std::collections::HashMap<String, serde_json::Value>,
+    pub extras: std::collections::HashMap<&'l str, serde_json::Value>,
 }
 
 impl<'l> Into<String> for JWTData<'l> {

+ 1 - 0
src/main.rs

@@ -6,6 +6,7 @@ mod jwt;
 mod login;
 mod schema;
 mod server;
+mod token;
 mod user;
 mod user_management;
 

+ 1 - 1
src/schema.rs

@@ -136,7 +136,7 @@ pub struct ScopeRole {
 #[derive(Debug, Entity, Serialize, Deserialize)]
 pub struct GroupRole {
     #[microrm_foreign]
-    pub scope: ScopeID,
+    pub group: GroupID,
     #[microrm_foreign]
     pub role: RoleID,
 }

+ 4 - 7
src/server.rs

@@ -1,14 +1,11 @@
-use crate::schema;
+use crate::{config,schema};
 use microrm::prelude::*;
 
-mod config;
 mod oidc;
 mod session;
 
-pub use config::ServerConfig;
-
 pub struct ServerState {
-    config: ServerConfig,
+    config: config::Config,
     pool: microrm::DBPool<'static>,
     templates: handlebars::Handlebars<'static>,
 }
@@ -53,7 +50,7 @@ async fn index(req: tide::Request<ServerStateWrapper>) -> tide::Result<tide::Res
     Ok(response)
 }
 
-pub async fn run_server(db: microrm::DB, config: ServerConfig, port: u16) {
+pub async fn run_server(db: microrm::DB, config: config::Config, port: u16) {
     let db_box = Box::new(db);
     let db: &'static mut microrm::DB = Box::leak(db_box);
     let pool = microrm::DBPool::new(db);
@@ -90,7 +87,7 @@ pub async fn run_server(db: microrm::DB, config: ServerConfig, port: u16) {
         .expect("Can't serve static files");
 
     session::session_v1_server(app.at("/:realm/v1/session/"));
-    oidc::oidc_v1_server(app.at("/:realm/v1/oidc/"));
+    oidc::oidc_server(app.at("/:realm/"));
 
     app.listen(("127.0.0.1", port)).await.expect("Can listen");
 }

+ 0 - 268
src/server/config.rs

@@ -1,268 +0,0 @@
-use crate::schema;
-use microrm::prelude::*;
-use serde::{Deserialize, Serialize};
-
-fn default_auth_token_expiry() -> u64 {
-    600
-}
-
-#[derive(Serialize, Deserialize)]
-pub struct ServerConfig {
-    pub base_url: String,
-
-    #[serde(default = "default_auth_token_expiry")]
-    pub auth_token_expiry: u64,
-}
-
-impl ServerConfig {
-    pub fn build_from(qi: &microrm::QueryInterface, cfile: Option<&str>) -> Self {
-        let mut config_map = std::collections::HashMap::<String, String>::new();
-        // load config keys from query interface
-        let db_pcs = qi
-            .get::<schema::PersistentConfig>()
-            .all()
-            .expect("couldn't get config keys from database");
-        config_map.extend(db_pcs.into_iter().map(|pc| {
-            let pc = pc.wrapped();
-            (pc.key, pc.value)
-        }));
-
-        if let Some(path) = cfile {
-            match std::fs::read(&path) {
-                Ok(data) => {
-                    log::info!("Loading config from {path}...");
-                    let toml_table: toml::Table = toml::from_str(
-                        std::str::from_utf8(data.as_slice())
-                            .expect("couldn't read config file contents as utf-8"),
-                    )
-                    .expect("couldn't parse config toml");
-                }
-                Err(e) => {
-                    log::error!("Could not open {path} for reading: {e}");
-                }
-            }
-        }
-
-        let mut deser = ConfigDeserializer {
-            config_map: &config_map,
-            prefix: "".to_string(),
-        };
-
-        let config = ServerConfig::deserialize(&mut deser).expect("couldn't load configuration");
-
-        config
-    }
-}
-
-struct ConfigDeserializer<'de> {
-    config_map: &'de std::collections::HashMap<String, String>,
-    prefix: String,
-}
-
-#[derive(Debug)]
-enum ConfigError {
-    Missing(String),
-    InvalidType(String),
-    CustomError(String),
-}
-
-impl std::fmt::Display for ConfigError {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        match self {
-            Self::Missing(what) => f.write_fmt(format_args!(
-                "Missing required config entry: {}",
-                what.as_str()
-            )),
-            Self::InvalidType(what) => f.write_fmt(format_args!(
-                "Could not parse config entry '{}'",
-                what.as_str()
-            )),
-            Self::CustomError(what) => {
-                f.write_fmt(format_args!("Custom error '{}'", what.as_str()))
-            }
-        }
-    }
-}
-
-impl std::error::Error for ConfigError {}
-
-impl serde::de::Error for ConfigError {
-    fn custom<T>(msg: T) -> Self
-    where
-        T: std::fmt::Display,
-    {
-        Self::CustomError(msg.to_string())
-    }
-
-    fn invalid_type(_unexp: serde::de::Unexpected, _exp: &dyn serde::de::Expected) -> Self {
-        Self::InvalidType("".into())
-    }
-
-    fn missing_field(field: &'static str) -> Self {
-        Self::Missing(field.into())
-    }
-}
-
-impl<'de> serde::Deserializer<'de> for &'de mut ConfigDeserializer<'de> {
-    type Error = ConfigError;
-
-    fn deserialize_any<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
-    where
-        V: serde::de::Visitor<'de>,
-    {
-        unreachable!("deserialize_any needs context")
-    }
-
-    fn deserialize_struct<V>(
-        self,
-        _name: &'static str,
-        _fields: &'static [&'static str],
-        visitor: V,
-    ) -> Result<V::Value, Self::Error>
-    where
-        V: serde::de::Visitor<'de>,
-    {
-        self.deserialize_map(visitor)
-    }
-
-    fn deserialize_seq<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
-    where
-        V: serde::de::Visitor<'de>,
-    {
-        todo!("deserialize_seq")
-    }
-
-    fn deserialize_map<V>(mut self, visitor: V) -> Result<V::Value, Self::Error>
-    where
-        V: serde::de::Visitor<'de>,
-    {
-        let mut map_access = ConfigDeserializerIterator {
-            it: self
-                .config_map
-                .iter()
-                .filter(|e| {
-                    e.0.starts_with(&self.prefix) && !e.0[self.prefix.len()..].contains(".")
-                })
-                .peekable(),
-        };
-
-        visitor.visit_map(&mut map_access)
-    }
-
-    fn deserialize_enum<V>(
-        self,
-        _name: &'static str,
-        _variants: &'static [&'static str],
-        _visitor: V,
-    ) -> Result<V::Value, Self::Error>
-    where
-        V: serde::de::Visitor<'de>,
-    {
-        todo!("deserialize_enum")
-    }
-
-    fn deserialize_tuple<V>(self, _len: usize, _visitor: V) -> Result<V::Value, Self::Error>
-    where
-        V: serde::de::Visitor<'de>,
-    {
-        todo!("deserialize_tuple")
-    }
-
-    fn deserialize_tuple_struct<V>(
-        self,
-        _name: &'static str,
-        _len: usize,
-        _visitor: V,
-    ) -> Result<V::Value, Self::Error>
-    where
-        V: serde::de::Visitor<'de>,
-    {
-        todo!("deserialize_tuple_struct")
-    }
-
-    serde::forward_to_deserialize_any!(
-        i8 u8 i16 u16 i32 u32 i64 u64 i128 u128 str string bytes
-        bool f32 f64 char byte_buf option unit unit_struct
-        newtype_struct identifier ignored_any
-    );
-}
-
-struct AtomicForwarder<'de> {
-    to_fwd: &'de str,
-}
-
-impl<'de> serde::Deserializer<'de> for AtomicForwarder<'de> {
-    type Error = ConfigError;
-    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
-    where
-        V: serde::de::Visitor<'de>,
-    {
-        unreachable!()
-    }
-
-    fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error>
-    where
-        V: serde::de::Visitor<'de>,
-    {
-        visitor.visit_str(self.to_fwd)
-    }
-
-    fn deserialize_string<V>(self, visitor: V) -> Result<V::Value, Self::Error>
-    where
-        V: serde::de::Visitor<'de>,
-    {
-        visitor.visit_str(self.to_fwd)
-    }
-
-    fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value, Self::Error>
-    where
-        V: serde::de::Visitor<'de>,
-    {
-        visitor.visit_str(self.to_fwd)
-    }
-
-    serde::forward_to_deserialize_any!(
-        i8 u8 i16 u16 i32 u32 i64 u64 i128 u128 bytes
-        bool f32 f64 char byte_buf unit unit_struct option
-        newtype_struct ignored_any struct tuple tuple_struct
-        seq map enum
-    );
-}
-
-struct ConfigDeserializerIterator<'de, I: Iterator<Item = (&'de String, &'de String)>> {
-    it: std::iter::Peekable<I>,
-}
-
-impl<'de, I: Iterator<Item = (&'de String, &'de String)>> serde::de::MapAccess<'de>
-    for ConfigDeserializerIterator<'de, I>
-{
-    type Error = ConfigError;
-
-    fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, Self::Error>
-    where
-        K: serde::de::DeserializeSeed<'de>,
-    {
-        if let Some(e) = self.it.peek() {
-            let de = AtomicForwarder {
-                to_fwd: e.0.as_str(),
-            };
-            Ok(seed.deserialize(de).ok())
-        } else {
-            Ok(None)
-        }
-    }
-
-    fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Self::Error>
-    where
-        V: serde::de::DeserializeSeed<'de>,
-    {
-        let value = self.it.next().unwrap();
-
-        let de = AtomicForwarder {
-            to_fwd: value.1.as_str(),
-        };
-
-        seed.deserialize(de)
-            .map_err(|e| ConfigError::InvalidType(e.to_string()))
-    }
-}

+ 71 - 42
src/server/oidc.rs

@@ -1,5 +1,6 @@
-use crate::{jwt, schema};
+use crate::{schema, token};
 use microrm::prelude::*;
+use ring::signature::KeyPair;
 use serde::{Deserialize, Serialize};
 
 type Request = tide::Request<super::ServerStateWrapper>;
@@ -80,45 +81,23 @@ fn do_token_authorize(
         return Ok(tide::Redirect::new(login_url).into());
     }
     let sauth = sauth.unwrap();
-    let user = qi
-        .get()
-        .by_id(&sauth.user)
-        .one()
-        .expect("couldn't query db")
-        .ok_or_else(|| {
-            OIDCError(
-                OIDCErrorType::ServerError,
-                "user no longer exists".to_string(),
-                state,
-            )
-        })?;
-
-    let issuer = format!(
-        "{}/{}",
-        request.state().core.config.base_url,
-        request.param("realm").unwrap()
-    );
-
-    let iat = std::time::SystemTime::now();
-    let exp = iat + std::time::Duration::from_secs(request.state().core.config.auth_token_expiry);
-
-    let token = jwt::JWTData {
-        sub: user.username.as_str(),
-        iss: issuer.as_str(),
-        aud: client.shortname.as_str(),
-        iat: iat.duration_since(std::time::UNIX_EPOCH).unwrap().as_secs(),
-        exp: exp.duration_since(std::time::UNIX_EPOCH).unwrap().as_secs(),
-        extras: Default::default(),
-    };
-
-    let response_body = serde_json::json!({
-        "token": token,
-    });
 
-    Ok(tide::Response::builder(200)
-        .content_type(tide::http::mime::JSON)
-        .body(response_body)
-        .build())
+    let scopes = qp.scope.as_ref().map(|slist| slist.as_str()).unwrap_or("").split_whitespace();
+
+    let token = token::generate_auth_token(&request.state().core.config, &qi, shelper.get_realm().unwrap(), client.id(), sauth.user, scopes);
+
+    if let Some(token) = token {
+        let response_body = serde_json::json!({
+            "token": token,
+        });
+        Ok(tide::Response::builder(200)
+            .content_type(tide::http::mime::JSON)
+            .body(response_body)
+            .build())
+    }
+    else {
+        Err(OIDCError(OIDCErrorType::ServerError, "internal error while generating token".to_string(), state))
+    }
 }
 
 fn do_authorize(request: Request, state: Option<&str>) -> Result<tide::Response, OIDCError> {
@@ -202,7 +181,57 @@ async fn token(mut request: Request) -> tide::Result<tide::Response> {
     todo!()
 }
 
-pub(super) fn oidc_v1_server(mut route: tide::Route<super::ServerStateWrapper>) {
-    route.at("authorize").get(authorize).post(authorize);
-    route.at("token").post(token);
+const AUTHORIZE_PATH : &'static str = "oidc/authorize";
+const TOKEN_PATH : &'static str = "oidc/token";
+const JWKS_PATH : &'static str = "oidc/jwks";
+
+async fn jwks(request: Request) -> tide::Result<tide::Response> {
+    let qi = request.state().core.pool.query_interface();
+
+    let shelper = super::session::SessionHelper::new(&request);
+    let realm = shelper.get_realm()?;
+
+    let keyinfo = qi.get().by(schema::Key::Realm, &realm).all().expect("couldn't query db").into_iter().map(|key| {
+        let kpair = ring::signature::Ed25519KeyPair::from_pkcs8(&key.keydata).expect("couldn't parse keypair in db");
+        let pubkey_bytes = kpair.public_key().as_ref();
+        assert_eq!(pubkey_bytes.len(), 32);
+
+        serde_json::json!({
+            "kty": "OKP",
+            "crv": "Ed25519",
+            "x": base64::encode(pubkey_bytes),
+            "kid": key.key_id,
+        })
+    });
+
+    let jwks_response = serde_json::json!({
+        "keys": keyinfo.collect::<Vec<_>>(),
+    });
+
+    Ok(tide::Response::builder(200).body(jwks_response).build())
+}
+
+async fn discovery_config(request: Request) -> tide::Result<tide::Response> {
+    let server_config = &request.state().core.config;
+    let base_url = format!("{}/{}", server_config.base_url, request.param("realm").unwrap());
+
+    let config_response = serde_json::json!({
+        "issuer": base_url,
+        "authorization_endpoint": format!("{}/{}", base_url, AUTHORIZE_PATH),
+        "token_endpoint": format!("{}/{}", base_url, TOKEN_PATH),
+        "jwks_uri": format!("{}/{}", base_url, JWKS_PATH),
+        "token_endpoint_auth_signing_alg_values_supported": ["EdDSA"],
+        "response_types_supported": ["code", "id_token", "token id_token"],
+        "subject_types_supported": ["public"],
+        "id_token_signing_alg_values_supported": ["EdDSA"],
+    });
+
+    Ok(tide::Response::builder(200).body(config_response).build())
+}
+
+pub(super) fn oidc_server(mut route: tide::Route<super::ServerStateWrapper>) {
+    route.at(AUTHORIZE_PATH).get(authorize).post(authorize);
+    route.at(TOKEN_PATH).post(token);
+    route.at(JWKS_PATH).get(jwks);
+    route.at(".well-known/openid-configuration").get(discovery_config);
 }

+ 61 - 0
src/token.rs

@@ -0,0 +1,61 @@
+use crate::{schema, config, jwt};
+use microrm::{prelude::*, entity::EntityID};
+
+pub fn generate_auth_token<'a>(config: &config::Config, qi: &microrm::QueryInterface, realm: schema::RealmID, client: schema::ClientID, user: schema::UserID, scopes: impl Iterator<Item = &'a str>) -> Option<String> {
+    let realm = qi.get().by_id(&realm).one().ok()??;
+    let client = qi.get().by_id(&client).one().ok()??;
+    let user = qi.get().by_id(&user).one().ok()??;
+
+    let issuer = format!(
+        "{}/{}",
+        config.base_url,
+        realm.shortname,
+    );
+
+    let iat = std::time::SystemTime::now();
+    let exp = iat + std::time::Duration::from_secs(config.auth_token_expiry);
+
+    // find all roles the user can possibly have access to
+    let mut user_roles = vec![];
+    for group in qi.get().by(schema::GroupMembership::User, &user.id()).all().ok()? {
+        for group_role in qi.get().by(schema::GroupRole::Group, &group.id()).all().ok()? {
+            user_roles.push(group_role.role);
+        }
+    }
+
+    // find all roles requested by the scopes
+    let mut requested_roles = vec![];
+    for scope in scopes {
+        if let Some(scope) = qi.get().by(schema::Scope::Shortname, scope).one().ok()? {
+            for scope_role in qi.get().by(schema::ScopeRole::Scope, &scope.id()).all().ok()? {
+                requested_roles.push(scope_role.role);
+            }
+        }
+    }
+
+    user_roles.sort_by_key(|k| k.raw_id());
+    user_roles.dedup();
+    requested_roles.sort_by_key(|k| k.raw_id());
+    requested_roles.dedup();
+
+    // find the intersection between requested roles and the ones the user actually has
+    let resulting_roles = requested_roles.iter().filter(|req| user_roles.contains(req)).map(|role_id| {
+        serde_json::Value::String(qi.get().by_id(role_id).one().unwrap().unwrap().shortname.clone())
+    });
+
+    let token = jwt::JWTData {
+        sub: user.username.as_str(),
+        iss: issuer.as_str(),
+        aud: client.shortname.as_str(),
+        iat: iat.duration_since(std::time::UNIX_EPOCH).unwrap().as_secs(),
+        exp: exp.duration_since(std::time::UNIX_EPOCH).unwrap().as_secs(),
+        extras: [
+            ("roles", serde_json::Value::Array(resulting_roles.collect::<Vec<_>>())),
+        ].into(),
+    };
+
+    let key = qi.get().by(schema::Key::Realm, &realm.id()).one().ok()??;
+    let kpair = ring::signature::Ed25519KeyPair::from_pkcs8(key.keydata.as_slice()).ok()?;
+
+    Some(jwt::JWT::sign(&kpair, token).into_string())
+}