aboutsummaryrefslogtreecommitdiff
path: root/crates/secd/src/client/sqldb.rs
diff options
context:
space:
mode:
authorbenj <benj@rse8.com>2022-12-24 00:43:38 -0800
committerbenj <benj@rse8.com>2022-12-24 00:43:38 -0800
commitc2268c285648ef02ece04de0d9df0813c6d70ff8 (patch)
treef84ec7ee42f97d78245f26d0c5a0c559cd35e89d /crates/secd/src/client/sqldb.rs
parentde6339da72af1d61ca5908b780977e2b037ce014 (diff)
downloadsecdiam-c2268c285648ef02ece04de0d9df0813c6d70ff8.tar
secdiam-c2268c285648ef02ece04de0d9df0813c6d70ff8.tar.gz
secdiam-c2268c285648ef02ece04de0d9df0813c6d70ff8.tar.bz2
secdiam-c2268c285648ef02ece04de0d9df0813c6d70ff8.tar.lz
secdiam-c2268c285648ef02ece04de0d9df0813c6d70ff8.tar.xz
secdiam-c2268c285648ef02ece04de0d9df0813c6d70ff8.tar.zst
secdiam-c2268c285648ef02ece04de0d9df0813c6d70ff8.zip
refactor everything with more abstraction and a nicer interface
Diffstat (limited to '')
-rw-r--r--crates/secd/src/client/sqldb.rs632
1 files changed, 0 insertions, 632 deletions
diff --git a/crates/secd/src/client/sqldb.rs b/crates/secd/src/client/sqldb.rs
deleted file mode 100644
index 6751ef6..0000000
--- a/crates/secd/src/client/sqldb.rs
+++ /dev/null
@@ -1,632 +0,0 @@
-use std::{str::FromStr, sync::Arc};
-
-use super::{
- EmailValidation, Identity, OauthProvider, OauthProviderName, OauthResponseType, Session,
- SessionSecret, Store, StoreError, ERR_MSG_MIGRATION_FAILED, FIND_EMAIL_VALIDATION,
- FIND_IDENTITY, FIND_IDENTITY_BY_CODE, PGSQL, READ_EMAIL_RAW_ID, READ_IDENTITY_RAW_ID,
- READ_OAUTH_PROVIDER, READ_OAUTH_VALIDATION, READ_SESSION, READ_VALIDATION_TYPE, SQLITE, SQLS,
- WRITE_EMAIL, WRITE_EMAIL_VALIDATION, WRITE_IDENTITY, WRITE_OAUTH_PROVIDER,
- WRITE_OAUTH_VALIDATION, WRITE_SESSION,
-};
-use crate::{util, OauthValidation, ValidationRequestId, ValidationType};
-use anyhow::bail;
-use log::{debug, error};
-use openssl::sha::Sha256;
-use sqlx::{
- self, database::HasArguments, ColumnIndex, Database, Decode, Encode, Executor, IntoArguments,
- Pool, Postgres, Sqlite, Transaction, Type,
-};
-use time::OffsetDateTime;
-use url::Url;
-use uuid::Uuid;
-
-fn get_sqls(root: &str, file: &str) -> Vec<String> {
- SQLS.get(root)
- .unwrap()
- .get(file)
- .unwrap()
- .split("--")
- .map(|p| p.to_string())
- .collect()
-}
-
-fn hash_secret(secret: &str) -> Vec<u8> {
- let mut hasher = Sha256::new();
- hasher.update(secret.as_bytes());
- hasher.finish().to_vec()
-}
-
-struct SqlClient<D>
-where
- D: sqlx::Database,
-{
- pool: sqlx::Pool<D>,
- sqls_root: String,
-}
-
-impl<D> SqlClient<D>
-where
- D: sqlx::Database,
- for<'c> <D as HasArguments<'c>>::Arguments: IntoArguments<'c, D>,
- for<'c> i64: Decode<'c, D> + Type<D>,
- for<'c> &'c str: Decode<'c, D> + Type<D>,
- for<'c> &'c str: Encode<'c, D> + Type<D>,
- for<'c> usize: ColumnIndex<<D as Database>::Row>,
- for<'c> Uuid: Decode<'c, D> + Type<D>,
- for<'c> Uuid: Encode<'c, D> + Type<D>,
- for<'c> &'c Pool<D>: Executor<'c, Database = D>,
-{
- async fn read_identity_raw_id(&self, id: &Uuid) -> Result<i64, StoreError> {
- let sqls = get_sqls(&self.sqls_root, READ_IDENTITY_RAW_ID);
-
- Ok(sqlx::query_as::<_, (i64,)>(&sqls[0])
- .bind(id)
- .fetch_one(&self.pool)
- .await
- .map_err(util::log_err_sqlx)?
- .0)
- }
-
- async fn read_email_raw_id(&self, address: &str) -> Result<i64, StoreError> {
- let sqls = get_sqls(&self.sqls_root, READ_EMAIL_RAW_ID);
-
- Ok(sqlx::query_as::<_, (i64,)>(&sqls[0])
- .bind(address)
- .fetch_one(&self.pool)
- .await
- .map_err(util::log_err_sqlx)?
- .0)
- }
-}
-
-#[async_trait::async_trait]
-impl<D> Store for SqlClient<D>
-where
- D: sqlx::Database,
- for<'c> <D as HasArguments<'c>>::Arguments: IntoArguments<'c, D>,
- for<'c> bool: Decode<'c, D> + Type<D>,
- for<'c> bool: Encode<'c, D> + Type<D>,
- for<'c> i64: Decode<'c, D> + Type<D>,
- for<'c> i64: Encode<'c, D> + Type<D>,
- for<'c> i32: Decode<'c, D> + Type<D>,
- for<'c> i32: Encode<'c, D> + Type<D>,
- for<'c> OffsetDateTime: Decode<'c, D> + Type<D>,
- for<'c> OffsetDateTime: Encode<'c, D> + Type<D>,
- for<'c> &'c str: ColumnIndex<<D as Database>::Row>,
- for<'c> &'c str: Decode<'c, D> + Type<D>,
- for<'c> &'c str: Encode<'c, D> + Type<D>,
- for<'c> Option<&'c str>: Decode<'c, D> + Type<D>,
- for<'c> Option<&'c str>: Encode<'c, D> + Type<D>,
- for<'c> String: Decode<'c, D> + Type<D>,
- for<'c> String: Encode<'c, D> + Type<D>,
- for<'c> Option<String>: Decode<'c, D> + Type<D>,
- for<'c> Option<String>: Encode<'c, D> + Type<D>,
- for<'c> OauthProviderName: Decode<'c, D> + Type<D>,
- for<'c> OauthResponseType: Decode<'c, D> + Type<D>,
- for<'c> usize: ColumnIndex<<D as Database>::Row>,
- for<'c> Uuid: Decode<'c, D> + Type<D>,
- for<'c> Uuid: Encode<'c, D> + Type<D>,
- for<'c> &'c [u8]: Encode<'c, D> + Type<D>,
- for<'c> Option<&'c Uuid>: Encode<'c, D> + Type<D>,
- for<'c> Option<&'c Vec<u8>>: Encode<'c, D> + Type<D>,
- for<'c> Option<OffsetDateTime>: Decode<'c, D> + Type<D>,
- for<'c> Option<OffsetDateTime>: Encode<'c, D> + Type<D>,
- for<'c> &'c Pool<D>: Executor<'c, Database = D>,
- for<'c> &'c mut Transaction<'c, D>: Executor<'c, Database = D>,
-{
- async fn write_email(&self, email_address: &str) -> Result<(), StoreError> {
- let sqls = get_sqls(&self.sqls_root, WRITE_EMAIL);
-
- sqlx::query(&sqls[0])
- .bind(email_address)
- .execute(&self.pool)
- .await
- .map_err(util::log_err_sqlx)?;
-
- Ok(())
- }
-
- async fn find_email_validation(
- &self,
- validation_id: Option<&Uuid>,
- code: Option<&str>,
- ) -> Result<EmailValidation, StoreError> {
- let sqls = get_sqls(&self.sqls_root, FIND_EMAIL_VALIDATION);
- let mut rows = sqlx::query_as::<_, EmailValidation>(&sqls[0])
- .bind(validation_id)
- .bind(code)
- .fetch_all(&self.pool)
- .await
- .map_err(util::log_err_sqlx)?;
-
- match rows.len() {
- 0 => Err(StoreError::NoEmailValidationFound),
- 1 => Ok(rows.swap_remove(0)),
- _ => Err(StoreError::TooManyValidations),
- }
- }
-
- async fn write_email_validation(&self, ev: &EmailValidation) -> anyhow::Result<Uuid> {
- let sqls = get_sqls(&self.sqls_root, WRITE_EMAIL_VALIDATION);
-
- let email_id = self.read_email_raw_id(&ev.email_address).await?;
- let validation_id = ev.id.unwrap_or(Uuid::new_v4());
- sqlx::query(&sqls[0])
- .bind(validation_id)
- .bind(email_id)
- .bind(&ev.code)
- .bind(ev.is_oauth_derived)
- .bind(ev.created_at)
- .bind(ev.validated_at)
- .bind(ev.expired_at)
- .execute(&self.pool)
- .await
- .map_err(util::log_err_sqlx)?;
-
- if ev.identity_id.is_some() || ev.revoked_at.is_some() || ev.deleted_at.is_some() {
- sqlx::query(&sqls[1])
- .bind(ev.identity_id.as_ref())
- .bind(validation_id)
- .bind(ev.revoked_at)
- .bind(ev.deleted_at)
- .execute(&self.pool)
- .await
- .map_err(util::log_err_sqlx)?;
- }
-
- Ok(validation_id)
- }
-
- async fn find_identity(
- &self,
- id: Option<&Uuid>,
- email: Option<&str>,
- ) -> anyhow::Result<Option<Identity>> {
- let sqls = get_sqls(&self.sqls_root, FIND_IDENTITY);
- Ok(
- match sqlx::query_as::<_, Identity>(&sqls[0])
- .bind(id)
- .bind(email)
- .fetch_all(&self.pool)
- .await
- {
- Ok(mut is) => match is.len() {
- // if only 1 found, then that's fine
- // if multiple are fond, then if they all have the same id, that's okay
- 1 => {
- let i = is.swap_remove(0);
- match i.deleted_at {
- Some(t) if t > OffsetDateTime::now_utc() => Some(i),
- None => Some(i),
- _ => None,
- }
- }
- 0 => None,
- _ => {
- match is
- .iter()
- .filter(|&i| i.id != is[0].id)
- .collect::<Vec<&Identity>>()
- .len()
- {
- 0 => Some(is.swap_remove(0)),
- _ => bail!(StoreError::TooManyIdentitiesFound),
- }
- }
- },
- Err(sqlx::Error::RowNotFound) => None,
- Err(e) => bail!(StoreError::SqlxError(e)),
- },
- )
- }
-
- async fn find_identity_by_code(&self, code: &str) -> Result<Identity, StoreError> {
- let sqls = get_sqls(&self.sqls_root, FIND_IDENTITY_BY_CODE);
-
- let rows = sqlx::query_as::<_, (i32,)>(&sqls[0])
- .bind(code)
- .fetch_all(&self.pool)
- .await
- .map_err(util::log_err_sqlx)?;
-
- if rows.len() == 0 {
- return Err(StoreError::CodeDoesNotExist(code.to_string()));
- }
-
- if rows.len() != 1 {
- return Err(StoreError::CodeAppearsMoreThanOnce);
- }
-
- let identity_email_id = rows.get(0).unwrap().0;
-
- // TODO: IF we expand beyond email codes, then we'll need to join against a bunch of identity tables.
- // but since a single code was found, only one of them should pop...
- Ok(sqlx::query_as::<_, Identity>(&sqls[1])
- .bind(identity_email_id)
- .fetch_one(&self.pool)
- .await
- .map_err(util::log_err_sqlx)?)
- }
-
- async fn write_identity(&self, i: &Identity) -> Result<(), StoreError> {
- let sqls = get_sqls(&self.sqls_root, WRITE_IDENTITY);
- sqlx::query(&sqls[0])
- .bind(i.id)
- .bind(i.data.clone())
- .bind(i.created_at)
- .execute(&self.pool)
- .await
- .map_err(|e| {
- error!("write_identity_failure");
- error!("{:?}", e);
- e
- })?;
-
- Ok(())
- }
- async fn read_identity(&self, id: &Uuid) -> Result<Identity, StoreError> {
- let identity = sqlx::query_as::<_, Identity>(
- "
-select identity_public_id, data, created_at from identity where identity_public_id = ?",
- )
- .bind(id)
- .fetch_one(&self.pool)
- .await
- .map_err(util::log_err_sqlx)?;
-
- Ok(identity)
- }
-
- async fn write_session(&self, session: &Session) -> Result<(), StoreError> {
- let sqls = get_sqls(&self.sqls_root, WRITE_SESSION);
-
- let secret_hash = session.secret.as_ref().map(|s| hash_secret(s));
-
- sqlx::query(&sqls[0])
- .bind(&session.identity_id)
- .bind(secret_hash.as_ref())
- .bind(session.created_at)
- .bind(session.expired_at)
- .bind(session.revoked_at)
- .execute(&self.pool)
- .await
- .map_err(util::log_err_sqlx)?;
-
- Ok(())
- }
- async fn read_session(&self, secret: &SessionSecret) -> Result<Session, StoreError> {
- let sqls = get_sqls(&self.sqls_root, READ_SESSION);
-
- let secret_hash = hash_secret(secret);
- let mut session = sqlx::query_as::<_, Session>(&sqls[0])
- .bind(&secret_hash[..])
- .fetch_one(&self.pool)
- .await
- .map_err(util::log_err_sqlx)?;
-
- // This should do nothing other than updated touched_at, and then
- // clear the plaintext secret
- session.secret = Some(secret.to_string());
- self.write_session(&session).await?;
- session.secret = None;
-
- Ok(session)
- }
-
- async fn write_oauth_provider(&self, provider: &OauthProvider) -> Result<(), StoreError> {
- let sqls = get_sqls(&self.sqls_root, WRITE_OAUTH_PROVIDER);
- sqlx::query(&sqls[0])
- .bind(&provider.name.to_string())
- .bind(&provider.flow)
- .bind(&provider.base_url.to_string())
- .bind(&provider.response.to_string())
- .bind(&provider.default_scope)
- .bind(&provider.client_id)
- // TODO: encrypt secret before writing
- .bind(&provider.client_secret)
- .bind(&provider.redirect_url.to_string())
- .bind(provider.created_at)
- .bind(provider.deleted_at)
- .execute(&self.pool)
- .await
- .map_err(util::log_err_sqlx)?;
- Ok(())
- }
-
- async fn read_oauth_provider(
- &self,
- provider: &OauthProviderName,
- flow: Option<String>,
- ) -> Result<OauthProvider, StoreError> {
- let sqls = get_sqls(&self.sqls_root, READ_OAUTH_PROVIDER);
- let flow = flow.unwrap_or("default".into());
- debug!("provider: {:?}, flow: {:?}", provider, flow);
- // TODO: Write the generic FromRow impl for OauthProvider...
- let res = sqlx::query_as::<
- _,
- (
- String,
- String,
- String,
- String,
- String,
- String,
- String,
- OffsetDateTime,
- Option<OffsetDateTime>,
- ),
- >(&sqls[0])
- .bind(&provider.to_string())
- .bind(&flow)
- .fetch_one(&self.pool)
- .await
- .map_err(util::log_err_sqlx)?;
-
- debug!("res: {:?}", res);
-
- Ok(OauthProvider {
- name: provider.clone(),
- flow: Some(res.0),
- base_url: Url::from_str(&res.1)
- .map_err(|_| StoreError::OauthProviderDoesNotExist(*provider))?,
- response: OauthResponseType::from_str(&res.2)
- .map_err(|_| StoreError::OauthProviderDoesNotExist(*provider))?,
- default_scope: res.3,
- client_id: res.4,
- client_secret: res.5,
- redirect_url: Url::from_str(&res.6)
- .map_err(|_| StoreError::OauthProviderDoesNotExist(*provider))?,
- created_at: res.7,
- deleted_at: res.8,
- })
- }
- async fn write_oauth_validation(
- &self,
- v: &OauthValidation,
- ) -> anyhow::Result<ValidationRequestId> {
- let sqls = get_sqls(&self.sqls_root, WRITE_OAUTH_VALIDATION);
-
- let validation_id = v.id.unwrap_or(Uuid::new_v4());
- sqlx::query(&sqls[0])
- .bind(validation_id)
- .bind(v.oauth_provider.name.to_string())
- .bind(v.oauth_provider.flow.clone())
- .bind(v.access_token.clone())
- .bind(v.raw_response.clone())
- .bind(v.created_at)
- .bind(v.validated_at)
- .execute(&self.pool)
- .await?;
-
- if v.identity_id.is_some() || v.revoked_at.is_some() || v.deleted_at.is_some() {
- sqlx::query(&sqls[1])
- .bind(v.identity_id.as_ref())
- .bind(validation_id)
- .bind(v.revoked_at)
- .bind(v.deleted_at)
- .execute(&self.pool)
- .await?;
- }
-
- Ok(validation_id)
- }
- async fn read_oauth_validation(
- &self,
- validation_id: &ValidationRequestId,
- ) -> anyhow::Result<OauthValidation> {
- let sqls = get_sqls(&self.sqls_root, READ_OAUTH_VALIDATION);
-
- let mut es = sqlx::query_as::<_, OauthValidation>(&sqls[0])
- .bind(validation_id)
- .fetch_all(&self.pool)
- .await?;
-
- if es.len() != 1 {
- bail!(StoreError::OauthValidationDoesNotExist(
- validation_id.clone()
- ));
- }
-
- Ok(es.swap_remove(0))
- }
- async fn find_validation_type(
- &self,
- validation_id: &ValidationRequestId,
- ) -> anyhow::Result<ValidationType> {
- let sqls = get_sqls(&self.sqls_root, READ_VALIDATION_TYPE);
-
- let mut es = sqlx::query_as::<_, (String,)>(&sqls[0])
- .bind(validation_id)
- .fetch_all(&self.pool)
- .await
- .map_err(util::log_err_sqlx)?;
-
- match es.len() {
- 1 => Ok(ValidationType::from_str(&es.swap_remove(0).0)?),
- _ => bail!(StoreError::Other(
- "expected a single validation but recieved 0 or multiple validations".into()
- )),
- }
- }
-}
-
-pub struct PgClient {
- sql: SqlClient<Postgres>,
-}
-
-impl PgClient {
- pub async fn new(pool: sqlx::Pool<Postgres>) -> Arc<dyn Store + Send + Sync + 'static> {
- sqlx::migrate!("store/pg/migrations")
- .run(&pool)
- .await
- .expect(ERR_MSG_MIGRATION_FAILED);
-
- Arc::new(PgClient {
- sql: SqlClient {
- pool,
- sqls_root: PGSQL.to_string(),
- },
- })
- }
-}
-
-#[async_trait::async_trait]
-impl Store for PgClient {
- async fn write_email(&self, email_address: &str) -> Result<(), StoreError> {
- self.sql.write_email(email_address).await
- }
- async fn find_email_validation(
- &self,
- validation_id: Option<&Uuid>,
- code: Option<&str>,
- ) -> Result<EmailValidation, StoreError> {
- self.sql.find_email_validation(validation_id, code).await
- }
- async fn write_email_validation(&self, ev: &EmailValidation) -> anyhow::Result<Uuid> {
- self.sql.write_email_validation(ev).await
- }
- async fn find_identity(
- &self,
- identity_id: Option<&Uuid>,
- email: Option<&str>,
- ) -> anyhow::Result<Option<Identity>> {
- self.sql.find_identity(identity_id, email).await
- }
- async fn find_identity_by_code(&self, code: &str) -> Result<Identity, StoreError> {
- self.sql.find_identity_by_code(code).await
- }
- async fn write_identity(&self, i: &Identity) -> Result<(), StoreError> {
- self.sql.write_identity(i).await
- }
- async fn read_identity(&self, identity_id: &Uuid) -> Result<Identity, StoreError> {
- self.sql.read_identity(identity_id).await
- }
- async fn write_session(&self, session: &Session) -> Result<(), StoreError> {
- self.sql.write_session(session).await
- }
- async fn read_session(&self, secret: &SessionSecret) -> Result<Session, StoreError> {
- self.sql.read_session(secret).await
- }
- async fn write_oauth_provider(&self, provider: &OauthProvider) -> Result<(), StoreError> {
- self.sql.write_oauth_provider(provider).await
- }
- async fn read_oauth_provider(
- &self,
- provider: &OauthProviderName,
- flow: Option<String>,
- ) -> Result<OauthProvider, StoreError> {
- self.sql.read_oauth_provider(provider, flow).await
- }
- async fn write_oauth_validation(
- &self,
- validation: &OauthValidation,
- ) -> anyhow::Result<ValidationRequestId> {
- self.sql.write_oauth_validation(validation).await
- }
- async fn read_oauth_validation(
- &self,
- validation_id: &ValidationRequestId,
- ) -> anyhow::Result<OauthValidation> {
- self.sql.read_oauth_validation(validation_id).await
- }
- async fn find_validation_type(
- &self,
- validation_id: &ValidationRequestId,
- ) -> anyhow::Result<ValidationType> {
- self.sql.find_validation_type(validation_id).await
- }
-}
-
-pub struct SqliteClient {
- sql: SqlClient<Sqlite>,
-}
-
-impl SqliteClient {
- pub async fn new(pool: sqlx::Pool<Sqlite>) -> Arc<dyn Store + Send + Sync + 'static> {
- sqlx::migrate!("store/sqlite/migrations")
- .run(&pool)
- .await
- .expect(ERR_MSG_MIGRATION_FAILED);
-
- sqlx::query("pragma foreign_keys = on")
- .execute(&pool)
- .await
- .expect(
- "Failed to initialize FK pragma. File a bug at https://www.github.com/secd-lib",
- );
-
- Arc::new(SqliteClient {
- sql: SqlClient {
- pool,
- sqls_root: SQLITE.to_string(),
- },
- })
- }
-}
-
-#[async_trait::async_trait]
-impl Store for SqliteClient {
- async fn write_email(&self, email_address: &str) -> Result<(), StoreError> {
- self.sql.write_email(email_address).await
- }
- async fn find_email_validation(
- &self,
- validation_id: Option<&Uuid>,
- code: Option<&str>,
- ) -> Result<EmailValidation, StoreError> {
- self.sql.find_email_validation(validation_id, code).await
- }
- async fn write_email_validation(&self, ev: &EmailValidation) -> anyhow::Result<Uuid> {
- self.sql.write_email_validation(ev).await
- }
- async fn find_identity(
- &self,
- identity_id: Option<&Uuid>,
- email: Option<&str>,
- ) -> anyhow::Result<Option<Identity>> {
- self.sql.find_identity(identity_id, email).await
- }
- async fn find_identity_by_code(&self, code: &str) -> Result<Identity, StoreError> {
- self.sql.find_identity_by_code(code).await
- }
- async fn write_identity(&self, i: &Identity) -> Result<(), StoreError> {
- self.sql.write_identity(i).await
- }
- async fn read_identity(&self, identity_id: &Uuid) -> Result<Identity, StoreError> {
- self.sql.read_identity(identity_id).await
- }
- async fn write_session(&self, session: &Session) -> Result<(), StoreError> {
- self.sql.write_session(session).await
- }
- async fn read_session(&self, secret: &SessionSecret) -> Result<Session, StoreError> {
- self.sql.read_session(secret).await
- }
- async fn write_oauth_provider(&self, provider: &OauthProvider) -> Result<(), StoreError> {
- self.sql.write_oauth_provider(provider).await
- }
- async fn read_oauth_provider(
- &self,
- provider: &OauthProviderName,
- flow: Option<String>,
- ) -> Result<OauthProvider, StoreError> {
- self.sql.read_oauth_provider(provider, flow).await
- }
- async fn write_oauth_validation(
- &self,
- validation: &OauthValidation,
- ) -> anyhow::Result<ValidationRequestId> {
- self.sql.write_oauth_validation(validation).await
- }
- async fn read_oauth_validation(
- &self,
- validation_id: &ValidationRequestId,
- ) -> anyhow::Result<OauthValidation> {
- self.sql.read_oauth_validation(validation_id).await
- }
- async fn find_validation_type(
- &self,
- validation_id: &ValidationRequestId,
- ) -> anyhow::Result<ValidationType> {
- self.sql.find_validation_type(validation_id).await
- }
-}