diff options
Diffstat (limited to '')
| -rw-r--r-- | crates/secd/src/client/sqldb.rs | 324 |
1 files changed, 266 insertions, 58 deletions
diff --git a/crates/secd/src/client/sqldb.rs b/crates/secd/src/client/sqldb.rs index 6048c48..15cc4b5 100644 --- a/crates/secd/src/client/sqldb.rs +++ b/crates/secd/src/client/sqldb.rs @@ -1,19 +1,23 @@ -use std::sync::Arc; +use std::{str::FromStr, sync::Arc}; use super::{ - EmailValidation, Identity, Session, SessionSecret, Store, StoreError, ERR_MSG_MIGRATION_FAILED, - FIND_EMAIL_VALIDATION, FIND_IDENTITY, FIND_IDENTITY_BY_CODE, PGSQL, READ_EMAIL_RAW_ID, - READ_IDENTITY_RAW_ID, READ_SESSION, SQLITE, SQLS, WRITE_EMAIL, WRITE_EMAIL_VALIDATION, - WRITE_IDENTITY, WRITE_SESSION, + EmailValidation, Identity, OauthProvider, OauthProviderName, OauthResponseType, Session, + SessionSecret, Store, StoreError, ERR_MSG_MIGRATION_FAILED, FIND_EMAIL_VALIDATION, + FIND_IDENTITY, FIND_IDENTITY_BY_CODE, PGSQL, READ_EMAIL_RAW_ID, READ_IDENTITY_RAW_ID, + READ_OAUTH_PROVIDER, READ_OAUTH_VALIDATION, READ_SESSION, READ_VALIDATION_TYPE, SQLITE, SQLS, + WRITE_EMAIL, WRITE_EMAIL_VALIDATION, WRITE_IDENTITY, WRITE_OAUTH_PROVIDER, + WRITE_OAUTH_VALIDATION, WRITE_SESSION, }; -use crate::util; -use log::error; +use crate::{util, OauthValidation, ValidationRequestId, ValidationType}; +use anyhow::bail; +use log::{debug, error}; use openssl::sha::Sha256; use sqlx::{ self, database::HasArguments, ColumnIndex, Database, Decode, Encode, Executor, IntoArguments, Pool, Postgres, Sqlite, Transaction, Type, }; use time::OffsetDateTime; +use url::Url; use uuid::Uuid; fn get_sqls(root: &str, file: &str) -> Vec<String> { @@ -97,6 +101,8 @@ where for<'c> String: Encode<'c, D> + Type<D>, for<'c> Option<String>: Decode<'c, D> + Type<D>, for<'c> Option<String>: Encode<'c, D> + Type<D>, + for<'c> OauthProviderName: Decode<'c, D> + Type<D>, + for<'c> OauthResponseType: Decode<'c, D> + Type<D>, for<'c> usize: ColumnIndex<<D as Database>::Row>, for<'c> Uuid: Decode<'c, D> + Type<D>, for<'c> Uuid: Encode<'c, D> + Type<D>, @@ -108,29 +114,11 @@ where for<'c> &'c Pool<D>: Executor<'c, Database = D>, for<'c> &'c mut Transaction<'c, D>: Executor<'c, Database = D>, { - async fn write_email(&self, identity_id: Uuid, email_address: &str) -> Result<(), StoreError> { + async fn write_email(&self, email_address: &str) -> Result<(), StoreError> { let sqls = get_sqls(&self.sqls_root, WRITE_EMAIL); - let identity_id = self.read_identity_raw_id(&identity_id).await?; - - let email_id: (i64,) = match sqlx::query_as(&sqls[0]) + sqlx::query(&sqls[0]) .bind(email_address) - .fetch_one(&self.pool) - .await - { - Ok(i) => i, - Err(sqlx::Error::RowNotFound) => sqlx::query_as::<_, (i64,)>(&sqls[1]) - .bind(email_address) - .fetch_one(&self.pool) - .await - .map_err(util::log_err_sqlx)?, - Err(e) => return Err(StoreError::SqlxError(e)), - }; - - sqlx::query(&sqls[2]) - .bind(identity_id) - .bind(email_id.0) - .bind(OffsetDateTime::now_utc()) .execute(&self.pool) .await .map_err(util::log_err_sqlx)?; @@ -154,57 +142,84 @@ where match rows.len() { 0 => Err(StoreError::NoEmailValidationFound), 1 => Ok(rows.swap_remove(0)), - _ => Err(StoreError::TooManyEmailValidations), + _ => Err(StoreError::TooManyValidations), } } - async fn write_email_validation(&self, ev: &EmailValidation) -> Result<Uuid, StoreError> { + async fn write_email_validation(&self, ev: &EmailValidation) -> anyhow::Result<Uuid> { let sqls = get_sqls(&self.sqls_root, WRITE_EMAIL_VALIDATION); - let identity_id = self - .read_identity_raw_id( - &ev.identity_id - .ok_or(StoreError::IdentityIdMustExistInvariant)?, - ) - .await?; let email_id = self.read_email_raw_id(&ev.email_address).await?; - - let new_id = Uuid::new_v4(); + let validation_id = ev.id.unwrap_or(Uuid::new_v4()); sqlx::query(&sqls[0]) - .bind(ev.id.unwrap_or(new_id)) - .bind(identity_id) + .bind(validation_id) .bind(email_id) - .bind(ev.attempts) .bind(&ev.code) - .bind(ev.is_validated) + .bind(ev.is_oauth_derived) .bind(ev.created_at) - .bind(ev.expires_at) + .bind(ev.validated_at) + .bind(ev.expired_at) .execute(&self.pool) .await .map_err(util::log_err_sqlx)?; - Ok(new_id) + if ev.identity_id.is_some() || ev.revoked_at.is_some() || ev.deleted_at.is_some() { + sqlx::query(&sqls[1]) + .bind(ev.identity_id.as_ref()) + .bind(validation_id) + .bind(ev.revoked_at) + .bind(ev.deleted_at) + .execute(&self.pool) + .await + .map_err(util::log_err_sqlx)?; + } + + Ok(validation_id) } async fn find_identity( &self, id: Option<&Uuid>, email: Option<&str>, - ) -> Result<Option<Identity>, StoreError> { + ) -> anyhow::Result<Option<Identity>> { let sqls = get_sqls(&self.sqls_root, FIND_IDENTITY); Ok( match sqlx::query_as::<_, Identity>(&sqls[0]) .bind(id) .bind(email) - .fetch_one(&self.pool) + .fetch_all(&self.pool) .await { - Ok(i) => Some(i), + Ok(mut is) => match is.len() { + // if only 1 found, then that's fine + // if multiple are fond, then if they all have the same id, that's okay + 1 => { + let i = is.swap_remove(0); + match i.deleted_at { + Some(t) if t > OffsetDateTime::now_utc() => Some(i), + None => Some(i), + _ => None, + } + } + 0 => None, + _ => { + match is + .iter() + .filter(|&i| i.id != is[0].id) + .collect::<Vec<&Identity>>() + .len() + { + 0 => Some(is.swap_remove(0)), + _ => bail!(StoreError::TooManyIdentitiesFound), + } + } + }, Err(sqlx::Error::RowNotFound) => None, - Err(e) => return Err(StoreError::SqlxError(e)), + Err(e) => bail!(StoreError::SqlxError(e)), }, ) } + async fn find_identity_by_code(&self, code: &str) -> Result<Identity, StoreError> { let sqls = get_sqls(&self.sqls_root, FIND_IDENTITY_BY_CODE); @@ -250,14 +265,16 @@ where Ok(()) } async fn read_identity(&self, id: &Uuid) -> Result<Identity, StoreError> { - Ok(sqlx::query_as::<_, Identity>( + let identity = sqlx::query_as::<_, Identity>( " select identity_public_id, data, created_at from identity where identity_public_id = ?", ) .bind(id) .fetch_one(&self.pool) .await - .map_err(util::log_err_sqlx)?) + .map_err(util::log_err_sqlx)?; + + Ok(identity) } async fn write_session(&self, session: &Session) -> Result<(), StoreError> { @@ -269,7 +286,6 @@ select identity_public_id, data, created_at from identity where identity_public_ .bind(&session.identity_id) .bind(secret_hash.as_ref()) .bind(session.created_at) - .bind(OffsetDateTime::now_utc()) .bind(session.expires_at) .bind(session.revoked_at) .execute(&self.pool) @@ -296,6 +312,142 @@ select identity_public_id, data, created_at from identity where identity_public_ Ok(session) } + + async fn write_oauth_provider(&self, provider: &OauthProvider) -> Result<(), StoreError> { + let sqls = get_sqls(&self.sqls_root, WRITE_OAUTH_PROVIDER); + sqlx::query(&sqls[0]) + .bind(&provider.name.to_string()) + .bind(&provider.flow) + .bind(&provider.base_url.to_string()) + .bind(&provider.response.to_string()) + .bind(&provider.default_scope) + .bind(&provider.client_id) + // TODO: encrypt secret before writing + .bind(&provider.client_secret) + .bind(&provider.redirect_url.to_string()) + .bind(provider.created_at) + .bind(provider.deleted_at) + .execute(&self.pool) + .await + .map_err(util::log_err_sqlx)?; + Ok(()) + } + + async fn read_oauth_provider( + &self, + provider: &OauthProviderName, + flow: Option<String>, + ) -> Result<OauthProvider, StoreError> { + let sqls = get_sqls(&self.sqls_root, READ_OAUTH_PROVIDER); + let flow = flow.unwrap_or("default".into()); + debug!("provider: {:?}, flow: {:?}", provider, flow); + // TODO: Write the generic FromRow impl for OauthProvider... + let res = sqlx::query_as::< + _, + ( + String, + String, + String, + String, + String, + String, + String, + OffsetDateTime, + Option<OffsetDateTime>, + ), + >(&sqls[0]) + .bind(&provider.to_string()) + .bind(&flow) + .fetch_one(&self.pool) + .await + .map_err(util::log_err_sqlx)?; + + debug!("res: {:?}", res); + + Ok(OauthProvider { + name: provider.clone(), + flow: Some(res.0), + base_url: Url::from_str(&res.1) + .map_err(|_| StoreError::OauthProviderDoesNotExist(*provider))?, + response: OauthResponseType::from_str(&res.2) + .map_err(|_| StoreError::OauthProviderDoesNotExist(*provider))?, + default_scope: res.3, + client_id: res.4, + client_secret: res.5, + redirect_url: Url::from_str(&res.6) + .map_err(|_| StoreError::OauthProviderDoesNotExist(*provider))?, + created_at: res.7, + deleted_at: res.8, + }) + } + async fn write_oauth_validation( + &self, + v: &OauthValidation, + ) -> anyhow::Result<ValidationRequestId> { + let sqls = get_sqls(&self.sqls_root, WRITE_OAUTH_VALIDATION); + + let validation_id = v.id.unwrap_or(Uuid::new_v4()); + sqlx::query(&sqls[0]) + .bind(validation_id) + .bind(v.oauth_provider.name.to_string()) + .bind(v.oauth_provider.flow.clone()) + .bind(v.access_token.clone()) + .bind(v.raw_response.clone()) + .bind(v.created_at) + .bind(v.validated_at) + .execute(&self.pool) + .await?; + + if v.identity_id.is_some() || v.revoked_at.is_some() || v.deleted_at.is_some() { + sqlx::query(&sqls[1]) + .bind(v.identity_id.as_ref()) + .bind(validation_id) + .bind(v.revoked_at) + .bind(v.deleted_at) + .execute(&self.pool) + .await?; + } + + Ok(validation_id) + } + async fn read_oauth_validation( + &self, + validation_id: &ValidationRequestId, + ) -> anyhow::Result<OauthValidation> { + let sqls = get_sqls(&self.sqls_root, READ_OAUTH_VALIDATION); + + let mut es = sqlx::query_as::<_, OauthValidation>(&sqls[0]) + .bind(validation_id) + .fetch_all(&self.pool) + .await?; + + if es.len() != 1 { + bail!(StoreError::OauthValidationDoesNotExist( + validation_id.clone() + )); + } + + Ok(es.swap_remove(0)) + } + async fn find_validation_type( + &self, + validation_id: &ValidationRequestId, + ) -> anyhow::Result<ValidationType> { + let sqls = get_sqls(&self.sqls_root, READ_VALIDATION_TYPE); + + let mut es = sqlx::query_as::<_, (String,)>(&sqls[0]) + .bind(validation_id) + .fetch_all(&self.pool) + .await + .map_err(util::log_err_sqlx)?; + + match es.len() { + 1 => Ok(ValidationType::from_str(&es.swap_remove(0).0)?), + _ => bail!(StoreError::Other( + "expected a single validation but recieved 0 or multiple validations".into() + )), + } + } } pub struct PgClient { @@ -320,8 +472,8 @@ impl PgClient { #[async_trait::async_trait] impl Store for PgClient { - async fn write_email(&self, identity_id: Uuid, email_address: &str) -> Result<(), StoreError> { - self.sql.write_email(identity_id, email_address).await + async fn write_email(&self, email_address: &str) -> Result<(), StoreError> { + self.sql.write_email(email_address).await } async fn find_email_validation( &self, @@ -330,14 +482,14 @@ impl Store for PgClient { ) -> Result<EmailValidation, StoreError> { self.sql.find_email_validation(validation_id, code).await } - async fn write_email_validation(&self, ev: &EmailValidation) -> Result<Uuid, StoreError> { + async fn write_email_validation(&self, ev: &EmailValidation) -> anyhow::Result<Uuid> { self.sql.write_email_validation(ev).await } async fn find_identity( &self, identity_id: Option<&Uuid>, email: Option<&str>, - ) -> Result<Option<Identity>, StoreError> { + ) -> anyhow::Result<Option<Identity>> { self.sql.find_identity(identity_id, email).await } async fn find_identity_by_code(&self, code: &str) -> Result<Identity, StoreError> { @@ -355,6 +507,34 @@ impl Store for PgClient { async fn read_session(&self, secret: &SessionSecret) -> Result<Session, StoreError> { self.sql.read_session(secret).await } + async fn write_oauth_provider(&self, provider: &OauthProvider) -> Result<(), StoreError> { + self.sql.write_oauth_provider(provider).await + } + async fn read_oauth_provider( + &self, + provider: &OauthProviderName, + flow: Option<String>, + ) -> Result<OauthProvider, StoreError> { + self.sql.read_oauth_provider(provider, flow).await + } + async fn write_oauth_validation( + &self, + validation: &OauthValidation, + ) -> anyhow::Result<ValidationRequestId> { + self.sql.write_oauth_validation(validation).await + } + async fn read_oauth_validation( + &self, + validation_id: &ValidationRequestId, + ) -> anyhow::Result<OauthValidation> { + self.sql.read_oauth_validation(validation_id).await + } + async fn find_validation_type( + &self, + validation_id: &ValidationRequestId, + ) -> anyhow::Result<ValidationType> { + self.sql.find_validation_type(validation_id).await + } } pub struct SqliteClient { @@ -386,8 +566,8 @@ impl SqliteClient { #[async_trait::async_trait] impl Store for SqliteClient { - async fn write_email(&self, identity_id: Uuid, email_address: &str) -> Result<(), StoreError> { - self.sql.write_email(identity_id, email_address).await + async fn write_email(&self, email_address: &str) -> Result<(), StoreError> { + self.sql.write_email(email_address).await } async fn find_email_validation( &self, @@ -396,14 +576,14 @@ impl Store for SqliteClient { ) -> Result<EmailValidation, StoreError> { self.sql.find_email_validation(validation_id, code).await } - async fn write_email_validation(&self, ev: &EmailValidation) -> Result<Uuid, StoreError> { + async fn write_email_validation(&self, ev: &EmailValidation) -> anyhow::Result<Uuid> { self.sql.write_email_validation(ev).await } async fn find_identity( &self, identity_id: Option<&Uuid>, email: Option<&str>, - ) -> Result<Option<Identity>, StoreError> { + ) -> anyhow::Result<Option<Identity>> { self.sql.find_identity(identity_id, email).await } async fn find_identity_by_code(&self, code: &str) -> Result<Identity, StoreError> { @@ -421,4 +601,32 @@ impl Store for SqliteClient { async fn read_session(&self, secret: &SessionSecret) -> Result<Session, StoreError> { self.sql.read_session(secret).await } + async fn write_oauth_provider(&self, provider: &OauthProvider) -> Result<(), StoreError> { + self.sql.write_oauth_provider(provider).await + } + async fn read_oauth_provider( + &self, + provider: &OauthProviderName, + flow: Option<String>, + ) -> Result<OauthProvider, StoreError> { + self.sql.read_oauth_provider(provider, flow).await + } + async fn write_oauth_validation( + &self, + validation: &OauthValidation, + ) -> anyhow::Result<ValidationRequestId> { + self.sql.write_oauth_validation(validation).await + } + async fn read_oauth_validation( + &self, + validation_id: &ValidationRequestId, + ) -> anyhow::Result<OauthValidation> { + self.sql.read_oauth_validation(validation_id).await + } + async fn find_validation_type( + &self, + validation_id: &ValidationRequestId, + ) -> anyhow::Result<ValidationType> { + self.sql.find_validation_type(validation_id).await + } } |
