From 2c4eb2d311919ad9fb70738199ecf99bf20c9fce Mon Sep 17 00:00:00 2001 From: benj Date: Thu, 1 Dec 2022 10:30:34 -0800 Subject: - basic functionality with psql and sqlite - cli helper tool --- crates/secd/src/client/sqldb.rs | 424 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 424 insertions(+) create mode 100644 crates/secd/src/client/sqldb.rs (limited to 'crates/secd/src/client/sqldb.rs') diff --git a/crates/secd/src/client/sqldb.rs b/crates/secd/src/client/sqldb.rs new file mode 100644 index 0000000..6048c48 --- /dev/null +++ b/crates/secd/src/client/sqldb.rs @@ -0,0 +1,424 @@ +use std::sync::Arc; + +use super::{ + EmailValidation, Identity, Session, SessionSecret, Store, StoreError, ERR_MSG_MIGRATION_FAILED, + FIND_EMAIL_VALIDATION, FIND_IDENTITY, FIND_IDENTITY_BY_CODE, PGSQL, READ_EMAIL_RAW_ID, + READ_IDENTITY_RAW_ID, READ_SESSION, SQLITE, SQLS, WRITE_EMAIL, WRITE_EMAIL_VALIDATION, + WRITE_IDENTITY, WRITE_SESSION, +}; +use crate::util; +use log::error; +use openssl::sha::Sha256; +use sqlx::{ + self, database::HasArguments, ColumnIndex, Database, Decode, Encode, Executor, IntoArguments, + Pool, Postgres, Sqlite, Transaction, Type, +}; +use time::OffsetDateTime; +use uuid::Uuid; + +fn get_sqls(root: &str, file: &str) -> Vec { + SQLS.get(root) + .unwrap() + .get(file) + .unwrap() + .split("--") + .map(|p| p.to_string()) + .collect() +} + +fn hash_secret(secret: &str) -> Vec { + let mut hasher = Sha256::new(); + hasher.update(secret.as_bytes()); + hasher.finish().to_vec() +} + +struct SqlClient +where + D: sqlx::Database, +{ + pool: sqlx::Pool, + sqls_root: String, +} + +impl SqlClient +where + D: sqlx::Database, + for<'c> >::Arguments: IntoArguments<'c, D>, + for<'c> i64: Decode<'c, D> + Type, + for<'c> &'c str: Decode<'c, D> + Type, + for<'c> &'c str: Encode<'c, D> + Type, + for<'c> usize: ColumnIndex<::Row>, + for<'c> Uuid: Decode<'c, D> + Type, + for<'c> Uuid: Encode<'c, D> + Type, + for<'c> &'c Pool: Executor<'c, Database = D>, +{ + async fn read_identity_raw_id(&self, id: &Uuid) -> Result { + let sqls = get_sqls(&self.sqls_root, READ_IDENTITY_RAW_ID); + + Ok(sqlx::query_as::<_, (i64,)>(&sqls[0]) + .bind(id) + .fetch_one(&self.pool) + .await + .map_err(util::log_err_sqlx)? + .0) + } + + async fn read_email_raw_id(&self, address: &str) -> Result { + let sqls = get_sqls(&self.sqls_root, READ_EMAIL_RAW_ID); + + Ok(sqlx::query_as::<_, (i64,)>(&sqls[0]) + .bind(address) + .fetch_one(&self.pool) + .await + .map_err(util::log_err_sqlx)? + .0) + } +} + +#[async_trait::async_trait] +impl Store for SqlClient +where + D: sqlx::Database, + for<'c> >::Arguments: IntoArguments<'c, D>, + for<'c> bool: Decode<'c, D> + Type, + for<'c> bool: Encode<'c, D> + Type, + for<'c> i64: Decode<'c, D> + Type, + for<'c> i64: Encode<'c, D> + Type, + for<'c> i32: Decode<'c, D> + Type, + for<'c> i32: Encode<'c, D> + Type, + for<'c> OffsetDateTime: Decode<'c, D> + Type, + for<'c> OffsetDateTime: Encode<'c, D> + Type, + for<'c> &'c str: ColumnIndex<::Row>, + for<'c> &'c str: Decode<'c, D> + Type, + for<'c> &'c str: Encode<'c, D> + Type, + for<'c> Option<&'c str>: Decode<'c, D> + Type, + for<'c> Option<&'c str>: Encode<'c, D> + Type, + for<'c> String: Decode<'c, D> + Type, + for<'c> String: Encode<'c, D> + Type, + for<'c> Option: Decode<'c, D> + Type, + for<'c> Option: Encode<'c, D> + Type, + for<'c> usize: ColumnIndex<::Row>, + for<'c> Uuid: Decode<'c, D> + Type, + for<'c> Uuid: Encode<'c, D> + Type, + for<'c> &'c [u8]: Encode<'c, D> + Type, + for<'c> Option<&'c Uuid>: Encode<'c, D> + Type, + for<'c> Option<&'c Vec>: Encode<'c, D> + Type, + for<'c> Option: Decode<'c, D> + Type, + for<'c> Option: Encode<'c, D> + Type, + for<'c> &'c Pool: Executor<'c, Database = D>, + for<'c> &'c mut Transaction<'c, D>: Executor<'c, Database = D>, +{ + async fn write_email(&self, identity_id: Uuid, email_address: &str) -> Result<(), StoreError> { + let sqls = get_sqls(&self.sqls_root, WRITE_EMAIL); + + let identity_id = self.read_identity_raw_id(&identity_id).await?; + + let email_id: (i64,) = match sqlx::query_as(&sqls[0]) + .bind(email_address) + .fetch_one(&self.pool) + .await + { + Ok(i) => i, + Err(sqlx::Error::RowNotFound) => sqlx::query_as::<_, (i64,)>(&sqls[1]) + .bind(email_address) + .fetch_one(&self.pool) + .await + .map_err(util::log_err_sqlx)?, + Err(e) => return Err(StoreError::SqlxError(e)), + }; + + sqlx::query(&sqls[2]) + .bind(identity_id) + .bind(email_id.0) + .bind(OffsetDateTime::now_utc()) + .execute(&self.pool) + .await + .map_err(util::log_err_sqlx)?; + + Ok(()) + } + + async fn find_email_validation( + &self, + validation_id: Option<&Uuid>, + code: Option<&str>, + ) -> Result { + let sqls = get_sqls(&self.sqls_root, FIND_EMAIL_VALIDATION); + let mut rows = sqlx::query_as::<_, EmailValidation>(&sqls[0]) + .bind(validation_id) + .bind(code) + .fetch_all(&self.pool) + .await + .map_err(util::log_err_sqlx)?; + + match rows.len() { + 0 => Err(StoreError::NoEmailValidationFound), + 1 => Ok(rows.swap_remove(0)), + _ => Err(StoreError::TooManyEmailValidations), + } + } + + async fn write_email_validation(&self, ev: &EmailValidation) -> Result { + let sqls = get_sqls(&self.sqls_root, WRITE_EMAIL_VALIDATION); + + let identity_id = self + .read_identity_raw_id( + &ev.identity_id + .ok_or(StoreError::IdentityIdMustExistInvariant)?, + ) + .await?; + let email_id = self.read_email_raw_id(&ev.email_address).await?; + + let new_id = Uuid::new_v4(); + sqlx::query(&sqls[0]) + .bind(ev.id.unwrap_or(new_id)) + .bind(identity_id) + .bind(email_id) + .bind(ev.attempts) + .bind(&ev.code) + .bind(ev.is_validated) + .bind(ev.created_at) + .bind(ev.expires_at) + .execute(&self.pool) + .await + .map_err(util::log_err_sqlx)?; + + Ok(new_id) + } + + async fn find_identity( + &self, + id: Option<&Uuid>, + email: Option<&str>, + ) -> Result, StoreError> { + let sqls = get_sqls(&self.sqls_root, FIND_IDENTITY); + Ok( + match sqlx::query_as::<_, Identity>(&sqls[0]) + .bind(id) + .bind(email) + .fetch_one(&self.pool) + .await + { + Ok(i) => Some(i), + Err(sqlx::Error::RowNotFound) => None, + Err(e) => return Err(StoreError::SqlxError(e)), + }, + ) + } + async fn find_identity_by_code(&self, code: &str) -> Result { + let sqls = get_sqls(&self.sqls_root, FIND_IDENTITY_BY_CODE); + + let rows = sqlx::query_as::<_, (i32,)>(&sqls[0]) + .bind(code) + .fetch_all(&self.pool) + .await + .map_err(util::log_err_sqlx)?; + + if rows.len() == 0 { + return Err(StoreError::CodeDoesNotExist(code.to_string())); + } + + if rows.len() != 1 { + return Err(StoreError::CodeAppearsMoreThanOnce); + } + + let identity_email_id = rows.get(0).unwrap().0; + + // TODO: IF we expand beyond email codes, then we'll need to join against a bunch of identity tables. + // but since a single code was found, only one of them should pop... + Ok(sqlx::query_as::<_, Identity>(&sqls[1]) + .bind(identity_email_id) + .fetch_one(&self.pool) + .await + .map_err(util::log_err_sqlx)?) + } + + async fn write_identity(&self, i: &Identity) -> Result<(), StoreError> { + let sqls = get_sqls(&self.sqls_root, WRITE_IDENTITY); + sqlx::query(&sqls[0]) + .bind(i.id) + .bind(i.data.clone()) + .bind(i.created_at) + .execute(&self.pool) + .await + .map_err(|e| { + error!("write_identity_failure"); + error!("{:?}", e); + e + })?; + + Ok(()) + } + async fn read_identity(&self, id: &Uuid) -> Result { + Ok(sqlx::query_as::<_, Identity>( + " +select identity_public_id, data, created_at from identity where identity_public_id = ?", + ) + .bind(id) + .fetch_one(&self.pool) + .await + .map_err(util::log_err_sqlx)?) + } + + async fn write_session(&self, session: &Session) -> Result<(), StoreError> { + let sqls = get_sqls(&self.sqls_root, WRITE_SESSION); + + let secret_hash = session.secret.as_ref().map(|s| hash_secret(s)); + + sqlx::query(&sqls[0]) + .bind(&session.identity_id) + .bind(secret_hash.as_ref()) + .bind(session.created_at) + .bind(OffsetDateTime::now_utc()) + .bind(session.expires_at) + .bind(session.revoked_at) + .execute(&self.pool) + .await + .map_err(util::log_err_sqlx)?; + + Ok(()) + } + async fn read_session(&self, secret: &SessionSecret) -> Result { + let sqls = get_sqls(&self.sqls_root, READ_SESSION); + + let secret_hash = hash_secret(secret); + let mut session = sqlx::query_as::<_, Session>(&sqls[0]) + .bind(&secret_hash[..]) + .fetch_one(&self.pool) + .await + .map_err(util::log_err_sqlx)?; + + // This should do nothing other than updated touched_at, and then + // clear the plaintext secret + session.secret = Some(secret.to_string()); + self.write_session(&session).await?; + session.secret = None; + + Ok(session) + } +} + +pub struct PgClient { + sql: SqlClient, +} + +impl PgClient { + pub async fn new(pool: sqlx::Pool) -> Arc { + sqlx::migrate!("store/pg/migrations") + .run(&pool) + .await + .expect(ERR_MSG_MIGRATION_FAILED); + + Arc::new(PgClient { + sql: SqlClient { + pool, + sqls_root: PGSQL.to_string(), + }, + }) + } +} + +#[async_trait::async_trait] +impl Store for PgClient { + async fn write_email(&self, identity_id: Uuid, email_address: &str) -> Result<(), StoreError> { + self.sql.write_email(identity_id, email_address).await + } + async fn find_email_validation( + &self, + validation_id: Option<&Uuid>, + code: Option<&str>, + ) -> Result { + self.sql.find_email_validation(validation_id, code).await + } + async fn write_email_validation(&self, ev: &EmailValidation) -> Result { + self.sql.write_email_validation(ev).await + } + async fn find_identity( + &self, + identity_id: Option<&Uuid>, + email: Option<&str>, + ) -> Result, StoreError> { + self.sql.find_identity(identity_id, email).await + } + async fn find_identity_by_code(&self, code: &str) -> Result { + self.sql.find_identity_by_code(code).await + } + async fn write_identity(&self, i: &Identity) -> Result<(), StoreError> { + self.sql.write_identity(i).await + } + async fn read_identity(&self, identity_id: &Uuid) -> Result { + self.sql.read_identity(identity_id).await + } + async fn write_session(&self, session: &Session) -> Result<(), StoreError> { + self.sql.write_session(session).await + } + async fn read_session(&self, secret: &SessionSecret) -> Result { + self.sql.read_session(secret).await + } +} + +pub struct SqliteClient { + sql: SqlClient, +} + +impl SqliteClient { + pub async fn new(pool: sqlx::Pool) -> Arc { + sqlx::migrate!("store/sqlite/migrations") + .run(&pool) + .await + .expect(ERR_MSG_MIGRATION_FAILED); + + sqlx::query("pragma foreign_keys = on") + .execute(&pool) + .await + .expect( + "Failed to initialize FK pragma. File a bug at https://www.github.com/secd-lib", + ); + + Arc::new(SqliteClient { + sql: SqlClient { + pool, + sqls_root: SQLITE.to_string(), + }, + }) + } +} + +#[async_trait::async_trait] +impl Store for SqliteClient { + async fn write_email(&self, identity_id: Uuid, email_address: &str) -> Result<(), StoreError> { + self.sql.write_email(identity_id, email_address).await + } + async fn find_email_validation( + &self, + validation_id: Option<&Uuid>, + code: Option<&str>, + ) -> Result { + self.sql.find_email_validation(validation_id, code).await + } + async fn write_email_validation(&self, ev: &EmailValidation) -> Result { + self.sql.write_email_validation(ev).await + } + async fn find_identity( + &self, + identity_id: Option<&Uuid>, + email: Option<&str>, + ) -> Result, StoreError> { + self.sql.find_identity(identity_id, email).await + } + async fn find_identity_by_code(&self, code: &str) -> Result { + self.sql.find_identity_by_code(code).await + } + async fn write_identity(&self, i: &Identity) -> Result<(), StoreError> { + self.sql.write_identity(i).await + } + async fn read_identity(&self, identity_id: &Uuid) -> Result { + self.sql.read_identity(identity_id).await + } + async fn write_session(&self, session: &Session) -> Result<(), StoreError> { + self.sql.write_session(session).await + } + async fn read_session(&self, secret: &SessionSecret) -> Result { + self.sql.read_session(secret).await + } +} -- cgit v1.2.3