aboutsummaryrefslogtreecommitdiff
path: root/crates/secd/src/client/store/sql_db.rs
diff options
context:
space:
mode:
authorbenj <benj@rse8.com>2022-12-24 00:43:38 -0800
committerbenj <benj@rse8.com>2022-12-24 00:43:38 -0800
commitc2268c285648ef02ece04de0d9df0813c6d70ff8 (patch)
treef84ec7ee42f97d78245f26d0c5a0c559cd35e89d /crates/secd/src/client/store/sql_db.rs
parentde6339da72af1d61ca5908b780977e2b037ce014 (diff)
downloadsecdiam-c2268c285648ef02ece04de0d9df0813c6d70ff8.tar
secdiam-c2268c285648ef02ece04de0d9df0813c6d70ff8.tar.gz
secdiam-c2268c285648ef02ece04de0d9df0813c6d70ff8.tar.bz2
secdiam-c2268c285648ef02ece04de0d9df0813c6d70ff8.tar.lz
secdiam-c2268c285648ef02ece04de0d9df0813c6d70ff8.tar.xz
secdiam-c2268c285648ef02ece04de0d9df0813c6d70ff8.tar.zst
secdiam-c2268c285648ef02ece04de0d9df0813c6d70ff8.zip
refactor everything with more abstraction and a nicer interface
Diffstat (limited to 'crates/secd/src/client/store/sql_db.rs')
-rw-r--r--crates/secd/src/client/store/sql_db.rs526
1 file changed, 526 insertions, 0 deletions
diff --git a/crates/secd/src/client/store/sql_db.rs b/crates/secd/src/client/store/sql_db.rs
new file mode 100644
index 0000000..6d84301
--- /dev/null
+++ b/crates/secd/src/client/store/sql_db.rs
@@ -0,0 +1,526 @@
+use std::{str::FromStr, sync::Arc};
+
+use email_address::EmailAddress;
+use serde_json::value::RawValue;
+use sqlx::{
+ database::HasArguments, types::Json, ColumnIndex, Database, Decode, Encode, Executor,
+ IntoArguments, Pool, Transaction, Type,
+};
+use time::OffsetDateTime;
+use uuid::Uuid;
+
+use crate::{
+ Address, AddressType, AddressValidation, AddressValidationMethod, Identity, Session,
+ SessionToken,
+};
+
+use lazy_static::lazy_static;
+use sqlx::{Postgres, Sqlite};
+use std::collections::HashMap;
+
+use super::{
+ AddressLens, AddressValidationLens, IdentityLens, SessionLens, Storable, Store, StoreError,
+ StoreType,
+};
+
// Top-level keys into `SQLS`, selecting the per-backend SQL file set.
const SQLITE: &str = "sqlite";
const PGSQL: &str = "pgsql";

// Inner keys into a backend's map; each names one bundled .sql file
// (see the `include_str!` calls in `SQLS`).
const WRITE_ADDRESS: &str = "write_address";
const FIND_ADDRESS: &str = "find_address";
const WRITE_ADDRESS_VALIDATION: &str = "write_address_validation";
const FIND_ADDRESS_VALIDATION: &str = "find_address_validation";
const WRITE_IDENTITY: &str = "write_identity";
const FIND_IDENTITY: &str = "find_identity";
const WRITE_SESSION: &str = "write_session";
const FIND_SESSION: &str = "find_session";

