use super::{Store, StoreError, StoreType};
use crate::Impersonator;
use crate::{
    util::ErrorContext, Address, AddressType, AddressValidation, AddressValidationMethod,
    Credential, CredentialId, CredentialType, Identity, IdentityId,
};
use email_address::EmailAddress;
use lazy_static::lazy_static;
use sqlx::{
    database::HasArguments, ColumnIndex, Database, Decode, Encode, Executor, IntoArguments, Pool,
    Transaction, Type,
};
use sqlx::{Postgres, Sqlite};
use std::collections::HashMap;
use std::{str::FromStr, sync::Arc};
use time::OffsetDateTime;
use uuid::Uuid;

const SQLITE: &str = "sqlite";
const PGSQL: &str = "pgsql";
const WRITE_ADDRESS: &str = "write_address";
const FIND_ADDRESS: &str = "find_address";
const WRITE_ADDRESS_VALIDATION: &str = "write_address_validation";
const FIND_ADDRESS_VALIDATION: &str = "find_address_validation";
const WRITE_CREDENTIAL: &str = "write_credential";
const FIND_CREDENTIAL: &str = "find_credential";
const WRITE_IDENTITY: &str = "write_identity";
const FIND_IDENTITY: &str = "find_identity";
const WRITE_IMPERSONATOR: &str = "write_impersonator";
const FIND_IMPERSONATOR: &str = "find_impersonator";

const ERR_MSG_MIGRATION_FAILED: &str =
    "Failed to apply secd migrations to a sql db. File a bug at https://www.github.com/branchcontrol/secdiam";

lazy_static! {
    static ref SQLS: HashMap<&'static str, HashMap<&'static str, &'static str>> = {
        let sqlite_sqls: HashMap<&'static str, &'static str> = [
            (WRITE_ADDRESS, include_str!("../../../store/sqlite/sql/write_address.sql")),
            (FIND_ADDRESS, include_str!("../../../store/sqlite/sql/find_address.sql")),
            (WRITE_ADDRESS_VALIDATION, include_str!("../../../store/sqlite/sql/write_address_validation.sql")),
            (FIND_ADDRESS_VALIDATION, include_str!("../../../store/sqlite/sql/find_address_validation.sql")),
            (WRITE_IDENTITY, include_str!("../../../store/sqlite/sql/write_identity.sql")),
            (FIND_IDENTITY, include_str!("../../../store/sqlite/sql/find_identity.sql")),
            (WRITE_CREDENTIAL, include_str!("../../../store/sqlite/sql/write_credential.sql")),
            (FIND_CREDENTIAL, include_str!("../../../store/sqlite/sql/find_credential.sql")),
            (WRITE_IMPERSONATOR, include_str!("../../../store/sqlite/sql/write_impersonator.sql")),
            (FIND_IMPERSONATOR, include_str!("../../../store/sqlite/sql/find_impersonator.sql")),
        ]
        .iter()
        .cloned()
        .collect();

        let pg_sqls: HashMap<&'static str, &'static str> = [
            (WRITE_ADDRESS, include_str!("../../../store/pg/sql/write_address.sql")),
            (FIND_ADDRESS, include_str!("../../../store/pg/sql/find_address.sql")),
            (WRITE_ADDRESS_VALIDATION, include_str!("../../../store/pg/sql/write_address_validation.sql")),
            (FIND_ADDRESS_VALIDATION, include_str!("../../../store/pg/sql/find_address_validation.sql")),
            (WRITE_IDENTITY, include_str!("../../../store/pg/sql/write_identity.sql")),
            (FIND_IDENTITY, include_str!("../../../store/pg/sql/find_identity.sql")),
            (WRITE_CREDENTIAL, include_str!("../../../store/pg/sql/write_credential.sql")),
            (FIND_CREDENTIAL, include_str!("../../../store/pg/sql/find_credential.sql")),
            (WRITE_IMPERSONATOR, include_str!("../../../store/pg/sql/write_impersonator.sql")),
            (FIND_IMPERSONATOR, include_str!("../../../store/pg/sql/find_impersonator.sql")),
        ]
        .iter()
        .cloned()
        .collect();

        let sqls: HashMap<&'static str, HashMap<&'static str, &'static str>> =
            [(SQLITE, sqlite_sqls), (PGSQL, pg_sqls)]
                .iter()
                .cloned()
                .collect();
        sqls
    };
}

pub trait SqlxResultExt<T> {
    fn extend_err(self) -> Result<T, StoreError>;
}

impl<T> SqlxResultExt<T> for Result<T, sqlx::Error> {
    fn extend_err(self) -> Result<T, StoreError> {
        if let Err(sqlx::Error::Database(dbe)) = &self {
            if dbe.code() == Some("23505".into()) || dbe.code() == Some("2067".into()) {
                return Err(StoreError::IdempotentCheckAlreadyExists);
            }
        }
        self.map_err(StoreError::SqlClientError)
    }
}

pub struct SqlClient<D>
where
    D: sqlx::Database,
{
    pool: sqlx::Pool<D>,
    sqls_root: String,
}

pub struct PgClient {
    sql: Arc<SqlClient<Postgres>>,
}

impl Store for PgClient {
    fn get_type(&self) -> StoreType {
        StoreType::Postgres {
            c: self.sql.clone(),
        }
    }
}

impl PgClient {
    pub async fn new_ref(pool: sqlx::Pool<Postgres>) -> Arc<Self> {
        sqlx::migrate!("store/pg/migrations", "secd")
            .run(&pool)
            .await
            .expect(ERR_MSG_MIGRATION_FAILED);
        Arc::new(PgClient {
            sql: Arc::new(SqlClient {
                pool,
                sqls_root: PGSQL.to_string(),
            }),
        })
    }
}

pub struct SqliteClient {
    sql: Arc<SqlClient<Sqlite>>,
}

impl Store for SqliteClient {
    fn get_type(&self) -> StoreType {
        StoreType::Sqlite {
            c: self.sql.clone(),
        }
    }
}

impl SqliteClient {
    pub async fn new_ref(pool: sqlx::Pool<Sqlite>) -> Arc<Self> {
        sqlx::migrate!("store/sqlite/migrations", "secd")
            .run(&pool)
            .await
            .expect(ERR_MSG_MIGRATION_FAILED);
        sqlx::query("pragma foreign_keys = on")
            .execute(&pool)
            .await
            .expect(
                "Failed to initialize FK pragma. File a bug at https://www.github.com/secd-lib",
            );
        Arc::new(SqliteClient {
            sql: Arc::new(SqlClient {
                pool,
                sqls_root: SQLITE.to_string(),
            }),
        })
    }
}

impl<D> SqlClient<D>
where
    D: sqlx::Database,
    for<'c> &'c Pool<D>: Executor<'c, Database = D>,
    for<'c> &'c mut Transaction<'c, D>: Executor<'c, Database = D>,
    for<'c> <D as HasArguments<'c>>::Arguments: IntoArguments<'c, D>,
    for<'c> bool: Decode<'c, D> + Type<D>,
    for<'c> bool: Encode<'c, D> + Type<D>,
    for<'c> Option<bool>: Decode<'c, D> + Type<D>,
    for<'c> Option<bool>: Encode<'c, D> + Type<D>,
    for<'c> i64: Decode<'c, D> + Type<D>,
    for<'c> i64: Encode<'c, D> + Type<D>,
    for<'c> i32: Decode<'c, D> + Type<D>,
    for<'c> i32: Encode<'c, D> + Type<D>,
    for<'c> OffsetDateTime: Decode<'c, D> + Type<D>,
    for<'c> OffsetDateTime: Encode<'c, D> + Type<D>,
    for<'c> &'c str: ColumnIndex<<D as Database>::Row>,
    for<'c> &'c str: Decode<'c, D> + Type<D>,
    for<'c> &'c str: Encode<'c, D> + Type<D>,
    for<'c> Option<&'c str>: Decode<'c, D> + Type<D>,
    for<'c> Option<&'c str>: Encode<'c, D> + Type<D>,
    for<'c> String: Decode<'c, D> + Type<D>,
    for<'c> String: Encode<'c, D> + Type<D>,
    for<'c> Option<String>: Decode<'c, D> + Type<D>,
    for<'c> Option<String>: Encode<'c, D> + Type<D>,
    for<'c> usize: ColumnIndex<<D as Database>::Row>,
    for<'c> Uuid: Decode<'c, D> + Type<D>,
    for<'c> Uuid: Encode<'c, D> + Type<D>,
    for<'c> &'c [u8]: Encode<'c, D> + Type<D>,
    for<'c> Option<&'c Uuid>: Encode<'c, D> + Type<D>,
    for<'c> Vec<u8>: Encode<'c, D> + Type<D>,
    for<'c> Vec<u8>: Decode<'c, D> + Type<D>,
    for<'c> Option<Vec<u8>>: Encode<'c, D> + Type<D>,
    for<'c> Option<Vec<u8>>: Decode<'c, D> + Type<D>,
    for<'c> Option<OffsetDateTime>: Decode<'c, D> + Type<D>,
    for<'c> Option<OffsetDateTime>: Encode<'c, D> + Type<D>,
{
    pub async fn write_address(&self, a: &Address) -> Result<(), StoreError> {
        let sqls = get_sqls(&self.sqls_root, WRITE_ADDRESS);
        sqlx::query(&sqls[0])
            .bind(a.id)
            .bind(a.t.to_string())
            .bind(match &a.t {
                AddressType::Email { email_address } => {
                    email_address.as_ref().map(ToString::to_string)
                }
                AddressType::Sms { phone_number } => phone_number.clone(),
            })
            .bind(a.created_at)
            .execute(&self.pool)
            .await
            .extend_err()?;
        Ok(())
    }

    pub async fn find_address(
        &self,
        id: Option<&Uuid>,
        typ: Option<&str>,
        val: Option<&str>,
    ) -> Result<Vec<Address>, StoreError> {
        let sqls = get_sqls(&self.sqls_root, FIND_ADDRESS);
        let res = sqlx::query_as::<_, (Uuid, String, String, OffsetDateTime)>(&sqls[0])
            .bind(id)
            .bind(typ)
            .bind(val)
            .fetch_all(&self.pool)
            .await
            .extend_err()?;
        let mut addresses = vec![];
        for (id, typ, val, created_at) in res.into_iter() {
            let t = match AddressType::from_str(&typ)
                .map_err(|_| StoreError::StoreValueCannotBeParsedInvariant)?
            {
                AddressType::Email { .. } => AddressType::Email {
                    email_address: Some(
                        EmailAddress::from_str(&val)
                            .map_err(|_| StoreError::StoreValueCannotBeParsedInvariant)?,
                    ),
                },
                AddressType::Sms { .. } => AddressType::Sms {
                    phone_number: Some(val.clone()),
                },
            };
            addresses.push(Address { id, t, created_at });
        }
        Ok(addresses)
    }

    pub async fn write_address_validation(&self, v: &AddressValidation) -> Result<(), StoreError> {
        let sqls = get_sqls(&self.sqls_root, WRITE_ADDRESS_VALIDATION);
        sqlx::query(&sqls[0])
            .bind(v.id)
            .bind(v.identity_id.as_ref())
            .bind(v.address.id)
            .bind(v.method.to_string())
            .bind(v.hashed_token.clone())
            .bind(v.hashed_code.clone())
            .bind(v.attempts)
            .bind(v.created_at)
            .bind(v.expires_at)
            .bind(v.revoked_at)
            .bind(v.validated_at)
            .execute(&self.pool)
            .await
            .extend_err()?;
        Ok(())
    }

    pub async fn find_address_validation(
        &self,
        id: Option<&Uuid>,
    ) -> Result<Vec<AddressValidation>, StoreError> {
        let sqls = get_sqls(&self.sqls_root, FIND_ADDRESS_VALIDATION);
        let rs = sqlx::query_as::<
            _,
            (
                Uuid,
                Option<IdentityId>,
                Uuid,
                String,
                String,
                OffsetDateTime,
                String,
                Vec<u8>,
                Vec<u8>,
                i32,
                OffsetDateTime,
                OffsetDateTime,
                Option<OffsetDateTime>,
                Option<OffsetDateTime>,
            ),
        >(&sqls[0])
        .bind(id)
        .fetch_all(&self.pool)
        .await
        .extend_err()?;
        let mut res = vec![];
        for (
            id,
            identity_id,
            address_id,
            address_typ,
            address_val,
            address_created_at,
            method,
            hashed_token,
            hashed_code,
            attempts,
            created_at,
            expires_at,
            revoked_at,
            validated_at,
        ) in rs.into_iter()
        {
            let t = match AddressType::from_str(&address_typ)
                .map_err(|_| StoreError::StoreValueCannotBeParsedInvariant)?
            {
                AddressType::Email { .. } => AddressType::Email {
                    email_address: Some(
                        EmailAddress::from_str(&address_val)
                            .map_err(|_| StoreError::StoreValueCannotBeParsedInvariant)?,
                    ),
                },
                AddressType::Sms { .. } => AddressType::Sms {
                    phone_number: Some(address_val.clone()),
                },
            };
            res.push(AddressValidation {
                id,
                identity_id,
                address: Address {
                    id: address_id,
                    t,
                    created_at: address_created_at,
                },
                method: AddressValidationMethod::from_str(&method)
                    .map_err(|_| StoreError::StoreValueCannotBeParsedInvariant)?,
                created_at,
                expires_at,
                revoked_at,
                validated_at,
                attempts,
                hashed_token,
                hashed_code,
            });
        }
        Ok(res)
    }

    pub async fn write_identity(&self, i: &Identity) -> Result<(), StoreError> {
        let sqls = get_sqls(&self.sqls_root, WRITE_IDENTITY);
        sqlx::query(&sqls[0])
            .bind(i.id)
            .bind(i.metadata.clone())
            .bind(i.created_at)
            .bind(OffsetDateTime::now_utc())
            .bind(i.deleted_at)
            .execute(&self.pool)
            .await
            .extend_err()?;
        Ok(())
    }

    pub async fn find_identity(
        &self,
        id: Option<&Uuid>,
        address_value: Option<&str>,
        address_is_validated: Option<bool>,
    ) -> Result<Vec<Identity>, StoreError> {
        let sqls = get_sqls(&self.sqls_root, FIND_IDENTITY);
        let rs = sqlx::query_as::<
            _,
            (
                Uuid,
                Option<String>,
                OffsetDateTime,
                OffsetDateTime,
                Option<OffsetDateTime>,
            ),
        >(&sqls[0])
        .bind(id)
        .bind(address_value)
        .bind(address_is_validated)
        .fetch_all(&self.pool)
        .await
        .extend_err()?;
        let mut res = vec![];
        for (id, metadata, created_at, _, deleted_at) in rs.into_iter() {
            res.push(Identity {
                id,
                address_validations: vec![],
                credentials: self.find_credential(None, Some(id), None).await?,
                new_credentials: vec![],
                rules: vec![],
                metadata,
                created_at,
                deleted_at,
            })
        }
        Ok(res)
    }

    pub async fn write_credential(&self, c: &Credential) -> Result<(), StoreError> {
        let sqls = get_sqls(&self.sqls_root, WRITE_CREDENTIAL);
        let partial_key = match &c.t {
            CredentialType::Passphrase { key, .. } => Some(key.clone()),
            CredentialType::ApiToken { public, .. } => Some(public.clone()),
            CredentialType::Session { key, .. } => Some(key.clone()),
        };
        sqlx::query(&sqls[0])
            .bind(c.id)
            .bind(c.identity_id)
            .bind(partial_key)
            .bind(c.t.to_string())
            .bind(serde_json::to_string(&c.t)?)
            .bind(c.created_at)
            .bind(c.revoked_at)
            .bind(c.deleted_at)
            .execute(&self.pool)
            .await
            .extend_err()?;
        Ok(())
    }

    pub async fn find_credential(
        &self,
        id: Option<CredentialId>,
        identity_id: Option<IdentityId>,
        t: Option<&CredentialType>,
    ) -> Result<Vec<Credential>, StoreError> {
        let sqls = get_sqls(&self.sqls_root, FIND_CREDENTIAL);
        let key = t.map(|i| match i {
            CredentialType::Passphrase { key, .. } => key.clone(),
            CredentialType::ApiToken { public, .. } => public.clone(),
            CredentialType::Session { key, .. } => key.clone(),
        });
        let rs = sqlx::query_as::<
            _,
            (
                CredentialId,
                IdentityId,
                String,
                OffsetDateTime,
                Option<OffsetDateTime>,
                Option<OffsetDateTime>,
            ),
        >(&sqls[0])
        .bind(id.as_ref())
        .bind(identity_id.as_ref())
        .bind(t.map(|i| i.to_string()))
        .bind(key)
        .fetch_all(&self.pool)
        .await
        .extend_err()?;
        let mut res = vec![];
        for (id, identity_id, data, created_at, revoked_at, deleted_at) in rs.into_iter() {
            let t: CredentialType =
                serde_json::from_str(&data).ctx("error while deserializing credential_type")?;
            res.push(Credential {
                id,
                identity_id,
                t,
                created_at,
                revoked_at,
                deleted_at,
            })
        }
        Ok(res)
    }

    pub async fn write_impersonator(&self, i: &Impersonator) -> Result<(), StoreError> {
        let sqls = get_sqls(&self.sqls_root, WRITE_IMPERSONATOR);
        sqlx::query(&sqls[0])
            .bind(i.impersonator.id)
            .bind(i.target.id)
            .bind(i.target.new_credentials.get(0).map(|e| &e.id))
            .bind(i.created_at)
            .execute(&self.pool)
            .await
            .extend_err()?;
        Ok(())
    }

    pub async fn find_impersonator(
        &self,
        impersonator_id: Option<&Uuid>,
        target_id: Option<&Uuid>,
    ) -> Result<Vec<Impersonator>, StoreError> {
        let sqls = get_sqls(&self.sqls_root, FIND_IMPERSONATOR);
        let rs = sqlx::query_as::<_, (Uuid, Uuid, OffsetDateTime)>(&sqls[0])
            .bind(impersonator_id)
            .bind(target_id)
            .bind(OffsetDateTime::now_utc())
            .fetch_all(&self.pool)
            .await
            .extend_err()?;
        let mut res = vec![];
        for (impersonator_id, target_id, created_at) in rs.into_iter() {
            let impersonator = self
                .find_identity(Some(&impersonator_id), None, None)
                .await?
                .into_iter()
                .next()
                .ok_or(StoreError::ExpectedEntity)?;
            let target = self
                .find_identity(Some(&target_id), None, None)
                .await?
                .into_iter()
                .next()
                .ok_or(StoreError::ExpectedEntity)?;
            res.push(Impersonator {
                impersonator,
                target,
                created_at,
            })
        }
        Ok(res)
    }
}

fn get_sqls(root: &str, file: &str) -> Vec<String> {
    SQLS.get(root)
        .unwrap()
        .get(file)
        .unwrap()
        .split("--")
        .map(|p| p.to_string())
        .collect()
}
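
// A minimal usage sketch, not part of the store API: it assumes the crate enables
// sqlx's `sqlite` feature, uuid's `v4` feature, and a Tokio test runtime, and that
// the find_address SQL treats NULL filter parameters as wildcards. The test name
// and the example address values are illustrative only.
#[cfg(test)]
mod usage_sketch {
    use super::*;
    use std::str::FromStr;

    #[tokio::test]
    async fn writes_and_finds_an_address() {
        // `SqliteClient::new_ref` runs migrations and enables foreign keys.
        let pool = sqlx::Pool::<Sqlite>::connect("sqlite::memory:")
            .await
            .expect("failed to open in-memory sqlite");
        let client = SqliteClient::new_ref(pool).await;

        // Reach the underlying `SqlClient<Sqlite>` through the store type.
        let sql = match client.get_type() {
            StoreType::Sqlite { c } => c,
            _ => unreachable!("SqliteClient reports StoreType::Sqlite"),
        };

        let address = Address {
            id: Uuid::new_v4(),
            t: AddressType::Email {
                email_address: Some(EmailAddress::from_str("user@example.com").unwrap()),
            },
            created_at: OffsetDateTime::now_utc(),
        };
        sql.write_address(&address).await.expect("write_address failed");

        // Look the row back up by id; the other filters are left unset.
        let found = sql
            .find_address(Some(&address.id), None, None)
            .await
            .expect("find_address failed");
        assert!(found.iter().any(|a| a.id == address.id));
    }
}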