Răsfoiți Sursa

Address clippy nits and follow microrm changes.

Kestrel 4 zile în urmă
părinte
comite
87a860ca2a
5 a modificat fișierele cu 132 adăugiri și 131 ștergeri
  1. 2 10
      src/cli.rs
  2. 117 113
      src/ext/github.rs
  3. 8 4
      src/realm.rs
  4. 4 3
      src/server/session.rs
  5. 1 1
      src/user.rs

+ 2 - 10
src/cli.rs

@@ -101,9 +101,6 @@ impl RootArgs {
             return self.init(db, &mut lease, config).await;
         }
 
-        // let db = UIDCDatabase::open_path(&config.db_path)
-        //    .map_err(|e| UIDCError::AbortString(format!("Error accessing database: {:?}", e)))?;
-
         let realm = db
             .realms
             .keyed(&self.realm)
@@ -153,18 +150,13 @@ impl RootArgs {
         }
     }
 
-    async fn init<'l>(
+    async fn init(
         &self,
         db: UIDCDatabase,
-        lease: &mut microrm::ConnectionLease<'l>,
+        lease: &mut microrm::ConnectionLease,
         config: Config,
     ) -> Result<(), UIDCError> {
-        // first check to see if the database is already vaguely set up
-        /*let db = UIDCDatabase::open_path(&config.db_path)
-        .map_err(|e| UIDCError::AbortString(format!("Error accessing database: {:?}", e)))?;*/
-
         log::info!("Initializing!");
-
         if db.realms.keyed("primary").get(lease)?.is_some() {
             log::warn!("Already initialized with primary realm!");
             return Ok(());

+ 117 - 113
src/ext/github.rs

@@ -8,6 +8,8 @@ use crate::{
     UIDCError,
 };
 
+use super::ExternalAuthenticator;
+
 #[derive(Debug, Clone, Deserialize)]
 #[serde(deny_unknown_fields)]
 pub struct GithubConfig {
@@ -35,7 +37,119 @@ pub struct GithubAuthenticator {
     config: GithubConfig,
 }
 
-impl super::ExternalAuthenticator for GithubAuthenticator {
+impl GithubAuthenticator {
+    async fn do_extract_login_state(&self, req: UIDCRequest) -> Result<tide::Response, UIDCError> {
+        let state = req.state();
+        let realm = req.param("realm").unwrap();
+
+        #[derive(Deserialize)]
+        struct Query {
+            code: String,
+            redirect: String,
+            mode: CallbackRequestType,
+        }
+        let Ok(query) = req.query::<Query>() else {
+            return Ok(tide::Response::builder(400)
+                .body("Query string invalid.")
+                .build());
+        };
+
+        #[derive(Deserialize)]
+        struct TokenResponse {
+            access_token: String,
+        }
+
+        let auth = surf::http::auth::BasicAuth::new(
+            self.config.client_id.as_str(),
+            self.config.client_secret.as_str(),
+        );
+
+        let resp: TokenResponse = match state
+            .client
+            .post(
+                tide::http::Url::parse_with_params(
+                    self.config
+                        .token_url
+                        .as_deref()
+                        .unwrap_or(DEFAULT_TOKEN_URL),
+                    &[
+                        ("client_id", self.config.client_id.as_str()),
+                        ("client_secret", self.config.client_secret.as_str()),
+                        ("code", query.code.as_str()),
+                    ],
+                )
+                .expect("couldn't generate token url for github"),
+            )
+            .header(auth.name(), auth.value())
+            .content_type(surf::http::mime::FORM)
+            .recv_form()
+            .await
+        {
+            Ok(resp) => resp,
+            Err(err) => {
+                return Err(UIDCError::AbortString(format!(
+                    "could not parse Github response for token: {err}"
+                )))
+            }
+        };
+
+        let atoken = resp.access_token;
+
+        #[derive(Deserialize)]
+        struct UserInfoResponse {
+            id: i64,
+        }
+
+        let resp: UserInfoResponse = match state
+            .client
+            .get(format!(
+                "{base}/user",
+                base = self.config.api_base.as_deref().unwrap_or(DEFAULT_API_BASE)
+            ))
+            .header("Authorization", format!("Bearer {atoken}"))
+            .content_type(surf::http::mime::JSON)
+            .recv_json()
+            .await
+        {
+            Ok(resp) => resp,
+            Err(err) => {
+                return Err(UIDCError::AbortString(format!(
+                    "could not parse Github response for token: {err}"
+                )))
+            }
+        };
+
+        let user_id = resp.id.to_string();
+
+        let external_auth_map = {
+            let mut lease = state.lease().await?;
+            let Some(realm) = state.db.realms.keyed(realm).get(&mut lease).ok().flatten() else {
+                return Ok(tide::Response::builder(404).body("no such realm").build());
+            };
+
+            realm
+                .external_auth
+                .keyed((
+                    &user_id,
+                    schema::ExternalAuthProvider::Github.into_serialized(),
+                ))
+                .get(&mut lease)
+                .ok()
+                .flatten()
+        };
+
+        match (query.mode, external_auth_map) {
+            (CallbackRequestType::Login, Some(map)) => {
+                self.handle_matching_login(req, map.internal_user_id, query.redirect.as_str())
+                    .await
+            }
+            (CallbackRequestType::Login, None) => self.handle_no_mapping(req, query.redirect).await,
+            (CallbackRequestType::Register, _) => self.handle_registration(req).await,
+        }
+    }
+}
+
+impl ExternalAuthenticator for GithubAuthenticator {
     fn build(_db: &UIDCDatabase, config: &Config) -> Option<Self> {
         config.github.as_ref().map(|ghc| Self {
             base_url: config.base_url.clone(),
@@ -92,117 +206,7 @@ impl super::ExternalAuthenticator for GithubAuthenticator {
     fn extract_login_state(
         &self,
         req: UIDCRequest,
-    ) -> impl smol::prelude::Future<Output = Result<tide::Response, UIDCError>> {
-        async move {
-            let state = req.state();
-            let realm = req.param("realm").unwrap();
-
-            #[derive(Deserialize)]
-            struct Query {
-                code: String,
-                redirect: String,
-                mode: CallbackRequestType,
-            }
-            let Ok(query) = req.query::<Query>() else {
-                return Ok(tide::Response::builder(400)
-                    .body("Query string invalid.")
-                    .build());
-            };
-
-            #[derive(Deserialize)]
-            struct TokenResponse {
-                access_token: String,
-            }
-
-            let auth = surf::http::auth::BasicAuth::new(
-                self.config.client_id.as_str(),
-                self.config.client_secret.as_str(),
-            );
-
-            let resp: TokenResponse = match state
-                .client
-                .post(
-                    tide::http::Url::parse_with_params(
-                        self.config
-                            .token_url
-                            .as_deref()
-                            .unwrap_or(DEFAULT_TOKEN_URL),
-                        &[
-                            ("client_id", self.config.client_id.as_str()),
-                            ("client_secret", self.config.client_secret.as_str()),
-                            ("code", query.code.as_str()),
-                        ],
-                    )
-                    .expect("couldn't generate token url for github"),
-                )
-                .header(auth.name(), auth.value())
-                .content_type(surf::http::mime::FORM)
-                .recv_form()
-                .await
-            {
-                Ok(resp) => resp,
-                Err(err) => {
-                    return Err(UIDCError::AbortString(format!(
-                        "could not parse Github response for token: {err}"
-                    )))
-                }
-            };
-
-            let atoken = resp.access_token;
-
-            #[derive(Deserialize)]
-            struct UserInfoResponse {
-                id: i64,
-            }
-
-            let resp: UserInfoResponse = match state
-                .client
-                .get(format!(
-                    "{base}/user",
-                    base = self.config.api_base.as_deref().unwrap_or(DEFAULT_API_BASE)
-                ))
-                .header("Authorization", format!("Bearer {atoken}"))
-                .content_type(surf::http::mime::JSON)
-                .recv_json()
-                .await
-            {
-                Ok(resp) => resp,
-                Err(err) => {
-                    return Err(UIDCError::AbortString(format!(
-                        "could not parse Github response for token: {err}"
-                    )))
-                }
-            };
-
-            let user_id = resp.id.to_string();
-
-            let mut lease = state.lease().await?;
-
-            let Some(realm) = state.db.realms.keyed(realm).get(&mut lease).ok().flatten() else {
-                return Ok(tide::Response::builder(404).body("no such realm").build());
-            };
-
-            let external_auth_map = realm
-                .external_auth
-                .keyed((
-                    &user_id,
-                    schema::ExternalAuthProvider::Github.into_serialized(),
-                ))
-                .get(&mut lease)
-                .ok()
-                .flatten();
-
-            drop(lease);
-            match (query.mode, external_auth_map) {
-                (CallbackRequestType::Login, Some(map)) => {
-                    self.handle_matching_login(req, map.internal_user_id, query.redirect.as_str())
-                        .await
-                }
-                (CallbackRequestType::Login, None) => {
-                    self.handle_no_mapping(req, query.redirect).await
-                }
-                (CallbackRequestType::Register, _) => self.handle_registration(req).await,
-            }
-        }
+    ) -> impl std::future::Future<Output = Result<tide::Response, UIDCError>> {
+        self.do_extract_login_state(req)
     }
 }

+ 8 - 4
src/realm.rs

@@ -301,6 +301,7 @@ impl RealmHelper {
         client: &microrm::Stored<schema::Client>,
         rtoken: &str,
     ) -> Result<(String, String), UIDCError> {
+        let mut txn = lease.guard("token trade")?;
         let header = jsonwebtoken::decode_header(rtoken)
             .map_err(|e| UIDCError::AbortString(format!("invalid JWT header: {e}")))?;
         let Some(kid) = header.kid else {
@@ -312,7 +313,7 @@ impl RealmHelper {
             .keys
             .with(schema::Key::KeyId, kid)
             .first()
-            .get(lease)?
+            .get(txn.as_mut())?
         else {
             return Err(UIDCError::Abort("no matching key"));
         };
@@ -338,14 +339,17 @@ impl RealmHelper {
             .realm
             .users
             .keyed((self.realm.id(), rt.claims.sub.as_str()))
-            .get(lease)?
+            .get(txn.as_mut())?
         else {
             return Err(UIDCError::Abort("user no longer exists or was renamed"));
         };
 
         let scopes = rt.claims.scopes.iter().map(String::as_str);
-        let access_token = self.generate_access_token(lease, client, &user, scopes.clone())?;
-        let refresh_token = self.generate_refresh_token(lease, client, &user, scopes)?;
+        let access_token =
+            self.generate_access_token(txn.as_mut(), client, &user, scopes.clone())?;
+        let refresh_token = self.generate_refresh_token(txn.as_mut(), client, &user, scopes)?;
+
+        txn.commit()?;
 
         Ok((access_token, refresh_token))
     }

+ 4 - 3
src/server/session.rs

@@ -174,7 +174,7 @@ impl<'l> SessionHelper<'l> {
                             "redirect": redirect,
                             "error_msg": error_msg.iter().collect::<Vec<_>>(),
                             "show_gh_login": gh.is_some(),
-                            "gh_login_url": gh.as_ref().map(|gh| gh.generate_login_url(self.realm_str, redirect.as_str())).unwrap_or(String::new()),
+                            "gh_login_url": gh.as_ref().map(|gh| gh.generate_login_url(self.realm_str, redirect.as_str())).unwrap_or_default(),
                         }
                     ),
                 )
@@ -244,9 +244,10 @@ async fn v1_onetime(req: Request) -> tide::Result<tide::Response> {
             response,
             otq.redirect,
             None,
-            Some(format!(
+            Some(
                 "Single-use authentication code does not exist or has already been used!"
-            )),
+                    .to_string(),
+            ),
         ));
     };
 

+ 1 - 1
src/user.rs

@@ -10,7 +10,7 @@ pub enum UserError {
     InvalidInput,
 }
 
-static PBKDF2_ROUNDS: std::num::NonZeroU32 = unsafe { std::num::NonZeroU32::new_unchecked(20000) };
+static PBKDF2_ROUNDS: std::num::NonZeroU32 = std::num::NonZeroU32::new(20000).unwrap();
 
 fn generate_totp_digits(secret: &[u8], time_offset: isize) -> Result<u32, UIDCError> {
     use hmac::Mac;