Переглянути джерело

Began adding config system and added login redirects.

Kestrel 1 рік тому
батько
коміт
73004fbd5f
12 змінених файлів з 656 додано та 164 видалено
  1. 77 3
      Cargo.lock
  2. 12 10
      Cargo.toml
  3. 36 2
      src/cli.rs
  4. 22 0
      src/client_management.rs
  5. 17 0
      src/jwt.rs
  6. 2 0
      src/main.rs
  7. 15 0
      src/schema.rs
  8. 37 13
      src/server.rs
  9. 262 0
      src/server/config.rs
  10. 63 21
      src/server/oidc.rs
  11. 112 115
      src/server/session.rs
  12. 1 0
      tmpl/id_v1_login.tmpl

+ 77 - 3
Cargo.lock

@@ -494,7 +494,7 @@ dependencies = [
  "bitflags 1.3.2",
  "clap_derive",
  "clap_lex",
- "indexmap",
+ "indexmap 1.9.3",
  "once_cell",
  "strsim",
  "termcolor",
@@ -661,6 +661,12 @@ version = "1.0.4"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "212d0f5754cb6769937f4501cc0e67f4f4483c8d2c3e1e922ee9edbe4ab4c7c0"
 
+[[package]]
+name = "equivalent"
+version = "1.0.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5"
+
 [[package]]
 name = "erased-serde"
 version = "0.3.31"
@@ -879,6 +885,12 @@ version = "0.12.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888"
 
+[[package]]
+name = "hashbrown"
+version = "0.14.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7dfda62a12f55daeae5015f81b0baea145391cb4520f86c248fc615d72640d12"
+
 [[package]]
 name = "heck"
 version = "0.4.1"
@@ -1010,7 +1022,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99"
 dependencies = [
  "autocfg",
- "hashbrown",
+ "hashbrown 0.12.3",
+]
+
+[[package]]
+name = "indexmap"
+version = "2.0.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8adf3ddd720272c6ea8bf59463c04e0f93d0bbf7c5439b691bca2987e0270897"
+dependencies = [
+ "equivalent",
+ "hashbrown 0.14.1",
 ]
 
 [[package]]
@@ -1582,6 +1604,15 @@ dependencies = [
  "thiserror",
 ]
 
+[[package]]
+name = "serde_spanned"
+version = "0.6.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "96426c9936fd7a0124915f9185ea1d20aa9445cc9821142f0a73bc9207a2e186"
+dependencies = [
+ "serde",
+]
+
 [[package]]
 name = "serde_urlencoded"
 version = "0.7.1"
@@ -2017,6 +2048,40 @@ version = "0.1.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
 
+[[package]]
+name = "toml"
+version = "0.8.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "185d8ab0dfbb35cf1399a6344d8484209c088f75f8f68230da55d48d95d43e3d"
+dependencies = [
+ "serde",
+ "serde_spanned",
+ "toml_datetime",
+ "toml_edit",
+]
+
+[[package]]
+name = "toml_datetime"
+version = "0.6.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7cda73e2f1397b1262d6dfdcef8aafae14d1de7748d66822d3bfeeb6d03e5e4b"
+dependencies = [
+ "serde",
+]
+
+[[package]]
+name = "toml_edit"
+version = "0.20.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "396e4d48bbb2b7554c944bde63101b5ae446cff6ec4a24227428f15eb72ef338"
+dependencies = [
+ "indexmap 2.0.2",
+ "serde",
+ "serde_spanned",
+ "toml_datetime",
+ "winnow",
+]
+
 [[package]]
 name = "tracing"
 version = "0.1.39"
@@ -2043,7 +2108,6 @@ checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825"
 name = "uauth2"
 version = "0.1.0"
 dependencies = [
- "anyhow",
  "base64 0.13.1",
  "clap",
  "handlebars",
@@ -2059,6 +2123,7 @@ dependencies = [
  "smol",
  "stderrlog",
  "tide",
+ "toml",
 ]
 
 [[package]]
@@ -2369,3 +2434,12 @@ name = "windows_x86_64_msvc"
 version = "0.48.5"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538"
+
+[[package]]
+name = "winnow"
+version = "0.5.17"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a3b801d0e0a6726477cc207f60162da452f3a95adb368399bef20a946e06f65c"
+dependencies = [
+ "memchr",
+]

+ 12 - 10
Cargo.toml

@@ -7,27 +7,29 @@ edition = "2021"
 
 [dependencies]
 # Core dependencies
-# async-std = "1.11.0"
 smol = "1.3"
+log = "0.4"
+serde = { version =  "1.0", features = ["derive"] }
+lazy_static = "1.4.0"
 
+# crypto
 ring = { version = "0.16.20", features = ["std"] }
-serde = { version =  "1.0", features = ["derive"] }
-serde_bytes = { version = "0.11.6" }
-serde_json = "1.0"
 sha2 = { version = "0.10.2" }
 base64 = { version = "0.13.0" }
-log = "0.4"
-stderrlog = "0.5"
-handlebars = { version = "4.3", features = ["dir_source"] }
-lazy_static = "1.4.0"
+
+# configuration
+toml = "0.8.2"
 
 # Data storage dependencies
 microrm = { version = "0.3.9" }
+serde_bytes = { version = "0.11.6" }
 
-# Public API dependencies
+# Public API/server dependencies
 tide = { version = "0.16.0" }
-anyhow = { version = "1.0" }
+handlebars = { version = "4.3", features = ["dir_source"] }
+serde_json = "1.0"
 
 # CLI dependencies
 clap = { version = "3.1.15", features = ["derive"] }
 rpassword = "6.0"
+stderrlog = "0.5"

+ 36 - 2
src/cli.rs

@@ -1,4 +1,4 @@
-use crate::{cert, schema, server, user_management};
+use crate::{cert, client_management, schema, server, user_management};
 use clap::{Parser, Subcommand};
 
 #[derive(Debug, Parser)]
@@ -20,6 +20,7 @@ struct RootArgs {
 enum Command {
     Init,
     Cert(CertArgs),
+    Client(ClientArgs),
     Group(GroupArgs),
     Server(ServerArgs),
     Token(TokenArgs),
@@ -43,6 +44,7 @@ impl RootArgs {
         match &self.command {
             Command::Init => unreachable!(),
             Command::Cert(v) => v.run(&self, storage).await,
+            Command::Client(v) => v.run(&self, storage).await,
             Command::Group(v) => v.run(&self, storage).await,
             Command::Server(v) => v.run(&self, storage).await,
             Command::Token(v) => v.run(&self, storage).await,
@@ -102,6 +104,33 @@ impl CertArgs {
     }
 }
 
+#[derive(Debug, Subcommand)]
+enum ClientCommand {
+    Create { name: String },
+    List,
+    Inspect { name: String },
+}
+
+#[derive(Debug, Parser)]
+struct ClientArgs {
+    #[clap(subcommand)]
+    command: ClientCommand,
+}
+
+impl ClientArgs {
+    async fn run(&self, root: &RootArgs, db: microrm::DB) {
+        match &self.command {
+            ClientCommand::Create { name } => {
+                client_management::create(&db, root.realm.as_str(), name);
+            }
+            ClientCommand::List => {}
+            ClientCommand::Inspect { name } => {
+                client_management::inspect(&db, name);
+            }
+        }
+    }
+}
+
 #[derive(Debug, Subcommand)]
 enum GroupCommand {}
 
@@ -118,11 +147,16 @@ impl GroupArgs {
 #[derive(Debug, Parser)]
 struct ServerArgs {
     port: Option<u16>,
+    config_path: Option<String>,
 }
 
 impl ServerArgs {
     async fn run(&self, root: &RootArgs, db: microrm::DB) {
-        server::run_server(db, self.port.unwrap_or(2114)).await
+        let config = server::ServerConfig::build_from(
+            &db.query_interface(),
+            self.config_path.as_ref().map(|x| x.as_str()),
+        );
+        server::run_server(db, config, self.port.unwrap_or(2114)).await
     }
 }
 

+ 22 - 0
src/client_management.rs

@@ -0,0 +1,22 @@
+use crate::schema;
+use microrm::prelude::*;
+
+pub fn create(db: &microrm::DB, realm: &str, name: &str) {
+    let qi = db.query_interface();
+
+    let realm = qi
+        .get()
+        .by(schema::Realm::Shortname, realm)
+        .one()
+        .expect("couldn't query db")
+        .expect("no such realm");
+
+    qi.add(&schema::Client {
+        realm: realm.id(),
+        shortname: name.into(),
+        secret: "".into(),
+    })
+    .expect("couldn't add client");
+}
+
+pub fn inspect(db: &microrm::DB, name: &str) {}

+ 17 - 0
src/jwt.rs

@@ -0,0 +1,17 @@
+#[derive(serde::Serialize, serde::Deserialize)]
+pub struct JWTData<'l> {
+    pub sub: &'l str,
+    pub iss: &'l str,
+    pub aud: &'l str,
+    pub iat: u64,
+    pub exp: u64,
+
+    #[serde(flatten)]
+    pub extras: std::collections::HashMap<String, serde_json::Value>,
+}
+
+impl<'l> Into<String> for JWTData<'l> {
+    fn into(self) -> String {
+        serde_json::to_string(&self).unwrap()
+    }
+}

+ 2 - 0
src/main.rs

@@ -1,6 +1,8 @@
 mod cert;
 mod cli;
+mod client_management;
 mod config;
+mod jwt;
 mod login;
 mod schema;
 mod server;

+ 15 - 0
src/schema.rs

@@ -1,6 +1,13 @@
 pub use microrm::{Entity, Modelable, Schema};
 use serde::{Deserialize, Serialize};
 
+/// Simple key-value store for persistent configuration
+#[derive(Debug, Entity, Serialize, Deserialize)]
+pub struct PersistentConfig {
+    pub key: String,
+    pub value: String,
+}
+
 #[derive(Debug, Entity, Serialize, Deserialize)]
 pub struct Session {
     pub key: String,
@@ -76,7 +83,9 @@ pub struct Group {
 /// User membership in group
 #[derive(Debug, Entity, Serialize, Deserialize)]
 pub struct GroupMembership {
+    #[microrm_foreign]
     pub group: GroupID,
+    #[microrm_foreign]
     pub user: UserID,
 }
 
@@ -117,25 +126,31 @@ pub struct Role {
 /// Role membership in scope
 #[derive(Debug, Entity, Serialize, Deserialize)]
 pub struct ScopeRole {
+    #[microrm_foreign]
     pub scope: ScopeID,
+    #[microrm_foreign]
     pub role: RoleID,
 }
 
 /// Assigned permissions in group
 #[derive(Debug, Entity, Serialize, Deserialize)]
 pub struct GroupRole {
+    #[microrm_foreign]
     pub scope: ScopeID,
+    #[microrm_foreign]
     pub role: RoleID,
 }
 
 #[derive(Debug, Entity, Serialize, Deserialize)]
 pub struct RevokedToken {
+    #[microrm_foreign]
     pub user: UserID,
     pub nonce: String,
 }
 
 pub fn schema() -> Schema {
     Schema::new()
+        .entity::<PersistentConfig>()
         .entity::<Session>()
         .index::<SessionKeyIndex>()
         .entity::<SessionAuthentication>()

+ 37 - 13
src/server.rs

@@ -1,20 +1,24 @@
 use crate::schema;
 use microrm::prelude::*;
 
+mod config;
 mod oidc;
 mod session;
 
-pub struct ServerCoreState {
+pub use config::ServerConfig;
+
+pub struct ServerState {
+    config: ServerConfig,
     pool: microrm::DBPool<'static>,
     templates: handlebars::Handlebars<'static>,
 }
 
 #[derive(Clone)]
-struct ServerState {
-    core: &'static ServerCoreState,
+struct ServerStateWrapper {
+    core: &'static ServerState,
 }
 
-fn is_auth_valid<T>(core: &'static ServerCoreState, of: &tide::Request<T>) -> Option<bool> {
+fn is_auth_valid<T>(core: &'static ServerState, of: &tide::Request<T>) -> Option<bool> {
     let cookie = of.cookie("vogt_session")?;
     let session_id = cookie.value();
 
@@ -29,18 +33,38 @@ fn is_auth_valid<T>(core: &'static ServerCoreState, of: &tide::Request<T>) -> Op
     )
 }
 
-pub async fn run_server(db: microrm::DB, port: u16) {
+async fn index(req: tide::Request<ServerStateWrapper>) -> tide::Result<tide::Response> {
+    let shelper = session::SessionHelper::new(&req);
+
+    let realm = shelper.get_realm()?;
+    let sid = shelper.get_session(&req);
+    let auth = sid.and_then(|sid| shelper.get_auth_for_session(realm, sid));
+
+    let response = tide::Response::builder(200)
+        .content_type(tide::http::mime::PLAIN)
+        .body(format!(
+            r#"
+            realm: {realm:?}
+            session: {sid:?}
+            auth: {auth:?}
+        "#
+        ))
+        .build();
+    Ok(response)
+}
+
+pub async fn run_server(db: microrm::DB, config: ServerConfig, port: u16) {
     let db_box = Box::new(db);
     let db: &'static mut microrm::DB = Box::leak(db_box);
     let pool = microrm::DBPool::new(db);
 
-    let core_state = Box::leak(Box::new(ServerCoreState {
+    let core_state = Box::leak(Box::new(ServerState {
+        config,
         pool,
         templates: handlebars::Handlebars::new(),
     }));
 
-    // XXX: for development only
-    // core_state.templates.write().unwrap().set_dev_mode(true);
+    core_state.templates.set_dev_mode(true);
 
     core_state
         .templates
@@ -53,20 +77,20 @@ pub async fn run_server(db: microrm::DB, port: u16) {
 
     core_state.templates.render("id_v1_login", &()).unwrap();
 
-    let state = ServerState { core: core_state };
+    let state = ServerStateWrapper { core: core_state };
 
     let mut app = tide::with_state(state);
 
     app.with(tide::log::LogMiddleware::new());
 
+    app.at("/:realm/").get(index);
+
     app.at("/static")
         .serve_dir("static/")
         .expect("Can't serve static files");
 
-    app.at("/:realm/v1/session/")
-        .nest(session::session_v1_server(core_state));
-    app.at("/:realm/v1/oidc/")
-        .nest(oidc::oidc_v1_server(core_state));
+    session::session_v1_server(app.at("/:realm/v1/session/"));
+    oidc::oidc_v1_server(app.at("/:realm/v1/oidc/"));
 
     app.listen(("127.0.0.1", port)).await.expect("Can listen");
 }

+ 262 - 0
src/server/config.rs

@@ -0,0 +1,262 @@
+use crate::schema;
+use microrm::prelude::*;
+use serde::{Deserialize, Serialize};
+
+#[derive(Serialize, Deserialize)]
+pub struct ServerConfig {
+    pub base_url: String,
+    // #[serde(defaul
+}
+
+impl ServerConfig {
+    pub fn build_from(qi: &microrm::QueryInterface, cfile: Option<&str>) -> Self {
+        let mut config_map = std::collections::HashMap::<String, String>::new();
+        // load config keys from query interface
+        let db_pcs = qi
+            .get::<schema::PersistentConfig>()
+            .all()
+            .expect("couldn't get config keys from database");
+        config_map.extend(db_pcs.into_iter().map(|pc| {
+            let pc = pc.wrapped();
+            (pc.key, pc.value)
+        }));
+
+        if let Some(path) = cfile {
+            match std::fs::read(&path) {
+                Ok(data) => {
+                    log::info!("Loading config from {path}...");
+                    let toml_table: toml::Table = toml::from_str(
+                        std::str::from_utf8(data.as_slice())
+                            .expect("couldn't read config file contents as utf-8"),
+                    )
+                    .expect("couldn't parse config toml");
+                }
+                Err(e) => {
+                    log::error!("Could not open {path} for reading: {e}");
+                }
+            }
+        }
+
+        let mut deser = ConfigDeserializer {
+            config_map: &config_map,
+            prefix: "".to_string(),
+        };
+
+        let config = ServerConfig::deserialize(&mut deser).expect("couldn't load configuration");
+
+        config
+    }
+}
+
+struct ConfigDeserializer<'de> {
+    config_map: &'de std::collections::HashMap<String, String>,
+    prefix: String,
+}
+
+#[derive(Debug)]
+enum ConfigError {
+    Missing(String),
+    InvalidType(String),
+    CustomError(String),
+}
+
+impl std::fmt::Display for ConfigError {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            Self::Missing(what) => f.write_fmt(format_args!(
+                "Missing required config entry: {}",
+                what.as_str()
+            )),
+            Self::InvalidType(what) => f.write_fmt(format_args!(
+                "Could not parse config entry '{}'",
+                what.as_str()
+            )),
+            Self::CustomError(what) => {
+                f.write_fmt(format_args!("Custom error '{}'", what.as_str()))
+            }
+        }
+    }
+}
+
+impl std::error::Error for ConfigError {}
+
+impl serde::de::Error for ConfigError {
+    fn custom<T>(msg: T) -> Self
+    where
+        T: std::fmt::Display,
+    {
+        Self::CustomError(msg.to_string())
+    }
+
+    fn invalid_type(_unexp: serde::de::Unexpected, _exp: &dyn serde::de::Expected) -> Self {
+        Self::InvalidType("".into())
+    }
+
+    fn missing_field(field: &'static str) -> Self {
+        Self::Missing(field.into())
+    }
+}
+
+impl<'de> serde::Deserializer<'de> for &'de mut ConfigDeserializer<'de> {
+    type Error = ConfigError;
+
+    fn deserialize_any<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
+    where
+        V: serde::de::Visitor<'de>,
+    {
+        unreachable!("deserialize_any needs context")
+    }
+
+    fn deserialize_struct<V>(
+        self,
+        _name: &'static str,
+        _fields: &'static [&'static str],
+        visitor: V,
+    ) -> Result<V::Value, Self::Error>
+    where
+        V: serde::de::Visitor<'de>,
+    {
+        self.deserialize_map(visitor)
+    }
+
+    fn deserialize_seq<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
+    where
+        V: serde::de::Visitor<'de>,
+    {
+        todo!("deserialize_seq")
+    }
+
+    fn deserialize_map<V>(mut self, visitor: V) -> Result<V::Value, Self::Error>
+    where
+        V: serde::de::Visitor<'de>,
+    {
+        let mut map_access = ConfigDeserializerIterator {
+            it: self
+                .config_map
+                .iter()
+                .filter(|e| {
+                    e.0.starts_with(&self.prefix) && !e.0[self.prefix.len()..].contains(".")
+                })
+                .peekable(),
+        };
+
+        visitor.visit_map(&mut map_access)
+    }
+
+    fn deserialize_enum<V>(
+        self,
+        _name: &'static str,
+        _variants: &'static [&'static str],
+        _visitor: V,
+    ) -> Result<V::Value, Self::Error>
+    where
+        V: serde::de::Visitor<'de>,
+    {
+        todo!("deserialize_enum")
+    }
+
+    fn deserialize_tuple<V>(self, _len: usize, _visitor: V) -> Result<V::Value, Self::Error>
+    where
+        V: serde::de::Visitor<'de>,
+    {
+        todo!("deserialize_tuple")
+    }
+
+    fn deserialize_tuple_struct<V>(
+        self,
+        _name: &'static str,
+        _len: usize,
+        _visitor: V,
+    ) -> Result<V::Value, Self::Error>
+    where
+        V: serde::de::Visitor<'de>,
+    {
+        todo!("deserialize_tuple_struct")
+    }
+
+    serde::forward_to_deserialize_any!(
+        i8 u8 i16 u16 i32 u32 i64 u64 i128 u128 str string bytes
+        bool f32 f64 char byte_buf option unit unit_struct
+        newtype_struct identifier ignored_any
+    );
+}
+
+struct AtomicForwarder<'de> {
+    to_fwd: &'de str,
+}
+
+impl<'de> serde::Deserializer<'de> for AtomicForwarder<'de> {
+    type Error = ConfigError;
+    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
+    where
+        V: serde::de::Visitor<'de>,
+    {
+        unreachable!()
+    }
+
+    fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error>
+    where
+        V: serde::de::Visitor<'de>,
+    {
+        visitor.visit_str(self.to_fwd)
+    }
+
+    fn deserialize_string<V>(self, visitor: V) -> Result<V::Value, Self::Error>
+    where
+        V: serde::de::Visitor<'de>,
+    {
+        visitor.visit_str(self.to_fwd)
+    }
+
+    fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value, Self::Error>
+    where
+        V: serde::de::Visitor<'de>,
+    {
+        visitor.visit_str(self.to_fwd)
+    }
+
+    serde::forward_to_deserialize_any!(
+        i8 u8 i16 u16 i32 u32 i64 u64 i128 u128 bytes
+        bool f32 f64 char byte_buf unit unit_struct option
+        newtype_struct ignored_any struct tuple tuple_struct
+        seq map enum
+    );
+}
+
+struct ConfigDeserializerIterator<'de, I: Iterator<Item = (&'de String, &'de String)>> {
+    it: std::iter::Peekable<I>,
+}
+
+impl<'de, I: Iterator<Item = (&'de String, &'de String)>> serde::de::MapAccess<'de>
+    for ConfigDeserializerIterator<'de, I>
+{
+    type Error = ConfigError;
+
+    fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, Self::Error>
+    where
+        K: serde::de::DeserializeSeed<'de>,
+    {
+        if let Some(e) = self.it.peek() {
+            let de = AtomicForwarder {
+                to_fwd: e.0.as_str(),
+            };
+            Ok(seed.deserialize(de).ok())
+        } else {
+            Ok(None)
+        }
+    }
+
+    fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Self::Error>
+    where
+        V: serde::de::DeserializeSeed<'de>,
+    {
+        let value = self.it.next().unwrap();
+
+        let de = AtomicForwarder {
+            to_fwd: value.1.as_str(),
+        };
+
+        seed.deserialize(de)
+            .map_err(|e| ConfigError::InvalidType(e.to_string()))
+    }
+}

+ 63 - 21
src/server/oidc.rs

@@ -1,13 +1,8 @@
-use crate::schema;
+use crate::{jwt, schema};
 use microrm::prelude::*;
 use serde::{Deserialize, Serialize};
 
-#[derive(Clone)]
-pub struct ServerState {
-    core: &'static super::ServerCoreState,
-}
-
-type Request = tide::Request<ServerState>;
+type Request = tide::Request<super::ServerStateWrapper>;
 
 #[derive(serde::Serialize)]
 pub enum OIDCErrorType {
@@ -20,6 +15,7 @@ pub enum OIDCErrorType {
     TemporarilyUnavailable,
 }
 
+/// error type,
 pub struct OIDCError<'a>(OIDCErrorType, String, Option<&'a str>);
 
 impl<'a> OIDCError<'a> {
@@ -66,7 +62,60 @@ fn do_token_authorize(
     state: Option<&str>,
     client: microrm::WithID<schema::Client>,
 ) -> Result<tide::Response, OIDCError> {
-    todo!()
+    let qi = request.state().core.pool.query_interface();
+
+    let shelper = super::session::SessionHelper::new(&request);
+
+    let sauth = shelper
+        .get_session(&request)
+        .and_then(|sid| shelper.get_auth_for_session(shelper.get_realm().unwrap(), sid));
+
+    if sauth.is_none() {
+        // if we don't have any relevant auth info, redirect to login
+        let mut login_url = request.url().join("../session/login").unwrap();
+        login_url
+            .query_pairs_mut()
+            .clear()
+            .append_pair("redirect", request.url().as_str());
+        return Ok(tide::Redirect::new(login_url).into());
+    }
+    let sauth = sauth.unwrap();
+    let user = qi
+        .get()
+        .by_id(&sauth.user)
+        .one()
+        .expect("couldn't query db")
+        .ok_or_else(|| {
+            OIDCError(
+                OIDCErrorType::ServerError,
+                "user no longer exists".to_string(),
+                state,
+            )
+        })?;
+
+    let issuer = format!(
+        "{}/{}",
+        request.state().core.config.base_url,
+        request.param("realm").unwrap()
+    );
+
+    let token = jwt::JWTData {
+        sub: user.username.as_str(),
+        iss: issuer.as_str(),
+        aud: client.shortname.as_str(),
+        iat: 123,
+        exp: 123,
+        extras: Default::default(),
+    };
+
+    let response_body = serde_json::json!({
+        "token": token,
+    });
+
+    Ok(tide::Response::builder(200)
+        .content_type(tide::http::mime::JSON)
+        .body(response_body)
+        .build())
 }
 
 fn do_authorize(request: Request, state: Option<&str>) -> Result<tide::Response, OIDCError> {
@@ -124,13 +173,10 @@ async fn authorize(request: Request) -> tide::Result<tide::Response> {
     }
     let state: Option<String> = request.query::<State>().ok().map(|x| x.state).flatten();
 
-    let result = do_authorize(request, state.as_ref().map(|x| x.as_str()));
-
-    if let Err(e) = result {
-        todo!()
+    match do_authorize(request, state.as_ref().map(|x| x.as_str())) {
+        Ok(r) => Ok(r),
+        Err(e) => Ok(e.to_response()),
     }
-
-    todo!()
 }
 
 #[derive(Deserialize)]
@@ -153,11 +199,7 @@ async fn token(mut request: Request) -> tide::Result<tide::Response> {
     todo!()
 }
 
-pub fn oidc_v1_server(core: &'static super::ServerCoreState) -> tide::Server<ServerState> {
-    let mut srv = tide::with_state(ServerState { core });
-
-    srv.at("authorize").get(authorize).post(authorize);
-    srv.at("token").post(token);
-
-    srv
+pub(super) fn oidc_v1_server(mut route: tide::Route<super::ServerStateWrapper>) {
+    route.at("authorize").get(authorize).post(authorize);
+    route.at("token").post(token);
 }

+ 112 - 115
src/server/session.rs

@@ -3,50 +3,37 @@ use microrm::prelude::*;
 use serde::Deserialize;
 use tide::http::Cookie;
 
-#[derive(Clone)]
-pub struct ServerState {
-    core: &'static super::ServerCoreState,
-    realm_cache:
-        std::sync::Arc<std::sync::RwLock<std::collections::HashMap<String, schema::RealmID>>>,
+pub(super) struct SessionHelper<'l> {
+    qi: &'l microrm::QueryInterface<'static>,
+    tmpl: &'l handlebars::Handlebars<'l>,
+    realm_str: &'l str,
 }
 
-type Request = tide::Request<ServerState>;
+type Request = tide::Request<super::ServerStateWrapper>;
 
-impl ServerState {
-    pub fn get_realm(&self, req: &Request) -> Option<schema::RealmID> {
-        let realm_str = req
-            .param("realm")
-            .expect("get_realm called with no :realm param");
-        let cache = self.realm_cache.read().unwrap();
-        let cache_lookup = cache.get(realm_str);
+const SESSION_COOKIE_NAME: &'static str = "uauth_session";
 
-        // expected case
-        if cache_lookup.is_some() {
-            return cache_lookup.map(|x| *x);
+impl<'l> SessionHelper<'l> {
+    pub fn new(req: &'l Request) -> Self {
+        Self {
+            qi: req.state().core.pool.query_interface(),
+            tmpl: &req.state().core.templates,
+            realm_str: req.param("realm").expect("no realm param?"),
         }
-        drop(cache);
-
-        // unexpected case, but maybe we haven't filled that cache entry yet
+    }
 
-        let qi = self.core.pool.query_interface();
-        let realm = qi
+    pub fn get_realm(&self) -> tide::Result<schema::RealmID> {
+        self.qi
             .get()
-            .by(schema::Realm::Shortname, realm_str)
+            .by(schema::Realm::Shortname, self.realm_str)
             .one()
-            .expect("couldn't query db");
-
-        if let Some(with_id) = realm {
-            let mut cache = self.realm_cache.write().unwrap();
-            cache.insert(realm_str.to_owned(), with_id.id());
-            return Some(with_id.id());
-        }
-
-        // other expected case, is bogus realm
-        return None;
+            .expect("couldn't query db")
+            .map(|r| r.id())
+            .ok_or(tide::Error::from_str(404, "No such realm"))
     }
 
     fn build_session(
-        qi: &microrm::QueryInterface,
+        &self,
     ) -> tide::Result<(schema::SessionID, Option<tide::http::Cookie<'static>>)> {
         let rng = ring::rand::SystemRandom::new();
         let session_id: [u8; 32] = ring::rand::generate(&rng)
@@ -54,22 +41,25 @@ impl ServerState {
             .expose();
         let session_id = base64::encode_config(session_id, base64::URL_SAFE_NO_PAD);
 
-        let maybe_id = qi.add(&schema::Session {
+        let maybe_id = self.qi.add(&schema::Session {
             key: session_id.clone(),
         });
+        let session_cookie = Cookie::build(SESSION_COOKIE_NAME, session_id)
+            .path("/")
+            .finish();
         Ok((
             maybe_id.ok().ok_or(tide::Error::from_str(
                 500,
                 "Failed to store session in database",
             ))?,
-            Some(Cookie::new("vogt_session", session_id)),
+            Some(session_cookie),
         ))
     }
 
     pub fn verify_session(&self, req: &Request) -> Option<(schema::RealmID, schema::UserID)> {
         self.get_or_build_session(req)
             .ok()
-            .zip(self.get_realm(req))
+            .zip(self.get_realm().ok())
             .and_then(|((sid, _cookie), realm)| {
                 self.get_auth_for_session(realm, sid).and_then(|auth| {
                     if auth.challenges_left.len() == 0 {
@@ -81,23 +71,26 @@ impl ServerState {
             })
     }
 
+    pub fn get_session(&self, req: &Request) -> Option<schema::SessionID> {
+        req.cookie(SESSION_COOKIE_NAME)
+            .and_then(|sid| {
+                self.qi
+                    .get()
+                    .by(schema::Session::Key, sid.value())
+                    .one()
+                    .expect("couldn't query db")
+            })
+            .and_then(|session| Some(session.id()))
+    }
+
     pub fn get_or_build_session(
         &self,
         req: &Request,
     ) -> tide::Result<(schema::SessionID, Option<tide::http::Cookie<'static>>)> {
-        let qi = self.core.pool.query_interface();
-        if let Some(sid) = req.cookie("vogt_session") {
-            let existing = qi
-                .get()
-                .by(schema::Session::Key, sid.value())
-                .one()
-                .expect("couldn't query db");
-
-            if existing.is_some() {
-                return Ok((existing.unwrap().id(), None));
-            }
+        match self.get_session(&req) {
+            Some(sid) => Ok((sid, None)),
+            None => self.build_session(),
         }
-        Self::build_session(qi)
     }
 
     pub fn get_auth_for_session(
@@ -105,10 +98,9 @@ impl ServerState {
         realm: schema::RealmID,
         session: schema::SessionID,
     ) -> Option<microrm::WithID<schema::SessionAuthentication>> {
-        let qi = self.core.pool.query_interface();
-
         use schema::SessionAuthentication as SAC;
-        qi.get()
+        self.qi
+            .get()
             .by(SAC::Realm, &realm)
             .by(SAC::Session, &session)
             .one()
@@ -116,10 +108,9 @@ impl ServerState {
     }
 
     pub fn destroy_auth(&self, realm: schema::RealmID, session: schema::SessionID) {
-        let qi = self.core.pool.query_interface();
-
         use schema::SessionAuthentication as SAC;
-        qi.delete()
+        self.qi
+            .delete()
             .by(SAC::Realm, &realm)
             .by(SAC::Session, &session)
             .exec()
@@ -127,10 +118,11 @@ impl ServerState {
     }
 }
 
-impl ServerState {
+impl<'l> SessionHelper<'l> {
     fn render_login_from_auth(
         &self,
         mut response: tide::Response,
+        redirect: String,
         auth: Option<schema::SessionAuthentication>,
         error_msg: Option<String>,
     ) -> tide::Response {
@@ -143,37 +135,38 @@ impl ServerState {
 
         if to_present.is_none() {
             response.set_status(302);
-            tide::Redirect::new("/").into()
+            tide::Redirect::new("../..").into()
         } else {
-            self.render_login_page(response, to_present.unwrap(), error_msg)
+            self.render_login_page(response, redirect, to_present.unwrap(), error_msg)
         }
     }
 
     fn render_login_page(
         &self,
         mut response: tide::Response,
+        redirect: String,
         to_present: schema::AuthChallengeType,
         error_msg: Option<String>,
     ) -> tide::Response {
-        let tmpl = &self.core.templates;
-
         let do_challenge = |ty, ch| {
-            tmpl.render(
-                "id_v1_login",
-                &serde_json::json!(
-                    {
-                        "challenge":
-                            format!(r#"
+            self.tmpl
+                .render(
+                    "id_v1_login",
+                    &serde_json::json!(
+                        {
+                            "challenge":
+                                format!(r#"
                             <input type="hidden" name="challenge_type" value="{:?}" />
                             <div class="challenge-type">{}</div>
                             <div class="challenge-content">{}</div>
                             "#,
-                                to_present, ty, ch),
-                        "error_msg": error_msg.iter().collect::<Vec<_>>()
-                    }
-                ),
-            )
-            .unwrap()
+                                    to_present, ty, ch),
+                            "redirect": redirect,
+                            "error_msg": error_msg.iter().collect::<Vec<_>>()
+                        }
+                    ),
+                )
+                .unwrap()
         };
 
         response.set_content_type("text/html");
@@ -198,48 +191,57 @@ impl ServerState {
     }
 }
 
-async fn v1_login(req: tide::Request<ServerState>) -> tide::Result<tide::Response> {
+async fn v1_login(req: Request) -> tide::Result<tide::Response> {
     let mut response = tide::Response::builder(200).build();
 
-    let realm = req
-        .state()
-        .get_realm(&req)
-        .ok_or(tide::Error::from_str(404, "No such realm"))?;
-    let (session_id, cookie) = req.state().get_or_build_session(&req)?;
+    let shelper = SessionHelper::new(&req);
+
+    let realm = shelper.get_realm()?;
+    let (session_id, cookie) = shelper.get_or_build_session(&req)?;
     cookie.map(|c| response.insert_cookie(c));
 
-    let auth = req.state().get_auth_for_session(realm, session_id);
+    let auth = shelper.get_auth_for_session(realm, session_id);
 
-    Ok(req
-        .state()
-        .render_login_from_auth(response, auth.map(|a| a.wrapped()), None))
-}
+    #[derive(serde::Deserialize)]
+    struct LoginQuery {
+        redirect: Option<String>,
+    }
 
-async fn v1_login_post(mut req: tide::Request<ServerState>) -> tide::Result<tide::Response> {
-    let mut response = tide::Response::builder(200).build();
+    let query: LoginQuery = req.query().unwrap();
 
-    let realm = req
-        .state()
-        .get_realm(&req)
-        .ok_or(tide::Error::from_str(404, "No such realm"))?;
-    let (session_id, cookie) = req.state().get_or_build_session(&req)?;
-    cookie.map(|c| response.insert_cookie(c));
+    Ok(shelper.render_login_from_auth(
+        response,
+        query.redirect.unwrap_or_else(|| "../..".to_string()),
+        auth.map(|a| a.wrapped()),
+        None,
+    ))
+}
 
-    let mut auth = req.state().get_auth_for_session(realm, session_id);
+async fn v1_login_post(mut req: Request) -> tide::Result<tide::Response> {
+    let mut response = tide::Response::builder(200).build();
 
     #[derive(Deserialize)]
     struct ResponseBody {
         challenge_type: String,
         challenge: String,
         reset: Option<String>,
+        redirect: String,
     }
 
     let body: ResponseBody = req.body_form().await?;
 
+    let shelper = SessionHelper::new(&req);
+
+    let realm = shelper.get_realm()?;
+    let (session_id, cookie) = shelper.get_or_build_session(&req)?;
+    cookie.map(|c| response.insert_cookie(c));
+
+    let mut auth = shelper.get_auth_for_session(realm, session_id);
+
     // check if a login reset was requested; if so, we start again from the top
     if body.reset.is_some() {
         if let Some(_) = auth {
-            req.state().destroy_auth(realm, session_id);
+            shelper.destroy_auth(realm, session_id);
             response.set_status(302);
             return Ok(tide::Redirect::new("login").into());
         }
@@ -268,7 +270,7 @@ async fn v1_login_post(mut req: tide::Request<ServerState>) -> tide::Result<tide
     match challenge {
         ChallengeType::Username => {
             let qi = req.state().core.pool.query_interface();
-            req.state().destroy_auth(realm, session_id);
+            shelper.destroy_auth(realm, session_id);
 
             let user = qi
                 .get()
@@ -330,32 +332,27 @@ async fn v1_login_post(mut req: tide::Request<ServerState>) -> tide::Result<tide
         }
     };
 
-    Ok(req
-        .state()
-        .render_login_from_auth(response, auth.map(|a| a.wrapped()), error))
+    Ok(shelper.render_login_from_auth(response, body.redirect, auth.map(|a| a.wrapped()), error))
 }
 
-async fn v1_logout(req: tide::Request<ServerState>) -> tide::Result<tide::Response> {
-    let realm = req
-        .state()
-        .get_realm(&req)
-        .ok_or(tide::Error::from_str(404, "No such realm"))?;
-    let (session_id, _) = req.state().get_or_build_session(&req)?;
+async fn v1_logout(req: Request) -> tide::Result<tide::Response> {
+    let shelper = SessionHelper::new(&req);
 
-    req.state().destroy_auth(realm, session_id);
-    Ok(tide::Redirect::new("/").into())
-}
-
-pub fn session_v1_server(core: &'static super::ServerCoreState) -> tide::Server<ServerState> {
-    let mut srv = tide::with_state(ServerState {
-        core,
-        realm_cache: std::sync::Arc::new(std::sync::RwLock::new(std::collections::HashMap::new())),
-    });
+    #[derive(serde::Deserialize)]
+    struct LogoutQuery {
+        redirect: Option<String>,
+    }
 
-    srv.with(tide::log::LogMiddleware::new());
+    let query: LogoutQuery = req.query().unwrap();
 
-    srv.at("login").get(v1_login).post(v1_login_post);
-    srv.at("logout").get(v1_logout);
+    let realm = shelper.get_realm()?;
+    shelper
+        .get_session(&req)
+        .map(|sid| shelper.destroy_auth(realm, sid));
+    Ok(tide::Redirect::new(query.redirect.unwrap_or_else(|| "../..".into())).into())
+}
 
-    srv
+pub(super) fn session_v1_server(mut route: tide::Route<super::ServerStateWrapper>) {
+    route.at("login").get(v1_login).post(v1_login_post);
+    route.at("logout").get(v1_logout);
 }

+ 1 - 0
tmpl/id_v1_login.tmpl

@@ -22,6 +22,7 @@
                             <input type="submit" value=">" />
                         </div>
                         <div class="spacer">&nbsp;</div>
+                        <input type="hidden" name="redirect" value="{{ redirect }}" />
                         <input type="submit" name="reset" value="Start over" />
                     </div>
                 </form>