diff options
| author | benj <benj@rse8.com> | 2022-12-24 00:43:38 -0800 |
|---|---|---|
| committer | benj <benj@rse8.com> | 2022-12-24 00:43:38 -0800 |
| commit | c2268c285648ef02ece04de0d9df0813c6d70ff8 (patch) | |
| tree | f84ec7ee42f97d78245f26d0c5a0c559cd35e89d /crates/secd/src/client | |
| parent | de6339da72af1d61ca5908b780977e2b037ce014 (diff) | |
| download | secdiam-c2268c285648ef02ece04de0d9df0813c6d70ff8.tar secdiam-c2268c285648ef02ece04de0d9df0813c6d70ff8.tar.gz secdiam-c2268c285648ef02ece04de0d9df0813c6d70ff8.tar.bz2 secdiam-c2268c285648ef02ece04de0d9df0813c6d70ff8.tar.lz secdiam-c2268c285648ef02ece04de0d9df0813c6d70ff8.tar.xz secdiam-c2268c285648ef02ece04de0d9df0813c6d70ff8.tar.zst secdiam-c2268c285648ef02ece04de0d9df0813c6d70ff8.zip | |
refactor everything with more abstraction and a nicer interface
Diffstat (limited to 'crates/secd/src/client')
| -rw-r--r-- | crates/secd/src/client/email.rs | 67 | ||||
| -rw-r--r-- | crates/secd/src/client/email/mod.rs | 68 | ||||
| -rw-r--r-- | crates/secd/src/client/mod.rs | 422 | ||||
| -rw-r--r-- | crates/secd/src/client/sqldb.rs | 632 | ||||
| -rw-r--r-- | crates/secd/src/client/store/mod.rs | 190 | ||||
| -rw-r--r-- | crates/secd/src/client/store/sql_db.rs | 526 | ||||
| -rw-r--r-- | crates/secd/src/client/types.rs | 3 |
7 files changed, 785 insertions, 1123 deletions
diff --git a/crates/secd/src/client/email.rs b/crates/secd/src/client/email.rs deleted file mode 100644 index 2712037..0000000 --- a/crates/secd/src/client/email.rs +++ /dev/null @@ -1,67 +0,0 @@ -use std::{path::PathBuf, str::FromStr}; - -use email_address::EmailAddress; -use time::OffsetDateTime; - -use super::{ - EmailMessenger, EmailMessengerError, EmailType, EMAIL_TEMPLATE_DEFAULT_LOGIN, - EMAIL_TEMPLATE_DEFAULT_SIGNUP, -}; - -pub(crate) struct LocalEmailStubber { - pub(crate) email_template_login: Option<String>, - pub(crate) email_template_signup: Option<String>, -} - -#[async_trait::async_trait] -impl EmailMessenger for LocalEmailStubber { - // TODO: this module really shouldn't be called client, it should be called services... the client is sqlx/mailgun/sns wrapper or whatever... - async fn send_email( - &self, - email_address: &str, - validation_id: &str, - secret_code: &str, - t: EmailType, - ) -> Result<(), EmailMessengerError> { - let login_template = self - .email_template_login - .clone() - .unwrap_or(EMAIL_TEMPLATE_DEFAULT_LOGIN.to_string()); - let signup_template = self - .email_template_signup - .clone() - .unwrap_or(EMAIL_TEMPLATE_DEFAULT_SIGNUP.to_string()); - - let replace_template = |s: &str| { - s.replace( - "%secd_link%", - &format!("{}?code={}", validation_id, secret_code), - ) - .replace("%secd_email_address%", email_address) - .replace("%secd_code%", secret_code) - }; - - if !EmailAddress::is_valid(email_address) { - return Err(EmailMessengerError::InvalidEmailAddress); - } - - let body = match t { - EmailType::Login => replace_template(&login_template), - EmailType::Signup => replace_template(&signup_template), - }; - - // TODO: write to the system mailbox instead? - std::fs::write( - PathBuf::from_str(&format!( - "/tmp/{}_{}.localmail", - OffsetDateTime::now_utc(), - validation_id - )) - .map_err(|_| EmailMessengerError::Unknown)?, - body, - ) - .map_err(|_| EmailMessengerError::FailedToSendEmail)?; - - Ok(()) - } -} diff --git a/crates/secd/src/client/email/mod.rs b/crates/secd/src/client/email/mod.rs new file mode 100644 index 0000000..915d18c --- /dev/null +++ b/crates/secd/src/client/email/mod.rs @@ -0,0 +1,68 @@ +use email_address::EmailAddress; +use lettre::Transport; +use log::error; +use std::collections::HashMap; + +#[derive(Debug, thiserror::Error, derive_more::Display)] +pub enum EmailMessengerError { + FailedToSendEmail, +} + +pub struct EmailValidationMessage { + pub recipient: EmailAddress, + pub subject: String, + pub body: String, +} + +#[async_trait::async_trait] +pub(crate) trait EmailMessenger { + async fn send_email( + &self, + email_address: &EmailAddress, + template: &str, + template_vars: HashMap<&str, &str>, + ) -> Result<(), EmailMessengerError>; +} + +pub(crate) struct LocalMailer {} + +#[async_trait::async_trait] +impl EmailMessenger for LocalMailer { + async fn send_email( + &self, + email_address: &EmailAddress, + template: &str, + template_vars: HashMap<&str, &str>, + ) -> Result<(), EmailMessengerError> { + todo!() + } +} + +#[async_trait::async_trait] +pub(crate) trait Sendable { + async fn send(&self) -> Result<(), EmailMessengerError>; +} + +#[async_trait::async_trait] +impl Sendable for EmailValidationMessage { + // TODO: We need to break this up as before, especially so we can feature + // gate unwanted things like Lettre... + async fn send(&self) -> Result<(), EmailMessengerError> { + // TODO: Get these things from the template... + let email = lettre::Message::builder() + .from("BranchControl <iam@branchcontrol.com>".parse().unwrap()) + .reply_to("BranchControl <iam@branchcontrol.com>".parse().unwrap()) + .to(self.recipient.to_string().parse().unwrap()) + .subject(self.subject.clone()) + .body(self.body.clone()) + .unwrap(); + + let mailer = lettre::SmtpTransport::unencrypted_localhost(); + + mailer.send(&email).map_err(|e| { + error!("failed to send email {:?}", e); + EmailMessengerError::FailedToSendEmail + })?; + Ok(()) + } +} diff --git a/crates/secd/src/client/mod.rs b/crates/secd/src/client/mod.rs index 38426ef..e5272fd 100644 --- a/crates/secd/src/client/mod.rs +++ b/crates/secd/src/client/mod.rs @@ -1,422 +1,2 @@ pub(crate) mod email; -pub(crate) mod sqldb; -pub(crate) mod types; - -use std::{collections::HashMap, str::FromStr}; - -use super::Identity; -use crate::{ - EmailValidation, OauthProvider, OauthProviderName, OauthResponseType, OauthValidation, Session, - SessionSecret, ValidationRequestId, ValidationType, -}; - -use email_address::EmailAddress; -use lazy_static::lazy_static; -use sqlx::{ - database::HasValueRef, sqlite::SqliteRow, ColumnIndex, Database, Decode, FromRow, Row, Sqlite, - Type, -}; -use thiserror::Error; -use time::OffsetDateTime; -use url::Url; -use uuid::Uuid; - -pub enum EmailType { - Login, - Signup, -} - -#[derive(Error, Debug, derive_more::Display)] -pub enum EmailMessengerError { - InvalidEmailAddress, - FailedToSendEmail, - Unknown, -} - -#[async_trait::async_trait] -pub trait EmailMessenger { - async fn send_email( - &self, - email_address: &str, - validation_id: &str, - secret_code: &str, - t: EmailType, - ) -> Result<(), EmailMessengerError>; -} - -#[derive(Error, Debug, derive_more::Display)] -pub enum StoreError { - SqlxError(#[from] sqlx::Error), - CodeAppearsMoreThanOnce, - CodeDoesNotExist(String), - IdentityIdMustExistInvariant, - TooManyValidations, - TooManyIdentitiesFound, - NoEmailValidationFound, - OauthProviderDoesNotExist(OauthProviderName), - OauthValidationDoesNotExist(ValidationRequestId), - Other(String), -} - -const EMAIL_TEMPLATE_DEFAULT_LOGIN: &str = "You requested a login link. Please click the following link %secd_code% to login as %secd_email_address%"; -const EMAIL_TEMPLATE_DEFAULT_SIGNUP: &str = "You requested a sign up. Please click the following link %secd_code% to complete your sign up and validate %secd_email_address%"; - -const ERR_MSG_MIGRATION_FAILED: &str = "Failed to execute migrations. This appears to be a secd issue. File a bug at https://www.github.com/secd-lib"; - -const SQLITE: &str = "sqlite"; -const PGSQL: &str = "pgsql"; - -const WRITE_IDENTITY: &str = "write_identity"; -const WRITE_EMAIL_VALIDATION: &str = "write_email_validation"; -const FIND_EMAIL_VALIDATION: &str = "find_email_validation"; -const READ_VALIDATION_TYPE: &str = "read_validation_type"; - -const WRITE_EMAIL: &str = "write_email"; - -const READ_IDENTITY: &str = "read_identity"; -const FIND_IDENTITY: &str = "find_identity"; -const FIND_IDENTITY_BY_CODE: &str = "find_identity_by_code"; - -const READ_IDENTITY_RAW_ID: &str = "read_identity_raw_id"; -const READ_EMAIL_RAW_ID: &str = "read_email_raw_id"; - -const WRITE_SESSION: &str = "write_session"; -const READ_SESSION: &str = "read_session"; - -const WRITE_OAUTH_PROVIDER: &str = "write_oauth_provider"; -const READ_OAUTH_PROVIDER: &str = "read_oauth_provider"; -const WRITE_OAUTH_VALIDATION: &str = "write_oauth_validation"; -const READ_OAUTH_VALIDATION: &str = "read_oauth_validation"; - -lazy_static! { - static ref SQLS: HashMap<&'static str, HashMap<&'static str, &'static str>> = { - let sqlite_sqls: HashMap<&'static str, &'static str> = [ - ( - WRITE_IDENTITY, - include_str!("../../store/sqlite/sql/write_identity.sql"), - ), - ( - WRITE_EMAIL_VALIDATION, - include_str!("../../store/sqlite/sql/write_email_validation.sql"), - ), - ( - WRITE_EMAIL, - include_str!("../../store/sqlite/sql/write_email.sql"), - ), - ( - READ_IDENTITY, - include_str!("../../store/sqlite/sql/read_identity.sql"), - ), - ( - FIND_IDENTITY, - include_str!("../../store/sqlite/sql/find_identity.sql"), - ), - ( - FIND_IDENTITY_BY_CODE, - include_str!("../../store/sqlite/sql/find_identity_by_code.sql"), - ), - ( - READ_IDENTITY_RAW_ID, - include_str!("../../store/sqlite/sql/read_identity_raw_id.sql"), - ), - ( - READ_EMAIL_RAW_ID, - include_str!("../../store/sqlite/sql/read_email_raw_id.sql"), - ), - ( - WRITE_SESSION, - include_str!("../../store/sqlite/sql/write_session.sql"), - ), - ( - READ_SESSION, - include_str!("../../store/sqlite/sql/read_session.sql"), - ), - ( - FIND_EMAIL_VALIDATION, - include_str!("../../store/sqlite/sql/find_email_validation.sql"), - ), - ( - WRITE_OAUTH_PROVIDER, - include_str!("../../store/sqlite/sql/write_oauth_provider.sql"), - ), - ( - READ_OAUTH_PROVIDER, - include_str!("../../store/sqlite/sql/read_oauth_provider.sql"), - ), - ( - READ_OAUTH_VALIDATION, - include_str!("../../store/sqlite/sql/read_oauth_validation.sql"), - ), - ( - WRITE_OAUTH_VALIDATION, - include_str!("../../store/sqlite/sql/write_oauth_validation.sql"), - ), - ( - READ_VALIDATION_TYPE, - include_str!("../../store/sqlite/sql/read_validation_type.sql"), - ), - ] - .iter() - .cloned() - .collect(); - - let pg_sqls: HashMap<&'static str, &'static str> = [ - ( - WRITE_IDENTITY, - include_str!("../../store/pg/sql/write_identity.sql"), - ), - ( - WRITE_EMAIL_VALIDATION, - include_str!("../../store/pg/sql/write_email_validation.sql"), - ), - ( - WRITE_EMAIL, - include_str!("../../store/pg/sql/write_email.sql"), - ), - ( - READ_IDENTITY, - include_str!("../../store/pg/sql/read_identity.sql"), - ), - ( - FIND_IDENTITY, - include_str!("../../store/pg/sql/find_identity.sql"), - ), - ( - FIND_IDENTITY_BY_CODE, - include_str!("../../store/pg/sql/find_identity_by_code.sql"), - ), - ( - READ_IDENTITY_RAW_ID, - include_str!("../../store/pg/sql/read_identity_raw_id.sql"), - ), - ( - READ_EMAIL_RAW_ID, - include_str!("../../store/pg/sql/read_email_raw_id.sql"), - ), - ( - WRITE_SESSION, - include_str!("../../store/pg/sql/write_session.sql"), - ), - ( - READ_SESSION, - include_str!("../../store/pg/sql/read_session.sql"), - ), - ( - FIND_EMAIL_VALIDATION, - include_str!("../../store/pg/sql/find_email_validation.sql"), - ), - ( - WRITE_OAUTH_PROVIDER, - include_str!("../../store/pg/sql/write_oauth_provider.sql"), - ), - ( - READ_OAUTH_PROVIDER, - include_str!("../../store/pg/sql/read_oauth_provider.sql"), - ), - ( - READ_OAUTH_VALIDATION, - include_str!("../../store/pg/sql/read_oauth_validation.sql"), - ), - ( - WRITE_OAUTH_VALIDATION, - include_str!("../../store/pg/sql/write_oauth_validation.sql"), - ), - ( - READ_VALIDATION_TYPE, - include_str!("../../store/pg/sql/read_validation_type.sql"), - ), - ] - .iter() - .cloned() - .collect(); - - let sqls: HashMap<&'static str, HashMap<&'static str, &'static str>> = - [(SQLITE, sqlite_sqls), (PGSQL, pg_sqls)] - .iter() - .cloned() - .collect(); - sqls - }; -} - -impl<'a, R: Row> FromRow<'a, R> for OauthValidation -where - &'a str: ColumnIndex<R>, - OauthProviderName: Decode<'a, R::Database> + Type<R::Database>, - OauthResponseType: Decode<'a, R::Database> + Type<R::Database>, - OffsetDateTime: Decode<'a, R::Database> + Type<R::Database>, - String: Decode<'a, R::Database> + Type<R::Database>, - Uuid: Decode<'a, R::Database> + Type<R::Database>, -{ - fn from_row(row: &'a R) -> Result<Self, sqlx::Error> { - let id: Option<Uuid> = row.try_get("oauth_validation_public_id")?; - let identity_id: Option<Uuid> = row.try_get("identity_public_id")?; - let access_token: Option<String> = row.try_get("access_token")?; - let raw_response: Option<String> = row.try_get("raw_response")?; - let created_at: Option<OffsetDateTime> = row.try_get("created_at")?; - let validated_at: Option<OffsetDateTime> = row.try_get("validated_at")?; - let revoked_at: Option<OffsetDateTime> = row.try_get("revoked_at")?; - let deleted_at: Option<OffsetDateTime> = row.try_get("deleted_at")?; - - let op_name: Option<OauthProviderName> = row.try_get("oauth_provider_name")?; - let op_flow: Option<String> = row.try_get("oauth_provider_flow")?; - let op_base_url: Option<String> = row.try_get("oauth_provider_base_url")?; - let op_response_type: Option<OauthResponseType> = - row.try_get("oauth_provider_response_type")?; - let op_default_scope: Option<String> = row.try_get("oauth_provider_default_scope")?; - let op_client_id: Option<String> = row.try_get("oauth_provider_client_id")?; - let op_client_secret: Option<String> = row.try_get("oauth_provider_client_secret")?; - let op_redirect_url: Option<String> = row.try_get("oauth_provider_redirect_url")?; - let op_created_at: Option<OffsetDateTime> = row.try_get("oauth_provider_created_at")?; - let op_deleted_at: Option<OffsetDateTime> = row.try_get("oauth_provider_deleted_at")?; - - let op_base_url = op_base_url - .map(|s| Url::from_str(&s).ok()) - .flatten() - .ok_or(sqlx::Error::ColumnDecode { - index: "oauth_provider_base_url".into(), - source: "secd".into(), - })?; - - let op_redirect_url = op_redirect_url - .map(|s| Url::from_str(&s).ok()) - .flatten() - .ok_or(sqlx::Error::ColumnDecode { - index: "oauth_provider_redirect_url".into(), - source: "secd".into(), - })?; - - Ok(OauthValidation { - id, - identity_id, - access_token, - raw_response, - created_at: created_at.ok_or(sqlx::Error::ColumnDecode { - index: "created_at".into(), - source: "secd".into(), - })?, - validated_at, - revoked_at, - deleted_at, - oauth_provider: OauthProvider { - name: op_name.unwrap(), - flow: op_flow, - base_url: op_base_url, - response: op_response_type.ok_or(sqlx::Error::ColumnDecode { - index: "oauth_provider_response_type".into(), - source: "secd".into(), - })?, - default_scope: op_default_scope.ok_or(sqlx::Error::ColumnDecode { - index: "oauth_provider_default_scope".into(), - source: "secd".into(), - })?, - client_id: op_client_id.ok_or(sqlx::Error::ColumnDecode { - index: "oauth_provider_client_id".into(), - source: "secd".into(), - })?, - client_secret: op_client_secret.ok_or(sqlx::Error::ColumnDecode { - index: "oauth_provider_client_secret".into(), - source: "secd".into(), - })?, - redirect_url: op_redirect_url, - created_at: op_created_at.ok_or(sqlx::Error::ColumnDecode { - index: "oauth_provider_created_at".into(), - source: "secd".into(), - })?, - deleted_at: op_deleted_at, - }, - }) - } -} - -impl<'a, D: Database> Decode<'a, D> for OauthProviderName -where - &'a str: Decode<'a, D>, -{ - fn decode( - value: <D as HasValueRef<'a>>::ValueRef, - ) -> Result<Self, Box<dyn ::std::error::Error + 'static + Send + Sync>> { - let v = <&str as Decode<D>>::decode(value)?; - <OauthProviderName as clap::ValueEnum>::from_str(v, true) - .map_err(|_| "OauthProviderName should exist and decode to a program value.".into()) - } -} - -impl<D: Database> Type<D> for OauthProviderName -where - str: Type<D>, -{ - fn type_info() -> D::TypeInfo { - <&str as Type<D>>::type_info() - } -} - -impl<'a, D: Database> Decode<'a, D> for OauthResponseType -where - &'a str: Decode<'a, D>, -{ - fn decode( - value: <D as HasValueRef<'a>>::ValueRef, - ) -> Result<Self, Box<dyn ::std::error::Error + 'static + Send + Sync>> { - let v = <&str as Decode<D>>::decode(value)?; - <OauthResponseType as clap::ValueEnum>::from_str(v, true) - .map_err(|_| "OauthResponseType should exist and decode to a program value.".into()) - } -} - -impl<D: Database> Type<D> for OauthResponseType -where - str: Type<D>, -{ - fn type_info() -> D::TypeInfo { - <&str as Type<D>>::type_info() - } -} - -#[async_trait::async_trait] -pub trait Store { - async fn write_email(&self, email_address: &str) -> Result<(), StoreError>; - - async fn find_email_validation( - &self, - validation_id: Option<&Uuid>, - code: Option<&str>, - ) -> Result<EmailValidation, StoreError>; - async fn write_email_validation( - &self, - ev: &EmailValidation, - // TODO: Make this write an EmailValidation - ) -> anyhow::Result<Uuid>; - - async fn find_identity( - &self, - identity_id: Option<&Uuid>, - email: Option<&str>, - ) -> anyhow::Result<Option<Identity>>; - async fn find_identity_by_code(&self, code: &str) -> Result<Identity, StoreError>; - async fn write_identity(&self, i: &Identity) -> Result<(), StoreError>; - async fn read_identity(&self, identity_id: &Uuid) -> Result<Identity, StoreError>; - - async fn write_session(&self, session: &Session) -> Result<(), StoreError>; - async fn read_session(&self, secret: &SessionSecret) -> Result<Session, StoreError>; - - async fn write_oauth_provider(&self, provider: &OauthProvider) -> Result<(), StoreError>; - async fn read_oauth_provider( - &self, - provider: &OauthProviderName, - flow: Option<String>, - ) -> Result<OauthProvider, StoreError>; - async fn write_oauth_validation( - &self, - validation: &OauthValidation, - ) -> anyhow::Result<ValidationRequestId>; - async fn read_oauth_validation( - &self, - validation_id: &ValidationRequestId, - ) -> anyhow::Result<OauthValidation>; - - async fn find_validation_type( - &self, - validation_id: &ValidationRequestId, - ) -> anyhow::Result<ValidationType>; -} +pub(crate) mod store; diff --git a/crates/secd/src/client/sqldb.rs b/crates/secd/src/client/sqldb.rs deleted file mode 100644 index 6751ef6..0000000 --- a/crates/secd/src/client/sqldb.rs +++ /dev/null @@ -1,632 +0,0 @@ -use std::{str::FromStr, sync::Arc}; - -use super::{ - EmailValidation, Identity, OauthProvider, OauthProviderName, OauthResponseType, Session, - SessionSecret, Store, StoreError, ERR_MSG_MIGRATION_FAILED, FIND_EMAIL_VALIDATION, - FIND_IDENTITY, FIND_IDENTITY_BY_CODE, PGSQL, READ_EMAIL_RAW_ID, READ_IDENTITY_RAW_ID, - READ_OAUTH_PROVIDER, READ_OAUTH_VALIDATION, READ_SESSION, READ_VALIDATION_TYPE, SQLITE, SQLS, - WRITE_EMAIL, WRITE_EMAIL_VALIDATION, WRITE_IDENTITY, WRITE_OAUTH_PROVIDER, - WRITE_OAUTH_VALIDATION, WRITE_SESSION, -}; -use crate::{util, OauthValidation, ValidationRequestId, ValidationType}; -use anyhow::bail; -use log::{debug, error}; -use openssl::sha::Sha256; -use sqlx::{ - self, database::HasArguments, ColumnIndex, Database, Decode, Encode, Executor, IntoArguments, - Pool, Postgres, Sqlite, Transaction, Type, -}; -use time::OffsetDateTime; -use url::Url; -use uuid::Uuid; - -fn get_sqls(root: &str, file: &str) -> Vec<String> { - SQLS.get(root) - .unwrap() - .get(file) - .unwrap() - .split("--") - .map(|p| p.to_string()) - .collect() -} - -fn hash_secret(secret: &str) -> Vec<u8> { - let mut hasher = Sha256::new(); - hasher.update(secret.as_bytes()); - hasher.finish().to_vec() -} - -struct SqlClient<D> -where - D: sqlx::Database, -{ - pool: sqlx::Pool<D>, - sqls_root: String, -} - -impl<D> SqlClient<D> -where - D: sqlx::Database, - for<'c> <D as HasArguments<'c>>::Arguments: IntoArguments<'c, D>, - for<'c> i64: Decode<'c, D> + Type<D>, - for<'c> &'c str: Decode<'c, D> + Type<D>, - for<'c> &'c str: Encode<'c, D> + Type<D>, - for<'c> usize: ColumnIndex<<D as Database>::Row>, - for<'c> Uuid: Decode<'c, D> + Type<D>, - for<'c> Uuid: Encode<'c, D> + Type<D>, - for<'c> &'c Pool<D>: Executor<'c, Database = D>, -{ - async fn read_identity_raw_id(&self, id: &Uuid) -> Result<i64, StoreError> { - let sqls = get_sqls(&self.sqls_root, READ_IDENTITY_RAW_ID); - - Ok(sqlx::query_as::<_, (i64,)>(&sqls[0]) - .bind(id) - .fetch_one(&self.pool) - .await - .map_err(util::log_err_sqlx)? - .0) - } - - async fn read_email_raw_id(&self, address: &str) -> Result<i64, StoreError> { - let sqls = get_sqls(&self.sqls_root, READ_EMAIL_RAW_ID); - - Ok(sqlx::query_as::<_, (i64,)>(&sqls[0]) - .bind(address) - .fetch_one(&self.pool) - .await - .map_err(util::log_err_sqlx)? - .0) - } -} - -#[async_trait::async_trait] -impl<D> Store for SqlClient<D> -where - D: sqlx::Database, - for<'c> <D as HasArguments<'c>>::Arguments: IntoArguments<'c, D>, - for<'c> bool: Decode<'c, D> + Type<D>, - for<'c> bool: Encode<'c, D> + Type<D>, - for<'c> i64: Decode<'c, D> + Type<D>, - for<'c> i64: Encode<'c, D> + Type<D>, - for<'c> i32: Decode<'c, D> + Type<D>, - for<'c> i32: Encode<'c, D> + Type<D>, - for<'c> OffsetDateTime: Decode<'c, D> + Type<D>, - for<'c> OffsetDateTime: Encode<'c, D> + Type<D>, - for<'c> &'c str: ColumnIndex<<D as Database>::Row>, - for<'c> &'c str: Decode<'c, D> + Type<D>, - for<'c> &'c str: Encode<'c, D> + Type<D>, - for<'c> Option<&'c str>: Decode<'c, D> + Type<D>, - for<'c> Option<&'c str>: Encode<'c, D> + Type<D>, - for<'c> String: Decode<'c, D> + Type<D>, - for<'c> String: Encode<'c, D> + Type<D>, - for<'c> Option<String>: Decode<'c, D> + Type<D>, - for<'c> Option<String>: Encode<'c, D> + Type<D>, - for<'c> OauthProviderName: Decode<'c, D> + Type<D>, - for<'c> OauthResponseType: Decode<'c, D> + Type<D>, - for<'c> usize: ColumnIndex<<D as Database>::Row>, - for<'c> Uuid: Decode<'c, D> + Type<D>, - for<'c> Uuid: Encode<'c, D> + Type<D>, - for<'c> &'c [u8]: Encode<'c, D> + Type<D>, - for<'c> Option<&'c Uuid>: Encode<'c, D> + Type<D>, - for<'c> Option<&'c Vec<u8>>: Encode<'c, D> + Type<D>, - for<'c> Option<OffsetDateTime>: Decode<'c, D> + Type<D>, - for<'c> Option<OffsetDateTime>: Encode<'c, D> + Type<D>, - for<'c> &'c Pool<D>: Executor<'c, Database = D>, - for<'c> &'c mut Transaction<'c, D>: Executor<'c, Database = D>, -{ - async fn write_email(&self, email_address: &str) -> Result<(), StoreError> { - let sqls = get_sqls(&self.sqls_root, WRITE_EMAIL); - - sqlx::query(&sqls[0]) - .bind(email_address) - .execute(&self.pool) - .await - .map_err(util::log_err_sqlx)?; - - Ok(()) - } - - async fn find_email_validation( - &self, - validation_id: Option<&Uuid>, - code: Option<&str>, - ) -> Result<EmailValidation, StoreError> { - let sqls = get_sqls(&self.sqls_root, FIND_EMAIL_VALIDATION); - let mut rows = sqlx::query_as::<_, EmailValidation>(&sqls[0]) - .bind(validation_id) - .bind(code) - .fetch_all(&self.pool) - .await - .map_err(util::log_err_sqlx)?; - - match rows.len() { - 0 => Err(StoreError::NoEmailValidationFound), - 1 => Ok(rows.swap_remove(0)), - _ => Err(StoreError::TooManyValidations), - } - } - - async fn write_email_validation(&self, ev: &EmailValidation) -> anyhow::Result<Uuid> { - let sqls = get_sqls(&self.sqls_root, WRITE_EMAIL_VALIDATION); - - let email_id = self.read_email_raw_id(&ev.email_address).await?; - let validation_id = ev.id.unwrap_or(Uuid::new_v4()); - sqlx::query(&sqls[0]) - .bind(validation_id) - .bind(email_id) - .bind(&ev.code) - .bind(ev.is_oauth_derived) - .bind(ev.created_at) - .bind(ev.validated_at) - .bind(ev.expired_at) - .execute(&self.pool) - .await - .map_err(util::log_err_sqlx)?; - - if ev.identity_id.is_some() || ev.revoked_at.is_some() || ev.deleted_at.is_some() { - sqlx::query(&sqls[1]) - .bind(ev.identity_id.as_ref()) - .bind(validation_id) - .bind(ev.revoked_at) - .bind(ev.deleted_at) - .execute(&self.pool) - .await - .map_err(util::log_err_sqlx)?; - } - - Ok(validation_id) - } - - async fn find_identity( - &self, - id: Option<&Uuid>, - email: Option<&str>, - ) -> anyhow::Result<Option<Identity>> { - let sqls = get_sqls(&self.sqls_root, FIND_IDENTITY); - Ok( - match sqlx::query_as::<_, Identity>(&sqls[0]) - .bind(id) - .bind(email) - .fetch_all(&self.pool) - .await - { - Ok(mut is) => match is.len() { - // if only 1 found, then that's fine - // if multiple are fond, then if they all have the same id, that's okay - 1 => { - let i = is.swap_remove(0); - match i.deleted_at { - Some(t) if t > OffsetDateTime::now_utc() => Some(i), - None => Some(i), - _ => None, - } - } - 0 => None, - _ => { - match is - .iter() - .filter(|&i| i.id != is[0].id) - .collect::<Vec<&Identity>>() - .len() - { - 0 => Some(is.swap_remove(0)), - _ => bail!(StoreError::TooManyIdentitiesFound), - } - } - }, - Err(sqlx::Error::RowNotFound) => None, - Err(e) => bail!(StoreError::SqlxError(e)), - }, - ) - } - - async fn find_identity_by_code(&self, code: &str) -> Result<Identity, StoreError> { - let sqls = get_sqls(&self.sqls_root, FIND_IDENTITY_BY_CODE); - - let rows = sqlx::query_as::<_, (i32,)>(&sqls[0]) - .bind(code) - .fetch_all(&self.pool) - .await - .map_err(util::log_err_sqlx)?; - - if rows.len() == 0 { - return Err(StoreError::CodeDoesNotExist(code.to_string())); - } - - if rows.len() != 1 { - return Err(StoreError::CodeAppearsMoreThanOnce); - } - - let identity_email_id = rows.get(0).unwrap().0; - - // TODO: IF we expand beyond email codes, then we'll need to join against a bunch of identity tables. - // but since a single code was found, only one of them should pop... - Ok(sqlx::query_as::<_, Identity>(&sqls[1]) - .bind(identity_email_id) - .fetch_one(&self.pool) - .await - .map_err(util::log_err_sqlx)?) - } - - async fn write_identity(&self, i: &Identity) -> Result<(), StoreError> { - let sqls = get_sqls(&self.sqls_root, WRITE_IDENTITY); - sqlx::query(&sqls[0]) - .bind(i.id) - .bind(i.data.clone()) - .bind(i.created_at) - .execute(&self.pool) - .await - .map_err(|e| { - error!("write_identity_failure"); - error!("{:?}", e); - e - })?; - - Ok(()) - } - async fn read_identity(&self, id: &Uuid) -> Result<Identity, StoreError> { - let identity = sqlx::query_as::<_, Identity>( - " -select identity_public_id, data, created_at from identity where identity_public_id = ?", - ) - .bind(id) - .fetch_one(&self.pool) - .await - .map_err(util::log_err_sqlx)?; - - Ok(identity) - } - - async fn write_session(&self, session: &Session) -> Result<(), StoreError> { - let sqls = get_sqls(&self.sqls_root, WRITE_SESSION); - - let secret_hash = session.secret.as_ref().map(|s| hash_secret(s)); - - sqlx::query(&sqls[0]) - .bind(&session.identity_id) - .bind(secret_hash.as_ref()) - .bind(session.created_at) - .bind(session.expired_at) - .bind(session.revoked_at) - .execute(&self.pool) - .await - .map_err(util::log_err_sqlx)?; - - Ok(()) - } - async fn read_session(&self, secret: &SessionSecret) -> Result<Session, StoreError> { - let sqls = get_sqls(&self.sqls_root, READ_SESSION); - - let secret_hash = hash_secret(secret); - let mut session = sqlx::query_as::<_, Session>(&sqls[0]) - .bind(&secret_hash[..]) - .fetch_one(&self.pool) - .await - .map_err(util::log_err_sqlx)?; - - // This should do nothing other than updated touched_at, and then - // clear the plaintext secret - session.secret = Some(secret.to_string()); - self.write_session(&session).await?; - session.secret = None; - - Ok(session) - } - - async fn write_oauth_provider(&self, provider: &OauthProvider) -> Result<(), StoreError> { - let sqls = get_sqls(&self.sqls_root, WRITE_OAUTH_PROVIDER); - sqlx::query(&sqls[0]) - .bind(&provider.name.to_string()) - .bind(&provider.flow) - .bind(&provider.base_url.to_string()) - .bind(&provider.response.to_string()) - .bind(&provider.default_scope) - .bind(&provider.client_id) - // TODO: encrypt secret before writing - .bind(&provider.client_secret) - .bind(&provider.redirect_url.to_string()) - .bind(provider.created_at) - .bind(provider.deleted_at) - .execute(&self.pool) - .await - .map_err(util::log_err_sqlx)?; - Ok(()) - } - - async fn read_oauth_provider( - &self, - provider: &OauthProviderName, - flow: Option<String>, - ) -> Result<OauthProvider, StoreError> { - let sqls = get_sqls(&self.sqls_root, READ_OAUTH_PROVIDER); - let flow = flow.unwrap_or("default".into()); - debug!("provider: {:?}, flow: {:?}", provider, flow); - // TODO: Write the generic FromRow impl for OauthProvider... - let res = sqlx::query_as::< - _, - ( - String, - String, - String, - String, - String, - String, - String, - OffsetDateTime, - Option<OffsetDateTime>, - ), - >(&sqls[0]) - .bind(&provider.to_string()) - .bind(&flow) - .fetch_one(&self.pool) - .await - .map_err(util::log_err_sqlx)?; - - debug!("res: {:?}", res); - - Ok(OauthProvider { - name: provider.clone(), - flow: Some(res.0), - base_url: Url::from_str(&res.1) - .map_err(|_| StoreError::OauthProviderDoesNotExist(*provider))?, - response: OauthResponseType::from_str(&res.2) - .map_err(|_| StoreError::OauthProviderDoesNotExist(*provider))?, - default_scope: res.3, - client_id: res.4, - client_secret: res.5, - redirect_url: Url::from_str(&res.6) - .map_err(|_| StoreError::OauthProviderDoesNotExist(*provider))?, - created_at: res.7, - deleted_at: res.8, - }) - } - async fn write_oauth_validation( - &self, - v: &OauthValidation, - ) -> anyhow::Result<ValidationRequestId> { - let sqls = get_sqls(&self.sqls_root, WRITE_OAUTH_VALIDATION); - - let validation_id = v.id.unwrap_or(Uuid::new_v4()); - sqlx::query(&sqls[0]) - .bind(validation_id) - .bind(v.oauth_provider.name.to_string()) - .bind(v.oauth_provider.flow.clone()) - .bind(v.access_token.clone()) - .bind(v.raw_response.clone()) - .bind(v.created_at) - .bind(v.validated_at) - .execute(&self.pool) - .await?; - - if v.identity_id.is_some() || v.revoked_at.is_some() || v.deleted_at.is_some() { - sqlx::query(&sqls[1]) - .bind(v.identity_id.as_ref()) - .bind(validation_id) - .bind(v.revoked_at) - .bind(v.deleted_at) - .execute(&self.pool) - .await?; - } - - Ok(validation_id) - } - async fn read_oauth_validation( - &self, - validation_id: &ValidationRequestId, - ) -> anyhow::Result<OauthValidation> { - let sqls = get_sqls(&self.sqls_root, READ_OAUTH_VALIDATION); - - let mut es = sqlx::query_as::<_, OauthValidation>(&sqls[0]) - .bind(validation_id) - .fetch_all(&self.pool) - .await?; - - if es.len() != 1 { - bail!(StoreError::OauthValidationDoesNotExist( - validation_id.clone() - )); - } - - Ok(es.swap_remove(0)) - } - async fn find_validation_type( - &self, - validation_id: &ValidationRequestId, - ) -> anyhow::Result<ValidationType> { - let sqls = get_sqls(&self.sqls_root, READ_VALIDATION_TYPE); - - let mut es = sqlx::query_as::<_, (String,)>(&sqls[0]) - .bind(validation_id) - .fetch_all(&self.pool) - .await - .map_err(util::log_err_sqlx)?; - - match es.len() { - 1 => Ok(ValidationType::from_str(&es.swap_remove(0).0)?), - _ => bail!(StoreError::Other( - "expected a single validation but recieved 0 or multiple validations".into() - )), - } - } -} - -pub struct PgClient { - sql: SqlClient<Postgres>, -} - -impl PgClient { - pub async fn new(pool: sqlx::Pool<Postgres>) -> Arc<dyn Store + Send + Sync + 'static> { - sqlx::migrate!("store/pg/migrations") - .run(&pool) - .await - .expect(ERR_MSG_MIGRATION_FAILED); - - Arc::new(PgClient { - sql: SqlClient { - pool, - sqls_root: PGSQL.to_string(), - }, - }) - } -} - -#[async_trait::async_trait] -impl Store for PgClient { - async fn write_email(&self, email_address: &str) -> Result<(), StoreError> { - self.sql.write_email(email_address).await - } - async fn find_email_validation( - &self, - validation_id: Option<&Uuid>, - code: Option<&str>, - ) -> Result<EmailValidation, StoreError> { - self.sql.find_email_validation(validation_id, code).await - } - async fn write_email_validation(&self, ev: &EmailValidation) -> anyhow::Result<Uuid> { - self.sql.write_email_validation(ev).await - } - async fn find_identity( - &self, - identity_id: Option<&Uuid>, - email: Option<&str>, - ) -> anyhow::Result<Option<Identity>> { - self.sql.find_identity(identity_id, email).await - } - async fn find_identity_by_code(&self, code: &str) -> Result<Identity, StoreError> { - self.sql.find_identity_by_code(code).await - } - async fn write_identity(&self, i: &Identity) -> Result<(), StoreError> { - self.sql.write_identity(i).await - } - async fn read_identity(&self, identity_id: &Uuid) -> Result<Identity, StoreError> { - self.sql.read_identity(identity_id).await - } - async fn write_session(&self, session: &Session) -> Result<(), StoreError> { - self.sql.write_session(session).await - } - async fn read_session(&self, secret: &SessionSecret) -> Result<Session, StoreError> { - self.sql.read_session(secret).await - } - async fn write_oauth_provider(&self, provider: &OauthProvider) -> Result<(), StoreError> { - self.sql.write_oauth_provider(provider).await - } - async fn read_oauth_provider( - &self, - provider: &OauthProviderName, - flow: Option<String>, - ) -> Result<OauthProvider, StoreError> { - self.sql.read_oauth_provider(provider, flow).await - } - async fn write_oauth_validation( - &self, - validation: &OauthValidation, - ) -> anyhow::Result<ValidationRequestId> { - self.sql.write_oauth_validation(validation).await - } - async fn read_oauth_validation( - &self, - validation_id: &ValidationRequestId, - ) -> anyhow::Result<OauthValidation> { - self.sql.read_oauth_validation(validation_id).await - } - async fn find_validation_type( - &self, - validation_id: &ValidationRequestId, - ) -> anyhow::Result<ValidationType> { - self.sql.find_validation_type(validation_id).await - } -} - -pub struct SqliteClient { - sql: SqlClient<Sqlite>, -} - -impl SqliteClient { - pub async fn new(pool: sqlx::Pool<Sqlite>) -> Arc<dyn Store + Send + Sync + 'static> { - sqlx::migrate!("store/sqlite/migrations") - .run(&pool) - .await - .expect(ERR_MSG_MIGRATION_FAILED); - - sqlx::query("pragma foreign_keys = on") - .execute(&pool) - .await - .expect( - "Failed to initialize FK pragma. File a bug at https://www.github.com/secd-lib", - ); - - Arc::new(SqliteClient { - sql: SqlClient { - pool, - sqls_root: SQLITE.to_string(), - }, - }) - } -} - -#[async_trait::async_trait] -impl Store for SqliteClient { - async fn write_email(&self, email_address: &str) -> Result<(), StoreError> { - self.sql.write_email(email_address).await - } - async fn find_email_validation( - &self, - validation_id: Option<&Uuid>, - code: Option<&str>, - ) -> Result<EmailValidation, StoreError> { - self.sql.find_email_validation(validation_id, code).await - } - async fn write_email_validation(&self, ev: &EmailValidation) -> anyhow::Result<Uuid> { - self.sql.write_email_validation(ev).await - } - async fn find_identity( - &self, - identity_id: Option<&Uuid>, - email: Option<&str>, - ) -> anyhow::Result<Option<Identity>> { - self.sql.find_identity(identity_id, email).await - } - async fn find_identity_by_code(&self, code: &str) -> Result<Identity, StoreError> { - self.sql.find_identity_by_code(code).await - } - async fn write_identity(&self, i: &Identity) -> Result<(), StoreError> { - self.sql.write_identity(i).await - } - async fn read_identity(&self, identity_id: &Uuid) -> Result<Identity, StoreError> { - self.sql.read_identity(identity_id).await - } - async fn write_session(&self, session: &Session) -> Result<(), StoreError> { - self.sql.write_session(session).await - } - async fn read_session(&self, secret: &SessionSecret) -> Result<Session, StoreError> { - self.sql.read_session(secret).await - } - async fn write_oauth_provider(&self, provider: &OauthProvider) -> Result<(), StoreError> { - self.sql.write_oauth_provider(provider).await - } - async fn read_oauth_provider( - &self, - provider: &OauthProviderName, - flow: Option<String>, - ) -> Result<OauthProvider, StoreError> { - self.sql.read_oauth_provider(provider, flow).await - } - async fn write_oauth_validation( - &self, - validation: &OauthValidation, - ) -> anyhow::Result<ValidationRequestId> { - self.sql.write_oauth_validation(validation).await - } - async fn read_oauth_validation( - &self, - validation_id: &ValidationRequestId, - ) -> anyhow::Result<OauthValidation> { - self.sql.read_oauth_validation(validation_id).await - } - async fn find_validation_type( - &self, - validation_id: &ValidationRequestId, - ) -> anyhow::Result<ValidationType> { - self.sql.find_validation_type(validation_id).await - } -} diff --git a/crates/secd/src/client/store/mod.rs b/crates/secd/src/client/store/mod.rs new file mode 100644 index 0000000..b93fd84 --- /dev/null +++ b/crates/secd/src/client/store/mod.rs @@ -0,0 +1,190 @@ +pub(crate) mod sql_db; + +use email_address::EmailAddress; +use sqlx::{Postgres, Sqlite}; +use std::sync::Arc; +use time::OffsetDateTime; +use uuid::Uuid; + +use crate::{ + util, Address, AddressType, AddressValidation, Identity, IdentityId, PhoneNumber, Session, + SessionToken, +}; + +use self::sql_db::SqlClient; + +#[derive(Debug, thiserror::Error, derive_more::Display)] +pub enum StoreError { + SqlClientError(#[from] sqlx::Error), + StoreValueCannotBeParsedInvariant, + IdempotentCheckAlreadyExists, +} + +#[async_trait::async_trait(?Send)] +pub trait Store { + fn get_type(&self) -> StoreType; +} + +pub enum StoreType { + Postgres { c: Arc<SqlClient<Postgres>> }, + Sqlite { c: Arc<SqlClient<Sqlite>> }, +} + +#[async_trait::async_trait(?Send)] +pub(crate) trait Storable<'a> { + type Item; + type Lens; + + async fn write(&self, store: Arc<dyn Store>) -> Result<(), StoreError>; + async fn find( + store: Arc<dyn Store>, + lens: &'a Self::Lens, + ) -> Result<Vec<Self::Item>, StoreError>; +} + +pub(crate) trait Lens {} + +pub(crate) struct AddressLens<'a> { + pub id: Option<&'a Uuid>, + pub t: Option<&'a AddressType>, +} +impl<'a> Lens for AddressLens<'a> {} + +pub(crate) struct AddressValidationLens<'a> { + pub id: Option<&'a Uuid>, +} +impl<'a> Lens for AddressValidationLens<'a> {} + +pub(crate) struct IdentityLens<'a> { + pub id: Option<&'a Uuid>, + pub address_type: Option<&'a AddressType>, + pub validated_address: Option<bool>, + pub session_token_hash: Option<Vec<u8>>, +} +impl<'a> Lens for IdentityLens<'a> {} + +pub(crate) struct SessionLens<'a> { + pub token_hash: Option<&'a Vec<u8>>, + pub identity_id: Option<&'a IdentityId>, +} +impl<'a> Lens for SessionLens<'a> {} + +#[async_trait::async_trait(?Send)] +impl<'a> Storable<'a> for Address { + type Item = Address; + type Lens = AddressLens<'a>; + + async fn write(&self, store: Arc<dyn Store>) -> Result<(), StoreError> { + match store.get_type() { + StoreType::Postgres { c } => c.write_address(self).await?, + StoreType::Sqlite { c } => c.write_address(self).await?, + } + Ok(()) + } + async fn find( + store: Arc<dyn Store>, + lens: &'a Self::Lens, + ) -> Result<Vec<Self::Item>, StoreError> { + let typ = lens.t.map(|at| at.to_string().clone()); + let typ = typ.as_deref(); + + let val = lens.t.and_then(|at| at.get_value()); + let val = val.as_deref(); + + Ok(match store.get_type() { + StoreType::Postgres { c } => c.find_address(lens.id, typ, val).await?, + StoreType::Sqlite { c } => c.find_address(lens.id, typ, val).await?, + }) + } +} + +#[async_trait::async_trait(?Send)] +impl<'a> Storable<'a> for AddressValidation { + type Item = AddressValidation; + type Lens = AddressValidationLens<'a>; + + async fn write(&self, store: Arc<dyn Store>) -> Result<(), StoreError> { + match store.get_type() { + StoreType::Sqlite { c } => c.write_address_validation(self).await?, + StoreType::Postgres { c } => c.write_address_validation(self).await?, + } + Ok(()) + } + async fn find( + store: Arc<dyn Store>, + lens: &'a Self::Lens, + ) -> Result<Vec<Self::Item>, StoreError> { + Ok(match store.get_type() { + StoreType::Postgres { c } => c.find_address_validation(lens.id).await?, + StoreType::Sqlite { c } => c.find_address_validation(lens.id).await?, + }) + } +} + +#[async_trait::async_trait(?Send)] +impl<'a> Storable<'a> for Identity { + type Item = Identity; + type Lens = IdentityLens<'a>; + + async fn write(&self, store: Arc<dyn Store>) -> Result<(), StoreError> { + match store.get_type() { + StoreType::Postgres { c } => c.write_identity(self).await?, + StoreType::Sqlite { c } => c.write_identity(self).await?, + } + Ok(()) + } + async fn find( + store: Arc<dyn Store>, + lens: &'a Self::Lens, + ) -> Result<Vec<Self::Item>, StoreError> { + let val = lens.address_type.and_then(|at| at.get_value()); + let val = val.as_deref(); + + Ok(match store.get_type() { + StoreType::Postgres { c } => { + c.find_identity( + lens.id, + val, + lens.validated_address, + &lens.session_token_hash, + ) + .await? + } + StoreType::Sqlite { c } => { + c.find_identity( + lens.id, + val, + lens.validated_address, + &lens.session_token_hash, + ) + .await? + } + }) + } +} + +#[async_trait::async_trait(?Send)] +impl<'a> Storable<'a> for Session { + type Item = Session; + type Lens = SessionLens<'a>; + + async fn write(&self, store: Arc<dyn Store>) -> Result<(), StoreError> { + let token_hash = util::hash(&self.token); + match store.get_type() { + StoreType::Postgres { c } => c.write_session(self, &token_hash).await?, + StoreType::Sqlite { c } => c.write_session(self, &token_hash).await?, + } + Ok(()) + } + async fn find( + store: Arc<dyn Store>, + lens: &'a Self::Lens, + ) -> Result<Vec<Self::Item>, StoreError> { + let token = lens.token_hash.map(|t| t.clone()).unwrap_or(vec![]); + + Ok(match store.get_type() { + StoreType::Postgres { c } => c.find_session(token, lens.identity_id).await?, + StoreType::Sqlite { c } => c.find_session(token, lens.identity_id).await?, + }) + } +} diff --git a/crates/secd/src/client/store/sql_db.rs b/crates/secd/src/client/store/sql_db.rs new file mode 100644 index 0000000..6d84301 --- /dev/null +++ b/crates/secd/src/client/store/sql_db.rs @@ -0,0 +1,526 @@ +use std::{str::FromStr, sync::Arc}; + +use email_address::EmailAddress; +use serde_json::value::RawValue; +use sqlx::{ + database::HasArguments, types::Json, ColumnIndex, Database, Decode, Encode, Executor, + IntoArguments, Pool, Transaction, Type, +}; +use time::OffsetDateTime; +use uuid::Uuid; + +use crate::{ + Address, AddressType, AddressValidation, AddressValidationMethod, Identity, Session, + SessionToken, +}; + +use lazy_static::lazy_static; +use sqlx::{Postgres, Sqlite}; +use std::collections::HashMap; + +use super::{ + AddressLens, AddressValidationLens, IdentityLens, SessionLens, Storable, Store, StoreError, + StoreType, +}; + +const SQLITE: &str = "sqlite"; +const PGSQL: &str = "pgsql"; + +const WRITE_ADDRESS: &str = "write_address"; +const FIND_ADDRESS: &str = "find_address"; +const WRITE_ADDRESS_VALIDATION: &str = "write_address_validation"; +const FIND_ADDRESS_VALIDATION: &str = "find_address_validation"; +const WRITE_IDENTITY: &str = "write_identity"; +const FIND_IDENTITY: &str = "find_identity"; +const WRITE_SESSION: &str = "write_session"; +const FIND_SESSION: &str = "find_session"; + +const ERR_MSG_MIGRATION_FAILED: &str = "Failed to apply secd migrations to a sql db. File a bug at https://www.github.com/branchcontrol/secdiam"; + +lazy_static! { + static ref SQLS: HashMap<&'static str, HashMap<&'static str, &'static str>> = { + let sqlite_sqls: HashMap<&'static str, &'static str> = [ + ( + WRITE_ADDRESS, + include_str!("../../../store/sqlite/sql/write_address.sql"), + ), + ( + FIND_ADDRESS, + include_str!("../../../store/sqlite/sql/find_address.sql"), + ), + ( + WRITE_ADDRESS_VALIDATION, + include_str!("../../../store/sqlite/sql/write_address_validation.sql"), + ), + ( + FIND_ADDRESS_VALIDATION, + include_str!("../../../store/sqlite/sql/find_address_validation.sql"), + ), + ( + WRITE_IDENTITY, + include_str!("../../../store/sqlite/sql/write_identity.sql"), + ), + ( + FIND_IDENTITY, + include_str!("../../../store/sqlite/sql/find_identity.sql"), + ), + ( + WRITE_SESSION, + include_str!("../../../store/sqlite/sql/write_session.sql"), + ), + ( + FIND_SESSION, + include_str!("../../../store/sqlite/sql/find_session.sql"), + ), + ] + .iter() + .cloned() + .collect(); + + let pg_sqls: HashMap<&'static str, &'static str> = [ + ( + WRITE_ADDRESS, + include_str!("../../../store/pg/sql/write_address.sql"), + ), + ( + FIND_ADDRESS, + include_str!("../../../store/pg/sql/find_address.sql"), + ), + ( + WRITE_ADDRESS_VALIDATION, + include_str!("../../../store/pg/sql/write_address_validation.sql"), + ), + ( + FIND_ADDRESS_VALIDATION, + include_str!("../../../store/pg/sql/find_address_validation.sql"), + ), + ( + WRITE_IDENTITY, + include_str!("../../../store/pg/sql/write_identity.sql"), + ), + ( + FIND_IDENTITY, + include_str!("../../../store/pg/sql/find_identity.sql"), + ), + ( + WRITE_SESSION, + include_str!("../../../store/pg/sql/write_session.sql"), + ), + ( + FIND_SESSION, + include_str!("../../../store/pg/sql/find_session.sql"), + ), + ] + .iter() + .cloned() + .collect(); + + let sqls: HashMap<&'static str, HashMap<&'static str, &'static str>> = + [(SQLITE, sqlite_sqls), (PGSQL, pg_sqls)] + .iter() + .cloned() + .collect(); + sqls + }; +} + +pub trait SqlxResultExt<T> { + fn extend_err(self) -> Result<T, StoreError>; +} + +impl<T> SqlxResultExt<T> for Result<T, sqlx::Error> { + fn extend_err(self) -> Result<T, StoreError> { + if let Err(sqlx::Error::Database(dbe)) = &self { + if dbe.code() == Some("23505".into()) { + return Err(StoreError::IdempotentCheckAlreadyExists); + } + } + self.map_err(|e| StoreError::SqlClientError(e)) + } +} + +pub struct SqlClient<D> +where + D: sqlx::Database, +{ + pool: sqlx::Pool<D>, + sqls_root: String, +} + +pub struct PgClient { + sql: Arc<SqlClient<Postgres>>, +} +impl Store for PgClient { + fn get_type(&self) -> StoreType { + StoreType::Postgres { + c: self.sql.clone(), + } + } +} + +impl PgClient { + pub async fn new(pool: sqlx::Pool<Postgres>) -> Arc<dyn Store + Send + Sync + 'static> { + sqlx::migrate!("store/pg/migrations") + .run(&pool) + .await + .expect(ERR_MSG_MIGRATION_FAILED); + + Arc::new(PgClient { + sql: Arc::new(SqlClient { + pool, + sqls_root: PGSQL.to_string(), + }), + }) + } +} + +pub struct SqliteClient { + sql: Arc<SqlClient<Sqlite>>, +} +impl Store for SqliteClient { + fn get_type(&self) -> StoreType { + StoreType::Sqlite { + c: self.sql.clone(), + } + } +} + +impl SqliteClient { + pub async fn new(pool: sqlx::Pool<Sqlite>) -> Arc<dyn Store + Send + Sync + 'static> { + sqlx::migrate!("store/sqlite/migrations") + .run(&pool) + .await + .expect(ERR_MSG_MIGRATION_FAILED); + + sqlx::query("pragma foreign_keys = on") + .execute(&pool) + .await + .expect( + "Failed to initialize FK pragma. File a bug at https://www.github.com/secd-lib", + ); + + Arc::new(SqliteClient { + sql: Arc::new(SqlClient { + pool, + sqls_root: SQLITE.to_string(), + }), + }) + } +} + +impl<D> SqlClient<D> +where + D: sqlx::Database, + for<'c> &'c Pool<D>: Executor<'c, Database = D>, + for<'c> &'c mut Transaction<'c, D>: Executor<'c, Database = D>, + for<'c> <D as HasArguments<'c>>::Arguments: IntoArguments<'c, D>, + for<'c> bool: Decode<'c, D> + Type<D>, + for<'c> bool: Encode<'c, D> + Type<D>, + for<'c> Option<bool>: Decode<'c, D> + Type<D>, + for<'c> Option<bool>: Encode<'c, D> + Type<D>, + for<'c> i64: Decode<'c, D> + Type<D>, + for<'c> i64: Encode<'c, D> + Type<D>, + for<'c> i32: Decode<'c, D> + Type<D>, + for<'c> i32: Encode<'c, D> + Type<D>, + for<'c> OffsetDateTime: Decode<'c, D> + Type<D>, + for<'c> OffsetDateTime: Encode<'c, D> + Type<D>, + for<'c> &'c str: ColumnIndex<<D as Database>::Row>, + for<'c> &'c str: Decode<'c, D> + Type<D>, + for<'c> &'c str: Encode<'c, D> + Type<D>, + for<'c> Option<&'c str>: Decode<'c, D> + Type<D>, + for<'c> Option<&'c str>: Encode<'c, D> + Type<D>, + for<'c> String: Decode<'c, D> + Type<D>, + for<'c> String: Encode<'c, D> + Type<D>, + for<'c> Option<String>: Decode<'c, D> + Type<D>, + for<'c> Option<String>: Encode<'c, D> + Type<D>, + for<'c> usize: ColumnIndex<<D as Database>::Row>, + for<'c> Uuid: Decode<'c, D> + Type<D>, + for<'c> Uuid: Encode<'c, D> + Type<D>, + for<'c> &'c [u8]: Encode<'c, D> + Type<D>, + for<'c> Option<&'c Uuid>: Encode<'c, D> + Type<D>, + for<'c> Vec<u8>: Encode<'c, D> + Type<D>, + for<'c> Vec<u8>: Decode<'c, D> + Type<D>, + for<'c> Option<Vec<u8>>: Encode<'c, D> + Type<D>, + for<'c> Option<Vec<u8>>: Decode<'c, D> + Type<D>, + for<'c> Option<OffsetDateTime>: Decode<'c, D> + Type<D>, + for<'c> Option<OffsetDateTime>: Encode<'c, D> + Type<D>, +{ + pub async fn write_address(&self, a: &Address) -> Result<(), StoreError> { + let sqls = get_sqls(&self.sqls_root, WRITE_ADDRESS); + sqlx::query(&sqls[0]) + .bind(a.id) + .bind(a.t.to_string()) + .bind(match &a.t { + AddressType::Email { email_address } => { + email_address.as_ref().map(ToString::to_string) + } + AddressType::Sms { phone_number } => phone_number.clone(), + }) + .bind(a.created_at) + .execute(&self.pool) + .await + .extend_err()?; + + Ok(()) + } + + pub async fn find_address( + &self, + id: Option<&Uuid>, + typ: Option<&str>, + val: Option<&str>, + ) -> Result<Vec<Address>, StoreError> { + let sqls = get_sqls(&self.sqls_root, FIND_ADDRESS); + let res = sqlx::query_as::<_, (Uuid, String, String, OffsetDateTime)>(&sqls[0]) + .bind(id) + .bind(typ) + .bind(val) + .fetch_all(&self.pool) + .await + .extend_err()?; + + let mut addresses = vec![]; + for (id, typ, val, created_at) in res.into_iter() { + let t = match AddressType::from_str(&typ) + .map_err(|_| StoreError::StoreValueCannotBeParsedInvariant)? + { + AddressType::Email { .. } => AddressType::Email { + email_address: Some( + EmailAddress::from_str(&val) + .map_err(|_| StoreError::StoreValueCannotBeParsedInvariant)?, + ), + }, + AddressType::Sms { .. } => AddressType::Sms { + phone_number: Some(val.clone()), + }, + }; + + addresses.push(Address { id, t, created_at }); + } + + Ok(addresses) + } + + pub async fn write_address_validation(&self, v: &AddressValidation) -> Result<(), StoreError> { + let sqls = get_sqls(&self.sqls_root, WRITE_ADDRESS_VALIDATION); + sqlx::query(&sqls[0]) + .bind(v.id) + .bind(v.identity_id.as_ref()) + .bind(v.address.id) + .bind(v.method.to_string()) + .bind(v.hashed_token.clone()) + .bind(v.hashed_code.clone()) + .bind(v.attempts) + .bind(v.created_at) + .bind(v.expires_at) + .bind(v.revoked_at) + .bind(v.validated_at) + .execute(&self.pool) + .await + .extend_err()?; + + Ok(()) + } + + pub async fn find_address_validation( + &self, + id: Option<&Uuid>, + ) -> Result<Vec<AddressValidation>, StoreError> { + let sqls = get_sqls(&self.sqls_root, FIND_ADDRESS_VALIDATION); + let rs = sqlx::query_as::< + _, + ( + Uuid, + Option<Uuid>, + Uuid, + String, + String, + OffsetDateTime, + String, + Vec<u8>, + Vec<u8>, + i32, + OffsetDateTime, + OffsetDateTime, + Option<OffsetDateTime>, + Option<OffsetDateTime>, + ), + >(&sqls[0]) + .bind(id) + .fetch_all(&self.pool) + .await + .extend_err()?; + + let mut res = vec![]; + for ( + id, + identity_id, + address_id, + address_typ, + address_val, + address_created_at, + method, + hashed_token, + hashed_code, + attempts, + created_at, + expires_at, + revoked_at, + validated_at, + ) in rs.into_iter() + { + let t = match AddressType::from_str(&address_typ) + .map_err(|_| StoreError::StoreValueCannotBeParsedInvariant)? + { + AddressType::Email { .. } => AddressType::Email { + email_address: Some( + EmailAddress::from_str(&address_val) + .map_err(|_| StoreError::StoreValueCannotBeParsedInvariant)?, + ), + }, + AddressType::Sms { .. } => AddressType::Sms { + phone_number: Some(address_val.clone()), + }, + }; + + res.push(AddressValidation { + id, + identity_id, + address: Address { + id: address_id, + t, + created_at: address_created_at, + }, + method: AddressValidationMethod::from_str(&method) + .map_err(|_| StoreError::StoreValueCannotBeParsedInvariant)?, + created_at, + expires_at, + revoked_at, + validated_at, + attempts, + hashed_token, + hashed_code, + }); + } + + Ok(res) + } + + pub async fn write_identity(&self, i: &Identity) -> Result<(), StoreError> { + let sqls = get_sqls(&self.sqls_root, WRITE_IDENTITY); + sqlx::query(&sqls[0]) + .bind(i.id) + // TODO: validate this is actually Json somewhere way up the chain (when being deserialized) + .bind(i.metadata.clone().unwrap_or("{}".into())) + .bind(i.created_at) + .bind(OffsetDateTime::now_utc()) + .bind(i.deleted_at) + .execute(&self.pool) + .await + .extend_err()?; + + Ok(()) + } + + pub async fn find_identity( + &self, + id: Option<&Uuid>, + address_value: Option<&str>, + address_is_validated: Option<bool>, + session_token_hash: &Option<Vec<u8>>, + ) -> Result<Vec<Identity>, StoreError> { + let sqls = get_sqls(&self.sqls_root, FIND_IDENTITY); + println!("{:?}", id); + println!("{:?}", address_value); + println!("{:?}", address_is_validated); + println!("{:?}", session_token_hash); + let rs = sqlx::query_as::< + _, + ( + Uuid, + Option<String>, + OffsetDateTime, + OffsetDateTime, + Option<OffsetDateTime>, + ), + >(&sqls[0]) + .bind(id) + .bind(address_value) + .bind(address_is_validated) + .bind(session_token_hash) + .fetch_all(&self.pool) + .await + .extend_err()?; + + let mut res = vec![]; + for (id, metadata, created_at, updated_at, deleted_at) in rs.into_iter() { + res.push(Identity { + id, + address_validations: vec![], + credentials: vec![], + rules: vec![], + metadata, + created_at, + deleted_at, + }) + } + + Ok(res) + } + + pub async fn write_session(&self, s: &Session, token_hash: &[u8]) -> Result<(), StoreError> { + let sqls = get_sqls(&self.sqls_root, WRITE_SESSION); + sqlx::query(&sqls[0]) + .bind(s.identity_id) + .bind(token_hash) + .bind(s.created_at) + .bind(s.expired_at) + .bind(s.revoked_at) + .execute(&self.pool) + .await + .extend_err()?; + + Ok(()) + } + + pub async fn find_session( + &self, + token: Vec<u8>, + identity_id: Option<&Uuid>, + ) -> Result<Vec<Session>, StoreError> { + let sqls = get_sqls(&self.sqls_root, FIND_SESSION); + let rs = + sqlx::query_as::<_, (Uuid, OffsetDateTime, OffsetDateTime, Option<OffsetDateTime>)>( + &sqls[0], + ) + .bind(token) + .bind(identity_id) + .bind(OffsetDateTime::now_utc()) + .bind(OffsetDateTime::now_utc()) + .fetch_all(&self.pool) + .await + .extend_err()?; + + let mut res = vec![]; + for (identity_id, created_at, expired_at, revoked_at) in rs.into_iter() { + res.push(Session { + identity_id, + token: vec![], + created_at, + expired_at, + revoked_at, + }); + } + Ok(res) + } +} + +fn get_sqls(root: &str, file: &str) -> Vec<String> { + SQLS.get(root) + .unwrap() + .get(file) + .unwrap() + .split("--") + .map(|p| p.to_string()) + .collect() +} diff --git a/crates/secd/src/client/types.rs b/crates/secd/src/client/types.rs deleted file mode 100644 index bacade4..0000000 --- a/crates/secd/src/client/types.rs +++ /dev/null @@ -1,3 +0,0 @@ -pub(crate) struct Email { - address: String, -} |
