aboutsummaryrefslogtreecommitdiff
path: root/crates/secd/src/client
diff options
context:
space:
mode:
authorbenj <benj@rse8.com>2022-12-24 00:43:38 -0800
committerbenj <benj@rse8.com>2022-12-24 00:43:38 -0800
commitc2268c285648ef02ece04de0d9df0813c6d70ff8 (patch)
treef84ec7ee42f97d78245f26d0c5a0c559cd35e89d /crates/secd/src/client
parentde6339da72af1d61ca5908b780977e2b037ce014 (diff)
downloadsecdiam-c2268c285648ef02ece04de0d9df0813c6d70ff8.tar
secdiam-c2268c285648ef02ece04de0d9df0813c6d70ff8.tar.gz
secdiam-c2268c285648ef02ece04de0d9df0813c6d70ff8.tar.bz2
secdiam-c2268c285648ef02ece04de0d9df0813c6d70ff8.tar.lz
secdiam-c2268c285648ef02ece04de0d9df0813c6d70ff8.tar.xz
secdiam-c2268c285648ef02ece04de0d9df0813c6d70ff8.tar.zst
secdiam-c2268c285648ef02ece04de0d9df0813c6d70ff8.zip
refactor everything with more abstraction and a nicer interface
Diffstat (limited to 'crates/secd/src/client')
-rw-r--r--crates/secd/src/client/email.rs67
-rw-r--r--crates/secd/src/client/email/mod.rs68
-rw-r--r--crates/secd/src/client/mod.rs422
-rw-r--r--crates/secd/src/client/sqldb.rs632
-rw-r--r--crates/secd/src/client/store/mod.rs190
-rw-r--r--crates/secd/src/client/store/sql_db.rs526
-rw-r--r--crates/secd/src/client/types.rs3
7 files changed, 785 insertions, 1123 deletions
diff --git a/crates/secd/src/client/email.rs b/crates/secd/src/client/email.rs
deleted file mode 100644
index 2712037..0000000
--- a/crates/secd/src/client/email.rs
+++ /dev/null
@@ -1,67 +0,0 @@
-use std::{path::PathBuf, str::FromStr};
-
-use email_address::EmailAddress;
-use time::OffsetDateTime;
-
-use super::{
- EmailMessenger, EmailMessengerError, EmailType, EMAIL_TEMPLATE_DEFAULT_LOGIN,
- EMAIL_TEMPLATE_DEFAULT_SIGNUP,
-};
-
-pub(crate) struct LocalEmailStubber {
- pub(crate) email_template_login: Option<String>,
- pub(crate) email_template_signup: Option<String>,
-}
-
-#[async_trait::async_trait]
-impl EmailMessenger for LocalEmailStubber {
- // TODO: this module really shouldn't be called client, it should be called services... the client is sqlx/mailgun/sns wrapper or whatever...
- async fn send_email(
- &self,
- email_address: &str,
- validation_id: &str,
- secret_code: &str,
- t: EmailType,
- ) -> Result<(), EmailMessengerError> {
- let login_template = self
- .email_template_login
- .clone()
- .unwrap_or(EMAIL_TEMPLATE_DEFAULT_LOGIN.to_string());
- let signup_template = self
- .email_template_signup
- .clone()
- .unwrap_or(EMAIL_TEMPLATE_DEFAULT_SIGNUP.to_string());
-
- let replace_template = |s: &str| {
- s.replace(
- "%secd_link%",
- &format!("{}?code={}", validation_id, secret_code),
- )
- .replace("%secd_email_address%", email_address)
- .replace("%secd_code%", secret_code)
- };
-
- if !EmailAddress::is_valid(email_address) {
- return Err(EmailMessengerError::InvalidEmailAddress);
- }
-
- let body = match t {
- EmailType::Login => replace_template(&login_template),
- EmailType::Signup => replace_template(&signup_template),
- };
-
- // TODO: write to the system mailbox instead?
- std::fs::write(
- PathBuf::from_str(&format!(
- "/tmp/{}_{}.localmail",
- OffsetDateTime::now_utc(),
- validation_id
- ))
- .map_err(|_| EmailMessengerError::Unknown)?,
- body,
- )
- .map_err(|_| EmailMessengerError::FailedToSendEmail)?;
-
- Ok(())
- }
-}
diff --git a/crates/secd/src/client/email/mod.rs b/crates/secd/src/client/email/mod.rs
new file mode 100644
index 0000000..915d18c
--- /dev/null
+++ b/crates/secd/src/client/email/mod.rs
@@ -0,0 +1,68 @@
+use email_address::EmailAddress;
+use lettre::Transport;
+use log::error;
+use std::collections::HashMap;
+
+#[derive(Debug, thiserror::Error, derive_more::Display)]
+pub enum EmailMessengerError {
+ FailedToSendEmail,
+}
+
+pub struct EmailValidationMessage {
+ pub recipient: EmailAddress,
+ pub subject: String,
+ pub body: String,
+}
+
+#[async_trait::async_trait]
+pub(crate) trait EmailMessenger {
+ async fn send_email(
+ &self,
+ email_address: &EmailAddress,
+ template: &str,
+ template_vars: HashMap<&str, &str>,
+ ) -> Result<(), EmailMessengerError>;
+}
+
+pub(crate) struct LocalMailer {}
+
+#[async_trait::async_trait]
+impl EmailMessenger for LocalMailer {
+ async fn send_email(
+ &self,
+ email_address: &EmailAddress,
+ template: &str,
+ template_vars: HashMap<&str, &str>,
+ ) -> Result<(), EmailMessengerError> {
+ todo!()
+ }
+}
+
+#[async_trait::async_trait]
+pub(crate) trait Sendable {
+ async fn send(&self) -> Result<(), EmailMessengerError>;
+}
+
+#[async_trait::async_trait]
+impl Sendable for EmailValidationMessage {
+ // TODO: We need to break this up as before, especially so we can feature
+ // gate unwanted things like Lettre...
+ async fn send(&self) -> Result<(), EmailMessengerError> {
+ // TODO: Get these things from the template...
+ let email = lettre::Message::builder()
+ .from("BranchControl <iam@branchcontrol.com>".parse().unwrap())
+ .reply_to("BranchControl <iam@branchcontrol.com>".parse().unwrap())
+ .to(self.recipient.to_string().parse().unwrap())
+ .subject(self.subject.clone())
+ .body(self.body.clone())
+ .unwrap();
+
+ let mailer = lettre::SmtpTransport::unencrypted_localhost();
+
+ mailer.send(&email).map_err(|e| {
+ error!("failed to send email {:?}", e);
+ EmailMessengerError::FailedToSendEmail
+ })?;
+ Ok(())
+ }
+}
diff --git a/crates/secd/src/client/mod.rs b/crates/secd/src/client/mod.rs
index 38426ef..e5272fd 100644
--- a/crates/secd/src/client/mod.rs
+++ b/crates/secd/src/client/mod.rs
@@ -1,422 +1,2 @@
pub(crate) mod email;
-pub(crate) mod sqldb;
-pub(crate) mod types;
-
-use std::{collections::HashMap, str::FromStr};
-
-use super::Identity;
-use crate::{
- EmailValidation, OauthProvider, OauthProviderName, OauthResponseType, OauthValidation, Session,
- SessionSecret, ValidationRequestId, ValidationType,
-};
-
-use email_address::EmailAddress;
-use lazy_static::lazy_static;
-use sqlx::{
- database::HasValueRef, sqlite::SqliteRow, ColumnIndex, Database, Decode, FromRow, Row, Sqlite,
- Type,
-};
-use thiserror::Error;
-use time::OffsetDateTime;
-use url::Url;
-use uuid::Uuid;
-
-pub enum EmailType {
- Login,
- Signup,
-}
-
-#[derive(Error, Debug, derive_more::Display)]
-pub enum EmailMessengerError {
- InvalidEmailAddress,
- FailedToSendEmail,
- Unknown,
-}
-
-#[async_trait::async_trait]
-pub trait EmailMessenger {
- async fn send_email(
- &self,
- email_address: &str,
- validation_id: &str,
- secret_code: &str,
- t: EmailType,
- ) -> Result<(), EmailMessengerError>;
-}
-
-#[derive(Error, Debug, derive_more::Display)]
-pub enum StoreError {
- SqlxError(#[from] sqlx::Error),
- CodeAppearsMoreThanOnce,
- CodeDoesNotExist(String),
- IdentityIdMustExistInvariant,
- TooManyValidations,
- TooManyIdentitiesFound,
- NoEmailValidationFound,
- OauthProviderDoesNotExist(OauthProviderName),
- OauthValidationDoesNotExist(ValidationRequestId),
- Other(String),
-}
-
-const EMAIL_TEMPLATE_DEFAULT_LOGIN: &str = "You requested a login link. Please click the following link %secd_code% to login as %secd_email_address%";
-const EMAIL_TEMPLATE_DEFAULT_SIGNUP: &str = "You requested a sign up. Please click the following link %secd_code% to complete your sign up and validate %secd_email_address%";
-
-const ERR_MSG_MIGRATION_FAILED: &str = "Failed to execute migrations. This appears to be a secd issue. File a bug at https://www.github.com/secd-lib";
-
-const SQLITE: &str = "sqlite";
-const PGSQL: &str = "pgsql";
-
-const WRITE_IDENTITY: &str = "write_identity";
-const WRITE_EMAIL_VALIDATION: &str = "write_email_validation";
-const FIND_EMAIL_VALIDATION: &str = "find_email_validation";
-const READ_VALIDATION_TYPE: &str = "read_validation_type";
-
-const WRITE_EMAIL: &str = "write_email";
-
-const READ_IDENTITY: &str = "read_identity";
-const FIND_IDENTITY: &str = "find_identity";
-const FIND_IDENTITY_BY_CODE: &str = "find_identity_by_code";
-
-const READ_IDENTITY_RAW_ID: &str = "read_identity_raw_id";
-const READ_EMAIL_RAW_ID: &str = "read_email_raw_id";
-
-const WRITE_SESSION: &str = "write_session";
-const READ_SESSION: &str = "read_session";
-
-const WRITE_OAUTH_PROVIDER: &str = "write_oauth_provider";
-const READ_OAUTH_PROVIDER: &str = "read_oauth_provider";
-const WRITE_OAUTH_VALIDATION: &str = "write_oauth_validation";
-const READ_OAUTH_VALIDATION: &str = "read_oauth_validation";
-
-lazy_static! {
- static ref SQLS: HashMap<&'static str, HashMap<&'static str, &'static str>> = {
- let sqlite_sqls: HashMap<&'static str, &'static str> = [
- (
- WRITE_IDENTITY,
- include_str!("../../store/sqlite/sql/write_identity.sql"),
- ),
- (
- WRITE_EMAIL_VALIDATION,
- include_str!("../../store/sqlite/sql/write_email_validation.sql"),
- ),
- (
- WRITE_EMAIL,
- include_str!("../../store/sqlite/sql/write_email.sql"),
- ),
- (
- READ_IDENTITY,
- include_str!("../../store/sqlite/sql/read_identity.sql"),
- ),
- (
- FIND_IDENTITY,
- include_str!("../../store/sqlite/sql/find_identity.sql"),
- ),
- (
- FIND_IDENTITY_BY_CODE,
- include_str!("../../store/sqlite/sql/find_identity_by_code.sql"),
- ),
- (
- READ_IDENTITY_RAW_ID,
- include_str!("../../store/sqlite/sql/read_identity_raw_id.sql"),
- ),
- (
- READ_EMAIL_RAW_ID,
- include_str!("../../store/sqlite/sql/read_email_raw_id.sql"),
- ),
- (
- WRITE_SESSION,
- include_str!("../../store/sqlite/sql/write_session.sql"),
- ),
- (
- READ_SESSION,
- include_str!("../../store/sqlite/sql/read_session.sql"),
- ),
- (
- FIND_EMAIL_VALIDATION,
- include_str!("../../store/sqlite/sql/find_email_validation.sql"),
- ),
- (
- WRITE_OAUTH_PROVIDER,
- include_str!("../../store/sqlite/sql/write_oauth_provider.sql"),
- ),
- (
- READ_OAUTH_PROVIDER,
- include_str!("../../store/sqlite/sql/read_oauth_provider.sql"),
- ),
- (
- READ_OAUTH_VALIDATION,
- include_str!("../../store/sqlite/sql/read_oauth_validation.sql"),
- ),
- (
- WRITE_OAUTH_VALIDATION,
- include_str!("../../store/sqlite/sql/write_oauth_validation.sql"),
- ),
- (
- READ_VALIDATION_TYPE,
- include_str!("../../store/sqlite/sql/read_validation_type.sql"),
- ),
- ]
- .iter()
- .cloned()
- .collect();
-
- let pg_sqls: HashMap<&'static str, &'static str> = [
- (
- WRITE_IDENTITY,
- include_str!("../../store/pg/sql/write_identity.sql"),
- ),
- (
- WRITE_EMAIL_VALIDATION,
- include_str!("../../store/pg/sql/write_email_validation.sql"),
- ),
- (
- WRITE_EMAIL,
- include_str!("../../store/pg/sql/write_email.sql"),
- ),
- (
- READ_IDENTITY,
- include_str!("../../store/pg/sql/read_identity.sql"),
- ),
- (
- FIND_IDENTITY,
- include_str!("../../store/pg/sql/find_identity.sql"),
- ),
- (
- FIND_IDENTITY_BY_CODE,
- include_str!("../../store/pg/sql/find_identity_by_code.sql"),
- ),
- (
- READ_IDENTITY_RAW_ID,
- include_str!("../../store/pg/sql/read_identity_raw_id.sql"),
- ),
- (
- READ_EMAIL_RAW_ID,
- include_str!("../../store/pg/sql/read_email_raw_id.sql"),
- ),
- (
- WRITE_SESSION,
- include_str!("../../store/pg/sql/write_session.sql"),
- ),
- (
- READ_SESSION,
- include_str!("../../store/pg/sql/read_session.sql"),
- ),
- (
- FIND_EMAIL_VALIDATION,
- include_str!("../../store/pg/sql/find_email_validation.sql"),
- ),
- (
- WRITE_OAUTH_PROVIDER,
- include_str!("../../store/pg/sql/write_oauth_provider.sql"),
- ),
- (
- READ_OAUTH_PROVIDER,
- include_str!("../../store/pg/sql/read_oauth_provider.sql"),
- ),
- (
- READ_OAUTH_VALIDATION,
- include_str!("../../store/pg/sql/read_oauth_validation.sql"),
- ),
- (
- WRITE_OAUTH_VALIDATION,
- include_str!("../../store/pg/sql/write_oauth_validation.sql"),
- ),
- (
- READ_VALIDATION_TYPE,
- include_str!("../../store/pg/sql/read_validation_type.sql"),
- ),
- ]
- .iter()
- .cloned()
- .collect();
-
- let sqls: HashMap<&'static str, HashMap<&'static str, &'static str>> =
- [(SQLITE, sqlite_sqls), (PGSQL, pg_sqls)]
- .iter()
- .cloned()
- .collect();
- sqls
- };
-}
-
-impl<'a, R: Row> FromRow<'a, R> for OauthValidation
-where
- &'a str: ColumnIndex<R>,
- OauthProviderName: Decode<'a, R::Database> + Type<R::Database>,
- OauthResponseType: Decode<'a, R::Database> + Type<R::Database>,
- OffsetDateTime: Decode<'a, R::Database> + Type<R::Database>,
- String: Decode<'a, R::Database> + Type<R::Database>,
- Uuid: Decode<'a, R::Database> + Type<R::Database>,
-{
- fn from_row(row: &'a R) -> Result<Self, sqlx::Error> {
- let id: Option<Uuid> = row.try_get("oauth_validation_public_id")?;
- let identity_id: Option<Uuid> = row.try_get("identity_public_id")?;
- let access_token: Option<String> = row.try_get("access_token")?;
- let raw_response: Option<String> = row.try_get("raw_response")?;
- let created_at: Option<OffsetDateTime> = row.try_get("created_at")?;
- let validated_at: Option<OffsetDateTime> = row.try_get("validated_at")?;
- let revoked_at: Option<OffsetDateTime> = row.try_get("revoked_at")?;
- let deleted_at: Option<OffsetDateTime> = row.try_get("deleted_at")?;
-
- let op_name: Option<OauthProviderName> = row.try_get("oauth_provider_name")?;
- let op_flow: Option<String> = row.try_get("oauth_provider_flow")?;
- let op_base_url: Option<String> = row.try_get("oauth_provider_base_url")?;
- let op_response_type: Option<OauthResponseType> =
- row.try_get("oauth_provider_response_type")?;
- let op_default_scope: Option<String> = row.try_get("oauth_provider_default_scope")?;
- let op_client_id: Option<String> = row.try_get("oauth_provider_client_id")?;
- let op_client_secret: Option<String> = row.try_get("oauth_provider_client_secret")?;
- let op_redirect_url: Option<String> = row.try_get("oauth_provider_redirect_url")?;
- let op_created_at: Option<OffsetDateTime> = row.try_get("oauth_provider_created_at")?;
- let op_deleted_at: Option<OffsetDateTime> = row.try_get("oauth_provider_deleted_at")?;
-
- let op_base_url = op_base_url
- .map(|s| Url::from_str(&s).ok())
- .flatten()
- .ok_or(sqlx::Error::ColumnDecode {
- index: "oauth_provider_base_url".into(),
- source: "secd".into(),
- })?;
-
- let op_redirect_url = op_redirect_url
- .map(|s| Url::from_str(&s).ok())
- .flatten()
- .ok_or(sqlx::Error::ColumnDecode {
- index: "oauth_provider_redirect_url".into(),
- source: "secd".into(),
- })?;
-
- Ok(OauthValidation {
- id,
- identity_id,
- access_token,
- raw_response,
- created_at: created_at.ok_or(sqlx::Error::ColumnDecode {
- index: "created_at".into(),
- source: "secd".into(),
- })?,
- validated_at,
- revoked_at,
- deleted_at,
- oauth_provider: OauthProvider {
- name: op_name.unwrap(),
- flow: op_flow,
- base_url: op_base_url,
- response: op_response_type.ok_or(sqlx::Error::ColumnDecode {
- index: "oauth_provider_response_type".into(),
- source: "secd".into(),
- })?,
- default_scope: op_default_scope.ok_or(sqlx::Error::ColumnDecode {
- index: "oauth_provider_default_scope".into(),
- source: "secd".into(),
- })?,
- client_id: op_client_id.ok_or(sqlx::Error::ColumnDecode {
- index: "oauth_provider_client_id".into(),
- source: "secd".into(),
- })?,
- client_secret: op_client_secret.ok_or(sqlx::Error::ColumnDecode {
- index: "oauth_provider_client_secret".into(),
- source: "secd".into(),
- })?,
- redirect_url: op_redirect_url,
- created_at: op_created_at.ok_or(sqlx::Error::ColumnDecode {
- index: "oauth_provider_created_at".into(),
- source: "secd".into(),
- })?,
- deleted_at: op_deleted_at,
- },
- })
- }
-}
-
-impl<'a, D: Database> Decode<'a, D> for OauthProviderName
-where
- &'a str: Decode<'a, D>,
-{
- fn decode(
- value: <D as HasValueRef<'a>>::ValueRef,
- ) -> Result<Self, Box<dyn ::std::error::Error + 'static + Send + Sync>> {
- let v = <&str as Decode<D>>::decode(value)?;
- <OauthProviderName as clap::ValueEnum>::from_str(v, true)
- .map_err(|_| "OauthProviderName should exist and decode to a program value.".into())
- }
-}
-
-impl<D: Database> Type<D> for OauthProviderName
-where
- str: Type<D>,
-{
- fn type_info() -> D::TypeInfo {
- <&str as Type<D>>::type_info()
- }
-}
-
-impl<'a, D: Database> Decode<'a, D> for OauthResponseType
-where
- &'a str: Decode<'a, D>,
-{
- fn decode(
- value: <D as HasValueRef<'a>>::ValueRef,
- ) -> Result<Self, Box<dyn ::std::error::Error + 'static + Send + Sync>> {
- let v = <&str as Decode<D>>::decode(value)?;
- <OauthResponseType as clap::ValueEnum>::from_str(v, true)
- .map_err(|_| "OauthResponseType should exist and decode to a program value.".into())
- }
-}
-
-impl<D: Database> Type<D> for OauthResponseType
-where
- str: Type<D>,
-{
- fn type_info() -> D::TypeInfo {
- <&str as Type<D>>::type_info()
- }
-}
-
-#[async_trait::async_trait]
-pub trait Store {
- async fn write_email(&self, email_address: &str) -> Result<(), StoreError>;
-
- async fn find_email_validation(
- &self,
- validation_id: Option<&Uuid>,
- code: Option<&str>,
- ) -> Result<EmailValidation, StoreError>;
- async fn write_email_validation(
- &self,
- ev: &EmailValidation,
- // TODO: Make this write an EmailValidation
- ) -> anyhow::Result<Uuid>;
-
- async fn find_identity(
- &self,
- identity_id: Option<&Uuid>,
- email: Option<&str>,
- ) -> anyhow::Result<Option<Identity>>;
- async fn find_identity_by_code(&self, code: &str) -> Result<Identity, StoreError>;
- async fn write_identity(&self, i: &Identity) -> Result<(), StoreError>;
- async fn read_identity(&self, identity_id: &Uuid) -> Result<Identity, StoreError>;
-
- async fn write_session(&self, session: &Session) -> Result<(), StoreError>;
- async fn read_session(&self, secret: &SessionSecret) -> Result<Session, StoreError>;
-
- async fn write_oauth_provider(&self, provider: &OauthProvider) -> Result<(), StoreError>;
- async fn read_oauth_provider(
- &self,
- provider: &OauthProviderName,
- flow: Option<String>,
- ) -> Result<OauthProvider, StoreError>;
- async fn write_oauth_validation(
- &self,
- validation: &OauthValidation,
- ) -> anyhow::Result<ValidationRequestId>;
- async fn read_oauth_validation(
- &self,
- validation_id: &ValidationRequestId,
- ) -> anyhow::Result<OauthValidation>;
-
- async fn find_validation_type(
- &self,
- validation_id: &ValidationRequestId,
- ) -> anyhow::Result<ValidationType>;
-}
+pub(crate) mod store;
diff --git a/crates/secd/src/client/sqldb.rs b/crates/secd/src/client/sqldb.rs
deleted file mode 100644
index 6751ef6..0000000
--- a/crates/secd/src/client/sqldb.rs
+++ /dev/null
@@ -1,632 +0,0 @@
-use std::{str::FromStr, sync::Arc};
-
-use super::{
- EmailValidation, Identity, OauthProvider, OauthProviderName, OauthResponseType, Session,
- SessionSecret, Store, StoreError, ERR_MSG_MIGRATION_FAILED, FIND_EMAIL_VALIDATION,
- FIND_IDENTITY, FIND_IDENTITY_BY_CODE, PGSQL, READ_EMAIL_RAW_ID, READ_IDENTITY_RAW_ID,
- READ_OAUTH_PROVIDER, READ_OAUTH_VALIDATION, READ_SESSION, READ_VALIDATION_TYPE, SQLITE, SQLS,
- WRITE_EMAIL, WRITE_EMAIL_VALIDATION, WRITE_IDENTITY, WRITE_OAUTH_PROVIDER,
- WRITE_OAUTH_VALIDATION, WRITE_SESSION,
-};
-use crate::{util, OauthValidation, ValidationRequestId, ValidationType};
-use anyhow::bail;
-use log::{debug, error};
-use openssl::sha::Sha256;
-use sqlx::{
- self, database::HasArguments, ColumnIndex, Database, Decode, Encode, Executor, IntoArguments,
- Pool, Postgres, Sqlite, Transaction, Type,
-};
-use time::OffsetDateTime;
-use url::Url;
-use uuid::Uuid;
-
-fn get_sqls(root: &str, file: &str) -> Vec<String> {
- SQLS.get(root)
- .unwrap()
- .get(file)
- .unwrap()
- .split("--")
- .map(|p| p.to_string())
- .collect()
-}
-
-fn hash_secret(secret: &str) -> Vec<u8> {
- let mut hasher = Sha256::new();
- hasher.update(secret.as_bytes());
- hasher.finish().to_vec()
-}
-
-struct SqlClient<D>
-where
- D: sqlx::Database,
-{
- pool: sqlx::Pool<D>,
- sqls_root: String,
-}
-
-impl<D> SqlClient<D>
-where
- D: sqlx::Database,
- for<'c> <D as HasArguments<'c>>::Arguments: IntoArguments<'c, D>,
- for<'c> i64: Decode<'c, D> + Type<D>,
- for<'c> &'c str: Decode<'c, D> + Type<D>,
- for<'c> &'c str: Encode<'c, D> + Type<D>,
- for<'c> usize: ColumnIndex<<D as Database>::Row>,
- for<'c> Uuid: Decode<'c, D> + Type<D>,
- for<'c> Uuid: Encode<'c, D> + Type<D>,
- for<'c> &'c Pool<D>: Executor<'c, Database = D>,
-{
- async fn read_identity_raw_id(&self, id: &Uuid) -> Result<i64, StoreError> {
- let sqls = get_sqls(&self.sqls_root, READ_IDENTITY_RAW_ID);
-
- Ok(sqlx::query_as::<_, (i64,)>(&sqls[0])
- .bind(id)
- .fetch_one(&self.pool)
- .await
- .map_err(util::log_err_sqlx)?
- .0)
- }
-
- async fn read_email_raw_id(&self, address: &str) -> Result<i64, StoreError> {
- let sqls = get_sqls(&self.sqls_root, READ_EMAIL_RAW_ID);
-
- Ok(sqlx::query_as::<_, (i64,)>(&sqls[0])
- .bind(address)
- .fetch_one(&self.pool)
- .await
- .map_err(util::log_err_sqlx)?
- .0)
- }
-}
-
-#[async_trait::async_trait]
-impl<D> Store for SqlClient<D>
-where
- D: sqlx::Database,
- for<'c> <D as HasArguments<'c>>::Arguments: IntoArguments<'c, D>,
- for<'c> bool: Decode<'c, D> + Type<D>,
- for<'c> bool: Encode<'c, D> + Type<D>,
- for<'c> i64: Decode<'c, D> + Type<D>,
- for<'c> i64: Encode<'c, D> + Type<D>,
- for<'c> i32: Decode<'c, D> + Type<D>,
- for<'c> i32: Encode<'c, D> + Type<D>,
- for<'c> OffsetDateTime: Decode<'c, D> + Type<D>,
- for<'c> OffsetDateTime: Encode<'c, D> + Type<D>,
- for<'c> &'c str: ColumnIndex<<D as Database>::Row>,
- for<'c> &'c str: Decode<'c, D> + Type<D>,
- for<'c> &'c str: Encode<'c, D> + Type<D>,
- for<'c> Option<&'c str>: Decode<'c, D> + Type<D>,
- for<'c> Option<&'c str>: Encode<'c, D> + Type<D>,
- for<'c> String: Decode<'c, D> + Type<D>,
- for<'c> String: Encode<'c, D> + Type<D>,
- for<'c> Option<String>: Decode<'c, D> + Type<D>,
- for<'c> Option<String>: Encode<'c, D> + Type<D>,
- for<'c> OauthProviderName: Decode<'c, D> + Type<D>,
- for<'c> OauthResponseType: Decode<'c, D> + Type<D>,
- for<'c> usize: ColumnIndex<<D as Database>::Row>,
- for<'c> Uuid: Decode<'c, D> + Type<D>,
- for<'c> Uuid: Encode<'c, D> + Type<D>,
- for<'c> &'c [u8]: Encode<'c, D> + Type<D>,
- for<'c> Option<&'c Uuid>: Encode<'c, D> + Type<D>,
- for<'c> Option<&'c Vec<u8>>: Encode<'c, D> + Type<D>,
- for<'c> Option<OffsetDateTime>: Decode<'c, D> + Type<D>,
- for<'c> Option<OffsetDateTime>: Encode<'c, D> + Type<D>,
- for<'c> &'c Pool<D>: Executor<'c, Database = D>,
- for<'c> &'c mut Transaction<'c, D>: Executor<'c, Database = D>,
-{
- async fn write_email(&self, email_address: &str) -> Result<(), StoreError> {
- let sqls = get_sqls(&self.sqls_root, WRITE_EMAIL);
-
- sqlx::query(&sqls[0])
- .bind(email_address)
- .execute(&self.pool)
- .await
- .map_err(util::log_err_sqlx)?;
-
- Ok(())
- }
-
- async fn find_email_validation(
- &self,
- validation_id: Option<&Uuid>,
- code: Option<&str>,
- ) -> Result<EmailValidation, StoreError> {
- let sqls = get_sqls(&self.sqls_root, FIND_EMAIL_VALIDATION);
- let mut rows = sqlx::query_as::<_, EmailValidation>(&sqls[0])
- .bind(validation_id)
- .bind(code)
- .fetch_all(&self.pool)
- .await
- .map_err(util::log_err_sqlx)?;
-
- match rows.len() {
- 0 => Err(StoreError::NoEmailValidationFound),
- 1 => Ok(rows.swap_remove(0)),
- _ => Err(StoreError::TooManyValidations),
- }
- }
-
- async fn write_email_validation(&self, ev: &EmailValidation) -> anyhow::Result<Uuid> {
- let sqls = get_sqls(&self.sqls_root, WRITE_EMAIL_VALIDATION);
-
- let email_id = self.read_email_raw_id(&ev.email_address).await?;
- let validation_id = ev.id.unwrap_or(Uuid::new_v4());
- sqlx::query(&sqls[0])
- .bind(validation_id)
- .bind(email_id)
- .bind(&ev.code)
- .bind(ev.is_oauth_derived)
- .bind(ev.created_at)
- .bind(ev.validated_at)
- .bind(ev.expired_at)
- .execute(&self.pool)
- .await
- .map_err(util::log_err_sqlx)?;
-
- if ev.identity_id.is_some() || ev.revoked_at.is_some() || ev.deleted_at.is_some() {
- sqlx::query(&sqls[1])
- .bind(ev.identity_id.as_ref())
- .bind(validation_id)
- .bind(ev.revoked_at)
- .bind(ev.deleted_at)
- .execute(&self.pool)
- .await
- .map_err(util::log_err_sqlx)?;
- }
-
- Ok(validation_id)
- }
-
- async fn find_identity(
- &self,
- id: Option<&Uuid>,
- email: Option<&str>,
- ) -> anyhow::Result<Option<Identity>> {
- let sqls = get_sqls(&self.sqls_root, FIND_IDENTITY);
- Ok(
- match sqlx::query_as::<_, Identity>(&sqls[0])
- .bind(id)
- .bind(email)
- .fetch_all(&self.pool)
- .await
- {
- Ok(mut is) => match is.len() {
- // if only 1 found, then that's fine
- // if multiple are fond, then if they all have the same id, that's okay
- 1 => {
- let i = is.swap_remove(0);
- match i.deleted_at {
- Some(t) if t > OffsetDateTime::now_utc() => Some(i),
- None => Some(i),
- _ => None,
- }
- }
- 0 => None,
- _ => {
- match is
- .iter()
- .filter(|&i| i.id != is[0].id)
- .collect::<Vec<&Identity>>()
- .len()
- {
- 0 => Some(is.swap_remove(0)),
- _ => bail!(StoreError::TooManyIdentitiesFound),
- }
- }
- },
- Err(sqlx::Error::RowNotFound) => None,
- Err(e) => bail!(StoreError::SqlxError(e)),
- },
- )
- }
-
- async fn find_identity_by_code(&self, code: &str) -> Result<Identity, StoreError> {
- let sqls = get_sqls(&self.sqls_root, FIND_IDENTITY_BY_CODE);
-
- let rows = sqlx::query_as::<_, (i32,)>(&sqls[0])
- .bind(code)
- .fetch_all(&self.pool)
- .await
- .map_err(util::log_err_sqlx)?;
-
- if rows.len() == 0 {
- return Err(StoreError::CodeDoesNotExist(code.to_string()));
- }
-
- if rows.len() != 1 {
- return Err(StoreError::CodeAppearsMoreThanOnce);
- }
-
- let identity_email_id = rows.get(0).unwrap().0;
-
- // TODO: IF we expand beyond email codes, then we'll need to join against a bunch of identity tables.
- // but since a single code was found, only one of them should pop...
- Ok(sqlx::query_as::<_, Identity>(&sqls[1])
- .bind(identity_email_id)
- .fetch_one(&self.pool)
- .await
- .map_err(util::log_err_sqlx)?)
- }
-
- async fn write_identity(&self, i: &Identity) -> Result<(), StoreError> {
- let sqls = get_sqls(&self.sqls_root, WRITE_IDENTITY);
- sqlx::query(&sqls[0])
- .bind(i.id)
- .bind(i.data.clone())
- .bind(i.created_at)
- .execute(&self.pool)
- .await
- .map_err(|e| {
- error!("write_identity_failure");
- error!("{:?}", e);
- e
- })?;
-
- Ok(())
- }
- async fn read_identity(&self, id: &Uuid) -> Result<Identity, StoreError> {
- let identity = sqlx::query_as::<_, Identity>(
- "
-select identity_public_id, data, created_at from identity where identity_public_id = ?",
- )
- .bind(id)
- .fetch_one(&self.pool)
- .await
- .map_err(util::log_err_sqlx)?;
-
- Ok(identity)
- }
-
- async fn write_session(&self, session: &Session) -> Result<(), StoreError> {
- let sqls = get_sqls(&self.sqls_root, WRITE_SESSION);
-
- let secret_hash = session.secret.as_ref().map(|s| hash_secret(s));
-
- sqlx::query(&sqls[0])
- .bind(&session.identity_id)
- .bind(secret_hash.as_ref())
- .bind(session.created_at)
- .bind(session.expired_at)
- .bind(session.revoked_at)
- .execute(&self.pool)
- .await
- .map_err(util::log_err_sqlx)?;
-
- Ok(())
- }
- async fn read_session(&self, secret: &SessionSecret) -> Result<Session, StoreError> {
- let sqls = get_sqls(&self.sqls_root, READ_SESSION);
-
- let secret_hash = hash_secret(secret);
- let mut session = sqlx::query_as::<_, Session>(&sqls[0])
- .bind(&secret_hash[..])
- .fetch_one(&self.pool)
- .await
- .map_err(util::log_err_sqlx)?;
-
- // This should do nothing other than updated touched_at, and then
- // clear the plaintext secret
- session.secret = Some(secret.to_string());
- self.write_session(&session).await?;
- session.secret = None;
-
- Ok(session)
- }
-
- async fn write_oauth_provider(&self, provider: &OauthProvider) -> Result<(), StoreError> {
- let sqls = get_sqls(&self.sqls_root, WRITE_OAUTH_PROVIDER);
- sqlx::query(&sqls[0])
- .bind(&provider.name.to_string())
- .bind(&provider.flow)
- .bind(&provider.base_url.to_string())
- .bind(&provider.response.to_string())
- .bind(&provider.default_scope)
- .bind(&provider.client_id)
- // TODO: encrypt secret before writing
- .bind(&provider.client_secret)
- .bind(&provider.redirect_url.to_string())
- .bind(provider.created_at)
- .bind(provider.deleted_at)
- .execute(&self.pool)
- .await
- .map_err(util::log_err_sqlx)?;
- Ok(())
- }
-
- async fn read_oauth_provider(
- &self,
- provider: &OauthProviderName,
- flow: Option<String>,
- ) -> Result<OauthProvider, StoreError> {
- let sqls = get_sqls(&self.sqls_root, READ_OAUTH_PROVIDER);
- let flow = flow.unwrap_or("default".into());
- debug!("provider: {:?}, flow: {:?}", provider, flow);
- // TODO: Write the generic FromRow impl for OauthProvider...
- let res = sqlx::query_as::<
- _,
- (
- String,
- String,
- String,
- String,
- String,
- String,
- String,
- OffsetDateTime,
- Option<OffsetDateTime>,
- ),
- >(&sqls[0])
- .bind(&provider.to_string())
- .bind(&flow)
- .fetch_one(&self.pool)
- .await
- .map_err(util::log_err_sqlx)?;
-
- debug!("res: {:?}", res);
-
- Ok(OauthProvider {
- name: provider.clone(),
- flow: Some(res.0),
- base_url: Url::from_str(&res.1)
- .map_err(|_| StoreError::OauthProviderDoesNotExist(*provider))?,
- response: OauthResponseType::from_str(&res.2)
- .map_err(|_| StoreError::OauthProviderDoesNotExist(*provider))?,
- default_scope: res.3,
- client_id: res.4,
- client_secret: res.5,
- redirect_url: Url::from_str(&res.6)
- .map_err(|_| StoreError::OauthProviderDoesNotExist(*provider))?,
- created_at: res.7,
- deleted_at: res.8,
- })
- }
- async fn write_oauth_validation(
- &self,
- v: &OauthValidation,
- ) -> anyhow::Result<ValidationRequestId> {
- let sqls = get_sqls(&self.sqls_root, WRITE_OAUTH_VALIDATION);
-
- let validation_id = v.id.unwrap_or(Uuid::new_v4());
- sqlx::query(&sqls[0])
- .bind(validation_id)
- .bind(v.oauth_provider.name.to_string())
- .bind(v.oauth_provider.flow.clone())
- .bind(v.access_token.clone())
- .bind(v.raw_response.clone())
- .bind(v.created_at)
- .bind(v.validated_at)
- .execute(&self.pool)
- .await?;
-
- if v.identity_id.is_some() || v.revoked_at.is_some() || v.deleted_at.is_some() {
- sqlx::query(&sqls[1])
- .bind(v.identity_id.as_ref())
- .bind(validation_id)
- .bind(v.revoked_at)
- .bind(v.deleted_at)
- .execute(&self.pool)
- .await?;
- }
-
- Ok(validation_id)
- }
- async fn read_oauth_validation(
- &self,
- validation_id: &ValidationRequestId,
- ) -> anyhow::Result<OauthValidation> {
- let sqls = get_sqls(&self.sqls_root, READ_OAUTH_VALIDATION);
-
- let mut es = sqlx::query_as::<_, OauthValidation>(&sqls[0])
- .bind(validation_id)
- .fetch_all(&self.pool)
- .await?;
-
- if es.len() != 1 {
- bail!(StoreError::OauthValidationDoesNotExist(
- validation_id.clone()
- ));
- }
-
- Ok(es.swap_remove(0))
- }
- async fn find_validation_type(
- &self,
- validation_id: &ValidationRequestId,
- ) -> anyhow::Result<ValidationType> {
- let sqls = get_sqls(&self.sqls_root, READ_VALIDATION_TYPE);
-
- let mut es = sqlx::query_as::<_, (String,)>(&sqls[0])
- .bind(validation_id)
- .fetch_all(&self.pool)
- .await
- .map_err(util::log_err_sqlx)?;
-
- match es.len() {
- 1 => Ok(ValidationType::from_str(&es.swap_remove(0).0)?),
- _ => bail!(StoreError::Other(
- "expected a single validation but recieved 0 or multiple validations".into()
- )),
- }
- }
-}
-
-pub struct PgClient {
- sql: SqlClient<Postgres>,
-}
-
-impl PgClient {
- pub async fn new(pool: sqlx::Pool<Postgres>) -> Arc<dyn Store + Send + Sync + 'static> {
- sqlx::migrate!("store/pg/migrations")
- .run(&pool)
- .await
- .expect(ERR_MSG_MIGRATION_FAILED);
-
- Arc::new(PgClient {
- sql: SqlClient {
- pool,
- sqls_root: PGSQL.to_string(),
- },
- })
- }
-}
-
-#[async_trait::async_trait]
-impl Store for PgClient {
- async fn write_email(&self, email_address: &str) -> Result<(), StoreError> {
- self.sql.write_email(email_address).await
- }
- async fn find_email_validation(
- &self,
- validation_id: Option<&Uuid>,
- code: Option<&str>,
- ) -> Result<EmailValidation, StoreError> {
- self.sql.find_email_validation(validation_id, code).await
- }
- async fn write_email_validation(&self, ev: &EmailValidation) -> anyhow::Result<Uuid> {
- self.sql.write_email_validation(ev).await
- }
- async fn find_identity(
- &self,
- identity_id: Option<&Uuid>,
- email: Option<&str>,
- ) -> anyhow::Result<Option<Identity>> {
- self.sql.find_identity(identity_id, email).await
- }
- async fn find_identity_by_code(&self, code: &str) -> Result<Identity, StoreError> {
- self.sql.find_identity_by_code(code).await
- }
- async fn write_identity(&self, i: &Identity) -> Result<(), StoreError> {
- self.sql.write_identity(i).await
- }
- async fn read_identity(&self, identity_id: &Uuid) -> Result<Identity, StoreError> {
- self.sql.read_identity(identity_id).await
- }
- async fn write_session(&self, session: &Session) -> Result<(), StoreError> {
- self.sql.write_session(session).await
- }
- async fn read_session(&self, secret: &SessionSecret) -> Result<Session, StoreError> {
- self.sql.read_session(secret).await
- }
- async fn write_oauth_provider(&self, provider: &OauthProvider) -> Result<(), StoreError> {
- self.sql.write_oauth_provider(provider).await
- }
- async fn read_oauth_provider(
- &self,
- provider: &OauthProviderName,
- flow: Option<String>,
- ) -> Result<OauthProvider, StoreError> {
- self.sql.read_oauth_provider(provider, flow).await
- }
- async fn write_oauth_validation(
- &self,
- validation: &OauthValidation,
- ) -> anyhow::Result<ValidationRequestId> {
- self.sql.write_oauth_validation(validation).await
- }
- async fn read_oauth_validation(
- &self,
- validation_id: &ValidationRequestId,
- ) -> anyhow::Result<OauthValidation> {
- self.sql.read_oauth_validation(validation_id).await
- }
- async fn find_validation_type(
- &self,
- validation_id: &ValidationRequestId,
- ) -> anyhow::Result<ValidationType> {
- self.sql.find_validation_type(validation_id).await
- }
-}
-
-pub struct SqliteClient {
- sql: SqlClient<Sqlite>,
-}
-
-impl SqliteClient {
- pub async fn new(pool: sqlx::Pool<Sqlite>) -> Arc<dyn Store + Send + Sync + 'static> {
- sqlx::migrate!("store/sqlite/migrations")
- .run(&pool)
- .await
- .expect(ERR_MSG_MIGRATION_FAILED);
-
- sqlx::query("pragma foreign_keys = on")
- .execute(&pool)
- .await
- .expect(
- "Failed to initialize FK pragma. File a bug at https://www.github.com/secd-lib",
- );
-
- Arc::new(SqliteClient {
- sql: SqlClient {
- pool,
- sqls_root: SQLITE.to_string(),
- },
- })
- }
-}
-
-#[async_trait::async_trait]
-impl Store for SqliteClient {
- async fn write_email(&self, email_address: &str) -> Result<(), StoreError> {
- self.sql.write_email(email_address).await
- }
- async fn find_email_validation(
- &self,
- validation_id: Option<&Uuid>,
- code: Option<&str>,
- ) -> Result<EmailValidation, StoreError> {
- self.sql.find_email_validation(validation_id, code).await
- }
- async fn write_email_validation(&self, ev: &EmailValidation) -> anyhow::Result<Uuid> {
- self.sql.write_email_validation(ev).await
- }
- async fn find_identity(
- &self,
- identity_id: Option<&Uuid>,
- email: Option<&str>,
- ) -> anyhow::Result<Option<Identity>> {
- self.sql.find_identity(identity_id, email).await
- }
- async fn find_identity_by_code(&self, code: &str) -> Result<Identity, StoreError> {
- self.sql.find_identity_by_code(code).await
- }
- async fn write_identity(&self, i: &Identity) -> Result<(), StoreError> {
- self.sql.write_identity(i).await
- }
- async fn read_identity(&self, identity_id: &Uuid) -> Result<Identity, StoreError> {
- self.sql.read_identity(identity_id).await
- }
- async fn write_session(&self, session: &Session) -> Result<(), StoreError> {
- self.sql.write_session(session).await
- }
- async fn read_session(&self, secret: &SessionSecret) -> Result<Session, StoreError> {
- self.sql.read_session(secret).await
- }
- async fn write_oauth_provider(&self, provider: &OauthProvider) -> Result<(), StoreError> {
- self.sql.write_oauth_provider(provider).await
- }
- async fn read_oauth_provider(
- &self,
- provider: &OauthProviderName,
- flow: Option<String>,
- ) -> Result<OauthProvider, StoreError> {
- self.sql.read_oauth_provider(provider, flow).await
- }
- async fn write_oauth_validation(
- &self,
- validation: &OauthValidation,
- ) -> anyhow::Result<ValidationRequestId> {
- self.sql.write_oauth_validation(validation).await
- }
- async fn read_oauth_validation(
- &self,
- validation_id: &ValidationRequestId,
- ) -> anyhow::Result<OauthValidation> {
- self.sql.read_oauth_validation(validation_id).await
- }
- async fn find_validation_type(
- &self,
- validation_id: &ValidationRequestId,
- ) -> anyhow::Result<ValidationType> {
- self.sql.find_validation_type(validation_id).await
- }
-}
diff --git a/crates/secd/src/client/store/mod.rs b/crates/secd/src/client/store/mod.rs
new file mode 100644
index 0000000..b93fd84
--- /dev/null
+++ b/crates/secd/src/client/store/mod.rs
@@ -0,0 +1,190 @@
+pub(crate) mod sql_db;
+
+use email_address::EmailAddress;
+use sqlx::{Postgres, Sqlite};
+use std::sync::Arc;
+use time::OffsetDateTime;
+use uuid::Uuid;
+
+use crate::{
+ util, Address, AddressType, AddressValidation, Identity, IdentityId, PhoneNumber, Session,
+ SessionToken,
+};
+
+use self::sql_db::SqlClient;
+
+#[derive(Debug, thiserror::Error, derive_more::Display)]
+pub enum StoreError {
+ SqlClientError(#[from] sqlx::Error),
+ StoreValueCannotBeParsedInvariant,
+ IdempotentCheckAlreadyExists,
+}
+
+#[async_trait::async_trait(?Send)]
+pub trait Store {
+ fn get_type(&self) -> StoreType;
+}
+
+pub enum StoreType {
+ Postgres { c: Arc<SqlClient<Postgres>> },
+ Sqlite { c: Arc<SqlClient<Sqlite>> },
+}
+
+#[async_trait::async_trait(?Send)]
+pub(crate) trait Storable<'a> {
+ type Item;
+ type Lens;
+
+ async fn write(&self, store: Arc<dyn Store>) -> Result<(), StoreError>;
+ async fn find(
+ store: Arc<dyn Store>,
+ lens: &'a Self::Lens,
+ ) -> Result<Vec<Self::Item>, StoreError>;
+}
+
+pub(crate) trait Lens {}
+
+pub(crate) struct AddressLens<'a> {
+ pub id: Option<&'a Uuid>,
+ pub t: Option<&'a AddressType>,
+}
+impl<'a> Lens for AddressLens<'a> {}
+
+pub(crate) struct AddressValidationLens<'a> {
+ pub id: Option<&'a Uuid>,
+}
+impl<'a> Lens for AddressValidationLens<'a> {}
+
+pub(crate) struct IdentityLens<'a> {
+ pub id: Option<&'a Uuid>,
+ pub address_type: Option<&'a AddressType>,
+ pub validated_address: Option<bool>,
+ pub session_token_hash: Option<Vec<u8>>,
+}
+impl<'a> Lens for IdentityLens<'a> {}
+
+pub(crate) struct SessionLens<'a> {
+ pub token_hash: Option<&'a Vec<u8>>,
+ pub identity_id: Option<&'a IdentityId>,
+}
+impl<'a> Lens for SessionLens<'a> {}
+
+#[async_trait::async_trait(?Send)]
+impl<'a> Storable<'a> for Address {
+ type Item = Address;
+ type Lens = AddressLens<'a>;
+
+ async fn write(&self, store: Arc<dyn Store>) -> Result<(), StoreError> {
+ match store.get_type() {
+ StoreType::Postgres { c } => c.write_address(self).await?,
+ StoreType::Sqlite { c } => c.write_address(self).await?,
+ }
+ Ok(())
+ }
+ async fn find(
+ store: Arc<dyn Store>,
+ lens: &'a Self::Lens,
+ ) -> Result<Vec<Self::Item>, StoreError> {
+ let typ = lens.t.map(|at| at.to_string().clone());
+ let typ = typ.as_deref();
+
+ let val = lens.t.and_then(|at| at.get_value());
+ let val = val.as_deref();
+
+ Ok(match store.get_type() {
+ StoreType::Postgres { c } => c.find_address(lens.id, typ, val).await?,
+ StoreType::Sqlite { c } => c.find_address(lens.id, typ, val).await?,
+ })
+ }
+}
+
+#[async_trait::async_trait(?Send)]
+impl<'a> Storable<'a> for AddressValidation {
+ type Item = AddressValidation;
+ type Lens = AddressValidationLens<'a>;
+
+ async fn write(&self, store: Arc<dyn Store>) -> Result<(), StoreError> {
+ match store.get_type() {
+ StoreType::Sqlite { c } => c.write_address_validation(self).await?,
+ StoreType::Postgres { c } => c.write_address_validation(self).await?,
+ }
+ Ok(())
+ }
+ async fn find(
+ store: Arc<dyn Store>,
+ lens: &'a Self::Lens,
+ ) -> Result<Vec<Self::Item>, StoreError> {
+ Ok(match store.get_type() {
+ StoreType::Postgres { c } => c.find_address_validation(lens.id).await?,
+ StoreType::Sqlite { c } => c.find_address_validation(lens.id).await?,
+ })
+ }
+}
+
+#[async_trait::async_trait(?Send)]
+impl<'a> Storable<'a> for Identity {
+ type Item = Identity;
+ type Lens = IdentityLens<'a>;
+
+ async fn write(&self, store: Arc<dyn Store>) -> Result<(), StoreError> {
+ match store.get_type() {
+ StoreType::Postgres { c } => c.write_identity(self).await?,
+ StoreType::Sqlite { c } => c.write_identity(self).await?,
+ }
+ Ok(())
+ }
+ async fn find(
+ store: Arc<dyn Store>,
+ lens: &'a Self::Lens,
+ ) -> Result<Vec<Self::Item>, StoreError> {
+ let val = lens.address_type.and_then(|at| at.get_value());
+ let val = val.as_deref();
+
+ Ok(match store.get_type() {
+ StoreType::Postgres { c } => {
+ c.find_identity(
+ lens.id,
+ val,
+ lens.validated_address,
+ &lens.session_token_hash,
+ )
+ .await?
+ }
+ StoreType::Sqlite { c } => {
+ c.find_identity(
+ lens.id,
+ val,
+ lens.validated_address,
+ &lens.session_token_hash,
+ )
+ .await?
+ }
+ })
+ }
+}
+
+#[async_trait::async_trait(?Send)]
+impl<'a> Storable<'a> for Session {
+ type Item = Session;
+ type Lens = SessionLens<'a>;
+
+ async fn write(&self, store: Arc<dyn Store>) -> Result<(), StoreError> {
+ let token_hash = util::hash(&self.token);
+ match store.get_type() {
+ StoreType::Postgres { c } => c.write_session(self, &token_hash).await?,
+ StoreType::Sqlite { c } => c.write_session(self, &token_hash).await?,
+ }
+ Ok(())
+ }
+ async fn find(
+ store: Arc<dyn Store>,
+ lens: &'a Self::Lens,
+ ) -> Result<Vec<Self::Item>, StoreError> {
+ let token = lens.token_hash.map(|t| t.clone()).unwrap_or(vec![]);
+
+ Ok(match store.get_type() {
+ StoreType::Postgres { c } => c.find_session(token, lens.identity_id).await?,
+ StoreType::Sqlite { c } => c.find_session(token, lens.identity_id).await?,
+ })
+ }
+}
diff --git a/crates/secd/src/client/store/sql_db.rs b/crates/secd/src/client/store/sql_db.rs
new file mode 100644
index 0000000..6d84301
--- /dev/null
+++ b/crates/secd/src/client/store/sql_db.rs
@@ -0,0 +1,526 @@
+use std::{str::FromStr, sync::Arc};
+
+use email_address::EmailAddress;
+use serde_json::value::RawValue;
+use sqlx::{
+ database::HasArguments, types::Json, ColumnIndex, Database, Decode, Encode, Executor,
+ IntoArguments, Pool, Transaction, Type,
+};
+use time::OffsetDateTime;
+use uuid::Uuid;
+
+use crate::{
+ Address, AddressType, AddressValidation, AddressValidationMethod, Identity, Session,
+ SessionToken,
+};
+
+use lazy_static::lazy_static;
+use sqlx::{Postgres, Sqlite};
+use std::collections::HashMap;
+
+use super::{
+ AddressLens, AddressValidationLens, IdentityLens, SessionLens, Storable, Store, StoreError,
+ StoreType,
+};
+
+const SQLITE: &str = "sqlite";
+const PGSQL: &str = "pgsql";
+
+const WRITE_ADDRESS: &str = "write_address";
+const FIND_ADDRESS: &str = "find_address";
+const WRITE_ADDRESS_VALIDATION: &str = "write_address_validation";
+const FIND_ADDRESS_VALIDATION: &str = "find_address_validation";
+const WRITE_IDENTITY: &str = "write_identity";
+const FIND_IDENTITY: &str = "find_identity";
+const WRITE_SESSION: &str = "write_session";
+const FIND_SESSION: &str = "find_session";
+
+const ERR_MSG_MIGRATION_FAILED: &str = "Failed to apply secd migrations to a sql db. File a bug at https://www.github.com/branchcontrol/secdiam";
+
+lazy_static! {
+ static ref SQLS: HashMap<&'static str, HashMap<&'static str, &'static str>> = {
+ let sqlite_sqls: HashMap<&'static str, &'static str> = [
+ (
+ WRITE_ADDRESS,
+ include_str!("../../../store/sqlite/sql/write_address.sql"),
+ ),
+ (
+ FIND_ADDRESS,
+ include_str!("../../../store/sqlite/sql/find_address.sql"),
+ ),
+ (
+ WRITE_ADDRESS_VALIDATION,
+ include_str!("../../../store/sqlite/sql/write_address_validation.sql"),
+ ),
+ (
+ FIND_ADDRESS_VALIDATION,
+ include_str!("../../../store/sqlite/sql/find_address_validation.sql"),
+ ),
+ (
+ WRITE_IDENTITY,
+ include_str!("../../../store/sqlite/sql/write_identity.sql"),
+ ),
+ (
+ FIND_IDENTITY,
+ include_str!("../../../store/sqlite/sql/find_identity.sql"),
+ ),
+ (
+ WRITE_SESSION,
+ include_str!("../../../store/sqlite/sql/write_session.sql"),
+ ),
+ (
+ FIND_SESSION,
+ include_str!("../../../store/sqlite/sql/find_session.sql"),
+ ),
+ ]
+ .iter()
+ .cloned()
+ .collect();
+
+ let pg_sqls: HashMap<&'static str, &'static str> = [
+ (
+ WRITE_ADDRESS,
+ include_str!("../../../store/pg/sql/write_address.sql"),
+ ),
+ (
+ FIND_ADDRESS,
+ include_str!("../../../store/pg/sql/find_address.sql"),
+ ),
+ (
+ WRITE_ADDRESS_VALIDATION,
+ include_str!("../../../store/pg/sql/write_address_validation.sql"),
+ ),
+ (
+ FIND_ADDRESS_VALIDATION,
+ include_str!("../../../store/pg/sql/find_address_validation.sql"),
+ ),
+ (
+ WRITE_IDENTITY,
+ include_str!("../../../store/pg/sql/write_identity.sql"),
+ ),
+ (
+ FIND_IDENTITY,
+ include_str!("../../../store/pg/sql/find_identity.sql"),
+ ),
+ (
+ WRITE_SESSION,
+ include_str!("../../../store/pg/sql/write_session.sql"),
+ ),
+ (
+ FIND_SESSION,
+ include_str!("../../../store/pg/sql/find_session.sql"),
+ ),
+ ]
+ .iter()
+ .cloned()
+ .collect();
+
+ let sqls: HashMap<&'static str, HashMap<&'static str, &'static str>> =
+ [(SQLITE, sqlite_sqls), (PGSQL, pg_sqls)]
+ .iter()
+ .cloned()
+ .collect();
+ sqls
+ };
+}
+
+pub trait SqlxResultExt<T> {
+ fn extend_err(self) -> Result<T, StoreError>;
+}
+
+impl<T> SqlxResultExt<T> for Result<T, sqlx::Error> {
+ fn extend_err(self) -> Result<T, StoreError> {
+ if let Err(sqlx::Error::Database(dbe)) = &self {
+ if dbe.code() == Some("23505".into()) {
+ return Err(StoreError::IdempotentCheckAlreadyExists);
+ }
+ }
+ self.map_err(|e| StoreError::SqlClientError(e))
+ }
+}
+
+pub struct SqlClient<D>
+where
+ D: sqlx::Database,
+{
+ pool: sqlx::Pool<D>,
+ sqls_root: String,
+}
+
+pub struct PgClient {
+ sql: Arc<SqlClient<Postgres>>,
+}
+impl Store for PgClient {
+ fn get_type(&self) -> StoreType {
+ StoreType::Postgres {
+ c: self.sql.clone(),
+ }
+ }
+}
+
+impl PgClient {
+ pub async fn new(pool: sqlx::Pool<Postgres>) -> Arc<dyn Store + Send + Sync + 'static> {
+ sqlx::migrate!("store/pg/migrations")
+ .run(&pool)
+ .await
+ .expect(ERR_MSG_MIGRATION_FAILED);
+
+ Arc::new(PgClient {
+ sql: Arc::new(SqlClient {
+ pool,
+ sqls_root: PGSQL.to_string(),
+ }),
+ })
+ }
+}
+
+pub struct SqliteClient {
+ sql: Arc<SqlClient<Sqlite>>,
+}
+impl Store for SqliteClient {
+ fn get_type(&self) -> StoreType {
+ StoreType::Sqlite {
+ c: self.sql.clone(),
+ }
+ }
+}
+
+impl SqliteClient {
+ pub async fn new(pool: sqlx::Pool<Sqlite>) -> Arc<dyn Store + Send + Sync + 'static> {
+ sqlx::migrate!("store/sqlite/migrations")
+ .run(&pool)
+ .await
+ .expect(ERR_MSG_MIGRATION_FAILED);
+
+ sqlx::query("pragma foreign_keys = on")
+ .execute(&pool)
+ .await
+ .expect(
+ "Failed to initialize FK pragma. File a bug at https://www.github.com/secd-lib",
+ );
+
+ Arc::new(SqliteClient {
+ sql: Arc::new(SqlClient {
+ pool,
+ sqls_root: SQLITE.to_string(),
+ }),
+ })
+ }
+}
+
+impl<D> SqlClient<D>
+where
+ D: sqlx::Database,
+ for<'c> &'c Pool<D>: Executor<'c, Database = D>,
+ for<'c> &'c mut Transaction<'c, D>: Executor<'c, Database = D>,
+ for<'c> <D as HasArguments<'c>>::Arguments: IntoArguments<'c, D>,
+ for<'c> bool: Decode<'c, D> + Type<D>,
+ for<'c> bool: Encode<'c, D> + Type<D>,
+ for<'c> Option<bool>: Decode<'c, D> + Type<D>,
+ for<'c> Option<bool>: Encode<'c, D> + Type<D>,
+ for<'c> i64: Decode<'c, D> + Type<D>,
+ for<'c> i64: Encode<'c, D> + Type<D>,
+ for<'c> i32: Decode<'c, D> + Type<D>,
+ for<'c> i32: Encode<'c, D> + Type<D>,
+ for<'c> OffsetDateTime: Decode<'c, D> + Type<D>,
+ for<'c> OffsetDateTime: Encode<'c, D> + Type<D>,
+ for<'c> &'c str: ColumnIndex<<D as Database>::Row>,
+ for<'c> &'c str: Decode<'c, D> + Type<D>,
+ for<'c> &'c str: Encode<'c, D> + Type<D>,
+ for<'c> Option<&'c str>: Decode<'c, D> + Type<D>,
+ for<'c> Option<&'c str>: Encode<'c, D> + Type<D>,
+ for<'c> String: Decode<'c, D> + Type<D>,
+ for<'c> String: Encode<'c, D> + Type<D>,
+ for<'c> Option<String>: Decode<'c, D> + Type<D>,
+ for<'c> Option<String>: Encode<'c, D> + Type<D>,
+ for<'c> usize: ColumnIndex<<D as Database>::Row>,
+ for<'c> Uuid: Decode<'c, D> + Type<D>,
+ for<'c> Uuid: Encode<'c, D> + Type<D>,
+ for<'c> &'c [u8]: Encode<'c, D> + Type<D>,
+ for<'c> Option<&'c Uuid>: Encode<'c, D> + Type<D>,
+ for<'c> Vec<u8>: Encode<'c, D> + Type<D>,
+ for<'c> Vec<u8>: Decode<'c, D> + Type<D>,
+ for<'c> Option<Vec<u8>>: Encode<'c, D> + Type<D>,
+ for<'c> Option<Vec<u8>>: Decode<'c, D> + Type<D>,
+ for<'c> Option<OffsetDateTime>: Decode<'c, D> + Type<D>,
+ for<'c> Option<OffsetDateTime>: Encode<'c, D> + Type<D>,
+{
+ pub async fn write_address(&self, a: &Address) -> Result<(), StoreError> {
+ let sqls = get_sqls(&self.sqls_root, WRITE_ADDRESS);
+ sqlx::query(&sqls[0])
+ .bind(a.id)
+ .bind(a.t.to_string())
+ .bind(match &a.t {
+ AddressType::Email { email_address } => {
+ email_address.as_ref().map(ToString::to_string)
+ }
+ AddressType::Sms { phone_number } => phone_number.clone(),
+ })
+ .bind(a.created_at)
+ .execute(&self.pool)
+ .await
+ .extend_err()?;
+
+ Ok(())
+ }
+
+ pub async fn find_address(
+ &self,
+ id: Option<&Uuid>,
+ typ: Option<&str>,
+ val: Option<&str>,
+ ) -> Result<Vec<Address>, StoreError> {
+ let sqls = get_sqls(&self.sqls_root, FIND_ADDRESS);
+ let res = sqlx::query_as::<_, (Uuid, String, String, OffsetDateTime)>(&sqls[0])
+ .bind(id)
+ .bind(typ)
+ .bind(val)
+ .fetch_all(&self.pool)
+ .await
+ .extend_err()?;
+
+ let mut addresses = vec![];
+ for (id, typ, val, created_at) in res.into_iter() {
+ let t = match AddressType::from_str(&typ)
+ .map_err(|_| StoreError::StoreValueCannotBeParsedInvariant)?
+ {
+ AddressType::Email { .. } => AddressType::Email {
+ email_address: Some(
+ EmailAddress::from_str(&val)
+ .map_err(|_| StoreError::StoreValueCannotBeParsedInvariant)?,
+ ),
+ },
+ AddressType::Sms { .. } => AddressType::Sms {
+ phone_number: Some(val.clone()),
+ },
+ };
+
+ addresses.push(Address { id, t, created_at });
+ }
+
+ Ok(addresses)
+ }
+
+ pub async fn write_address_validation(&self, v: &AddressValidation) -> Result<(), StoreError> {
+ let sqls = get_sqls(&self.sqls_root, WRITE_ADDRESS_VALIDATION);
+ sqlx::query(&sqls[0])
+ .bind(v.id)
+ .bind(v.identity_id.as_ref())
+ .bind(v.address.id)
+ .bind(v.method.to_string())
+ .bind(v.hashed_token.clone())
+ .bind(v.hashed_code.clone())
+ .bind(v.attempts)
+ .bind(v.created_at)
+ .bind(v.expires_at)
+ .bind(v.revoked_at)
+ .bind(v.validated_at)
+ .execute(&self.pool)
+ .await
+ .extend_err()?;
+
+ Ok(())
+ }
+
+ pub async fn find_address_validation(
+ &self,
+ id: Option<&Uuid>,
+ ) -> Result<Vec<AddressValidation>, StoreError> {
+ let sqls = get_sqls(&self.sqls_root, FIND_ADDRESS_VALIDATION);
+ let rs = sqlx::query_as::<
+ _,
+ (
+ Uuid,
+ Option<Uuid>,
+ Uuid,
+ String,
+ String,
+ OffsetDateTime,
+ String,
+ Vec<u8>,
+ Vec<u8>,
+ i32,
+ OffsetDateTime,
+ OffsetDateTime,
+ Option<OffsetDateTime>,
+ Option<OffsetDateTime>,
+ ),
+ >(&sqls[0])
+ .bind(id)
+ .fetch_all(&self.pool)
+ .await
+ .extend_err()?;
+
+ let mut res = vec![];
+ for (
+ id,
+ identity_id,
+ address_id,
+ address_typ,
+ address_val,
+ address_created_at,
+ method,
+ hashed_token,
+ hashed_code,
+ attempts,
+ created_at,
+ expires_at,
+ revoked_at,
+ validated_at,
+ ) in rs.into_iter()
+ {
+ let t = match AddressType::from_str(&address_typ)
+ .map_err(|_| StoreError::StoreValueCannotBeParsedInvariant)?
+ {
+ AddressType::Email { .. } => AddressType::Email {
+ email_address: Some(
+ EmailAddress::from_str(&address_val)
+ .map_err(|_| StoreError::StoreValueCannotBeParsedInvariant)?,
+ ),
+ },
+ AddressType::Sms { .. } => AddressType::Sms {
+ phone_number: Some(address_val.clone()),
+ },
+ };
+
+ res.push(AddressValidation {
+ id,
+ identity_id,
+ address: Address {
+ id: address_id,
+ t,
+ created_at: address_created_at,
+ },
+ method: AddressValidationMethod::from_str(&method)
+ .map_err(|_| StoreError::StoreValueCannotBeParsedInvariant)?,
+ created_at,
+ expires_at,
+ revoked_at,
+ validated_at,
+ attempts,
+ hashed_token,
+ hashed_code,
+ });
+ }
+
+ Ok(res)
+ }
+
+ pub async fn write_identity(&self, i: &Identity) -> Result<(), StoreError> {
+ let sqls = get_sqls(&self.sqls_root, WRITE_IDENTITY);
+ sqlx::query(&sqls[0])
+ .bind(i.id)
+ // TODO: validate this is actually Json somewhere way up the chain (when being deserialized)
+ .bind(i.metadata.clone().unwrap_or("{}".into()))
+ .bind(i.created_at)
+ .bind(OffsetDateTime::now_utc())
+ .bind(i.deleted_at)
+ .execute(&self.pool)
+ .await
+ .extend_err()?;
+
+ Ok(())
+ }
+
+ pub async fn find_identity(
+ &self,
+ id: Option<&Uuid>,
+ address_value: Option<&str>,
+ address_is_validated: Option<bool>,
+ session_token_hash: &Option<Vec<u8>>,
+ ) -> Result<Vec<Identity>, StoreError> {
+ let sqls = get_sqls(&self.sqls_root, FIND_IDENTITY);
+ println!("{:?}", id);
+ println!("{:?}", address_value);
+ println!("{:?}", address_is_validated);
+ println!("{:?}", session_token_hash);
+ let rs = sqlx::query_as::<
+ _,
+ (
+ Uuid,
+ Option<String>,
+ OffsetDateTime,
+ OffsetDateTime,
+ Option<OffsetDateTime>,
+ ),
+ >(&sqls[0])
+ .bind(id)
+ .bind(address_value)
+ .bind(address_is_validated)
+ .bind(session_token_hash)
+ .fetch_all(&self.pool)
+ .await
+ .extend_err()?;
+
+ let mut res = vec![];
+ for (id, metadata, created_at, updated_at, deleted_at) in rs.into_iter() {
+ res.push(Identity {
+ id,
+ address_validations: vec![],
+ credentials: vec![],
+ rules: vec![],
+ metadata,
+ created_at,
+ deleted_at,
+ })
+ }
+
+ Ok(res)
+ }
+
+ pub async fn write_session(&self, s: &Session, token_hash: &[u8]) -> Result<(), StoreError> {
+ let sqls = get_sqls(&self.sqls_root, WRITE_SESSION);
+ sqlx::query(&sqls[0])
+ .bind(s.identity_id)
+ .bind(token_hash)
+ .bind(s.created_at)
+ .bind(s.expired_at)
+ .bind(s.revoked_at)
+ .execute(&self.pool)
+ .await
+ .extend_err()?;
+
+ Ok(())
+ }
+
+ pub async fn find_session(
+ &self,
+ token: Vec<u8>,
+ identity_id: Option<&Uuid>,
+ ) -> Result<Vec<Session>, StoreError> {
+ let sqls = get_sqls(&self.sqls_root, FIND_SESSION);
+ let rs =
+ sqlx::query_as::<_, (Uuid, OffsetDateTime, OffsetDateTime, Option<OffsetDateTime>)>(
+ &sqls[0],
+ )
+ .bind(token)
+ .bind(identity_id)
+ .bind(OffsetDateTime::now_utc())
+ .bind(OffsetDateTime::now_utc())
+ .fetch_all(&self.pool)
+ .await
+ .extend_err()?;
+
+ let mut res = vec![];
+ for (identity_id, created_at, expired_at, revoked_at) in rs.into_iter() {
+ res.push(Session {
+ identity_id,
+ token: vec![],
+ created_at,
+ expired_at,
+ revoked_at,
+ });
+ }
+ Ok(res)
+ }
+}
+
+fn get_sqls(root: &str, file: &str) -> Vec<String> {
+ SQLS.get(root)
+ .unwrap()
+ .get(file)
+ .unwrap()
+ .split("--")
+ .map(|p| p.to_string())
+ .collect()
+}
diff --git a/crates/secd/src/client/types.rs b/crates/secd/src/client/types.rs
deleted file mode 100644
index bacade4..0000000
--- a/crates/secd/src/client/types.rs
+++ /dev/null
@@ -1,3 +0,0 @@
-pub(crate) struct Email {
- address: String,
-}