| field | value | date |
|---|---|---|
| author | benj <benj@rse8.com> | 2022-12-24 00:43:38 -0800 |
| committer | benj <benj@rse8.com> | 2022-12-24 00:43:38 -0800 |
| commit | c2268c285648ef02ece04de0d9df0813c6d70ff8 (patch) | |
| tree | f84ec7ee42f97d78245f26d0c5a0c559cd35e89d /crates/secd/src/client/store | |
| parent | de6339da72af1d61ca5908b780977e2b037ce014 (diff) | |
refactor everything with more abstraction and a nicer interface
Diffstat

| mode | file | lines |
|---|---|---|
| -rw-r--r-- | crates/secd/src/client/store/mod.rs | 190 |
| -rw-r--r-- | crates/secd/src/client/store/sql_db.rs | 526 |

2 files changed, 716 insertions, 0 deletions
```diff
diff --git a/crates/secd/src/client/store/mod.rs b/crates/secd/src/client/store/mod.rs
new file mode 100644
index 0000000..b93fd84
--- /dev/null
+++ b/crates/secd/src/client/store/mod.rs
@@ -0,0 +1,190 @@
```

```rust
pub(crate) mod sql_db;

use email_address::EmailAddress;
use sqlx::{Postgres, Sqlite};
use std::sync::Arc;
use time::OffsetDateTime;
use uuid::Uuid;

use crate::{
    util, Address, AddressType, AddressValidation, Identity, IdentityId, PhoneNumber, Session,
    SessionToken,
};

use self::sql_db::SqlClient;

#[derive(Debug, thiserror::Error, derive_more::Display)]
pub enum StoreError {
    SqlClientError(#[from] sqlx::Error),
    StoreValueCannotBeParsedInvariant,
    IdempotentCheckAlreadyExists,
}

#[async_trait::async_trait(?Send)]
pub trait Store {
    fn get_type(&self) -> StoreType;
}

pub enum StoreType {
    Postgres { c: Arc<SqlClient<Postgres>> },
    Sqlite { c: Arc<SqlClient<Sqlite>> },
}

#[async_trait::async_trait(?Send)]
pub(crate) trait Storable<'a> {
    type Item;
    type Lens;

    async fn write(&self, store: Arc<dyn Store>) -> Result<(), StoreError>;
    async fn find(
        store: Arc<dyn Store>,
        lens: &'a Self::Lens,
    ) -> Result<Vec<Self::Item>, StoreError>;
}

pub(crate) trait Lens {}

pub(crate) struct AddressLens<'a> {
    pub id: Option<&'a Uuid>,
    pub t: Option<&'a AddressType>,
}
impl<'a> Lens for AddressLens<'a> {}

pub(crate) struct AddressValidationLens<'a> {
    pub id: Option<&'a Uuid>,
}
impl<'a> Lens for AddressValidationLens<'a> {}

pub(crate) struct IdentityLens<'a> {
    pub id: Option<&'a Uuid>,
    pub address_type: Option<&'a AddressType>,
    pub validated_address: Option<bool>,
    pub session_token_hash: Option<Vec<u8>>,
}
impl<'a> Lens for IdentityLens<'a> {}

pub(crate) struct SessionLens<'a> {
    pub token_hash: Option<&'a Vec<u8>>,
    pub identity_id: Option<&'a IdentityId>,
}
impl<'a> Lens for SessionLens<'a> {}

#[async_trait::async_trait(?Send)]
impl<'a> Storable<'a> for Address {
    type Item = Address;
    type Lens = AddressLens<'a>;

    async fn write(&self, store: Arc<dyn Store>) -> Result<(), StoreError> {
        match store.get_type() {
            StoreType::Postgres { c } => c.write_address(self).await?,
            StoreType::Sqlite { c } => c.write_address(self).await?,
        }
        Ok(())
    }
    async fn find(
        store: Arc<dyn Store>,
        lens: &'a Self::Lens,
    ) -> Result<Vec<Self::Item>, StoreError> {
        let typ = lens.t.map(|at| at.to_string().clone());
        let typ = typ.as_deref();

        let val = lens.t.and_then(|at| at.get_value());
        let val = val.as_deref();

        Ok(match store.get_type() {
            StoreType::Postgres { c } => c.find_address(lens.id, typ, val).await?,
            StoreType::Sqlite { c } => c.find_address(lens.id, typ, val).await?,
        })
    }
}

#[async_trait::async_trait(?Send)]
impl<'a> Storable<'a> for AddressValidation {
    type Item = AddressValidation;
    type Lens = AddressValidationLens<'a>;

    async fn write(&self, store: Arc<dyn Store>) -> Result<(), StoreError> {
        match store.get_type() {
            StoreType::Sqlite { c } => c.write_address_validation(self).await?,
            StoreType::Postgres { c } => c.write_address_validation(self).await?,
        }
        Ok(())
    }
    async fn find(
        store: Arc<dyn Store>,
        lens: &'a Self::Lens,
    ) -> Result<Vec<Self::Item>, StoreError> {
        Ok(match store.get_type() {
            StoreType::Postgres { c } => c.find_address_validation(lens.id).await?,
            StoreType::Sqlite { c } => c.find_address_validation(lens.id).await?,
        })
    }
}

#[async_trait::async_trait(?Send)]
impl<'a> Storable<'a> for Identity {
    type Item = Identity;
    type Lens = IdentityLens<'a>;

    async fn write(&self, store: Arc<dyn Store>) -> Result<(), StoreError> {
        match store.get_type() {
            StoreType::Postgres { c } => c.write_identity(self).await?,
            StoreType::Sqlite { c } => c.write_identity(self).await?,
        }
        Ok(())
    }
    async fn find(
        store: Arc<dyn Store>,
        lens: &'a Self::Lens,
    ) -> Result<Vec<Self::Item>, StoreError> {
        let val = lens.address_type.and_then(|at| at.get_value());
        let val = val.as_deref();

        Ok(match store.get_type() {
            StoreType::Postgres { c } => {
                c.find_identity(
                    lens.id,
                    val,
                    lens.validated_address,
                    &lens.session_token_hash,
                )
                .await?
            }
            StoreType::Sqlite { c } => {
                c.find_identity(
                    lens.id,
                    val,
                    lens.validated_address,
                    &lens.session_token_hash,
                )
                .await?
            }
        })
    }
}

#[async_trait::async_trait(?Send)]
impl<'a> Storable<'a> for Session {
    type Item = Session;
    type Lens = SessionLens<'a>;

    async fn write(&self, store: Arc<dyn Store>) -> Result<(), StoreError> {
        let token_hash = util::hash(&self.token);
        match store.get_type() {
            StoreType::Postgres { c } => c.write_session(self, &token_hash).await?,
            StoreType::Sqlite { c } => c.write_session(self, &token_hash).await?,
        }
        Ok(())
    }
    async fn find(
        store: Arc<dyn Store>,
        lens: &'a Self::Lens,
    ) -> Result<Vec<Self::Item>, StoreError> {
        let token = lens.token_hash.map(|t| t.clone()).unwrap_or(vec![]);

        Ok(match store.get_type() {
            StoreType::Postgres { c } => c.find_session(token, lens.identity_id).await?,
            StoreType::Sqlite { c } => c.find_session(token, lens.identity_id).await?,
        })
    }
}
```

```diff
diff --git a/crates/secd/src/client/store/sql_db.rs b/crates/secd/src/client/store/sql_db.rs
new file mode 100644
index 0000000..6d84301
--- /dev/null
+++ b/crates/secd/src/client/store/sql_db.rs
@@ -0,0 +1,526 @@
```

```rust
use std::{str::FromStr, sync::Arc};

use email_address::EmailAddress;
use serde_json::value::RawValue;
use sqlx::{
    database::HasArguments, types::Json, ColumnIndex, Database, Decode, Encode, Executor,
    IntoArguments, Pool, Transaction, Type,
};
use time::OffsetDateTime;
use uuid::Uuid;

use crate::{
    Address, AddressType, AddressValidation, AddressValidationMethod, Identity, Session,
    SessionToken,
};

use lazy_static::lazy_static;
use sqlx::{Postgres, Sqlite};
use std::collections::HashMap;

use super::{
    AddressLens, AddressValidationLens, IdentityLens, SessionLens, Storable, Store, StoreError,
    StoreType,
};

const SQLITE: &str = "sqlite";
const PGSQL: &str = "pgsql";

const WRITE_ADDRESS: &str = "write_address";
const FIND_ADDRESS: &str = "find_address";
const WRITE_ADDRESS_VALIDATION: &str = "write_address_validation";
const FIND_ADDRESS_VALIDATION: &str = "find_address_validation";
const WRITE_IDENTITY: &str = "write_identity";
const FIND_IDENTITY: &str = "find_identity";
const WRITE_SESSION: &str = "write_session";
const FIND_SESSION: &str = "find_session";

const ERR_MSG_MIGRATION_FAILED: &str = "Failed to apply secd migrations to a sql db. File a bug at https://www.github.com/branchcontrol/secdiam";

lazy_static! {
    static ref SQLS: HashMap<&'static str, HashMap<&'static str, &'static str>> = {
        let sqlite_sqls: HashMap<&'static str, &'static str> = [
            (
                WRITE_ADDRESS,
                include_str!("../../../store/sqlite/sql/write_address.sql"),
            ),
            (
                FIND_ADDRESS,
                include_str!("../../../store/sqlite/sql/find_address.sql"),
            ),
            (
                WRITE_ADDRESS_VALIDATION,
                include_str!("../../../store/sqlite/sql/write_address_validation.sql"),
            ),
            (
                FIND_ADDRESS_VALIDATION,
                include_str!("../../../store/sqlite/sql/find_address_validation.sql"),
            ),
            (
                WRITE_IDENTITY,
                include_str!("../../../store/sqlite/sql/write_identity.sql"),
            ),
            (
                FIND_IDENTITY,
                include_str!("../../../store/sqlite/sql/find_identity.sql"),
            ),
            (
                WRITE_SESSION,
                include_str!("../../../store/sqlite/sql/write_session.sql"),
            ),
            (
                FIND_SESSION,
                include_str!("../../../store/sqlite/sql/find_session.sql"),
            ),
        ]
        .iter()
        .cloned()
        .collect();

        let pg_sqls: HashMap<&'static str, &'static str> = [
            (
                WRITE_ADDRESS,
                include_str!("../../../store/pg/sql/write_address.sql"),
            ),
            (
                FIND_ADDRESS,
                include_str!("../../../store/pg/sql/find_address.sql"),
            ),
            (
                WRITE_ADDRESS_VALIDATION,
                include_str!("../../../store/pg/sql/write_address_validation.sql"),
            ),
            (
                FIND_ADDRESS_VALIDATION,
                include_str!("../../../store/pg/sql/find_address_validation.sql"),
            ),
            (
                WRITE_IDENTITY,
                include_str!("../../../store/pg/sql/write_identity.sql"),
            ),
            (
                FIND_IDENTITY,
                include_str!("../../../store/pg/sql/find_identity.sql"),
            ),
            (
                WRITE_SESSION,
                include_str!("../../../store/pg/sql/write_session.sql"),
            ),
            (
                FIND_SESSION,
                include_str!("../../../store/pg/sql/find_session.sql"),
            ),
        ]
        .iter()
        .cloned()
        .collect();

        let sqls: HashMap<&'static str, HashMap<&'static str, &'static str>> =
            [(SQLITE, sqlite_sqls), (PGSQL, pg_sqls)]
                .iter()
                .cloned()
                .collect();
        sqls
    };
}

pub trait SqlxResultExt<T> {
    fn extend_err(self) -> Result<T, StoreError>;
}

impl<T> SqlxResultExt<T> for Result<T, sqlx::Error> {
    fn extend_err(self) -> Result<T, StoreError> {
        if let Err(sqlx::Error::Database(dbe)) = &self {
            if dbe.code() == Some("23505".into()) {
                return Err(StoreError::IdempotentCheckAlreadyExists);
            }
        }
        self.map_err(|e| StoreError::SqlClientError(e))
    }
}

pub struct SqlClient<D>
where
    D: sqlx::Database,
{
    pool: sqlx::Pool<D>,
    sqls_root: String,
}

pub struct PgClient {
    sql: Arc<SqlClient<Postgres>>,
}
impl Store for PgClient {
    fn get_type(&self) -> StoreType {
        StoreType::Postgres {
            c: self.sql.clone(),
        }
    }
}

impl PgClient {
    pub async fn new(pool: sqlx::Pool<Postgres>) -> Arc<dyn Store + Send + Sync + 'static> {
        sqlx::migrate!("store/pg/migrations")
            .run(&pool)
            .await
            .expect(ERR_MSG_MIGRATION_FAILED);

        Arc::new(PgClient {
            sql: Arc::new(SqlClient {
                pool,
                sqls_root: PGSQL.to_string(),
            }),
        })
    }
}

pub struct SqliteClient {
    sql: Arc<SqlClient<Sqlite>>,
}
impl Store for SqliteClient {
    fn get_type(&self) -> StoreType {
        StoreType::Sqlite {
            c: self.sql.clone(),
        }
    }
}

impl SqliteClient {
    pub async fn new(pool: sqlx::Pool<Sqlite>) -> Arc<dyn Store + Send + Sync + 'static> {
        sqlx::migrate!("store/sqlite/migrations")
            .run(&pool)
            .await
            .expect(ERR_MSG_MIGRATION_FAILED);

        sqlx::query("pragma foreign_keys = on")
            .execute(&pool)
            .await
            .expect(
                "Failed to initialize FK pragma. File a bug at https://www.github.com/secd-lib",
            );

        Arc::new(SqliteClient {
            sql: Arc::new(SqlClient {
                pool,
                sqls_root: SQLITE.to_string(),
            }),
        })
    }
}

impl<D> SqlClient<D>
where
    D: sqlx::Database,
    for<'c> &'c Pool<D>: Executor<'c, Database = D>,
    for<'c> &'c mut Transaction<'c, D>: Executor<'c, Database = D>,
    for<'c> <D as HasArguments<'c>>::Arguments: IntoArguments<'c, D>,
    for<'c> bool: Decode<'c, D> + Type<D>,
    for<'c> bool: Encode<'c, D> + Type<D>,
    for<'c> Option<bool>: Decode<'c, D> + Type<D>,
    for<'c> Option<bool>: Encode<'c, D> + Type<D>,
    for<'c> i64: Decode<'c, D> + Type<D>,
    for<'c> i64: Encode<'c, D> + Type<D>,
    for<'c> i32: Decode<'c, D> + Type<D>,
    for<'c> i32: Encode<'c, D> + Type<D>,
    for<'c> OffsetDateTime: Decode<'c, D> + Type<D>,
    for<'c> OffsetDateTime: Encode<'c, D> + Type<D>,
    for<'c> &'c str: ColumnIndex<<D as Database>::Row>,
    for<'c> &'c str: Decode<'c, D> + Type<D>,
    for<'c> &'c str: Encode<'c, D> + Type<D>,
    for<'c> Option<&'c str>: Decode<'c, D> + Type<D>,
    for<'c> Option<&'c str>: Encode<'c, D> + Type<D>,
    for<'c> String: Decode<'c, D> + Type<D>,
    for<'c> String: Encode<'c, D> + Type<D>,
    for<'c> Option<String>: Decode<'c, D> + Type<D>,
    for<'c> Option<String>: Encode<'c, D> + Type<D>,
    for<'c> usize: ColumnIndex<<D as Database>::Row>,
    for<'c> Uuid: Decode<'c, D> + Type<D>,
    for<'c> Uuid: Encode<'c, D> + Type<D>,
    for<'c> &'c [u8]: Encode<'c, D> + Type<D>,
    for<'c> Option<&'c Uuid>: Encode<'c, D> + Type<D>,
    for<'c> Vec<u8>: Encode<'c, D> + Type<D>,
    for<'c> Vec<u8>: Decode<'c, D> + Type<D>,
    for<'c> Option<Vec<u8>>: Encode<'c, D> + Type<D>,
    for<'c> Option<Vec<u8>>: Decode<'c, D> + Type<D>,
    for<'c> Option<OffsetDateTime>: Decode<'c, D> + Type<D>,
    for<'c> Option<OffsetDateTime>: Encode<'c, D> + Type<D>,
{
    pub async fn write_address(&self, a: &Address) -> Result<(), StoreError> {
        let sqls = get_sqls(&self.sqls_root, WRITE_ADDRESS);
        sqlx::query(&sqls[0])
            .bind(a.id)
            .bind(a.t.to_string())
            .bind(match &a.t {
                AddressType::Email { email_address } => {
                    email_address.as_ref().map(ToString::to_string)
                }
                AddressType::Sms { phone_number } => phone_number.clone(),
            })
            .bind(a.created_at)
            .execute(&self.pool)
            .await
            .extend_err()?;

        Ok(())
    }

    pub async fn find_address(
        &self,
        id: Option<&Uuid>,
        typ: Option<&str>,
        val: Option<&str>,
    ) -> Result<Vec<Address>, StoreError> {
        let sqls = get_sqls(&self.sqls_root, FIND_ADDRESS);
        let res = sqlx::query_as::<_, (Uuid, String, String, OffsetDateTime)>(&sqls[0])
            .bind(id)
            .bind(typ)
            .bind(val)
            .fetch_all(&self.pool)
            .await
            .extend_err()?;

        let mut addresses = vec![];
        for (id, typ, val, created_at) in res.into_iter() {
            let t = match AddressType::from_str(&typ)
                .map_err(|_| StoreError::StoreValueCannotBeParsedInvariant)?
            {
                AddressType::Email { .. } => AddressType::Email {
                    email_address: Some(
                        EmailAddress::from_str(&val)
                            .map_err(|_| StoreError::StoreValueCannotBeParsedInvariant)?,
                    ),
                },
                AddressType::Sms { .. } => AddressType::Sms {
                    phone_number: Some(val.clone()),
                },
            };

            addresses.push(Address { id, t, created_at });
        }

        Ok(addresses)
    }

    pub async fn write_address_validation(&self, v: &AddressValidation) -> Result<(), StoreError> {
        let sqls = get_sqls(&self.sqls_root, WRITE_ADDRESS_VALIDATION);
        sqlx::query(&sqls[0])
            .bind(v.id)
            .bind(v.identity_id.as_ref())
            .bind(v.address.id)
            .bind(v.method.to_string())
            .bind(v.hashed_token.clone())
            .bind(v.hashed_code.clone())
            .bind(v.attempts)
            .bind(v.created_at)
            .bind(v.expires_at)
            .bind(v.revoked_at)
            .bind(v.validated_at)
            .execute(&self.pool)
            .await
            .extend_err()?;

        Ok(())
    }

    pub async fn find_address_validation(
        &self,
        id: Option<&Uuid>,
    ) -> Result<Vec<AddressValidation>, StoreError> {
        let sqls = get_sqls(&self.sqls_root, FIND_ADDRESS_VALIDATION);
        let rs = sqlx::query_as::<
            _,
            (
                Uuid,
                Option<Uuid>,
                Uuid,
                String,
                String,
                OffsetDateTime,
                String,
                Vec<u8>,
                Vec<u8>,
                i32,
                OffsetDateTime,
                OffsetDateTime,
                Option<OffsetDateTime>,
                Option<OffsetDateTime>,
            ),
        >(&sqls[0])
        .bind(id)
        .fetch_all(&self.pool)
        .await
        .extend_err()?;

        let mut res = vec![];
        for (
            id,
            identity_id,
            address_id,
            address_typ,
            address_val,
            address_created_at,
            method,
            hashed_token,
            hashed_code,
            attempts,
            created_at,
            expires_at,
            revoked_at,
            validated_at,
        ) in rs.into_iter()
        {
            let t = match AddressType::from_str(&address_typ)
                .map_err(|_| StoreError::StoreValueCannotBeParsedInvariant)?
            {
                AddressType::Email { .. } => AddressType::Email {
                    email_address: Some(
                        EmailAddress::from_str(&address_val)
                            .map_err(|_| StoreError::StoreValueCannotBeParsedInvariant)?,
                    ),
                },
                AddressType::Sms { .. } => AddressType::Sms {
                    phone_number: Some(address_val.clone()),
                },
            };

            res.push(AddressValidation {
                id,
                identity_id,
                address: Address {
                    id: address_id,
                    t,
                    created_at: address_created_at,
                },
                method: AddressValidationMethod::from_str(&method)
                    .map_err(|_| StoreError::StoreValueCannotBeParsedInvariant)?,
                created_at,
                expires_at,
                revoked_at,
                validated_at,
                attempts,
                hashed_token,
                hashed_code,
            });
        }

        Ok(res)
    }

    pub async fn write_identity(&self, i: &Identity) -> Result<(), StoreError> {
        let sqls = get_sqls(&self.sqls_root, WRITE_IDENTITY);
        sqlx::query(&sqls[0])
            .bind(i.id)
            // TODO: validate this is actually Json somewhere way up the chain (when being deserialized)
            .bind(i.metadata.clone().unwrap_or("{}".into()))
            .bind(i.created_at)
            .bind(OffsetDateTime::now_utc())
            .bind(i.deleted_at)
            .execute(&self.pool)
            .await
            .extend_err()?;

        Ok(())
    }

    pub async fn find_identity(
        &self,
        id: Option<&Uuid>,
        address_value: Option<&str>,
        address_is_validated: Option<bool>,
        session_token_hash: &Option<Vec<u8>>,
    ) -> Result<Vec<Identity>, StoreError> {
        let sqls = get_sqls(&self.sqls_root, FIND_IDENTITY);
        println!("{:?}", id);
        println!("{:?}", address_value);
        println!("{:?}", address_is_validated);
        println!("{:?}", session_token_hash);
        let rs = sqlx::query_as::<
            _,
            (
                Uuid,
                Option<String>,
                OffsetDateTime,
                OffsetDateTime,
                Option<OffsetDateTime>,
            ),
        >(&sqls[0])
        .bind(id)
        .bind(address_value)
        .bind(address_is_validated)
        .bind(session_token_hash)
        .fetch_all(&self.pool)
        .await
        .extend_err()?;

        let mut res = vec![];
        for (id, metadata, created_at, updated_at, deleted_at) in rs.into_iter() {
            res.push(Identity {
                id,
                address_validations: vec![],
                credentials: vec![],
                rules: vec![],
                metadata,
                created_at,
                deleted_at,
            })
        }

        Ok(res)
    }

    pub async fn write_session(&self, s: &Session, token_hash: &[u8]) -> Result<(), StoreError> {
        let sqls = get_sqls(&self.sqls_root, WRITE_SESSION);
        sqlx::query(&sqls[0])
            .bind(s.identity_id)
            .bind(token_hash)
            .bind(s.created_at)
            .bind(s.expired_at)
            .bind(s.revoked_at)
            .execute(&self.pool)
            .await
            .extend_err()?;

        Ok(())
    }

    pub async fn find_session(
        &self,
        token: Vec<u8>,
        identity_id: Option<&Uuid>,
    ) -> Result<Vec<Session>, StoreError> {
        let sqls = get_sqls(&self.sqls_root, FIND_SESSION);
        let rs =
            sqlx::query_as::<_, (Uuid, OffsetDateTime, OffsetDateTime, Option<OffsetDateTime>)>(
                &sqls[0],
            )
            .bind(token)
            .bind(identity_id)
            .bind(OffsetDateTime::now_utc())
            .bind(OffsetDateTime::now_utc())
            .fetch_all(&self.pool)
            .await
            .extend_err()?;

        let mut res = vec![];
        for (identity_id, created_at, expired_at, revoked_at) in rs.into_iter() {
            res.push(Session {
                identity_id,
                token: vec![],
                created_at,
                expired_at,
                revoked_at,
            });
        }
        Ok(res)
    }
}

fn get_sqls(root: &str, file: &str) -> Vec<String> {
    SQLS.get(root)
        .unwrap()
        .get(file)
        .unwrap()
        .split("--")
        .map(|p| p.to_string())
        .collect()
}
```
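For orientation, here is a rough sketch of how the new `Store` / `Storable` / lens interface might be exercised from elsewhere in the crate. It is not part of this commit: the module path `crate::client::store`, the in-memory SQLite pool setup, crate-internal visibility of the `Address` fields, and the `demo` function itself are assumptions for illustration only.

```rust
use std::{str::FromStr, sync::Arc};

use email_address::EmailAddress;
use sqlx::sqlite::SqlitePoolOptions;
use time::OffsetDateTime;
use uuid::Uuid;

// Assumed module path; Storable and AddressLens are pub(crate), so this only
// compiles from inside the secd crate.
use crate::client::store::{sql_db::SqliteClient, AddressLens, Storable, Store};
use crate::{Address, AddressType};

async fn demo() -> Result<(), Box<dyn std::error::Error>> {
    // In-memory SQLite pool; SqliteClient::new runs the bundled migrations
    // and turns on the foreign-key pragma.
    let pool = SqlitePoolOptions::new().connect("sqlite::memory:").await?;
    let store: Arc<dyn Store> = SqliteClient::new(pool).await;

    // Write an address through its Storable impl.
    let address = Address {
        id: Uuid::new_v4(),
        t: AddressType::Email {
            email_address: Some(EmailAddress::from_str("benj@rse8.com")?),
        },
        created_at: OffsetDateTime::now_utc(),
    };
    address.write(store.clone()).await?;

    // Read it back by id with a lens; fields left as None are not filtered on.
    let found = Address::find(
        store.clone(),
        &AddressLens {
            id: Some(&address.id),
            t: None,
        },
    )
    .await?;
    assert_eq!(found.len(), 1);

    Ok(())
}
```

The lens structs act as typed query filters: a caller sets only the fields it wants to match on, and the backend-specific `SqlClient` binds them into the prepared SQL loaded via `get_sqls`.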