// Panic message used when applying bundled migrations fails at startup.
const ERR_MSG_MIGRATION_FAILED: &str = "Failed to apply secd migrations to a sql db. File a bug at https://www.github.com/branchcontrol/secdiam";
+
+lazy_static! {
+ static ref SQLS: HashMap<&'static str, HashMap<&'static str, &'static str>> = {
+ let sqlite_sqls: HashMap<&'static str, &'static str> = [
+ (
+ WRITE_ADDRESS,
+ include_str!("../../../store/sqlite/sql/write_address.sql"),
+ ),
+ (
+ FIND_ADDRESS,
+ include_str!("../../../store/sqlite/sql/find_address.sql"),
+ ),
+ (
+ WRITE_ADDRESS_VALIDATION,
+ include_str!("../../../store/sqlite/sql/write_address_validation.sql"),
+ ),
+ (
+ FIND_ADDRESS_VALIDATION,
+ include_str!("../../../store/sqlite/sql/find_address_validation.sql"),
+ ),
+ (
+ WRITE_IDENTITY,
+ include_str!("../../../store/sqlite/sql/write_identity.sql"),
+ ),
+ (
+ FIND_IDENTITY,
+ include_str!("../../../store/sqlite/sql/find_identity.sql"),
+ ),
+ (
+ WRITE_SESSION,
+ include_str!("../../../store/sqlite/sql/write_session.sql"),
+ ),
+ (
+ FIND_SESSION,
+ include_str!("../../../store/sqlite/sql/find_session.sql"),
+ ),
+ ]
+ .iter()
+ .cloned()
+ .collect();
+
+ let pg_sqls: HashMap<&'static str, &'static str> = [
+ (
+ WRITE_ADDRESS,
+ include_str!("../../../store/pg/sql/write_address.sql"),
+ ),
+ (
+ FIND_ADDRESS,
+ include_str!("../../../store/pg/sql/find_address.sql"),
+ ),
+ (
+ WRITE_ADDRESS_VALIDATION,
+ include_str!("../../../store/pg/sql/write_address_validation.sql"),
+ ),
+ (
+ FIND_ADDRESS_VALIDATION,
+ include_str!("../../../store/pg/sql/find_address_validation.sql"),
+ ),
+ (
+ WRITE_IDENTITY,
+ include_str!("../../../store/pg/sql/write_identity.sql"),
+ ),
+ (
+ FIND_IDENTITY,
+ include_str!("../../../store/pg/sql/find_identity.sql"),
+ ),
+ (
+ WRITE_SESSION,
+ include_str!("../../../store/pg/sql/write_session.sql"),
+ ),
+ (
+ FIND_SESSION,
+ include_str!("../../../store/pg/sql/find_session.sql"),
+ ),
+ ]
+ .iter()
+ .cloned()
+ .collect();
+
+ let sqls: HashMap<&'static str, HashMap<&'static str, &'static str>> =
+ [(SQLITE, sqlite_sqls), (PGSQL, pg_sqls)]
+ .iter()
+ .cloned()
+ .collect();
+ sqls
+ };
+}
+
/// Extension trait for converting raw sqlx results into this store's
/// `StoreError` results (see the blanket impl below for the mapping).
pub trait SqlxResultExt<T> {
    fn extend_err(self) -> Result<T, StoreError>;
}
+
+impl<T> SqlxResultExt<T> for Result<T, sqlx::Error> {
+ fn extend_err(self) -> Result<T, StoreError> {
+ if let Err(sqlx::Error::Database(dbe)) = &self {
+ if dbe.code() == Some("23505".into()) {
+ return Err(StoreError::IdempotentCheckAlreadyExists);
+ }
+ }
+ self.map_err(|e| StoreError::SqlClientError(e))
+ }
+}
+
/// Generic sqlx-backed store client: a connection pool plus the backend key
/// selecting which bundled SQL set to run.
pub struct SqlClient<D>
where
    D: sqlx::Database,
{
    pool: sqlx::Pool<D>,
    // Top-level key into `SQLS`: either `SQLITE` or `PGSQL`.
    sqls_root: String,
}
+
/// Postgres-backed `Store`: a thin wrapper handing out a shared
/// `SqlClient<Postgres>`.
pub struct PgClient {
    sql: Arc<SqlClient<Postgres>>,
}
impl Store for PgClient {
    // Exposes the concrete client through `StoreType` so callers can
    // dispatch on the backend without downcasting.
    fn get_type(&self) -> StoreType {
        StoreType::Postgres {
            c: self.sql.clone(),
        }
    }
}
+
impl PgClient {
    /// Builds a Postgres-backed store from an existing pool, applying the
    /// bundled migrations (`store/pg/migrations`) before handing it out.
    ///
    /// # Panics
    /// Panics (with `ERR_MSG_MIGRATION_FAILED`) if the migrations cannot be
    /// applied — treated as an unrecoverable startup error.
    pub async fn new(pool: sqlx::Pool<Postgres>) -> Arc<dyn Store + Send + Sync + 'static> {
        sqlx::migrate!("store/pg/migrations")
            .run(&pool)
            .await
            .expect(ERR_MSG_MIGRATION_FAILED);

