diff options
Diffstat (limited to '')
| -rw-r--r-- | crates/secd/src/client/sqldb.rs | 632 |
1 files changed, 0 insertions, 632 deletions
diff --git a/crates/secd/src/client/sqldb.rs b/crates/secd/src/client/sqldb.rs deleted file mode 100644 index 6751ef6..0000000 --- a/crates/secd/src/client/sqldb.rs +++ /dev/null @@ -1,632 +0,0 @@ -use std::{str::FromStr, sync::Arc}; - -use super::{ - EmailValidation, Identity, OauthProvider, OauthProviderName, OauthResponseType, Session, - SessionSecret, Store, StoreError, ERR_MSG_MIGRATION_FAILED, FIND_EMAIL_VALIDATION, - FIND_IDENTITY, FIND_IDENTITY_BY_CODE, PGSQL, READ_EMAIL_RAW_ID, READ_IDENTITY_RAW_ID, - READ_OAUTH_PROVIDER, READ_OAUTH_VALIDATION, READ_SESSION, READ_VALIDATION_TYPE, SQLITE, SQLS, - WRITE_EMAIL, WRITE_EMAIL_VALIDATION, WRITE_IDENTITY, WRITE_OAUTH_PROVIDER, - WRITE_OAUTH_VALIDATION, WRITE_SESSION, -}; -use crate::{util, OauthValidation, ValidationRequestId, ValidationType}; -use anyhow::bail; -use log::{debug, error}; -use openssl::sha::Sha256; -use sqlx::{ - self, database::HasArguments, ColumnIndex, Database, Decode, Encode, Executor, IntoArguments, - Pool, Postgres, Sqlite, Transaction, Type, -}; -use time::OffsetDateTime; -use url::Url; -use uuid::Uuid; - -fn get_sqls(root: &str, file: &str) -> Vec<String> { - SQLS.get(root) - .unwrap() - .get(file) - .unwrap() - .split("--") - .map(|p| p.to_string()) - .collect() -} - -fn hash_secret(secret: &str) -> Vec<u8> { - let mut hasher = Sha256::new(); - hasher.update(secret.as_bytes()); - hasher.finish().to_vec() -} - -struct SqlClient<D> -where - D: sqlx::Database, -{ - pool: sqlx::Pool<D>, - sqls_root: String, -} - -impl<D> SqlClient<D> -where - D: sqlx::Database, - for<'c> <D as HasArguments<'c>>::Arguments: IntoArguments<'c, D>, - for<'c> i64: Decode<'c, D> + Type<D>, - for<'c> &'c str: Decode<'c, D> + Type<D>, - for<'c> &'c str: Encode<'c, D> + Type<D>, - for<'c> usize: ColumnIndex<<D as Database>::Row>, - for<'c> Uuid: Decode<'c, D> + Type<D>, - for<'c> Uuid: Encode<'c, D> + Type<D>, - for<'c> &'c Pool<D>: Executor<'c, Database = D>, -{ - async fn read_identity_raw_id(&self, id: &Uuid) -> Result<i64, StoreError> { - let sqls = get_sqls(&self.sqls_root, READ_IDENTITY_RAW_ID); - - Ok(sqlx::query_as::<_, (i64,)>(&sqls[0]) - .bind(id) - .fetch_one(&self.pool) - .await - .map_err(util::log_err_sqlx)? - .0) - } - - async fn read_email_raw_id(&self, address: &str) -> Result<i64, StoreError> { - let sqls = get_sqls(&self.sqls_root, READ_EMAIL_RAW_ID); - - Ok(sqlx::query_as::<_, (i64,)>(&sqls[0]) - .bind(address) - .fetch_one(&self.pool) - .await - .map_err(util::log_err_sqlx)? - .0) - } -} - -#[async_trait::async_trait] -impl<D> Store for SqlClient<D> -where - D: sqlx::Database, - for<'c> <D as HasArguments<'c>>::Arguments: IntoArguments<'c, D>, - for<'c> bool: Decode<'c, D> + Type<D>, - for<'c> bool: Encode<'c, D> + Type<D>, - for<'c> i64: Decode<'c, D> + Type<D>, - for<'c> i64: Encode<'c, D> + Type<D>, - for<'c> i32: Decode<'c, D> + Type<D>, - for<'c> i32: Encode<'c, D> + Type<D>, - for<'c> OffsetDateTime: Decode<'c, D> + Type<D>, - for<'c> OffsetDateTime: Encode<'c, D> + Type<D>, - for<'c> &'c str: ColumnIndex<<D as Database>::Row>, - for<'c> &'c str: Decode<'c, D> + Type<D>, - for<'c> &'c str: Encode<'c, D> + Type<D>, - for<'c> Option<&'c str>: Decode<'c, D> + Type<D>, - for<'c> Option<&'c str>: Encode<'c, D> + Type<D>, - for<'c> String: Decode<'c, D> + Type<D>, - for<'c> String: Encode<'c, D> + Type<D>, - for<'c> Option<String>: Decode<'c, D> + Type<D>, - for<'c> Option<String>: Encode<'c, D> + Type<D>, - for<'c> OauthProviderName: Decode<'c, D> + Type<D>, - for<'c> OauthResponseType: Decode<'c, D> + Type<D>, - for<'c> usize: ColumnIndex<<D as Database>::Row>, - for<'c> Uuid: Decode<'c, D> + Type<D>, - for<'c> Uuid: Encode<'c, D> + Type<D>, - for<'c> &'c [u8]: Encode<'c, D> + Type<D>, - for<'c> Option<&'c Uuid>: Encode<'c, D> + Type<D>, - for<'c> Option<&'c Vec<u8>>: Encode<'c, D> + Type<D>, - for<'c> Option<OffsetDateTime>: Decode<'c, D> + Type<D>, - for<'c> Option<OffsetDateTime>: Encode<'c, D> + Type<D>, - for<'c> &'c Pool<D>: Executor<'c, Database = D>, - for<'c> &'c mut Transaction<'c, D>: Executor<'c, Database = D>, -{ - async fn write_email(&self, email_address: &str) -> Result<(), StoreError> { - let sqls = get_sqls(&self.sqls_root, WRITE_EMAIL); - - sqlx::query(&sqls[0]) - .bind(email_address) - .execute(&self.pool) - .await - .map_err(util::log_err_sqlx)?; - - Ok(()) - } - - async fn find_email_validation( - &self, - validation_id: Option<&Uuid>, - code: Option<&str>, - ) -> Result<EmailValidation, StoreError> { - let sqls = get_sqls(&self.sqls_root, FIND_EMAIL_VALIDATION); - let mut rows = sqlx::query_as::<_, EmailValidation>(&sqls[0]) - .bind(validation_id) - .bind(code) - .fetch_all(&self.pool) - .await - .map_err(util::log_err_sqlx)?; - - match rows.len() { - 0 => Err(StoreError::NoEmailValidationFound), - 1 => Ok(rows.swap_remove(0)), - _ => Err(StoreError::TooManyValidations), - } - } - - async fn write_email_validation(&self, ev: &EmailValidation) -> anyhow::Result<Uuid> { - let sqls = get_sqls(&self.sqls_root, WRITE_EMAIL_VALIDATION); - - let email_id = self.read_email_raw_id(&ev.email_address).await?; - let validation_id = ev.id.unwrap_or(Uuid::new_v4()); - sqlx::query(&sqls[0]) - .bind(validation_id) - .bind(email_id) - .bind(&ev.code) - .bind(ev.is_oauth_derived) - .bind(ev.created_at) - .bind(ev.validated_at) - .bind(ev.expired_at) - .execute(&self.pool) - .await - .map_err(util::log_err_sqlx)?; - - if ev.identity_id.is_some() || ev.revoked_at.is_some() || ev.deleted_at.is_some() { - sqlx::query(&sqls[1]) - .bind(ev.identity_id.as_ref()) - .bind(validation_id) - .bind(ev.revoked_at) - .bind(ev.deleted_at) - .execute(&self.pool) - .await - .map_err(util::log_err_sqlx)?; - } - - Ok(validation_id) - } - - async fn find_identity( - &self, - id: Option<&Uuid>, - email: Option<&str>, - ) -> anyhow::Result<Option<Identity>> { - let sqls = get_sqls(&self.sqls_root, FIND_IDENTITY); - Ok( - match sqlx::query_as::<_, Identity>(&sqls[0]) - .bind(id) - .bind(email) - .fetch_all(&self.pool) - .await - { - Ok(mut is) => match is.len() { - // if only 1 found, then that's fine - // if multiple are fond, then if they all have the same id, that's okay - 1 => { - let i = is.swap_remove(0); - match i.deleted_at { - Some(t) if t > OffsetDateTime::now_utc() => Some(i), - None => Some(i), - _ => None, - } - } - 0 => None, - _ => { - match is - .iter() - .filter(|&i| i.id != is[0].id) - .collect::<Vec<&Identity>>() - .len() - { - 0 => Some(is.swap_remove(0)), - _ => bail!(StoreError::TooManyIdentitiesFound), - } - } - }, - Err(sqlx::Error::RowNotFound) => None, - Err(e) => bail!(StoreError::SqlxError(e)), - }, - ) - } - - async fn find_identity_by_code(&self, code: &str) -> Result<Identity, StoreError> { - let sqls = get_sqls(&self.sqls_root, FIND_IDENTITY_BY_CODE); - - let rows = sqlx::query_as::<_, (i32,)>(&sqls[0]) - .bind(code) - .fetch_all(&self.pool) - .await - .map_err(util::log_err_sqlx)?; - - if rows.len() == 0 { - return Err(StoreError::CodeDoesNotExist(code.to_string())); - } - - if rows.len() != 1 { - return Err(StoreError::CodeAppearsMoreThanOnce); - } - - let identity_email_id = rows.get(0).unwrap().0; - - // TODO: IF we expand beyond email codes, then we'll need to join against a bunch of identity tables. - // but since a single code was found, only one of them should pop... - Ok(sqlx::query_as::<_, Identity>(&sqls[1]) - .bind(identity_email_id) - .fetch_one(&self.pool) - .await - .map_err(util::log_err_sqlx)?) - } - - async fn write_identity(&self, i: &Identity) -> Result<(), StoreError> { - let sqls = get_sqls(&self.sqls_root, WRITE_IDENTITY); - sqlx::query(&sqls[0]) - .bind(i.id) - .bind(i.data.clone()) - .bind(i.created_at) - .execute(&self.pool) - .await - .map_err(|e| { - error!("write_identity_failure"); - error!("{:?}", e); - e - })?; - - Ok(()) - } - async fn read_identity(&self, id: &Uuid) -> Result<Identity, StoreError> { - let identity = sqlx::query_as::<_, Identity>( - " -select identity_public_id, data, created_at from identity where identity_public_id = ?", - ) - .bind(id) - .fetch_one(&self.pool) - .await - .map_err(util::log_err_sqlx)?; - - Ok(identity) - } - - async fn write_session(&self, session: &Session) -> Result<(), StoreError> { - let sqls = get_sqls(&self.sqls_root, WRITE_SESSION); - - let secret_hash = session.secret.as_ref().map(|s| hash_secret(s)); - - sqlx::query(&sqls[0]) - .bind(&session.identity_id) - .bind(secret_hash.as_ref()) - .bind(session.created_at) - .bind(session.expired_at) - .bind(session.revoked_at) - .execute(&self.pool) - .await - .map_err(util::log_err_sqlx)?; - - Ok(()) - } - async fn read_session(&self, secret: &SessionSecret) -> Result<Session, StoreError> { - let sqls = get_sqls(&self.sqls_root, READ_SESSION); - - let secret_hash = hash_secret(secret); - let mut session = sqlx::query_as::<_, Session>(&sqls[0]) - .bind(&secret_hash[..]) - .fetch_one(&self.pool) - .await - .map_err(util::log_err_sqlx)?; - - // This should do nothing other than updated touched_at, and then - // clear the plaintext secret - session.secret = Some(secret.to_string()); - self.write_session(&session).await?; - session.secret = None; - - Ok(session) - } - - async fn write_oauth_provider(&self, provider: &OauthProvider) -> Result<(), StoreError> { - let sqls = get_sqls(&self.sqls_root, WRITE_OAUTH_PROVIDER); - sqlx::query(&sqls[0]) - .bind(&provider.name.to_string()) - .bind(&provider.flow) - .bind(&provider.base_url.to_string()) - .bind(&provider.response.to_string()) - .bind(&provider.default_scope) - .bind(&provider.client_id) - // TODO: encrypt secret before writing - .bind(&provider.client_secret) - .bind(&provider.redirect_url.to_string()) - .bind(provider.created_at) - .bind(provider.deleted_at) - .execute(&self.pool) - .await - .map_err(util::log_err_sqlx)?; - Ok(()) - } - - async fn read_oauth_provider( - &self, - provider: &OauthProviderName, - flow: Option<String>, - ) -> Result<OauthProvider, StoreError> { - let sqls = get_sqls(&self.sqls_root, READ_OAUTH_PROVIDER); - let flow = flow.unwrap_or("default".into()); - debug!("provider: {:?}, flow: {:?}", provider, flow); - // TODO: Write the generic FromRow impl for OauthProvider... - let res = sqlx::query_as::< - _, - ( - String, - String, - String, - String, - String, - String, - String, - OffsetDateTime, - Option<OffsetDateTime>, - ), - >(&sqls[0]) - .bind(&provider.to_string()) - .bind(&flow) - .fetch_one(&self.pool) - .await - .map_err(util::log_err_sqlx)?; - - debug!("res: {:?}", res); - - Ok(OauthProvider { - name: provider.clone(), - flow: Some(res.0), - base_url: Url::from_str(&res.1) - .map_err(|_| StoreError::OauthProviderDoesNotExist(*provider))?, - response: OauthResponseType::from_str(&res.2) - .map_err(|_| StoreError::OauthProviderDoesNotExist(*provider))?, - default_scope: res.3, - client_id: res.4, - client_secret: res.5, - redirect_url: Url::from_str(&res.6) - .map_err(|_| StoreError::OauthProviderDoesNotExist(*provider))?, - created_at: res.7, - deleted_at: res.8, - }) - } - async fn write_oauth_validation( - &self, - v: &OauthValidation, - ) -> anyhow::Result<ValidationRequestId> { - let sqls = get_sqls(&self.sqls_root, WRITE_OAUTH_VALIDATION); - - let validation_id = v.id.unwrap_or(Uuid::new_v4()); - sqlx::query(&sqls[0]) - .bind(validation_id) - .bind(v.oauth_provider.name.to_string()) - .bind(v.oauth_provider.flow.clone()) - .bind(v.access_token.clone()) - .bind(v.raw_response.clone()) - .bind(v.created_at) - .bind(v.validated_at) - .execute(&self.pool) - .await?; - - if v.identity_id.is_some() || v.revoked_at.is_some() || v.deleted_at.is_some() { - sqlx::query(&sqls[1]) - .bind(v.identity_id.as_ref()) - .bind(validation_id) - .bind(v.revoked_at) - .bind(v.deleted_at) - .execute(&self.pool) - .await?; - } - - Ok(validation_id) - } - async fn read_oauth_validation( - &self, - validation_id: &ValidationRequestId, - ) -> anyhow::Result<OauthValidation> { - let sqls = get_sqls(&self.sqls_root, READ_OAUTH_VALIDATION); - - let mut es = sqlx::query_as::<_, OauthValidation>(&sqls[0]) - .bind(validation_id) - .fetch_all(&self.pool) - .await?; - - if es.len() != 1 { - bail!(StoreError::OauthValidationDoesNotExist( - validation_id.clone() - )); - } - - Ok(es.swap_remove(0)) - } - async fn find_validation_type( - &self, - validation_id: &ValidationRequestId, - ) -> anyhow::Result<ValidationType> { - let sqls = get_sqls(&self.sqls_root, READ_VALIDATION_TYPE); - - let mut es = sqlx::query_as::<_, (String,)>(&sqls[0]) - .bind(validation_id) - .fetch_all(&self.pool) - .await - .map_err(util::log_err_sqlx)?; - - match es.len() { - 1 => Ok(ValidationType::from_str(&es.swap_remove(0).0)?), - _ => bail!(StoreError::Other( - "expected a single validation but recieved 0 or multiple validations".into() - )), - } - } -} - -pub struct PgClient { - sql: SqlClient<Postgres>, -} - -impl PgClient { - pub async fn new(pool: sqlx::Pool<Postgres>) -> Arc<dyn Store + Send + Sync + 'static> { - sqlx::migrate!("store/pg/migrations") - .run(&pool) - .await - .expect(ERR_MSG_MIGRATION_FAILED); - - Arc::new(PgClient { - sql: SqlClient { - pool, - sqls_root: PGSQL.to_string(), - }, - }) - } -} - -#[async_trait::async_trait] -impl Store for PgClient { - async fn write_email(&self, email_address: &str) -> Result<(), StoreError> { - self.sql.write_email(email_address).await - } - async fn find_email_validation( - &self, - validation_id: Option<&Uuid>, - code: Option<&str>, - ) -> Result<EmailValidation, StoreError> { - self.sql.find_email_validation(validation_id, code).await - } - async fn write_email_validation(&self, ev: &EmailValidation) -> anyhow::Result<Uuid> { - self.sql.write_email_validation(ev).await - } - async fn find_identity( - &self, - identity_id: Option<&Uuid>, - email: Option<&str>, - ) -> anyhow::Result<Option<Identity>> { - self.sql.find_identity(identity_id, email).await - } - async fn find_identity_by_code(&self, code: &str) -> Result<Identity, StoreError> { - self.sql.find_identity_by_code(code).await - } - async fn write_identity(&self, i: &Identity) -> Result<(), StoreError> { - self.sql.write_identity(i).await - } - async fn read_identity(&self, identity_id: &Uuid) -> Result<Identity, StoreError> { - self.sql.read_identity(identity_id).await - } - async fn write_session(&self, session: &Session) -> Result<(), StoreError> { - self.sql.write_session(session).await - } - async fn read_session(&self, secret: &SessionSecret) -> Result<Session, StoreError> { - self.sql.read_session(secret).await - } - async fn write_oauth_provider(&self, provider: &OauthProvider) -> Result<(), StoreError> { - self.sql.write_oauth_provider(provider).await - } - async fn read_oauth_provider( - &self, - provider: &OauthProviderName, - flow: Option<String>, - ) -> Result<OauthProvider, StoreError> { - self.sql.read_oauth_provider(provider, flow).await - } - async fn write_oauth_validation( - &self, - validation: &OauthValidation, - ) -> anyhow::Result<ValidationRequestId> { - self.sql.write_oauth_validation(validation).await - } - async fn read_oauth_validation( - &self, - validation_id: &ValidationRequestId, - ) -> anyhow::Result<OauthValidation> { - self.sql.read_oauth_validation(validation_id).await - } - async fn find_validation_type( - &self, - validation_id: &ValidationRequestId, - ) -> anyhow::Result<ValidationType> { - self.sql.find_validation_type(validation_id).await - } -} - -pub struct SqliteClient { - sql: SqlClient<Sqlite>, -} - -impl SqliteClient { - pub async fn new(pool: sqlx::Pool<Sqlite>) -> Arc<dyn Store + Send + Sync + 'static> { - sqlx::migrate!("store/sqlite/migrations") - .run(&pool) - .await - .expect(ERR_MSG_MIGRATION_FAILED); - - sqlx::query("pragma foreign_keys = on") - .execute(&pool) - .await - .expect( - "Failed to initialize FK pragma. File a bug at https://www.github.com/secd-lib", - ); - - Arc::new(SqliteClient { - sql: SqlClient { - pool, - sqls_root: SQLITE.to_string(), - }, - }) - } -} - -#[async_trait::async_trait] -impl Store for SqliteClient { - async fn write_email(&self, email_address: &str) -> Result<(), StoreError> { - self.sql.write_email(email_address).await - } - async fn find_email_validation( - &self, - validation_id: Option<&Uuid>, - code: Option<&str>, - ) -> Result<EmailValidation, StoreError> { - self.sql.find_email_validation(validation_id, code).await - } - async fn write_email_validation(&self, ev: &EmailValidation) -> anyhow::Result<Uuid> { - self.sql.write_email_validation(ev).await - } - async fn find_identity( - &self, - identity_id: Option<&Uuid>, - email: Option<&str>, - ) -> anyhow::Result<Option<Identity>> { - self.sql.find_identity(identity_id, email).await - } - async fn find_identity_by_code(&self, code: &str) -> Result<Identity, StoreError> { - self.sql.find_identity_by_code(code).await - } - async fn write_identity(&self, i: &Identity) -> Result<(), StoreError> { - self.sql.write_identity(i).await - } - async fn read_identity(&self, identity_id: &Uuid) -> Result<Identity, StoreError> { - self.sql.read_identity(identity_id).await - } - async fn write_session(&self, session: &Session) -> Result<(), StoreError> { - self.sql.write_session(session).await - } - async fn read_session(&self, secret: &SessionSecret) -> Result<Session, StoreError> { - self.sql.read_session(secret).await - } - async fn write_oauth_provider(&self, provider: &OauthProvider) -> Result<(), StoreError> { - self.sql.write_oauth_provider(provider).await - } - async fn read_oauth_provider( - &self, - provider: &OauthProviderName, - flow: Option<String>, - ) -> Result<OauthProvider, StoreError> { - self.sql.read_oauth_provider(provider, flow).await - } - async fn write_oauth_validation( - &self, - validation: &OauthValidation, - ) -> anyhow::Result<ValidationRequestId> { - self.sql.write_oauth_validation(validation).await - } - async fn read_oauth_validation( - &self, - validation_id: &ValidationRequestId, - ) -> anyhow::Result<OauthValidation> { - self.sql.read_oauth_validation(validation_id).await - } - async fn find_validation_type( - &self, - validation_id: &ValidationRequestId, - ) -> anyhow::Result<ValidationType> { - self.sql.find_validation_type(validation_id).await - } -} |
