diff options
Diffstat (limited to 'crates/secd/src/client/mod.rs')
| -rw-r--r-- | crates/secd/src/client/mod.rs | 233 |
1 files changed, 223 insertions, 10 deletions
diff --git a/crates/secd/src/client/mod.rs b/crates/secd/src/client/mod.rs index 3925657..38426ef 100644 --- a/crates/secd/src/client/mod.rs +++ b/crates/secd/src/client/mod.rs @@ -1,13 +1,24 @@ -pub mod email; -pub mod sqldb; +pub(crate) mod email; +pub(crate) mod sqldb; +pub(crate) mod types; -use std::collections::HashMap; +use std::{collections::HashMap, str::FromStr}; use super::Identity; -use crate::{EmailValidation, Session, SessionSecret}; +use crate::{ + EmailValidation, OauthProvider, OauthProviderName, OauthResponseType, OauthValidation, Session, + SessionSecret, ValidationRequestId, ValidationType, +}; +use email_address::EmailAddress; use lazy_static::lazy_static; +use sqlx::{ + database::HasValueRef, sqlite::SqliteRow, ColumnIndex, Database, Decode, FromRow, Row, Sqlite, + Type, +}; use thiserror::Error; +use time::OffsetDateTime; +use url::Url; use uuid::Uuid; pub enum EmailType { @@ -36,13 +47,15 @@ pub trait EmailMessenger { #[derive(Error, Debug, derive_more::Display)] pub enum StoreError { SqlxError(#[from] sqlx::Error), - EmailAlreadyExists, CodeAppearsMoreThanOnce, CodeDoesNotExist(String), IdentityIdMustExistInvariant, - TooManyEmailValidations, + TooManyValidations, + TooManyIdentitiesFound, NoEmailValidationFound, - Unknown, + OauthProviderDoesNotExist(OauthProviderName), + OauthValidationDoesNotExist(ValidationRequestId), + Other(String), } const EMAIL_TEMPLATE_DEFAULT_LOGIN: &str = "You requested a login link. Please click the following link %secd_code% to login as %secd_email_address%"; @@ -56,6 +69,7 @@ const PGSQL: &str = "pgsql"; const WRITE_IDENTITY: &str = "write_identity"; const WRITE_EMAIL_VALIDATION: &str = "write_email_validation"; const FIND_EMAIL_VALIDATION: &str = "find_email_validation"; +const READ_VALIDATION_TYPE: &str = "read_validation_type"; const WRITE_EMAIL: &str = "write_email"; @@ -69,6 +83,11 @@ const READ_EMAIL_RAW_ID: &str = "read_email_raw_id"; const WRITE_SESSION: &str = "write_session"; const READ_SESSION: &str = "read_session"; +const WRITE_OAUTH_PROVIDER: &str = "write_oauth_provider"; +const READ_OAUTH_PROVIDER: &str = "read_oauth_provider"; +const WRITE_OAUTH_VALIDATION: &str = "write_oauth_validation"; +const READ_OAUTH_VALIDATION: &str = "read_oauth_validation"; + lazy_static! { static ref SQLS: HashMap<&'static str, HashMap<&'static str, &'static str>> = { let sqlite_sqls: HashMap<&'static str, &'static str> = [ @@ -116,6 +135,26 @@ lazy_static! { FIND_EMAIL_VALIDATION, include_str!("../../store/sqlite/sql/find_email_validation.sql"), ), + ( + WRITE_OAUTH_PROVIDER, + include_str!("../../store/sqlite/sql/write_oauth_provider.sql"), + ), + ( + READ_OAUTH_PROVIDER, + include_str!("../../store/sqlite/sql/read_oauth_provider.sql"), + ), + ( + READ_OAUTH_VALIDATION, + include_str!("../../store/sqlite/sql/read_oauth_validation.sql"), + ), + ( + WRITE_OAUTH_VALIDATION, + include_str!("../../store/sqlite/sql/write_oauth_validation.sql"), + ), + ( + READ_VALIDATION_TYPE, + include_str!("../../store/sqlite/sql/read_validation_type.sql"), + ), ] .iter() .cloned() @@ -166,6 +205,26 @@ lazy_static! { FIND_EMAIL_VALIDATION, include_str!("../../store/pg/sql/find_email_validation.sql"), ), + ( + WRITE_OAUTH_PROVIDER, + include_str!("../../store/pg/sql/write_oauth_provider.sql"), + ), + ( + READ_OAUTH_PROVIDER, + include_str!("../../store/pg/sql/read_oauth_provider.sql"), + ), + ( + READ_OAUTH_VALIDATION, + include_str!("../../store/pg/sql/read_oauth_validation.sql"), + ), + ( + WRITE_OAUTH_VALIDATION, + include_str!("../../store/pg/sql/write_oauth_validation.sql"), + ), + ( + READ_VALIDATION_TYPE, + include_str!("../../store/pg/sql/read_validation_type.sql"), + ), ] .iter() .cloned() @@ -180,9 +239,143 @@ lazy_static! { }; } +impl<'a, R: Row> FromRow<'a, R> for OauthValidation +where + &'a str: ColumnIndex<R>, + OauthProviderName: Decode<'a, R::Database> + Type<R::Database>, + OauthResponseType: Decode<'a, R::Database> + Type<R::Database>, + OffsetDateTime: Decode<'a, R::Database> + Type<R::Database>, + String: Decode<'a, R::Database> + Type<R::Database>, + Uuid: Decode<'a, R::Database> + Type<R::Database>, +{ + fn from_row(row: &'a R) -> Result<Self, sqlx::Error> { + let id: Option<Uuid> = row.try_get("oauth_validation_public_id")?; + let identity_id: Option<Uuid> = row.try_get("identity_public_id")?; + let access_token: Option<String> = row.try_get("access_token")?; + let raw_response: Option<String> = row.try_get("raw_response")?; + let created_at: Option<OffsetDateTime> = row.try_get("created_at")?; + let validated_at: Option<OffsetDateTime> = row.try_get("validated_at")?; + let revoked_at: Option<OffsetDateTime> = row.try_get("revoked_at")?; + let deleted_at: Option<OffsetDateTime> = row.try_get("deleted_at")?; + + let op_name: Option<OauthProviderName> = row.try_get("oauth_provider_name")?; + let op_flow: Option<String> = row.try_get("oauth_provider_flow")?; + let op_base_url: Option<String> = row.try_get("oauth_provider_base_url")?; + let op_response_type: Option<OauthResponseType> = + row.try_get("oauth_provider_response_type")?; + let op_default_scope: Option<String> = row.try_get("oauth_provider_default_scope")?; + let op_client_id: Option<String> = row.try_get("oauth_provider_client_id")?; + let op_client_secret: Option<String> = row.try_get("oauth_provider_client_secret")?; + let op_redirect_url: Option<String> = row.try_get("oauth_provider_redirect_url")?; + let op_created_at: Option<OffsetDateTime> = row.try_get("oauth_provider_created_at")?; + let op_deleted_at: Option<OffsetDateTime> = row.try_get("oauth_provider_deleted_at")?; + + let op_base_url = op_base_url + .map(|s| Url::from_str(&s).ok()) + .flatten() + .ok_or(sqlx::Error::ColumnDecode { + index: "oauth_provider_base_url".into(), + source: "secd".into(), + })?; + + let op_redirect_url = op_redirect_url + .map(|s| Url::from_str(&s).ok()) + .flatten() + .ok_or(sqlx::Error::ColumnDecode { + index: "oauth_provider_redirect_url".into(), + source: "secd".into(), + })?; + + Ok(OauthValidation { + id, + identity_id, + access_token, + raw_response, + created_at: created_at.ok_or(sqlx::Error::ColumnDecode { + index: "created_at".into(), + source: "secd".into(), + })?, + validated_at, + revoked_at, + deleted_at, + oauth_provider: OauthProvider { + name: op_name.unwrap(), + flow: op_flow, + base_url: op_base_url, + response: op_response_type.ok_or(sqlx::Error::ColumnDecode { + index: "oauth_provider_response_type".into(), + source: "secd".into(), + })?, + default_scope: op_default_scope.ok_or(sqlx::Error::ColumnDecode { + index: "oauth_provider_default_scope".into(), + source: "secd".into(), + })?, + client_id: op_client_id.ok_or(sqlx::Error::ColumnDecode { + index: "oauth_provider_client_id".into(), + source: "secd".into(), + })?, + client_secret: op_client_secret.ok_or(sqlx::Error::ColumnDecode { + index: "oauth_provider_client_secret".into(), + source: "secd".into(), + })?, + redirect_url: op_redirect_url, + created_at: op_created_at.ok_or(sqlx::Error::ColumnDecode { + index: "oauth_provider_created_at".into(), + source: "secd".into(), + })?, + deleted_at: op_deleted_at, + }, + }) + } +} + +impl<'a, D: Database> Decode<'a, D> for OauthProviderName +where + &'a str: Decode<'a, D>, +{ + fn decode( + value: <D as HasValueRef<'a>>::ValueRef, + ) -> Result<Self, Box<dyn ::std::error::Error + 'static + Send + Sync>> { + let v = <&str as Decode<D>>::decode(value)?; + <OauthProviderName as clap::ValueEnum>::from_str(v, true) + .map_err(|_| "OauthProviderName should exist and decode to a program value.".into()) + } +} + +impl<D: Database> Type<D> for OauthProviderName +where + str: Type<D>, +{ + fn type_info() -> D::TypeInfo { + <&str as Type<D>>::type_info() + } +} + +impl<'a, D: Database> Decode<'a, D> for OauthResponseType +where + &'a str: Decode<'a, D>, +{ + fn decode( + value: <D as HasValueRef<'a>>::ValueRef, + ) -> Result<Self, Box<dyn ::std::error::Error + 'static + Send + Sync>> { + let v = <&str as Decode<D>>::decode(value)?; + <OauthResponseType as clap::ValueEnum>::from_str(v, true) + .map_err(|_| "OauthResponseType should exist and decode to a program value.".into()) + } +} + +impl<D: Database> Type<D> for OauthResponseType +where + str: Type<D>, +{ + fn type_info() -> D::TypeInfo { + <&str as Type<D>>::type_info() + } +} + #[async_trait::async_trait] pub trait Store { - async fn write_email(&self, identity_id: Uuid, email_address: &str) -> Result<(), StoreError>; + async fn write_email(&self, email_address: &str) -> Result<(), StoreError>; async fn find_email_validation( &self, @@ -193,17 +386,37 @@ pub trait Store { &self, ev: &EmailValidation, // TODO: Make this write an EmailValidation - ) -> Result<Uuid, StoreError>; + ) -> anyhow::Result<Uuid>; async fn find_identity( &self, identity_id: Option<&Uuid>, email: Option<&str>, - ) -> Result<Option<Identity>, StoreError>; + ) -> anyhow::Result<Option<Identity>>; async fn find_identity_by_code(&self, code: &str) -> Result<Identity, StoreError>; async fn write_identity(&self, i: &Identity) -> Result<(), StoreError>; async fn read_identity(&self, identity_id: &Uuid) -> Result<Identity, StoreError>; async fn write_session(&self, session: &Session) -> Result<(), StoreError>; async fn read_session(&self, secret: &SessionSecret) -> Result<Session, StoreError>; + + async fn write_oauth_provider(&self, provider: &OauthProvider) -> Result<(), StoreError>; + async fn read_oauth_provider( + &self, + provider: &OauthProviderName, + flow: Option<String>, + ) -> Result<OauthProvider, StoreError>; + async fn write_oauth_validation( + &self, + validation: &OauthValidation, + ) -> anyhow::Result<ValidationRequestId>; + async fn read_oauth_validation( + &self, + validation_id: &ValidationRequestId, + ) -> anyhow::Result<OauthValidation>; + + async fn find_validation_type( + &self, + validation_id: &ValidationRequestId, + ) -> anyhow::Result<ValidationType>; } |
