// oidc.rs: OIDC HTTP endpoints (authorize, token, JWKS, discovery document).
  1. use crate::{key, schema, server::session::SessionHelper};
  2. use microrm::prelude::*;
  3. use serde::{Deserialize, Serialize};
  4. mod api;
  5. mod authorize;
  6. mod token;
  7. type Request = tide::Request<super::ServerStateWrapper>;
  8. const AUTHORIZE_PATH: &str = "oidc/authorize";
  9. const TOKEN_PATH: &str = "oidc/token";
  10. const JWKS_PATH: &str = "oidc/jwks";
  11. const DISCOVERY_PATH: &str = ".well-known/openid-configuration";
  12. #[derive(serde::Serialize)]
  13. pub enum OIDCErrorType {
  14. InvalidRequest,
  15. UnauthorizedClient,
  16. AccessDenied,
  17. UnsupportedResponseType,
  18. // InvalidScope,
  19. ServerError,
  20. // TemporarilyUnavailable,
  21. }
  22. pub enum OIDCErrorPayload<'a> {
  23. Borrowed(&'a str),
  24. Owned(String),
  25. }
  26. impl<'a> OIDCErrorPayload<'a> {
  27. fn as_str(&self) -> &str {
  28. match self {
  29. Self::Borrowed(s) => s,
  30. Self::Owned(s) => s.as_str(),
  31. }
  32. }
  33. }
  34. impl<'a> From<&'a str> for OIDCErrorPayload<'a> {
  35. fn from(value: &'a str) -> Self {
  36. Self::Borrowed(value)
  37. }
  38. }
  39. impl<'a> From<String> for OIDCErrorPayload<'a> {
  40. fn from(value: String) -> Self {
  41. Self::Owned(value)
  42. }
  43. }
  44. /// error type,
  45. pub struct OIDCError<'a>(OIDCErrorType, OIDCErrorPayload<'a>, Option<&'a str>);
  46. impl<'a> OIDCError<'a> {
  47. fn into_response(self) -> tide::Response {
  48. #[derive(Serialize)]
  49. struct ErrorOut<'a> {
  50. error: OIDCErrorType,
  51. error_description: &'a str,
  52. state: Option<&'a str>,
  53. }
  54. let eo = ErrorOut {
  55. error: self.0,
  56. error_description: self.1.as_str(),
  57. state: self.2,
  58. };
  59. tide::Response::builder(400)
  60. .body(serde_json::to_vec(&eo).unwrap())
  61. .build()
  62. }
  63. }
  64. impl<'a> From<microrm::Error> for OIDCError<'a> {
  65. fn from(value: microrm::Error) -> Self {
  66. Self(
  67. OIDCErrorType::ServerError,
  68. format!("Internal database error: {value}").into(),
  69. None,
  70. )
  71. }
  72. }
  73. async fn authorize(request: Request) -> tide::Result<tide::Response> {
  74. #[derive(Deserialize)]
  75. struct State {
  76. state: Option<String>,
  77. }
  78. let state: Option<String> = request.query::<State>().ok().and_then(|x| x.state);
  79. match authorize::do_authorize(request, state.as_deref()) {
  80. Ok(r) => Ok(r),
  81. Err(e) => Ok(e.into_response()),
  82. }
  83. }
  84. async fn token(request: Request) -> tide::Result<tide::Response> {
  85. match token::do_token(request).await {
  86. Ok(res) => Ok(res),
  87. Err(e) => Ok(e.into_response()),
  88. }
  89. }
  90. async fn jwks(request: Request) -> tide::Result<tide::Response> {
  91. let shelper = SessionHelper::new(&request);
  92. let realm = shelper.get_realm()?;
  93. // build JWK set
  94. let mut jwkset = jsonwebtoken::jwk::JwkSet { keys: vec![] };
  95. for key in realm.keys.get()?.into_iter() {
  96. if *key.key_state.as_ref() == schema::KeyState::Retired {
  97. continue;
  98. }
  99. // skip HMAC keys
  100. if let key::KeyType::HMac(_) = *key.key_type.as_ref() {
  101. continue;
  102. }
  103. jwkset.keys.push(key.wrapped().into_jwk());
  104. }
  105. Ok(tide::Response::builder(200)
  106. .header(tide::http::headers::ACCESS_CONTROL_ALLOW_ORIGIN, "*")
  107. .content_type(tide::http::mime::JSON)
  108. .body(serde_json::to_vec(&jwkset).unwrap())
  109. .build())
  110. }
  111. async fn discovery_config(request: Request) -> tide::Result<tide::Response> {
  112. let server_config = &request.state().core.config;
  113. let base_url = format!(
  114. "{}/{}",
  115. server_config.base_url,
  116. request.param("realm").unwrap()
  117. );
  118. let config_response = serde_json::json!({
  119. "issuer": base_url,
  120. "authorization_endpoint": format!("{}/{}", base_url, AUTHORIZE_PATH),
  121. "token_endpoint": format!("{}/{}", base_url, TOKEN_PATH),
  122. "jwks_uri": format!("{}/{}", base_url, JWKS_PATH),
  123. "token_endpoint_auth_signing_alg_values_supported": ["EdDSA", "RS256"],
  124. "response_types_supported": ["code", "id_token", "token id_token"],
  125. "subject_types_supported": ["public"],
  126. "id_token_signing_alg_values_supported": ["EdDSA", "RS256"],
  127. });
  128. Ok(tide::Response::builder(200)
  129. .header(tide::http::headers::ACCESS_CONTROL_ALLOW_ORIGIN, "*")
  130. .body(config_response)
  131. .build())
  132. }
  133. pub(super) fn oidc_server(mut route: tide::Route<super::ServerStateWrapper>) {
  134. route.at(AUTHORIZE_PATH).get(authorize).post(authorize);
  135. route.at(TOKEN_PATH).post(token);
  136. route.at(JWKS_PATH).get(jwks);
  137. route.at(DISCOVERY_PATH).get(discovery_config);
  138. }