        Arc::new(PgClient {
            sql: Arc::new(SqlClient {
                pool,
                sqls_root: PGSQL.to_string(),
            }),
        })
    }
}
+
/// SQLite-backed `Store`: a thin wrapper handing out a shared
/// `SqlClient<Sqlite>`.
pub struct SqliteClient {
    sql: Arc<SqlClient<Sqlite>>,
}
impl Store for SqliteClient {
    // Exposes the concrete client through `StoreType` so callers can
    // dispatch on the backend without downcasting.
    fn get_type(&self) -> StoreType {
        StoreType::Sqlite {
            c: self.sql.clone(),
        }
    }
}
+
impl SqliteClient {
    /// Builds a SQLite-backed store from an existing pool, applying the
    /// bundled migrations (`store/sqlite/migrations`) and enabling the
    /// foreign-key pragma before handing it out.
    ///
    /// # Panics
    /// Panics if the migrations or the pragma statement fail — treated as an
    /// unrecoverable startup error.
    pub async fn new(pool: sqlx::Pool<Sqlite>) -> Arc<dyn Store + Send + Sync + 'static> {
        sqlx::migrate!("store/sqlite/migrations")
            .run(&pool)
            .await
            .expect(ERR_MSG_MIGRATION_FAILED);

        // NOTE(review): SQLite pragmas are per-connection; executing this on
        // the pool appears to configure only the connection that happens to
        // run it, not connections the pool opens later. Confirm whether FK
        // enforcement should instead be set via the pool's connect options.
        sqlx::query("pragma foreign_keys = on")
            .execute(&pool)
            .await
            .expect(
                "Failed to initialize FK pragma. File a bug at https://www.github.com/secd-lib",
            );

