use std::{str::FromStr, sync::Arc}; use super::{ EmailValidation, Identity, OauthProvider, OauthProviderName, OauthResponseType, Session, SessionSecret, Store, StoreError, ERR_MSG_MIGRATION_FAILED, FIND_EMAIL_VALIDATION, FIND_IDENTITY, FIND_IDENTITY_BY_CODE, PGSQL, READ_EMAIL_RAW_ID, READ_IDENTITY_RAW_ID, READ_OAUTH_PROVIDER, READ_OAUTH_VALIDATION, READ_SESSION, READ_VALIDATION_TYPE, SQLITE, SQLS, WRITE_EMAIL, WRITE_EMAIL_VALIDATION, WRITE_IDENTITY, WRITE_OAUTH_PROVIDER, WRITE_OAUTH_VALIDATION, WRITE_SESSION, }; use crate::{util, OauthValidation, ValidationRequestId, ValidationType}; use anyhow::bail; use log::{debug, error}; use openssl::sha::Sha256; use sqlx::{ self, database::HasArguments, ColumnIndex, Database, Decode, Encode, Executor, IntoArguments, Pool, Postgres, Sqlite, Transaction, Type, }; use time::OffsetDateTime; use url::Url; use uuid::Uuid; fn get_sqls(root: &str, file: &str) -> Vec { SQLS.get(root) .unwrap() .get(file) .unwrap() .split("--") .map(|p| p.to_string()) .collect() } fn hash_secret(secret: &str) -> Vec { let mut hasher = Sha256::new(); hasher.update(secret.as_bytes()); hasher.finish().to_vec() } struct SqlClient where D: sqlx::Database, { pool: sqlx::Pool, sqls_root: String, } impl SqlClient where D: sqlx::Database, for<'c> >::Arguments: IntoArguments<'c, D>, for<'c> i64: Decode<'c, D> + Type, for<'c> &'c str: Decode<'c, D> + Type, for<'c> &'c str: Encode<'c, D> + Type, for<'c> usize: ColumnIndex<::Row>, for<'c> Uuid: Decode<'c, D> + Type, for<'c> Uuid: Encode<'c, D> + Type, for<'c> &'c Pool: Executor<'c, Database = D>, { async fn read_identity_raw_id(&self, id: &Uuid) -> Result { let sqls = get_sqls(&self.sqls_root, READ_IDENTITY_RAW_ID); Ok(sqlx::query_as::<_, (i64,)>(&sqls[0]) .bind(id) .fetch_one(&self.pool) .await .map_err(util::log_err_sqlx)? .0) } async fn read_email_raw_id(&self, address: &str) -> Result { let sqls = get_sqls(&self.sqls_root, READ_EMAIL_RAW_ID); Ok(sqlx::query_as::<_, (i64,)>(&sqls[0]) .bind(address) .fetch_one(&self.pool) .await .map_err(util::log_err_sqlx)? .0) } } #[async_trait::async_trait] impl Store for SqlClient where D: sqlx::Database, for<'c> >::Arguments: IntoArguments<'c, D>, for<'c> bool: Decode<'c, D> + Type, for<'c> bool: Encode<'c, D> + Type, for<'c> i64: Decode<'c, D> + Type, for<'c> i64: Encode<'c, D> + Type, for<'c> i32: Decode<'c, D> + Type, for<'c> i32: Encode<'c, D> + Type, for<'c> OffsetDateTime: Decode<'c, D> + Type, for<'c> OffsetDateTime: Encode<'c, D> + Type, for<'c> &'c str: ColumnIndex<::Row>, for<'c> &'c str: Decode<'c, D> + Type, for<'c> &'c str: Encode<'c, D> + Type, for<'c> Option<&'c str>: Decode<'c, D> + Type, for<'c> Option<&'c str>: Encode<'c, D> + Type, for<'c> String: Decode<'c, D> + Type, for<'c> String: Encode<'c, D> + Type, for<'c> Option: Decode<'c, D> + Type, for<'c> Option: Encode<'c, D> + Type, for<'c> OauthProviderName: Decode<'c, D> + Type, for<'c> OauthResponseType: Decode<'c, D> + Type, for<'c> usize: ColumnIndex<::Row>, for<'c> Uuid: Decode<'c, D> + Type, for<'c> Uuid: Encode<'c, D> + Type, for<'c> &'c [u8]: Encode<'c, D> + Type, for<'c> Option<&'c Uuid>: Encode<'c, D> + Type, for<'c> Option<&'c Vec>: Encode<'c, D> + Type, for<'c> Option: Decode<'c, D> + Type, for<'c> Option: Encode<'c, D> + Type, for<'c> &'c Pool: Executor<'c, Database = D>, for<'c> &'c mut Transaction<'c, D>: Executor<'c, Database = D>, { async fn write_email(&self, email_address: &str) -> Result<(), StoreError> { let sqls = get_sqls(&self.sqls_root, WRITE_EMAIL); sqlx::query(&sqls[0]) .bind(email_address) .execute(&self.pool) .await .map_err(util::log_err_sqlx)?; Ok(()) } async fn find_email_validation( &self, validation_id: Option<&Uuid>, code: Option<&str>, ) -> Result { let sqls = get_sqls(&self.sqls_root, FIND_EMAIL_VALIDATION); let mut rows = sqlx::query_as::<_, EmailValidation>(&sqls[0]) .bind(validation_id) .bind(code) .fetch_all(&self.pool) .await .map_err(util::log_err_sqlx)?; match rows.len() { 0 => Err(StoreError::NoEmailValidationFound), 1 => Ok(rows.swap_remove(0)), _ => Err(StoreError::TooManyValidations), } } async fn write_email_validation(&self, ev: &EmailValidation) -> anyhow::Result { let sqls = get_sqls(&self.sqls_root, WRITE_EMAIL_VALIDATION); let email_id = self.read_email_raw_id(&ev.email_address).await?; let validation_id = ev.id.unwrap_or(Uuid::new_v4()); sqlx::query(&sqls[0]) .bind(validation_id) .bind(email_id) .bind(&ev.code) .bind(ev.is_oauth_derived) .bind(ev.created_at) .bind(ev.validated_at) .bind(ev.expired_at) .execute(&self.pool) .await .map_err(util::log_err_sqlx)?; if ev.identity_id.is_some() || ev.revoked_at.is_some() || ev.deleted_at.is_some() { sqlx::query(&sqls[1]) .bind(ev.identity_id.as_ref()) .bind(validation_id) .bind(ev.revoked_at) .bind(ev.deleted_at) .execute(&self.pool) .await .map_err(util::log_err_sqlx)?; } Ok(validation_id) } async fn find_identity( &self, id: Option<&Uuid>, email: Option<&str>, ) -> anyhow::Result> { let sqls = get_sqls(&self.sqls_root, FIND_IDENTITY); Ok( match sqlx::query_as::<_, Identity>(&sqls[0]) .bind(id) .bind(email) .fetch_all(&self.pool) .await { Ok(mut is) => match is.len() { // if only 1 found, then that's fine // if multiple are fond, then if they all have the same id, that's okay 1 => { let i = is.swap_remove(0); match i.deleted_at { Some(t) if t > OffsetDateTime::now_utc() => Some(i), None => Some(i), _ => None, } } 0 => None, _ => { match is .iter() .filter(|&i| i.id != is[0].id) .collect::>() .len() { 0 => Some(is.swap_remove(0)), _ => bail!(StoreError::TooManyIdentitiesFound), } } }, Err(sqlx::Error::RowNotFound) => None, Err(e) => bail!(StoreError::SqlxError(e)), }, ) } async fn find_identity_by_code(&self, code: &str) -> Result { let sqls = get_sqls(&self.sqls_root, FIND_IDENTITY_BY_CODE); let rows = sqlx::query_as::<_, (i32,)>(&sqls[0]) .bind(code) .fetch_all(&self.pool) .await .map_err(util::log_err_sqlx)?; if rows.len() == 0 { return Err(StoreError::CodeDoesNotExist(code.to_string())); } if rows.len() != 1 { return Err(StoreError::CodeAppearsMoreThanOnce); } let identity_email_id = rows.get(0).unwrap().0; // TODO: IF we expand beyond email codes, then we'll need to join against a bunch of identity tables. // but since a single code was found, only one of them should pop... Ok(sqlx::query_as::<_, Identity>(&sqls[1]) .bind(identity_email_id) .fetch_one(&self.pool) .await .map_err(util::log_err_sqlx)?) } async fn write_identity(&self, i: &Identity) -> Result<(), StoreError> { let sqls = get_sqls(&self.sqls_root, WRITE_IDENTITY); sqlx::query(&sqls[0]) .bind(i.id) .bind(i.data.clone()) .bind(i.created_at) .execute(&self.pool) .await .map_err(|e| { error!("write_identity_failure"); error!("{:?}", e); e })?; Ok(()) } async fn read_identity(&self, id: &Uuid) -> Result { let identity = sqlx::query_as::<_, Identity>( " select identity_public_id, data, created_at from identity where identity_public_id = ?", ) .bind(id) .fetch_one(&self.pool) .await .map_err(util::log_err_sqlx)?; Ok(identity) } async fn write_session(&self, session: &Session) -> Result<(), StoreError> { let sqls = get_sqls(&self.sqls_root, WRITE_SESSION); let secret_hash = session.secret.as_ref().map(|s| hash_secret(s)); sqlx::query(&sqls[0]) .bind(&session.identity_id) .bind(secret_hash.as_ref()) .bind(session.created_at) .bind(session.expired_at) .bind(session.revoked_at) .execute(&self.pool) .await .map_err(util::log_err_sqlx)?; Ok(()) } async fn read_session(&self, secret: &SessionSecret) -> Result { let sqls = get_sqls(&self.sqls_root, READ_SESSION); let secret_hash = hash_secret(secret); let mut session = sqlx::query_as::<_, Session>(&sqls[0]) .bind(&secret_hash[..]) .fetch_one(&self.pool) .await .map_err(util::log_err_sqlx)?; // This should do nothing other than updated touched_at, and then // clear the plaintext secret session.secret = Some(secret.to_string()); self.write_session(&session).await?; session.secret = None; Ok(session) } async fn write_oauth_provider(&self, provider: &OauthProvider) -> Result<(), StoreError> { let sqls = get_sqls(&self.sqls_root, WRITE_OAUTH_PROVIDER); sqlx::query(&sqls[0]) .bind(&provider.name.to_string()) .bind(&provider.flow) .bind(&provider.base_url.to_string()) .bind(&provider.response.to_string()) .bind(&provider.default_scope) .bind(&provider.client_id) // TODO: encrypt secret before writing .bind(&provider.client_secret) .bind(&provider.redirect_url.to_string()) .bind(provider.created_at) .bind(provider.deleted_at) .execute(&self.pool) .await .map_err(util::log_err_sqlx)?; Ok(()) } async fn read_oauth_provider( &self, provider: &OauthProviderName, flow: Option, ) -> Result { let sqls = get_sqls(&self.sqls_root, READ_OAUTH_PROVIDER); let flow = flow.unwrap_or("default".into()); debug!("provider: {:?}, flow: {:?}", provider, flow); // TODO: Write the generic FromRow impl for OauthProvider... let res = sqlx::query_as::< _, ( String, String, String, String, String, String, String, OffsetDateTime, Option, ), >(&sqls[0]) .bind(&provider.to_string()) .bind(&flow) .fetch_one(&self.pool) .await .map_err(util::log_err_sqlx)?; debug!("res: {:?}", res); Ok(OauthProvider { name: provider.clone(), flow: Some(res.0), base_url: Url::from_str(&res.1) .map_err(|_| StoreError::OauthProviderDoesNotExist(*provider))?, response: OauthResponseType::from_str(&res.2) .map_err(|_| StoreError::OauthProviderDoesNotExist(*provider))?, default_scope: res.3, client_id: res.4, client_secret: res.5, redirect_url: Url::from_str(&res.6) .map_err(|_| StoreError::OauthProviderDoesNotExist(*provider))?, created_at: res.7, deleted_at: res.8, }) } async fn write_oauth_validation( &self, v: &OauthValidation, ) -> anyhow::Result { let sqls = get_sqls(&self.sqls_root, WRITE_OAUTH_VALIDATION); let validation_id = v.id.unwrap_or(Uuid::new_v4()); sqlx::query(&sqls[0]) .bind(validation_id) .bind(v.oauth_provider.name.to_string()) .bind(v.oauth_provider.flow.clone()) .bind(v.access_token.clone()) .bind(v.raw_response.clone()) .bind(v.created_at) .bind(v.validated_at) .execute(&self.pool) .await?; if v.identity_id.is_some() || v.revoked_at.is_some() || v.deleted_at.is_some() { sqlx::query(&sqls[1]) .bind(v.identity_id.as_ref()) .bind(validation_id) .bind(v.revoked_at) .bind(v.deleted_at) .execute(&self.pool) .await?; } Ok(validation_id) } async fn read_oauth_validation( &self, validation_id: &ValidationRequestId, ) -> anyhow::Result { let sqls = get_sqls(&self.sqls_root, READ_OAUTH_VALIDATION); let mut es = sqlx::query_as::<_, OauthValidation>(&sqls[0]) .bind(validation_id) .fetch_all(&self.pool) .await?; if es.len() != 1 { bail!(StoreError::OauthValidationDoesNotExist( validation_id.clone() )); } Ok(es.swap_remove(0)) } async fn find_validation_type( &self, validation_id: &ValidationRequestId, ) -> anyhow::Result { let sqls = get_sqls(&self.sqls_root, READ_VALIDATION_TYPE); let mut es = sqlx::query_as::<_, (String,)>(&sqls[0]) .bind(validation_id) .fetch_all(&self.pool) .await .map_err(util::log_err_sqlx)?; match es.len() { 1 => Ok(ValidationType::from_str(&es.swap_remove(0).0)?), _ => bail!(StoreError::Other( "expected a single validation but recieved 0 or multiple validations".into() )), } } } pub struct PgClient { sql: SqlClient, } impl PgClient { pub async fn new(pool: sqlx::Pool) -> Arc { sqlx::migrate!("store/pg/migrations") .run(&pool) .await .expect(ERR_MSG_MIGRATION_FAILED); Arc::new(PgClient { sql: SqlClient { pool, sqls_root: PGSQL.to_string(), }, }) } } #[async_trait::async_trait] impl Store for PgClient { async fn write_email(&self, email_address: &str) -> Result<(), StoreError> { self.sql.write_email(email_address).await } async fn find_email_validation( &self, validation_id: Option<&Uuid>, code: Option<&str>, ) -> Result { self.sql.find_email_validation(validation_id, code).await } async fn write_email_validation(&self, ev: &EmailValidation) -> anyhow::Result { self.sql.write_email_validation(ev).await } async fn find_identity( &self, identity_id: Option<&Uuid>, email: Option<&str>, ) -> anyhow::Result> { self.sql.find_identity(identity_id, email).await } async fn find_identity_by_code(&self, code: &str) -> Result { self.sql.find_identity_by_code(code).await } async fn write_identity(&self, i: &Identity) -> Result<(), StoreError> { self.sql.write_identity(i).await } async fn read_identity(&self, identity_id: &Uuid) -> Result { self.sql.read_identity(identity_id).await } async fn write_session(&self, session: &Session) -> Result<(), StoreError> { self.sql.write_session(session).await } async fn read_session(&self, secret: &SessionSecret) -> Result { self.sql.read_session(secret).await } async fn write_oauth_provider(&self, provider: &OauthProvider) -> Result<(), StoreError> { self.sql.write_oauth_provider(provider).await } async fn read_oauth_provider( &self, provider: &OauthProviderName, flow: Option, ) -> Result { self.sql.read_oauth_provider(provider, flow).await } async fn write_oauth_validation( &self, validation: &OauthValidation, ) -> anyhow::Result { self.sql.write_oauth_validation(validation).await } async fn read_oauth_validation( &self, validation_id: &ValidationRequestId, ) -> anyhow::Result { self.sql.read_oauth_validation(validation_id).await } async fn find_validation_type( &self, validation_id: &ValidationRequestId, ) -> anyhow::Result { self.sql.find_validation_type(validation_id).await } } pub struct SqliteClient { sql: SqlClient, } impl SqliteClient { pub async fn new(pool: sqlx::Pool) -> Arc { sqlx::migrate!("store/sqlite/migrations") .run(&pool) .await .expect(ERR_MSG_MIGRATION_FAILED); sqlx::query("pragma foreign_keys = on") .execute(&pool) .await .expect( "Failed to initialize FK pragma. File a bug at https://www.github.com/secd-lib", ); Arc::new(SqliteClient { sql: SqlClient { pool, sqls_root: SQLITE.to_string(), }, }) } } #[async_trait::async_trait] impl Store for SqliteClient { async fn write_email(&self, email_address: &str) -> Result<(), StoreError> { self.sql.write_email(email_address).await } async fn find_email_validation( &self, validation_id: Option<&Uuid>, code: Option<&str>, ) -> Result { self.sql.find_email_validation(validation_id, code).await } async fn write_email_validation(&self, ev: &EmailValidation) -> anyhow::Result { self.sql.write_email_validation(ev).await } async fn find_identity( &self, identity_id: Option<&Uuid>, email: Option<&str>, ) -> anyhow::Result> { self.sql.find_identity(identity_id, email).await } async fn find_identity_by_code(&self, code: &str) -> Result { self.sql.find_identity_by_code(code).await } async fn write_identity(&self, i: &Identity) -> Result<(), StoreError> { self.sql.write_identity(i).await } async fn read_identity(&self, identity_id: &Uuid) -> Result { self.sql.read_identity(identity_id).await } async fn write_session(&self, session: &Session) -> Result<(), StoreError> { self.sql.write_session(session).await } async fn read_session(&self, secret: &SessionSecret) -> Result { self.sql.read_session(secret).await } async fn write_oauth_provider(&self, provider: &OauthProvider) -> Result<(), StoreError> { self.sql.write_oauth_provider(provider).await } async fn read_oauth_provider( &self, provider: &OauthProviderName, flow: Option, ) -> Result { self.sql.read_oauth_provider(provider, flow).await } async fn write_oauth_validation( &self, validation: &OauthValidation, ) -> anyhow::Result { self.sql.write_oauth_validation(validation).await } async fn read_oauth_validation( &self, validation_id: &ValidationRequestId, ) -> anyhow::Result { self.sql.read_oauth_validation(validation_id).await } async fn find_validation_type( &self, validation_id: &ValidationRequestId, ) -> anyhow::Result { self.sql.find_validation_type(validation_id).await } }