diff options
| author | benj <benj@rse8.com> | 2022-12-24 00:43:38 -0800 |
|---|---|---|
| committer | benj <benj@rse8.com> | 2022-12-24 00:43:38 -0800 |
| commit | c2268c285648ef02ece04de0d9df0813c6d70ff8 (patch) | |
| tree | f84ec7ee42f97d78245f26d0c5a0c559cd35e89d /crates/secd/src/client/store/sql_db.rs | |
| parent | de6339da72af1d61ca5908b780977e2b037ce014 (diff) | |
| download | secdiam-c2268c285648ef02ece04de0d9df0813c6d70ff8.tar secdiam-c2268c285648ef02ece04de0d9df0813c6d70ff8.tar.gz secdiam-c2268c285648ef02ece04de0d9df0813c6d70ff8.tar.bz2 secdiam-c2268c285648ef02ece04de0d9df0813c6d70ff8.tar.lz secdiam-c2268c285648ef02ece04de0d9df0813c6d70ff8.tar.xz secdiam-c2268c285648ef02ece04de0d9df0813c6d70ff8.tar.zst secdiam-c2268c285648ef02ece04de0d9df0813c6d70ff8.zip | |
refactor everything with more abstraction and a nicer interface
Diffstat (limited to '')
| -rw-r--r-- | crates/secd/src/client/store/sql_db.rs | 526 |
1 files changed, 526 insertions, 0 deletions
diff --git a/crates/secd/src/client/store/sql_db.rs b/crates/secd/src/client/store/sql_db.rs new file mode 100644 index 0000000..6d84301 --- /dev/null +++ b/crates/secd/src/client/store/sql_db.rs @@ -0,0 +1,526 @@ +use std::{str::FromStr, sync::Arc}; + +use email_address::EmailAddress; +use serde_json::value::RawValue; +use sqlx::{ + database::HasArguments, types::Json, ColumnIndex, Database, Decode, Encode, Executor, + IntoArguments, Pool, Transaction, Type, +}; +use time::OffsetDateTime; +use uuid::Uuid; + +use crate::{ + Address, AddressType, AddressValidation, AddressValidationMethod, Identity, Session, + SessionToken, +}; + +use lazy_static::lazy_static; +use sqlx::{Postgres, Sqlite}; +use std::collections::HashMap; + +use super::{ + AddressLens, AddressValidationLens, IdentityLens, SessionLens, Storable, Store, StoreError, + StoreType, +}; + +const SQLITE: &str = "sqlite"; +const PGSQL: &str = "pgsql"; + +const WRITE_ADDRESS: &str = "write_address"; +const FIND_ADDRESS: &str = "find_address"; +const WRITE_ADDRESS_VALIDATION: &str = "write_address_validation"; +const FIND_ADDRESS_VALIDATION: &str = "find_address_validation"; +const WRITE_IDENTITY: &str = "write_identity"; +const FIND_IDENTITY: &str = "find_identity"; +const WRITE_SESSION: &str = "write_session"; +const FIND_SESSION: &str = "find_session"; + +const ERR_MSG_MIGRATION_FAILED: &str = "Failed to apply secd migrations to a sql db. File a bug at https://www.github.com/branchcontrol/secdiam"; + +lazy_static! { + static ref SQLS: HashMap<&'static str, HashMap<&'static str, &'static str>> = { + let sqlite_sqls: HashMap<&'static str, &'static str> = [ + ( + WRITE_ADDRESS, + include_str!("../../../store/sqlite/sql/write_address.sql"), + ), + ( + FIND_ADDRESS, + include_str!("../../../store/sqlite/sql/find_address.sql"), + ), + ( + WRITE_ADDRESS_VALIDATION, + include_str!("../../../store/sqlite/sql/write_address_validation.sql"), + ), + ( + FIND_ADDRESS_VALIDATION, + include_str!("../../../store/sqlite/sql/find_address_validation.sql"), + ), + ( + WRITE_IDENTITY, + include_str!("../../../store/sqlite/sql/write_identity.sql"), + ), + ( + FIND_IDENTITY, + include_str!("../../../store/sqlite/sql/find_identity.sql"), + ), + ( + WRITE_SESSION, + include_str!("../../../store/sqlite/sql/write_session.sql"), + ), + ( + FIND_SESSION, + include_str!("../../../store/sqlite/sql/find_session.sql"), + ), + ] + .iter() + .cloned() + .collect(); + + let pg_sqls: HashMap<&'static str, &'static str> = [ + ( + WRITE_ADDRESS, + include_str!("../../../store/pg/sql/write_address.sql"), + ), + ( + FIND_ADDRESS, + include_str!("../../../store/pg/sql/find_address.sql"), + ), + ( + WRITE_ADDRESS_VALIDATION, + include_str!("../../../store/pg/sql/write_address_validation.sql"), + ), + ( + FIND_ADDRESS_VALIDATION, + include_str!("../../../store/pg/sql/find_address_validation.sql"), + ), + ( + WRITE_IDENTITY, + include_str!("../../../store/pg/sql/write_identity.sql"), + ), + ( + FIND_IDENTITY, + include_str!("../../../store/pg/sql/find_identity.sql"), + ), + ( + WRITE_SESSION, + include_str!("../../../store/pg/sql/write_session.sql"), + ), + ( + FIND_SESSION, + include_str!("../../../store/pg/sql/find_session.sql"), + ), + ] + .iter() + .cloned() + .collect(); + + let sqls: HashMap<&'static str, HashMap<&'static str, &'static str>> = + [(SQLITE, sqlite_sqls), (PGSQL, pg_sqls)] + .iter() + .cloned() + .collect(); + sqls + }; +} + +pub trait SqlxResultExt<T> { + fn extend_err(self) -> Result<T, StoreError>; +} + +impl<T> SqlxResultExt<T> for Result<T, sqlx::Error> { + fn extend_err(self) -> Result<T, StoreError> { + if let Err(sqlx::Error::Database(dbe)) = &self { + if dbe.code() == Some("23505".into()) { + return Err(StoreError::IdempotentCheckAlreadyExists); + } + } + self.map_err(|e| StoreError::SqlClientError(e)) + } +} + +pub struct SqlClient<D> +where + D: sqlx::Database, +{ + pool: sqlx::Pool<D>, + sqls_root: String, +} + +pub struct PgClient { + sql: Arc<SqlClient<Postgres>>, +} +impl Store for PgClient { + fn get_type(&self) -> StoreType { + StoreType::Postgres { + c: self.sql.clone(), + } + } +} + +impl PgClient { + pub async fn new(pool: sqlx::Pool<Postgres>) -> Arc<dyn Store + Send + Sync + 'static> { + sqlx::migrate!("store/pg/migrations") + .run(&pool) + .await + .expect(ERR_MSG_MIGRATION_FAILED); + + Arc::new(PgClient { + sql: Arc::new(SqlClient { + pool, + sqls_root: PGSQL.to_string(), + }), + }) + } +} + +pub struct SqliteClient { + sql: Arc<SqlClient<Sqlite>>, +} +impl Store for SqliteClient { + fn get_type(&self) -> StoreType { + StoreType::Sqlite { + c: self.sql.clone(), + } + } +} + +impl SqliteClient { + pub async fn new(pool: sqlx::Pool<Sqlite>) -> Arc<dyn Store + Send + Sync + 'static> { + sqlx::migrate!("store/sqlite/migrations") + .run(&pool) + .await + .expect(ERR_MSG_MIGRATION_FAILED); + + sqlx::query("pragma foreign_keys = on") + .execute(&pool) + .await + .expect( + "Failed to initialize FK pragma. File a bug at https://www.github.com/secd-lib", + ); + + Arc::new(SqliteClient { + sql: Arc::new(SqlClient { + pool, + sqls_root: SQLITE.to_string(), + }), + }) + } +} + +impl<D> SqlClient<D> +where + D: sqlx::Database, + for<'c> &'c Pool<D>: Executor<'c, Database = D>, + for<'c> &'c mut Transaction<'c, D>: Executor<'c, Database = D>, + for<'c> <D as HasArguments<'c>>::Arguments: IntoArguments<'c, D>, + for<'c> bool: Decode<'c, D> + Type<D>, + for<'c> bool: Encode<'c, D> + Type<D>, + for<'c> Option<bool>: Decode<'c, D> + Type<D>, + for<'c> Option<bool>: Encode<'c, D> + Type<D>, + for<'c> i64: Decode<'c, D> + Type<D>, + for<'c> i64: Encode<'c, D> + Type<D>, + for<'c> i32: Decode<'c, D> + Type<D>, + for<'c> i32: Encode<'c, D> + Type<D>, + for<'c> OffsetDateTime: Decode<'c, D> + Type<D>, + for<'c> OffsetDateTime: Encode<'c, D> + Type<D>, + for<'c> &'c str: ColumnIndex<<D as Database>::Row>, + for<'c> &'c str: Decode<'c, D> + Type<D>, + for<'c> &'c str: Encode<'c, D> + Type<D>, + for<'c> Option<&'c str>: Decode<'c, D> + Type<D>, + for<'c> Option<&'c str>: Encode<'c, D> + Type<D>, + for<'c> String: Decode<'c, D> + Type<D>, + for<'c> String: Encode<'c, D> + Type<D>, + for<'c> Option<String>: Decode<'c, D> + Type<D>, + for<'c> Option<String>: Encode<'c, D> + Type<D>, + for<'c> usize: ColumnIndex<<D as Database>::Row>, + for<'c> Uuid: Decode<'c, D> + Type<D>, + for<'c> Uuid: Encode<'c, D> + Type<D>, + for<'c> &'c [u8]: Encode<'c, D> + Type<D>, + for<'c> Option<&'c Uuid>: Encode<'c, D> + Type<D>, + for<'c> Vec<u8>: Encode<'c, D> + Type<D>, + for<'c> Vec<u8>: Decode<'c, D> + Type<D>, + for<'c> Option<Vec<u8>>: Encode<'c, D> + Type<D>, + for<'c> Option<Vec<u8>>: Decode<'c, D> + Type<D>, + for<'c> Option<OffsetDateTime>: Decode<'c, D> + Type<D>, + for<'c> Option<OffsetDateTime>: Encode<'c, D> + Type<D>, +{ + pub async fn write_address(&self, a: &Address) -> Result<(), StoreError> { + let sqls = get_sqls(&self.sqls_root, WRITE_ADDRESS); + sqlx::query(&sqls[0]) + .bind(a.id) + .bind(a.t.to_string()) + .bind(match &a.t { + AddressType::Email { email_address } => { + email_address.as_ref().map(ToString::to_string) + } + AddressType::Sms { phone_number } => phone_number.clone(), + }) + .bind(a.created_at) + .execute(&self.pool) + .await + .extend_err()?; + + Ok(()) + } + + pub async fn find_address( + &self, + id: Option<&Uuid>, + typ: Option<&str>, + val: Option<&str>, + ) -> Result<Vec<Address>, StoreError> { + let sqls = get_sqls(&self.sqls_root, FIND_ADDRESS); + let res = sqlx::query_as::<_, (Uuid, String, String, OffsetDateTime)>(&sqls[0]) + .bind(id) + .bind(typ) + .bind(val) + .fetch_all(&self.pool) + .await + .extend_err()?; + + let mut addresses = vec![]; + for (id, typ, val, created_at) in res.into_iter() { + let t = match AddressType::from_str(&typ) + .map_err(|_| StoreError::StoreValueCannotBeParsedInvariant)? + { + AddressType::Email { .. } => AddressType::Email { + email_address: Some( + EmailAddress::from_str(&val) + .map_err(|_| StoreError::StoreValueCannotBeParsedInvariant)?, + ), + }, + AddressType::Sms { .. } => AddressType::Sms { + phone_number: Some(val.clone()), + }, + }; + + addresses.push(Address { id, t, created_at }); + } + + Ok(addresses) + } + + pub async fn write_address_validation(&self, v: &AddressValidation) -> Result<(), StoreError> { + let sqls = get_sqls(&self.sqls_root, WRITE_ADDRESS_VALIDATION); + sqlx::query(&sqls[0]) + .bind(v.id) + .bind(v.identity_id.as_ref()) + .bind(v.address.id) + .bind(v.method.to_string()) + .bind(v.hashed_token.clone()) + .bind(v.hashed_code.clone()) + .bind(v.attempts) + .bind(v.created_at) + .bind(v.expires_at) + .bind(v.revoked_at) + .bind(v.validated_at) + .execute(&self.pool) + .await + .extend_err()?; + + Ok(()) + } + + pub async fn find_address_validation( + &self, + id: Option<&Uuid>, + ) -> Result<Vec<AddressValidation>, StoreError> { + let sqls = get_sqls(&self.sqls_root, FIND_ADDRESS_VALIDATION); + let rs = sqlx::query_as::< + _, + ( + Uuid, + Option<Uuid>, + Uuid, + String, + String, + OffsetDateTime, + String, + Vec<u8>, + Vec<u8>, + i32, + OffsetDateTime, + OffsetDateTime, + Option<OffsetDateTime>, + Option<OffsetDateTime>, + ), + >(&sqls[0]) + .bind(id) + .fetch_all(&self.pool) + .await + .extend_err()?; + + let mut res = vec![]; + for ( + id, + identity_id, + address_id, + address_typ, + address_val, + address_created_at, + method, + hashed_token, + hashed_code, + attempts, + created_at, + expires_at, + revoked_at, + validated_at, + ) in rs.into_iter() + { + let t = match AddressType::from_str(&address_typ) + .map_err(|_| StoreError::StoreValueCannotBeParsedInvariant)? + { + AddressType::Email { .. } => AddressType::Email { + email_address: Some( + EmailAddress::from_str(&address_val) + .map_err(|_| StoreError::StoreValueCannotBeParsedInvariant)?, + ), + }, + AddressType::Sms { .. } => AddressType::Sms { + phone_number: Some(address_val.clone()), + }, + }; + + res.push(AddressValidation { + id, + identity_id, + address: Address { + id: address_id, + t, + created_at: address_created_at, + }, + method: AddressValidationMethod::from_str(&method) + .map_err(|_| StoreError::StoreValueCannotBeParsedInvariant)?, + created_at, + expires_at, + revoked_at, + validated_at, + attempts, + hashed_token, + hashed_code, + }); + } + + Ok(res) + } + + pub async fn write_identity(&self, i: &Identity) -> Result<(), StoreError> { + let sqls = get_sqls(&self.sqls_root, WRITE_IDENTITY); + sqlx::query(&sqls[0]) + .bind(i.id) + // TODO: validate this is actually Json somewhere way up the chain (when being deserialized) + .bind(i.metadata.clone().unwrap_or("{}".into())) + .bind(i.created_at) + .bind(OffsetDateTime::now_utc()) + .bind(i.deleted_at) + .execute(&self.pool) + .await + .extend_err()?; + + Ok(()) + } + + pub async fn find_identity( + &self, + id: Option<&Uuid>, + address_value: Option<&str>, + address_is_validated: Option<bool>, + session_token_hash: &Option<Vec<u8>>, + ) -> Result<Vec<Identity>, StoreError> { + let sqls = get_sqls(&self.sqls_root, FIND_IDENTITY); + println!("{:?}", id); + println!("{:?}", address_value); + println!("{:?}", address_is_validated); + println!("{:?}", session_token_hash); + let rs = sqlx::query_as::< + _, + ( + Uuid, + Option<String>, + OffsetDateTime, + OffsetDateTime, + Option<OffsetDateTime>, + ), + >(&sqls[0]) + .bind(id) + .bind(address_value) + .bind(address_is_validated) + .bind(session_token_hash) + .fetch_all(&self.pool) + .await + .extend_err()?; + + let mut res = vec![]; + for (id, metadata, created_at, updated_at, deleted_at) in rs.into_iter() { + res.push(Identity { + id, + address_validations: vec![], + credentials: vec![], + rules: vec![], + metadata, + created_at, + deleted_at, + }) + } + + Ok(res) + } + + pub async fn write_session(&self, s: &Session, token_hash: &[u8]) -> Result<(), StoreError> { + let sqls = get_sqls(&self.sqls_root, WRITE_SESSION); + sqlx::query(&sqls[0]) + .bind(s.identity_id) + .bind(token_hash) + .bind(s.created_at) + .bind(s.expired_at) + .bind(s.revoked_at) + .execute(&self.pool) + .await + .extend_err()?; + + Ok(()) + } + + pub async fn find_session( + &self, + token: Vec<u8>, + identity_id: Option<&Uuid>, + ) -> Result<Vec<Session>, StoreError> { + let sqls = get_sqls(&self.sqls_root, FIND_SESSION); + let rs = + sqlx::query_as::<_, (Uuid, OffsetDateTime, OffsetDateTime, Option<OffsetDateTime>)>( + &sqls[0], + ) + .bind(token) + .bind(identity_id) + .bind(OffsetDateTime::now_utc()) + .bind(OffsetDateTime::now_utc()) + .fetch_all(&self.pool) + .await + .extend_err()?; + + let mut res = vec![]; + for (identity_id, created_at, expired_at, revoked_at) in rs.into_iter() { + res.push(Session { + identity_id, + token: vec![], + created_at, + expired_at, + revoked_at, + }); + } + Ok(res) + } +} + +fn get_sqls(root: &str, file: &str) -> Vec<String> { + SQLS.get(root) + .unwrap() + .get(file) + .unwrap() + .split("--") + .map(|p| p.to_string()) + .collect() +} |
