diff options
| author | benj <benj@rse8.com> | 2022-12-12 17:06:57 -0800 |
|---|---|---|
| committer | benj <benj@rse8.com> | 2022-12-12 17:06:57 -0800 |
| commit | 0920c4d4f30a3345870d385d5c6f3e0919228b56 (patch) | |
| tree | f54668d91db469b7304758893a51b590c8f9b0de /crates/secd/src/client | |
| parent | 3a4de13528fc85dcbe6bc9055d97ba5cc87f5712 (diff) | |
| download | secdiam-0920c4d4f30a3345870d385d5c6f3e0919228b56.tar secdiam-0920c4d4f30a3345870d385d5c6f3e0919228b56.tar.gz secdiam-0920c4d4f30a3345870d385d5c6f3e0919228b56.tar.bz2 secdiam-0920c4d4f30a3345870d385d5c6f3e0919228b56.tar.lz secdiam-0920c4d4f30a3345870d385d5c6f3e0919228b56.tar.xz secdiam-0920c4d4f30a3345870d385d5c6f3e0919228b56.tar.zst secdiam-0920c4d4f30a3345870d385d5c6f3e0919228b56.zip | |
(oauth2 + email added): a mess that may or may not really work and needs to be refactored...
Diffstat (limited to '')
| -rw-r--r-- | crates/secd/src/client/mod.rs | 233 | ||||
| -rw-r--r-- | crates/secd/src/client/sqldb.rs | 324 | ||||
| -rw-r--r-- | crates/secd/src/client/types.rs | 3 |
3 files changed, 492 insertions, 68 deletions
diff --git a/crates/secd/src/client/mod.rs b/crates/secd/src/client/mod.rs index 3925657..38426ef 100644 --- a/crates/secd/src/client/mod.rs +++ b/crates/secd/src/client/mod.rs @@ -1,13 +1,24 @@ -pub mod email; -pub mod sqldb; +pub(crate) mod email; +pub(crate) mod sqldb; +pub(crate) mod types; -use std::collections::HashMap; +use std::{collections::HashMap, str::FromStr}; use super::Identity; -use crate::{EmailValidation, Session, SessionSecret}; +use crate::{ + EmailValidation, OauthProvider, OauthProviderName, OauthResponseType, OauthValidation, Session, + SessionSecret, ValidationRequestId, ValidationType, +}; +use email_address::EmailAddress; use lazy_static::lazy_static; +use sqlx::{ + database::HasValueRef, sqlite::SqliteRow, ColumnIndex, Database, Decode, FromRow, Row, Sqlite, + Type, +}; use thiserror::Error; +use time::OffsetDateTime; +use url::Url; use uuid::Uuid; pub enum EmailType { @@ -36,13 +47,15 @@ pub trait EmailMessenger { #[derive(Error, Debug, derive_more::Display)] pub enum StoreError { SqlxError(#[from] sqlx::Error), - EmailAlreadyExists, CodeAppearsMoreThanOnce, CodeDoesNotExist(String), IdentityIdMustExistInvariant, - TooManyEmailValidations, + TooManyValidations, + TooManyIdentitiesFound, NoEmailValidationFound, - Unknown, + OauthProviderDoesNotExist(OauthProviderName), + OauthValidationDoesNotExist(ValidationRequestId), + Other(String), } const EMAIL_TEMPLATE_DEFAULT_LOGIN: &str = "You requested a login link. Please click the following link %secd_code% to login as %secd_email_address%"; @@ -56,6 +69,7 @@ const PGSQL: &str = "pgsql"; const WRITE_IDENTITY: &str = "write_identity"; const WRITE_EMAIL_VALIDATION: &str = "write_email_validation"; const FIND_EMAIL_VALIDATION: &str = "find_email_validation"; +const READ_VALIDATION_TYPE: &str = "read_validation_type"; const WRITE_EMAIL: &str = "write_email"; @@ -69,6 +83,11 @@ const READ_EMAIL_RAW_ID: &str = "read_email_raw_id"; const WRITE_SESSION: &str = "write_session"; const READ_SESSION: &str = "read_session"; +const WRITE_OAUTH_PROVIDER: &str = "write_oauth_provider"; +const READ_OAUTH_PROVIDER: &str = "read_oauth_provider"; +const WRITE_OAUTH_VALIDATION: &str = "write_oauth_validation"; +const READ_OAUTH_VALIDATION: &str = "read_oauth_validation"; + lazy_static! { static ref SQLS: HashMap<&'static str, HashMap<&'static str, &'static str>> = { let sqlite_sqls: HashMap<&'static str, &'static str> = [ @@ -116,6 +135,26 @@ lazy_static! { FIND_EMAIL_VALIDATION, include_str!("../../store/sqlite/sql/find_email_validation.sql"), ), + ( + WRITE_OAUTH_PROVIDER, + include_str!("../../store/sqlite/sql/write_oauth_provider.sql"), + ), + ( + READ_OAUTH_PROVIDER, + include_str!("../../store/sqlite/sql/read_oauth_provider.sql"), + ), + ( + READ_OAUTH_VALIDATION, + include_str!("../../store/sqlite/sql/read_oauth_validation.sql"), + ), + ( + WRITE_OAUTH_VALIDATION, + include_str!("../../store/sqlite/sql/write_oauth_validation.sql"), + ), + ( + READ_VALIDATION_TYPE, + include_str!("../../store/sqlite/sql/read_validation_type.sql"), + ), ] .iter() .cloned() @@ -166,6 +205,26 @@ lazy_static! { FIND_EMAIL_VALIDATION, include_str!("../../store/pg/sql/find_email_validation.sql"), ), + ( + WRITE_OAUTH_PROVIDER, + include_str!("../../store/pg/sql/write_oauth_provider.sql"), + ), + ( + READ_OAUTH_PROVIDER, + include_str!("../../store/pg/sql/read_oauth_provider.sql"), + ), + ( + READ_OAUTH_VALIDATION, + include_str!("../../store/pg/sql/read_oauth_validation.sql"), + ), + ( + WRITE_OAUTH_VALIDATION, + include_str!("../../store/pg/sql/write_oauth_validation.sql"), + ), + ( + READ_VALIDATION_TYPE, + include_str!("../../store/pg/sql/read_validation_type.sql"), + ), ] .iter() .cloned() @@ -180,9 +239,143 @@ lazy_static! { }; } +impl<'a, R: Row> FromRow<'a, R> for OauthValidation +where + &'a str: ColumnIndex<R>, + OauthProviderName: Decode<'a, R::Database> + Type<R::Database>, + OauthResponseType: Decode<'a, R::Database> + Type<R::Database>, + OffsetDateTime: Decode<'a, R::Database> + Type<R::Database>, + String: Decode<'a, R::Database> + Type<R::Database>, + Uuid: Decode<'a, R::Database> + Type<R::Database>, +{ + fn from_row(row: &'a R) -> Result<Self, sqlx::Error> { + let id: Option<Uuid> = row.try_get("oauth_validation_public_id")?; + let identity_id: Option<Uuid> = row.try_get("identity_public_id")?; + let access_token: Option<String> = row.try_get("access_token")?; + let raw_response: Option<String> = row.try_get("raw_response")?; + let created_at: Option<OffsetDateTime> = row.try_get("created_at")?; + let validated_at: Option<OffsetDateTime> = row.try_get("validated_at")?; + let revoked_at: Option<OffsetDateTime> = row.try_get("revoked_at")?; + let deleted_at: Option<OffsetDateTime> = row.try_get("deleted_at")?; + + let op_name: Option<OauthProviderName> = row.try_get("oauth_provider_name")?; + let op_flow: Option<String> = row.try_get("oauth_provider_flow")?; + let op_base_url: Option<String> = row.try_get("oauth_provider_base_url")?; + let op_response_type: Option<OauthResponseType> = + row.try_get("oauth_provider_response_type")?; + let op_default_scope: Option<String> = row.try_get("oauth_provider_default_scope")?; + let op_client_id: Option<String> = row.try_get("oauth_provider_client_id")?; + let op_client_secret: Option<String> = row.try_get("oauth_provider_client_secret")?; + let op_redirect_url: Option<String> = row.try_get("oauth_provider_redirect_url")?; + let op_created_at: Option<OffsetDateTime> = row.try_get("oauth_provider_created_at")?; + let op_deleted_at: Option<OffsetDateTime> = row.try_get("oauth_provider_deleted_at")?; + + let op_base_url = op_base_url + .map(|s| Url::from_str(&s).ok()) + .flatten() + .ok_or(sqlx::Error::ColumnDecode { + index: "oauth_provider_base_url".into(), + source: "secd".into(), + })?; + + let op_redirect_url = op_redirect_url + .map(|s| Url::from_str(&s).ok()) + .flatten() + .ok_or(sqlx::Error::ColumnDecode { + index: "oauth_provider_redirect_url".into(), + source: "secd".into(), + })?; + + Ok(OauthValidation { + id, + identity_id, + access_token, + raw_response, + created_at: created_at.ok_or(sqlx::Error::ColumnDecode { + index: "created_at".into(), + source: "secd".into(), + })?, + validated_at, + revoked_at, + deleted_at, + oauth_provider: OauthProvider { + name: op_name.unwrap(), + flow: op_flow, + base_url: op_base_url, + response: op_response_type.ok_or(sqlx::Error::ColumnDecode { + index: "oauth_provider_response_type".into(), + source: "secd".into(), + })?, + default_scope: op_default_scope.ok_or(sqlx::Error::ColumnDecode { + index: "oauth_provider_default_scope".into(), + source: "secd".into(), + })?, + client_id: op_client_id.ok_or(sqlx::Error::ColumnDecode { + index: "oauth_provider_client_id".into(), + source: "secd".into(), + })?, + client_secret: op_client_secret.ok_or(sqlx::Error::ColumnDecode { + index: "oauth_provider_client_secret".into(), + source: "secd".into(), + })?, + redirect_url: op_redirect_url, + created_at: op_created_at.ok_or(sqlx::Error::ColumnDecode { + index: "oauth_provider_created_at".into(), + source: "secd".into(), + })?, + deleted_at: op_deleted_at, + }, + }) + } +} + +impl<'a, D: Database> Decode<'a, D> for OauthProviderName +where + &'a str: Decode<'a, D>, +{ + fn decode( + value: <D as HasValueRef<'a>>::ValueRef, + ) -> Result<Self, Box<dyn ::std::error::Error + 'static + Send + Sync>> { + let v = <&str as Decode<D>>::decode(value)?; + <OauthProviderName as clap::ValueEnum>::from_str(v, true) + .map_err(|_| "OauthProviderName should exist and decode to a program value.".into()) + } +} + +impl<D: Database> Type<D> for OauthProviderName +where + str: Type<D>, +{ + fn type_info() -> D::TypeInfo { + <&str as Type<D>>::type_info() + } +} + +impl<'a, D: Database> Decode<'a, D> for OauthResponseType +where + &'a str: Decode<'a, D>, +{ + fn decode( + value: <D as HasValueRef<'a>>::ValueRef, + ) -> Result<Self, Box<dyn ::std::error::Error + 'static + Send + Sync>> { + let v = <&str as Decode<D>>::decode(value)?; + <OauthResponseType as clap::ValueEnum>::from_str(v, true) + .map_err(|_| "OauthResponseType should exist and decode to a program value.".into()) + } +} + +impl<D: Database> Type<D> for OauthResponseType +where + str: Type<D>, +{ + fn type_info() -> D::TypeInfo { + <&str as Type<D>>::type_info() + } +} + #[async_trait::async_trait] pub trait Store { - async fn write_email(&self, identity_id: Uuid, email_address: &str) -> Result<(), StoreError>; + async fn write_email(&self, email_address: &str) -> Result<(), StoreError>; async fn find_email_validation( &self, @@ -193,17 +386,37 @@ pub trait Store { &self, ev: &EmailValidation, // TODO: Make this write an EmailValidation - ) -> Result<Uuid, StoreError>; + ) -> anyhow::Result<Uuid>; async fn find_identity( &self, identity_id: Option<&Uuid>, email: Option<&str>, - ) -> Result<Option<Identity>, StoreError>; + ) -> anyhow::Result<Option<Identity>>; async fn find_identity_by_code(&self, code: &str) -> Result<Identity, StoreError>; async fn write_identity(&self, i: &Identity) -> Result<(), StoreError>; async fn read_identity(&self, identity_id: &Uuid) -> Result<Identity, StoreError>; async fn write_session(&self, session: &Session) -> Result<(), StoreError>; async fn read_session(&self, secret: &SessionSecret) -> Result<Session, StoreError>; + + async fn write_oauth_provider(&self, provider: &OauthProvider) -> Result<(), StoreError>; + async fn read_oauth_provider( + &self, + provider: &OauthProviderName, + flow: Option<String>, + ) -> Result<OauthProvider, StoreError>; + async fn write_oauth_validation( + &self, + validation: &OauthValidation, + ) -> anyhow::Result<ValidationRequestId>; + async fn read_oauth_validation( + &self, + validation_id: &ValidationRequestId, + ) -> anyhow::Result<OauthValidation>; + + async fn find_validation_type( + &self, + validation_id: &ValidationRequestId, + ) -> anyhow::Result<ValidationType>; } diff --git a/crates/secd/src/client/sqldb.rs b/crates/secd/src/client/sqldb.rs index 6048c48..15cc4b5 100644 --- a/crates/secd/src/client/sqldb.rs +++ b/crates/secd/src/client/sqldb.rs @@ -1,19 +1,23 @@ -use std::sync::Arc; +use std::{str::FromStr, sync::Arc}; use super::{ - EmailValidation, Identity, Session, SessionSecret, Store, StoreError, ERR_MSG_MIGRATION_FAILED, - FIND_EMAIL_VALIDATION, FIND_IDENTITY, FIND_IDENTITY_BY_CODE, PGSQL, READ_EMAIL_RAW_ID, - READ_IDENTITY_RAW_ID, READ_SESSION, SQLITE, SQLS, WRITE_EMAIL, WRITE_EMAIL_VALIDATION, - WRITE_IDENTITY, WRITE_SESSION, + EmailValidation, Identity, OauthProvider, OauthProviderName, OauthResponseType, Session, + SessionSecret, Store, StoreError, ERR_MSG_MIGRATION_FAILED, FIND_EMAIL_VALIDATION, + FIND_IDENTITY, FIND_IDENTITY_BY_CODE, PGSQL, READ_EMAIL_RAW_ID, READ_IDENTITY_RAW_ID, + READ_OAUTH_PROVIDER, READ_OAUTH_VALIDATION, READ_SESSION, READ_VALIDATION_TYPE, SQLITE, SQLS, + WRITE_EMAIL, WRITE_EMAIL_VALIDATION, WRITE_IDENTITY, WRITE_OAUTH_PROVIDER, + WRITE_OAUTH_VALIDATION, WRITE_SESSION, }; -use crate::util; -use log::error; +use crate::{util, OauthValidation, ValidationRequestId, ValidationType}; +use anyhow::bail; +use log::{debug, error}; use openssl::sha::Sha256; use sqlx::{ self, database::HasArguments, ColumnIndex, Database, Decode, Encode, Executor, IntoArguments, Pool, Postgres, Sqlite, Transaction, Type, }; use time::OffsetDateTime; +use url::Url; use uuid::Uuid; fn get_sqls(root: &str, file: &str) -> Vec<String> { @@ -97,6 +101,8 @@ where for<'c> String: Encode<'c, D> + Type<D>, for<'c> Option<String>: Decode<'c, D> + Type<D>, for<'c> Option<String>: Encode<'c, D> + Type<D>, + for<'c> OauthProviderName: Decode<'c, D> + Type<D>, + for<'c> OauthResponseType: Decode<'c, D> + Type<D>, for<'c> usize: ColumnIndex<<D as Database>::Row>, for<'c> Uuid: Decode<'c, D> + Type<D>, for<'c> Uuid: Encode<'c, D> + Type<D>, @@ -108,29 +114,11 @@ where for<'c> &'c Pool<D>: Executor<'c, Database = D>, for<'c> &'c mut Transaction<'c, D>: Executor<'c, Database = D>, { - async fn write_email(&self, identity_id: Uuid, email_address: &str) -> Result<(), StoreError> { + async fn write_email(&self, email_address: &str) -> Result<(), StoreError> { let sqls = get_sqls(&self.sqls_root, WRITE_EMAIL); - let identity_id = self.read_identity_raw_id(&identity_id).await?; - - let email_id: (i64,) = match sqlx::query_as(&sqls[0]) + sqlx::query(&sqls[0]) .bind(email_address) - .fetch_one(&self.pool) - .await - { - Ok(i) => i, - Err(sqlx::Error::RowNotFound) => sqlx::query_as::<_, (i64,)>(&sqls[1]) - .bind(email_address) - .fetch_one(&self.pool) - .await - .map_err(util::log_err_sqlx)?, - Err(e) => return Err(StoreError::SqlxError(e)), - }; - - sqlx::query(&sqls[2]) - .bind(identity_id) - .bind(email_id.0) - .bind(OffsetDateTime::now_utc()) .execute(&self.pool) .await .map_err(util::log_err_sqlx)?; @@ -154,57 +142,84 @@ where match rows.len() { 0 => Err(StoreError::NoEmailValidationFound), 1 => Ok(rows.swap_remove(0)), - _ => Err(StoreError::TooManyEmailValidations), + _ => Err(StoreError::TooManyValidations), } } - async fn write_email_validation(&self, ev: &EmailValidation) -> Result<Uuid, StoreError> { + async fn write_email_validation(&self, ev: &EmailValidation) -> anyhow::Result<Uuid> { let sqls = get_sqls(&self.sqls_root, WRITE_EMAIL_VALIDATION); - let identity_id = self - .read_identity_raw_id( - &ev.identity_id - .ok_or(StoreError::IdentityIdMustExistInvariant)?, - ) - .await?; let email_id = self.read_email_raw_id(&ev.email_address).await?; - - let new_id = Uuid::new_v4(); + let validation_id = ev.id.unwrap_or(Uuid::new_v4()); sqlx::query(&sqls[0]) - .bind(ev.id.unwrap_or(new_id)) - .bind(identity_id) + .bind(validation_id) .bind(email_id) - .bind(ev.attempts) .bind(&ev.code) - .bind(ev.is_validated) + .bind(ev.is_oauth_derived) .bind(ev.created_at) - .bind(ev.expires_at) + .bind(ev.validated_at) + .bind(ev.expired_at) .execute(&self.pool) .await .map_err(util::log_err_sqlx)?; - Ok(new_id) + if ev.identity_id.is_some() || ev.revoked_at.is_some() || ev.deleted_at.is_some() { + sqlx::query(&sqls[1]) + .bind(ev.identity_id.as_ref()) + .bind(validation_id) + .bind(ev.revoked_at) + .bind(ev.deleted_at) + .execute(&self.pool) + .await + .map_err(util::log_err_sqlx)?; + } + + Ok(validation_id) } async fn find_identity( &self, id: Option<&Uuid>, email: Option<&str>, - ) -> Result<Option<Identity>, StoreError> { + ) -> anyhow::Result<Option<Identity>> { let sqls = get_sqls(&self.sqls_root, FIND_IDENTITY); Ok( match sqlx::query_as::<_, Identity>(&sqls[0]) .bind(id) .bind(email) - .fetch_one(&self.pool) + .fetch_all(&self.pool) .await { - Ok(i) => Some(i), + Ok(mut is) => match is.len() { + // if only 1 found, then that's fine + // if multiple are fond, then if they all have the same id, that's okay + 1 => { + let i = is.swap_remove(0); + match i.deleted_at { + Some(t) if t > OffsetDateTime::now_utc() => Some(i), + None => Some(i), + _ => None, + } + } + 0 => None, + _ => { + match is + .iter() + .filter(|&i| i.id != is[0].id) + .collect::<Vec<&Identity>>() + .len() + { + 0 => Some(is.swap_remove(0)), + _ => bail!(StoreError::TooManyIdentitiesFound), + } + } + }, Err(sqlx::Error::RowNotFound) => None, - Err(e) => return Err(StoreError::SqlxError(e)), + Err(e) => bail!(StoreError::SqlxError(e)), }, ) } + async fn find_identity_by_code(&self, code: &str) -> Result<Identity, StoreError> { let sqls = get_sqls(&self.sqls_root, FIND_IDENTITY_BY_CODE); @@ -250,14 +265,16 @@ where Ok(()) } async fn read_identity(&self, id: &Uuid) -> Result<Identity, StoreError> { - Ok(sqlx::query_as::<_, Identity>( + let identity = sqlx::query_as::<_, Identity>( " select identity_public_id, data, created_at from identity where identity_public_id = ?", ) .bind(id) .fetch_one(&self.pool) .await - .map_err(util::log_err_sqlx)?) + .map_err(util::log_err_sqlx)?; + + Ok(identity) } async fn write_session(&self, session: &Session) -> Result<(), StoreError> { @@ -269,7 +286,6 @@ select identity_public_id, data, created_at from identity where identity_public_ .bind(&session.identity_id) .bind(secret_hash.as_ref()) .bind(session.created_at) - .bind(OffsetDateTime::now_utc()) .bind(session.expires_at) .bind(session.revoked_at) .execute(&self.pool) @@ -296,6 +312,142 @@ select identity_public_id, data, created_at from identity where identity_public_ Ok(session) } + + async fn write_oauth_provider(&self, provider: &OauthProvider) -> Result<(), StoreError> { + let sqls = get_sqls(&self.sqls_root, WRITE_OAUTH_PROVIDER); + sqlx::query(&sqls[0]) + .bind(&provider.name.to_string()) + .bind(&provider.flow) + .bind(&provider.base_url.to_string()) + .bind(&provider.response.to_string()) + .bind(&provider.default_scope) + .bind(&provider.client_id) + // TODO: encrypt secret before writing + .bind(&provider.client_secret) + .bind(&provider.redirect_url.to_string()) + .bind(provider.created_at) + .bind(provider.deleted_at) + .execute(&self.pool) + .await + .map_err(util::log_err_sqlx)?; + Ok(()) + } + + async fn read_oauth_provider( + &self, + provider: &OauthProviderName, + flow: Option<String>, + ) -> Result<OauthProvider, StoreError> { + let sqls = get_sqls(&self.sqls_root, READ_OAUTH_PROVIDER); + let flow = flow.unwrap_or("default".into()); + debug!("provider: {:?}, flow: {:?}", provider, flow); + // TODO: Write the generic FromRow impl for OauthProvider... + let res = sqlx::query_as::< + _, + ( + String, + String, + String, + String, + String, + String, + String, + OffsetDateTime, + Option<OffsetDateTime>, + ), + >(&sqls[0]) + .bind(&provider.to_string()) + .bind(&flow) + .fetch_one(&self.pool) + .await + .map_err(util::log_err_sqlx)?; + + debug!("res: {:?}", res); + + Ok(OauthProvider { + name: provider.clone(), + flow: Some(res.0), + base_url: Url::from_str(&res.1) + .map_err(|_| StoreError::OauthProviderDoesNotExist(*provider))?, + response: OauthResponseType::from_str(&res.2) + .map_err(|_| StoreError::OauthProviderDoesNotExist(*provider))?, + default_scope: res.3, + client_id: res.4, + client_secret: res.5, + redirect_url: Url::from_str(&res.6) + .map_err(|_| StoreError::OauthProviderDoesNotExist(*provider))?, + created_at: res.7, + deleted_at: res.8, + }) + } + async fn write_oauth_validation( + &self, + v: &OauthValidation, + ) -> anyhow::Result<ValidationRequestId> { + let sqls = get_sqls(&self.sqls_root, WRITE_OAUTH_VALIDATION); + + let validation_id = v.id.unwrap_or(Uuid::new_v4()); + sqlx::query(&sqls[0]) + .bind(validation_id) + .bind(v.oauth_provider.name.to_string()) + .bind(v.oauth_provider.flow.clone()) + .bind(v.access_token.clone()) + .bind(v.raw_response.clone()) + .bind(v.created_at) + .bind(v.validated_at) + .execute(&self.pool) + .await?; + + if v.identity_id.is_some() || v.revoked_at.is_some() || v.deleted_at.is_some() { + sqlx::query(&sqls[1]) + .bind(v.identity_id.as_ref()) + .bind(validation_id) + .bind(v.revoked_at) + .bind(v.deleted_at) + .execute(&self.pool) + .await?; + } + + Ok(validation_id) + } + async fn read_oauth_validation( + &self, + validation_id: &ValidationRequestId, + ) -> anyhow::Result<OauthValidation> { + let sqls = get_sqls(&self.sqls_root, READ_OAUTH_VALIDATION); + + let mut es = sqlx::query_as::<_, OauthValidation>(&sqls[0]) + .bind(validation_id) + .fetch_all(&self.pool) + .await?; + + if es.len() != 1 { + bail!(StoreError::OauthValidationDoesNotExist( + validation_id.clone() + )); + } + + Ok(es.swap_remove(0)) + } + async fn find_validation_type( + &self, + validation_id: &ValidationRequestId, + ) -> anyhow::Result<ValidationType> { + let sqls = get_sqls(&self.sqls_root, READ_VALIDATION_TYPE); + + let mut es = sqlx::query_as::<_, (String,)>(&sqls[0]) + .bind(validation_id) + .fetch_all(&self.pool) + .await + .map_err(util::log_err_sqlx)?; + + match es.len() { + 1 => Ok(ValidationType::from_str(&es.swap_remove(0).0)?), + _ => bail!(StoreError::Other( + "expected a single validation but recieved 0 or multiple validations".into() + )), + } + } } pub struct PgClient { @@ -320,8 +472,8 @@ impl PgClient { #[async_trait::async_trait] impl Store for PgClient { - async fn write_email(&self, identity_id: Uuid, email_address: &str) -> Result<(), StoreError> { - self.sql.write_email(identity_id, email_address).await + async fn write_email(&self, email_address: &str) -> Result<(), StoreError> { + self.sql.write_email(email_address).await } async fn find_email_validation( &self, @@ -330,14 +482,14 @@ impl Store for PgClient { ) -> Result<EmailValidation, StoreError> { self.sql.find_email_validation(validation_id, code).await } - async fn write_email_validation(&self, ev: &EmailValidation) -> Result<Uuid, StoreError> { + async fn write_email_validation(&self, ev: &EmailValidation) -> anyhow::Result<Uuid> { self.sql.write_email_validation(ev).await } async fn find_identity( &self, identity_id: Option<&Uuid>, email: Option<&str>, - ) -> Result<Option<Identity>, StoreError> { + ) -> anyhow::Result<Option<Identity>> { self.sql.find_identity(identity_id, email).await } async fn find_identity_by_code(&self, code: &str) -> Result<Identity, StoreError> { @@ -355,6 +507,34 @@ impl Store for PgClient { async fn read_session(&self, secret: &SessionSecret) -> Result<Session, StoreError> { self.sql.read_session(secret).await } + async fn write_oauth_provider(&self, provider: &OauthProvider) -> Result<(), StoreError> { + self.sql.write_oauth_provider(provider).await + } + async fn read_oauth_provider( + &self, + provider: &OauthProviderName, + flow: Option<String>, + ) -> Result<OauthProvider, StoreError> { + self.sql.read_oauth_provider(provider, flow).await + } + async fn write_oauth_validation( + &self, + validation: &OauthValidation, + ) -> anyhow::Result<ValidationRequestId> { + self.sql.write_oauth_validation(validation).await + } + async fn read_oauth_validation( + &self, + validation_id: &ValidationRequestId, + ) -> anyhow::Result<OauthValidation> { + self.sql.read_oauth_validation(validation_id).await + } + async fn find_validation_type( + &self, + validation_id: &ValidationRequestId, + ) -> anyhow::Result<ValidationType> { + self.sql.find_validation_type(validation_id).await + } } pub struct SqliteClient { @@ -386,8 +566,8 @@ impl SqliteClient { #[async_trait::async_trait] impl Store for SqliteClient { - async fn write_email(&self, identity_id: Uuid, email_address: &str) -> Result<(), StoreError> { - self.sql.write_email(identity_id, email_address).await + async fn write_email(&self, email_address: &str) -> Result<(), StoreError> { + self.sql.write_email(email_address).await } async fn find_email_validation( &self, @@ -396,14 +576,14 @@ impl Store for SqliteClient { ) -> Result<EmailValidation, StoreError> { self.sql.find_email_validation(validation_id, code).await } - async fn write_email_validation(&self, ev: &EmailValidation) -> Result<Uuid, StoreError> { + async fn write_email_validation(&self, ev: &EmailValidation) -> anyhow::Result<Uuid> { self.sql.write_email_validation(ev).await } async fn find_identity( &self, identity_id: Option<&Uuid>, email: Option<&str>, - ) -> Result<Option<Identity>, StoreError> { + ) -> anyhow::Result<Option<Identity>> { self.sql.find_identity(identity_id, email).await } async fn find_identity_by_code(&self, code: &str) -> Result<Identity, StoreError> { @@ -421,4 +601,32 @@ impl Store for SqliteClient { async fn read_session(&self, secret: &SessionSecret) -> Result<Session, StoreError> { self.sql.read_session(secret).await } + async fn write_oauth_provider(&self, provider: &OauthProvider) -> Result<(), StoreError> { + self.sql.write_oauth_provider(provider).await + } + async fn read_oauth_provider( + &self, + provider: &OauthProviderName, + flow: Option<String>, + ) -> Result<OauthProvider, StoreError> { + self.sql.read_oauth_provider(provider, flow).await + } + async fn write_oauth_validation( + &self, + validation: &OauthValidation, + ) -> anyhow::Result<ValidationRequestId> { + self.sql.write_oauth_validation(validation).await + } + async fn read_oauth_validation( + &self, + validation_id: &ValidationRequestId, + ) -> anyhow::Result<OauthValidation> { + self.sql.read_oauth_validation(validation_id).await + } + async fn find_validation_type( + &self, + validation_id: &ValidationRequestId, + ) -> anyhow::Result<ValidationType> { + self.sql.find_validation_type(validation_id).await + } } diff --git a/crates/secd/src/client/types.rs b/crates/secd/src/client/types.rs new file mode 100644 index 0000000..bacade4 --- /dev/null +++ b/crates/secd/src/client/types.rs @@ -0,0 +1,3 @@ +pub(crate) struct Email { + address: String, +} |