        Arc::new(SqliteClient {
            sql: Arc::new(SqlClient {
                pool,
                sqls_root: SQLITE.to_string(),
            }),
        })
    }
}
+
+impl<D> SqlClient<D>
+where
+ D: sqlx::Database,
+ for<'c> &'c Pool<D>: Executor<'c, Database = D>,
+ for<'c> &'c mut Transaction<'c, D>: Executor<'c, Database = D>,
+ for<'c> <D as HasArguments<'c>>::Arguments: IntoArguments<'c, D>,
+ for<'c> bool: Decode<'c, D> + Type<D>,
+ for<'c> bool: Encode<'c, D> + Type<D>,
+ for<'c> Option<bool>: Decode<'c, D> + Type<D>,
+ for<'c> Option<bool>: Encode<'c, D> + Type<D>,
+ for<'c> i64: Decode<'c, D> + Type<D>,
+ for<'c> i64: Encode<'c, D> + Type<D>,
+ for<'c> i32: Decode<'c, D> + Type<D>,
+ for<'c> i32: Encode<'c, D> + Type<D>,
+ for<'c> OffsetDateTime: Decode<'c, D> + Type<D>,
+ for<'c> OffsetDateTime: Encode<'c, D> + Type<D>,
+ for<'c> &'c str: ColumnIndex<<D as Database>::Row>,
+ for<'c> &'c str: Decode<'c, D> + Type<D>,
+ for<'c> &'c str: Encode<'c, D> + Type<D>,
+ for<'c> Option<&'c str>: Decode<'c, D> + Type<D>,
+ for<'c> Option<&'c str>: Encode<'c, D> + Type<D>,
+ for<'c> String: Decode<'c, D> + Type<D>,
+ for<'c> String: Encode<'c, D> + Type<D>,
+ for<'c> Option<String>: Decode<'c, D> + Type<D>,
+ for<'c> Option<String>: Encode<'c, D> + Type<D>,
+ for<'c> usize: ColumnIndex<<D as Database>::Row>,
+ for<'c> Uuid: Decode<'c, D> + Type<D>,
+ for<'c> Uuid: Encode<'c, D> + Type<D>,
+ for<'c> &'c [u8]: Encode<'c, D> + Type<D>,
+ for<'c> Option<&'c Uuid>: Encode<'c, D> + Type<D>,
+ for<'c> Vec<u8>: Encode<'c, D> + Type<D>,
+ for<'c> Vec<u8>: Decode<'c, D> + Type<D>,
+ for<'c> Option<Vec<u8>>: Encode<'c, D> + Type<D>,
+ for<'c> Option<Vec<u8>>: Decode<'c, D> + Type<D>,
+ for<'c> Option<OffsetDateTime>: Decode<'c, D> + Type<D>,
+ for<'c> Option<OffsetDateTime>: Encode<'c, D> + Type<D>,
+{
+ pub async fn write_address(&self, a: &Address) -> Result<(), StoreError> {
+ let sqls = get_sqls(&self.sqls_root, WRITE_ADDRESS);
+ sqlx::query(&sqls[0])
+ .bind(a.id)
+ .bind(a.t.to_string())
+ .bind(match &a.t {
+ AddressType::Email { email_address } => {
+ email_address.as_ref().map(ToString::to_string)
+ }
+ AddressType::Sms { phone_number } => phone_number.clone(),
+ })
+ .bind(a.created_at)
+ .execute(&self.pool)
+ .await
+ .extend_err()?;
+
+ Ok(())
+ }
+
+ pub async fn find_address(
+ &self,
+ id: Option<&Uuid>,
+ typ: Option<&str>,
+ val: Option<&str>,
+ ) -> Result<Vec<Address>, StoreError> {
+ let sqls = get_sqls(&self.sqls_root, FIND_ADDRESS);
+ let res = sqlx::query_as::<_, (Uuid, String, String, OffsetDateTime)>(&sqls[0])
+ .bind(id)
+ .bind(typ)
+ .bind(val)
+ .fetch_all(&self.pool)
+ .await
+ .extend_err()?;
+
+ let mut addresses = vec![];
+ for (id, typ, val, created_at) in res.into_iter() {
+ let t = match AddressType::from_str(&typ)
+ .map_err(|_| StoreError::StoreValueCannotBeParsedInvariant)?
+ {
+ AddressType::Email { .. } => AddressType::Email {
+ email_address: Some(
+ EmailAddress::from_str(&val)
+ .map_err(|_| StoreError::StoreValueCannotBeParsedInvariant)?,
+ ),
+ },
+ AddressType::Sms { .. } => AddressType::Sms {
+ phone_number: Some(val.clone()),
+ },
+ };
+
+ addresses.push(Address { id, t, created_at });
+ }
+
+ Ok(addresses)
+ }
+
+ pub async fn write_address_validation(&self, v: &AddressValidation) -> Result<(), StoreError> {
+ let sqls = get_sqls(&self.sqls_root, WRITE_ADDRESS_VALIDATION);
+ sqlx::query(&sqls[0])
+ .bind(v.id)
+ .bind(v.identity_id.as_ref())
+ .bind(v.address.id)
+ .bind(v.method.to_string())
+ .bind(v.hashed_token.clone())
+ .bind(v.hashed_code.clone())
+ .bind(v.attempts)
+ .bind(v.created_at)
+ .bind(v.expires_at)
+ .bind(v.revoked_at)
+ .bind(v.validated_at)
+ .execute(&self.pool)
+ .await
+ .extend_err()?;
+
+ Ok(())
+ }
+
+ pub async fn find_address_validation(
+ &self,
+ id: Option<&Uuid>,
+ ) -> Result<Vec<AddressValidation>, StoreError> {
+ let sqls = get_sqls(&self.sqls_root, FIND_ADDRESS_VALIDATION);
+ let rs = sqlx::query_as::<
+ _,
+ (
+ Uuid,
+ Option<Uuid>,
+ Uuid,
+ String,
+ String,
+ OffsetDateTime,
+ String,
+ Vec<u8>,
+ Vec<u8>,
+ i32,
+ OffsetDateTime,
+ OffsetDateTime,
+ Option<OffsetDateTime>,
+ Option<OffsetDateTime>,
+ ),
+ >(&sqls[0])
+ .bind(id)
+ .fetch_all(&self.pool)
+ .await
+ .extend_err()?;
+
+ let mut res = vec![];
+ for (
+ id,
+ identity_id,
+ address_id,
+ address_typ,
+ address_val,
+ address_created_at,
+ method,
+ hashed_token,
+ hashed_code,
+ attempts,
+ created_at,
+ expires_at,
+ revoked_at,
+ validated_at,
+ ) in rs.into_iter()
+ {
+ let t = match AddressType::from_str(&address_typ)
+ .map_err(|_| StoreError::StoreValueCannotBeParsedInvariant)?
+ {
+ AddressType::Email { .. } => AddressType::Email {
+ email_address: Some(
+ EmailAddress::from_str(&address_val)
+ .map_err(|_| StoreError::StoreValueCannotBeParsedInvariant)?,
+ ),
+ },
+ AddressType::Sms { .. } => AddressType::Sms {
+ phone_number: Some(address_val.clone()),
+ },
+ };
+
+ res.push(AddressValidation {
+ id,
+ identity_id,
+ address: Address {
+ id: address_id,
+ t,
+ created_at: address_created_at,
+ },
+ method: AddressValidationMethod::from_str(&method)
+ .map_err(|_| StoreError::StoreValueCannotBeParsedInvariant)?,
+ created_at,
+ expires_at,
+ revoked_at,
+ validated_at,
+ attempts,
+ hashed_token,
+ hashed_code,
+ });
+ }
+
+ Ok(res)
+ }
+
+ pub async fn write_identity(&self, i: &Identity) -> Result<(), StoreError> {
+ let sqls = get_sqls(&self.sqls_root, WRITE_IDENTITY);
+ sqlx::query(&sqls[0])
+ .bind(i.id)
+ // TODO: validate this is actually Json somewhere way up the chain (when being deserialized)
+ .bind(i.metadata.clone().unwrap_or("{}".into()))
+ .bind(i.created_at)
+ .bind(OffsetDateTime::now_utc())
+ .bind(i.deleted_at)
+ .execute(&self.pool)
+ .await
+ .extend_err()?;
+
+ Ok(())
+ }
+
+ pub async fn find_identity(
+ &self,
+ id: Option<&Uuid>,
+ address_value: Option<&str>,
+ address_is_validated: Option<bool>,
+ session_token_hash: &Option<Vec<u8>>,
+ ) -> Result<Vec<Identity>, StoreError> {
+ let sqls = get_sqls(&self.sqls_root, FIND_IDENTITY);
+ println!("{:?}", id);
+ println!("{:?}", address_value);
+ println!("{:?}", address_is_validated);
+ println!("{:?}", session_token_hash);
+ let rs = sqlx::query_as::<
+ _,
+ (
+ Uuid,
+ Option<String>,
+ OffsetDateTime,
+ OffsetDateTime,
+ Option<OffsetDateTime>,
+ ),
+ >(&sqls[0])
+ .bind(id)
+ .bind(address_value)
+ .bind(address_is_validated)
+ .bind(session_token_hash)
+ .fetch_all(&self.pool)
+ .await
+ .extend_err()?;
+
+ let mut res = vec![];
+ for (id, metadata, created_at, updated_at, deleted_at) in rs.into_iter() {
+ res.push(Identity {
+ id,
+ address_validations: vec![],
+ credentials: vec![],
+ rules: vec![],
+ metadata,
+ created_at,
+ deleted_at,
+ })
+ }
+
+ Ok(res)
+ }
+
+ pub async fn write_session(&self, s: &Session, token_hash: &[u8]) -> Result<(), StoreError> {
+ let sqls = get_sqls(&self.sqls_root, WRITE_SESSION);
+ sqlx::query(&sqls[0])
+ .bind(s.identity_id)
+ .bind(token_hash)
+ .bind(s.created_at)
+ .bind(s.expired_at)
+ .bind(s.revoked_at)
+ .execute(&self.pool)
+ .await
+ .extend_err()?;
+
+ Ok(())
+ }
+
+ pub async fn find_session(
+ &self,
+ token: Vec<u8>,
+ identity_id: Option<&Uuid>,
+ ) -> Result<Vec<Session>, StoreError> {
+ let sqls = get_sqls(&self.sqls_root, FIND_SESSION);
+ let rs =
+ sqlx::query_as::<_, (Uuid, OffsetDateTime, OffsetDateTime, Option<OffsetDateTime>)>(
+ &sqls[0],
+ )
+ .bind(token)
+ .bind(identity_id)
+ .bind(OffsetDateTime::now_utc())
+ .bind(OffsetDateTime::now_utc())
+ .fetch_all(&self.pool)
+ .await
+ .extend_err()?;
+
+ let mut res = vec![];
+ for (identity_id, created_at, expired_at, revoked_at) in rs.into_iter() {
+ res.push(Session {
+ identity_id,
+ token: vec![],
+ created_at,
+ expired_at,
+ revoked_at,
+ });
+ }
+ Ok(res)
+ }
+}
+
+fn get_sqls(root: &str, file: &str) -> Vec<String> {
+ SQLS.get(root)
+ .unwrap()
+ .get(file)
+ .unwrap()
+ .split("--")
+ .map(|p| p.to_string())
+ .collect()
+}