use std::{str::FromStr, sync::Arc}; use email_address::EmailAddress; use serde_json::value::RawValue; use sqlx::{ database::HasArguments, types::Json, ColumnIndex, Database, Decode, Encode, Executor, IntoArguments, Pool, Transaction, Type, }; use time::OffsetDateTime; use uuid::Uuid; use crate::{ Address, AddressType, AddressValidation, AddressValidationMethod, Identity, Session, SessionToken, }; use lazy_static::lazy_static; use sqlx::{Postgres, Sqlite}; use std::collections::HashMap; use super::{ AddressLens, AddressValidationLens, IdentityLens, SessionLens, Storable, Store, StoreError, StoreType, }; const SQLITE: &str = "sqlite"; const PGSQL: &str = "pgsql"; const WRITE_ADDRESS: &str = "write_address"; const FIND_ADDRESS: &str = "find_address"; const WRITE_ADDRESS_VALIDATION: &str = "write_address_validation"; const FIND_ADDRESS_VALIDATION: &str = "find_address_validation"; const WRITE_IDENTITY: &str = "write_identity"; const FIND_IDENTITY: &str = "find_identity"; const WRITE_SESSION: &str = "write_session"; const FIND_SESSION: &str = "find_session"; const ERR_MSG_MIGRATION_FAILED: &str = "Failed to apply secd migrations to a sql db. File a bug at https://www.github.com/branchcontrol/secdiam"; lazy_static! { static ref SQLS: HashMap<&'static str, HashMap<&'static str, &'static str>> = { let sqlite_sqls: HashMap<&'static str, &'static str> = [ ( WRITE_ADDRESS, include_str!("../../../store/sqlite/sql/write_address.sql"), ), ( FIND_ADDRESS, include_str!("../../../store/sqlite/sql/find_address.sql"), ), ( WRITE_ADDRESS_VALIDATION, include_str!("../../../store/sqlite/sql/write_address_validation.sql"), ), ( FIND_ADDRESS_VALIDATION, include_str!("../../../store/sqlite/sql/find_address_validation.sql"), ), ( WRITE_IDENTITY, include_str!("../../../store/sqlite/sql/write_identity.sql"), ), ( FIND_IDENTITY, include_str!("../../../store/sqlite/sql/find_identity.sql"), ), ( WRITE_SESSION, include_str!("../../../store/sqlite/sql/write_session.sql"), ), ( FIND_SESSION, include_str!("../../../store/sqlite/sql/find_session.sql"), ), ] .iter() .cloned() .collect(); let pg_sqls: HashMap<&'static str, &'static str> = [ ( WRITE_ADDRESS, include_str!("../../../store/pg/sql/write_address.sql"), ), ( FIND_ADDRESS, include_str!("../../../store/pg/sql/find_address.sql"), ), ( WRITE_ADDRESS_VALIDATION, include_str!("../../../store/pg/sql/write_address_validation.sql"), ), ( FIND_ADDRESS_VALIDATION, include_str!("../../../store/pg/sql/find_address_validation.sql"), ), ( WRITE_IDENTITY, include_str!("../../../store/pg/sql/write_identity.sql"), ), ( FIND_IDENTITY, include_str!("../../../store/pg/sql/find_identity.sql"), ), ( WRITE_SESSION, include_str!("../../../store/pg/sql/write_session.sql"), ), ( FIND_SESSION, include_str!("../../../store/pg/sql/find_session.sql"), ), ] .iter() .cloned() .collect(); let sqls: HashMap<&'static str, HashMap<&'static str, &'static str>> = [(SQLITE, sqlite_sqls), (PGSQL, pg_sqls)] .iter() .cloned() .collect(); sqls }; } pub trait SqlxResultExt { fn extend_err(self) -> Result; } impl SqlxResultExt for Result { fn extend_err(self) -> Result { if let Err(sqlx::Error::Database(dbe)) = &self { if dbe.code() == Some("23505".into()) { return Err(StoreError::IdempotentCheckAlreadyExists); } } self.map_err(|e| StoreError::SqlClientError(e)) } } pub struct SqlClient where D: sqlx::Database, { pool: sqlx::Pool, sqls_root: String, } pub struct PgClient { sql: Arc>, } impl Store for PgClient { fn get_type(&self) -> StoreType { StoreType::Postgres { c: self.sql.clone(), } } } impl PgClient { pub async fn new(pool: sqlx::Pool) -> Arc { sqlx::migrate!("store/pg/migrations") .run(&pool) .await .expect(ERR_MSG_MIGRATION_FAILED); Arc::new(PgClient { sql: Arc::new(SqlClient { pool, sqls_root: PGSQL.to_string(), }), }) } } pub struct SqliteClient { sql: Arc>, } impl Store for SqliteClient { fn get_type(&self) -> StoreType { StoreType::Sqlite { c: self.sql.clone(), } } } impl SqliteClient { pub async fn new(pool: sqlx::Pool) -> Arc { sqlx::migrate!("store/sqlite/migrations") .run(&pool) .await .expect(ERR_MSG_MIGRATION_FAILED); sqlx::query("pragma foreign_keys = on") .execute(&pool) .await .expect( "Failed to initialize FK pragma. File a bug at https://www.github.com/secd-lib", ); Arc::new(SqliteClient { sql: Arc::new(SqlClient { pool, sqls_root: SQLITE.to_string(), }), }) } } impl SqlClient where D: sqlx::Database, for<'c> &'c Pool: Executor<'c, Database = D>, for<'c> &'c mut Transaction<'c, D>: Executor<'c, Database = D>, for<'c> >::Arguments: IntoArguments<'c, D>, for<'c> bool: Decode<'c, D> + Type, for<'c> bool: Encode<'c, D> + Type, for<'c> Option: Decode<'c, D> + Type, for<'c> Option: Encode<'c, D> + Type, for<'c> i64: Decode<'c, D> + Type, for<'c> i64: Encode<'c, D> + Type, for<'c> i32: Decode<'c, D> + Type, for<'c> i32: Encode<'c, D> + Type, for<'c> OffsetDateTime: Decode<'c, D> + Type, for<'c> OffsetDateTime: Encode<'c, D> + Type, for<'c> &'c str: ColumnIndex<::Row>, for<'c> &'c str: Decode<'c, D> + Type, for<'c> &'c str: Encode<'c, D> + Type, for<'c> Option<&'c str>: Decode<'c, D> + Type, for<'c> Option<&'c str>: Encode<'c, D> + Type, for<'c> String: Decode<'c, D> + Type, for<'c> String: Encode<'c, D> + Type, for<'c> Option: Decode<'c, D> + Type, for<'c> Option: Encode<'c, D> + Type, for<'c> usize: ColumnIndex<::Row>, for<'c> Uuid: Decode<'c, D> + Type, for<'c> Uuid: Encode<'c, D> + Type, for<'c> &'c [u8]: Encode<'c, D> + Type, for<'c> Option<&'c Uuid>: Encode<'c, D> + Type, for<'c> Vec: Encode<'c, D> + Type, for<'c> Vec: Decode<'c, D> + Type, for<'c> Option>: Encode<'c, D> + Type, for<'c> Option>: Decode<'c, D> + Type, for<'c> Option: Decode<'c, D> + Type, for<'c> Option: Encode<'c, D> + Type, { pub async fn write_address(&self, a: &Address) -> Result<(), StoreError> { let sqls = get_sqls(&self.sqls_root, WRITE_ADDRESS); sqlx::query(&sqls[0]) .bind(a.id) .bind(a.t.to_string()) .bind(match &a.t { AddressType::Email { email_address } => { email_address.as_ref().map(ToString::to_string) } AddressType::Sms { phone_number } => phone_number.clone(), }) .bind(a.created_at) .execute(&self.pool) .await .extend_err()?; Ok(()) } pub async fn find_address( &self, id: Option<&Uuid>, typ: Option<&str>, val: Option<&str>, ) -> Result, StoreError> { let sqls = get_sqls(&self.sqls_root, FIND_ADDRESS); let res = sqlx::query_as::<_, (Uuid, String, String, OffsetDateTime)>(&sqls[0]) .bind(id) .bind(typ) .bind(val) .fetch_all(&self.pool) .await .extend_err()?; let mut addresses = vec![]; for (id, typ, val, created_at) in res.into_iter() { let t = match AddressType::from_str(&typ) .map_err(|_| StoreError::StoreValueCannotBeParsedInvariant)? { AddressType::Email { .. } => AddressType::Email { email_address: Some( EmailAddress::from_str(&val) .map_err(|_| StoreError::StoreValueCannotBeParsedInvariant)?, ), }, AddressType::Sms { .. } => AddressType::Sms { phone_number: Some(val.clone()), }, }; addresses.push(Address { id, t, created_at }); } Ok(addresses) } pub async fn write_address_validation(&self, v: &AddressValidation) -> Result<(), StoreError> { let sqls = get_sqls(&self.sqls_root, WRITE_ADDRESS_VALIDATION); sqlx::query(&sqls[0]) .bind(v.id) .bind(v.identity_id.as_ref()) .bind(v.address.id) .bind(v.method.to_string()) .bind(v.hashed_token.clone()) .bind(v.hashed_code.clone()) .bind(v.attempts) .bind(v.created_at) .bind(v.expires_at) .bind(v.revoked_at) .bind(v.validated_at) .execute(&self.pool) .await .extend_err()?; Ok(()) } pub async fn find_address_validation( &self, id: Option<&Uuid>, ) -> Result, StoreError> { let sqls = get_sqls(&self.sqls_root, FIND_ADDRESS_VALIDATION); let rs = sqlx::query_as::< _, ( Uuid, Option, Uuid, String, String, OffsetDateTime, String, Vec, Vec, i32, OffsetDateTime, OffsetDateTime, Option, Option, ), >(&sqls[0]) .bind(id) .fetch_all(&self.pool) .await .extend_err()?; let mut res = vec![]; for ( id, identity_id, address_id, address_typ, address_val, address_created_at, method, hashed_token, hashed_code, attempts, created_at, expires_at, revoked_at, validated_at, ) in rs.into_iter() { let t = match AddressType::from_str(&address_typ) .map_err(|_| StoreError::StoreValueCannotBeParsedInvariant)? { AddressType::Email { .. } => AddressType::Email { email_address: Some( EmailAddress::from_str(&address_val) .map_err(|_| StoreError::StoreValueCannotBeParsedInvariant)?, ), }, AddressType::Sms { .. } => AddressType::Sms { phone_number: Some(address_val.clone()), }, }; res.push(AddressValidation { id, identity_id, address: Address { id: address_id, t, created_at: address_created_at, }, method: AddressValidationMethod::from_str(&method) .map_err(|_| StoreError::StoreValueCannotBeParsedInvariant)?, created_at, expires_at, revoked_at, validated_at, attempts, hashed_token, hashed_code, }); } Ok(res) } pub async fn write_identity(&self, i: &Identity) -> Result<(), StoreError> { let sqls = get_sqls(&self.sqls_root, WRITE_IDENTITY); sqlx::query(&sqls[0]) .bind(i.id) // TODO: validate this is actually Json somewhere way up the chain (when being deserialized) .bind(i.metadata.clone().unwrap_or("{}".into())) .bind(i.created_at) .bind(OffsetDateTime::now_utc()) .bind(i.deleted_at) .execute(&self.pool) .await .extend_err()?; Ok(()) } pub async fn find_identity( &self, id: Option<&Uuid>, address_value: Option<&str>, address_is_validated: Option, session_token_hash: &Option>, ) -> Result, StoreError> { let sqls = get_sqls(&self.sqls_root, FIND_IDENTITY); let rs = sqlx::query_as::< _, ( Uuid, Option, OffsetDateTime, OffsetDateTime, Option, ), >(&sqls[0]) .bind(id) .bind(address_value) .bind(address_is_validated) .bind(session_token_hash) .fetch_all(&self.pool) .await .extend_err()?; let mut res = vec![]; for (id, metadata, created_at, updated_at, deleted_at) in rs.into_iter() { res.push(Identity { id, address_validations: vec![], credentials: vec![], rules: vec![], metadata, created_at, deleted_at, }) } Ok(res) } pub async fn write_session(&self, s: &Session, token_hash: &[u8]) -> Result<(), StoreError> { let sqls = get_sqls(&self.sqls_root, WRITE_SESSION); sqlx::query(&sqls[0]) .bind(s.identity_id) .bind(token_hash) .bind(s.created_at) .bind(s.expired_at) .bind(s.revoked_at) .execute(&self.pool) .await .extend_err()?; Ok(()) } pub async fn find_session( &self, token: Vec, identity_id: Option<&Uuid>, ) -> Result, StoreError> { let sqls = get_sqls(&self.sqls_root, FIND_SESSION); let rs = sqlx::query_as::<_, (Uuid, OffsetDateTime, OffsetDateTime, Option)>( &sqls[0], ) .bind(token) .bind(identity_id) .bind(OffsetDateTime::now_utc()) .bind(OffsetDateTime::now_utc()) .fetch_all(&self.pool) .await .extend_err()?; let mut res = vec![]; for (identity_id, created_at, expired_at, revoked_at) in rs.into_iter() { res.push(Session { identity_id, token: vec![], created_at, expired_at, revoked_at, }); } Ok(res) } } fn get_sqls(root: &str, file: &str) -> Vec { SQLS.get(root) .unwrap() .get(file) .unwrap() .split("--") .map(|p| p.to_string()) .collect() }