aboutsummaryrefslogtreecommitdiff
path: root/crates/secd/src/client/sqldb.rs
diff options
context:
space:
mode:
Diffstat (limited to 'crates/secd/src/client/sqldb.rs')
-rw-r--r--crates/secd/src/client/sqldb.rs424
1 files changed, 424 insertions, 0 deletions
diff --git a/crates/secd/src/client/sqldb.rs b/crates/secd/src/client/sqldb.rs
new file mode 100644
index 0000000..6048c48
--- /dev/null
+++ b/crates/secd/src/client/sqldb.rs
@@ -0,0 +1,424 @@
+use std::sync::Arc;
+
+use super::{
+ EmailValidation, Identity, Session, SessionSecret, Store, StoreError, ERR_MSG_MIGRATION_FAILED,
+ FIND_EMAIL_VALIDATION, FIND_IDENTITY, FIND_IDENTITY_BY_CODE, PGSQL, READ_EMAIL_RAW_ID,
+ READ_IDENTITY_RAW_ID, READ_SESSION, SQLITE, SQLS, WRITE_EMAIL, WRITE_EMAIL_VALIDATION,
+ WRITE_IDENTITY, WRITE_SESSION,
+};
+use crate::util;
+use log::error;
+use openssl::sha::Sha256;
+use sqlx::{
+ self, database::HasArguments, ColumnIndex, Database, Decode, Encode, Executor, IntoArguments,
+ Pool, Postgres, Sqlite, Transaction, Type,
+};
+use time::OffsetDateTime;
+use uuid::Uuid;
+
+fn get_sqls(root: &str, file: &str) -> Vec<String> {
+ SQLS.get(root)
+ .unwrap()
+ .get(file)
+ .unwrap()
+ .split("--")
+ .map(|p| p.to_string())
+ .collect()
+}
+
+fn hash_secret(secret: &str) -> Vec<u8> {
+ let mut hasher = Sha256::new();
+ hasher.update(secret.as_bytes());
+ hasher.finish().to_vec()
+}
+
+struct SqlClient<D>
+where
+ D: sqlx::Database,
+{
+ pool: sqlx::Pool<D>,
+ sqls_root: String,
+}
+
+impl<D> SqlClient<D>
+where
+ D: sqlx::Database,
+ for<'c> <D as HasArguments<'c>>::Arguments: IntoArguments<'c, D>,
+ for<'c> i64: Decode<'c, D> + Type<D>,
+ for<'c> &'c str: Decode<'c, D> + Type<D>,
+ for<'c> &'c str: Encode<'c, D> + Type<D>,
+ for<'c> usize: ColumnIndex<<D as Database>::Row>,
+ for<'c> Uuid: Decode<'c, D> + Type<D>,
+ for<'c> Uuid: Encode<'c, D> + Type<D>,
+ for<'c> &'c Pool<D>: Executor<'c, Database = D>,
+{
+ async fn read_identity_raw_id(&self, id: &Uuid) -> Result<i64, StoreError> {
+ let sqls = get_sqls(&self.sqls_root, READ_IDENTITY_RAW_ID);
+
+ Ok(sqlx::query_as::<_, (i64,)>(&sqls[0])
+ .bind(id)
+ .fetch_one(&self.pool)
+ .await
+ .map_err(util::log_err_sqlx)?
+ .0)
+ }
+
+ async fn read_email_raw_id(&self, address: &str) -> Result<i64, StoreError> {
+ let sqls = get_sqls(&self.sqls_root, READ_EMAIL_RAW_ID);
+
+ Ok(sqlx::query_as::<_, (i64,)>(&sqls[0])
+ .bind(address)
+ .fetch_one(&self.pool)
+ .await
+ .map_err(util::log_err_sqlx)?
+ .0)
+ }
+}
+
+#[async_trait::async_trait]
+impl<D> Store for SqlClient<D>
+where
+ D: sqlx::Database,
+ for<'c> <D as HasArguments<'c>>::Arguments: IntoArguments<'c, D>,
+ for<'c> bool: Decode<'c, D> + Type<D>,
+ for<'c> bool: Encode<'c, D> + Type<D>,
+ for<'c> i64: Decode<'c, D> + Type<D>,
+ for<'c> i64: Encode<'c, D> + Type<D>,
+ for<'c> i32: Decode<'c, D> + Type<D>,
+ for<'c> i32: Encode<'c, D> + Type<D>,
+ for<'c> OffsetDateTime: Decode<'c, D> + Type<D>,
+ for<'c> OffsetDateTime: Encode<'c, D> + Type<D>,
+ for<'c> &'c str: ColumnIndex<<D as Database>::Row>,
+ for<'c> &'c str: Decode<'c, D> + Type<D>,
+ for<'c> &'c str: Encode<'c, D> + Type<D>,
+ for<'c> Option<&'c str>: Decode<'c, D> + Type<D>,
+ for<'c> Option<&'c str>: Encode<'c, D> + Type<D>,
+ for<'c> String: Decode<'c, D> + Type<D>,
+ for<'c> String: Encode<'c, D> + Type<D>,
+ for<'c> Option<String>: Decode<'c, D> + Type<D>,
+ for<'c> Option<String>: Encode<'c, D> + Type<D>,
+ for<'c> usize: ColumnIndex<<D as Database>::Row>,
+ for<'c> Uuid: Decode<'c, D> + Type<D>,
+ for<'c> Uuid: Encode<'c, D> + Type<D>,
+ for<'c> &'c [u8]: Encode<'c, D> + Type<D>,
+ for<'c> Option<&'c Uuid>: Encode<'c, D> + Type<D>,
+ for<'c> Option<&'c Vec<u8>>: Encode<'c, D> + Type<D>,
+ for<'c> Option<OffsetDateTime>: Decode<'c, D> + Type<D>,
+ for<'c> Option<OffsetDateTime>: Encode<'c, D> + Type<D>,
+ for<'c> &'c Pool<D>: Executor<'c, Database = D>,
+ for<'c> &'c mut Transaction<'c, D>: Executor<'c, Database = D>,
+{
+ async fn write_email(&self, identity_id: Uuid, email_address: &str) -> Result<(), StoreError> {
+ let sqls = get_sqls(&self.sqls_root, WRITE_EMAIL);
+
+ let identity_id = self.read_identity_raw_id(&identity_id).await?;
+
+ let email_id: (i64,) = match sqlx::query_as(&sqls[0])
+ .bind(email_address)
+ .fetch_one(&self.pool)
+ .await
+ {
+ Ok(i) => i,
+ Err(sqlx::Error::RowNotFound) => sqlx::query_as::<_, (i64,)>(&sqls[1])
+ .bind(email_address)
+ .fetch_one(&self.pool)
+ .await
+ .map_err(util::log_err_sqlx)?,
+ Err(e) => return Err(StoreError::SqlxError(e)),
+ };
+
+ sqlx::query(&sqls[2])
+ .bind(identity_id)
+ .bind(email_id.0)
+ .bind(OffsetDateTime::now_utc())
+ .execute(&self.pool)
+ .await
+ .map_err(util::log_err_sqlx)?;
+
+ Ok(())
+ }
+
+ async fn find_email_validation(
+ &self,
+ validation_id: Option<&Uuid>,
+ code: Option<&str>,
+ ) -> Result<EmailValidation, StoreError> {
+ let sqls = get_sqls(&self.sqls_root, FIND_EMAIL_VALIDATION);
+ let mut rows = sqlx::query_as::<_, EmailValidation>(&sqls[0])
+ .bind(validation_id)
+ .bind(code)
+ .fetch_all(&self.pool)
+ .await
+ .map_err(util::log_err_sqlx)?;
+
+ match rows.len() {
+ 0 => Err(StoreError::NoEmailValidationFound),
+ 1 => Ok(rows.swap_remove(0)),
+ _ => Err(StoreError::TooManyEmailValidations),
+ }
+ }
+
+ async fn write_email_validation(&self, ev: &EmailValidation) -> Result<Uuid, StoreError> {
+ let sqls = get_sqls(&self.sqls_root, WRITE_EMAIL_VALIDATION);
+
+ let identity_id = self
+ .read_identity_raw_id(
+ &ev.identity_id
+ .ok_or(StoreError::IdentityIdMustExistInvariant)?,
+ )
+ .await?;
+ let email_id = self.read_email_raw_id(&ev.email_address).await?;
+
+ let new_id = Uuid::new_v4();
+ sqlx::query(&sqls[0])
+ .bind(ev.id.unwrap_or(new_id))
+ .bind(identity_id)
+ .bind(email_id)
+ .bind(ev.attempts)
+ .bind(&ev.code)
+ .bind(ev.is_validated)
+ .bind(ev.created_at)
+ .bind(ev.expires_at)
+ .execute(&self.pool)
+ .await
+ .map_err(util::log_err_sqlx)?;
+
+ Ok(new_id)
+ }
+
+ async fn find_identity(
+ &self,
+ id: Option<&Uuid>,
+ email: Option<&str>,
+ ) -> Result<Option<Identity>, StoreError> {
+ let sqls = get_sqls(&self.sqls_root, FIND_IDENTITY);
+ Ok(
+ match sqlx::query_as::<_, Identity>(&sqls[0])
+ .bind(id)
+ .bind(email)
+ .fetch_one(&self.pool)
+ .await
+ {
+ Ok(i) => Some(i),
+ Err(sqlx::Error::RowNotFound) => None,
+ Err(e) => return Err(StoreError::SqlxError(e)),
+ },
+ )
+ }
+ async fn find_identity_by_code(&self, code: &str) -> Result<Identity, StoreError> {
+ let sqls = get_sqls(&self.sqls_root, FIND_IDENTITY_BY_CODE);
+
+ let rows = sqlx::query_as::<_, (i32,)>(&sqls[0])
+ .bind(code)
+ .fetch_all(&self.pool)
+ .await
+ .map_err(util::log_err_sqlx)?;
+
+ if rows.len() == 0 {
+ return Err(StoreError::CodeDoesNotExist(code.to_string()));
+ }
+
+ if rows.len() != 1 {
+ return Err(StoreError::CodeAppearsMoreThanOnce);
+ }
+
+ let identity_email_id = rows.get(0).unwrap().0;
+
+ // TODO: IF we expand beyond email codes, then we'll need to join against a bunch of identity tables.
+ // but since a single code was found, only one of them should pop...
+ Ok(sqlx::query_as::<_, Identity>(&sqls[1])
+ .bind(identity_email_id)
+ .fetch_one(&self.pool)
+ .await
+ .map_err(util::log_err_sqlx)?)
+ }
+
+ async fn write_identity(&self, i: &Identity) -> Result<(), StoreError> {
+ let sqls = get_sqls(&self.sqls_root, WRITE_IDENTITY);
+ sqlx::query(&sqls[0])
+ .bind(i.id)
+ .bind(i.data.clone())
+ .bind(i.created_at)
+ .execute(&self.pool)
+ .await
+ .map_err(|e| {
+ error!("write_identity_failure");
+ error!("{:?}", e);
+ e
+ })?;
+
+ Ok(())
+ }
+ async fn read_identity(&self, id: &Uuid) -> Result<Identity, StoreError> {
+ Ok(sqlx::query_as::<_, Identity>(
+ "
+select identity_public_id, data, created_at from identity where identity_public_id = ?",
+ )
+ .bind(id)
+ .fetch_one(&self.pool)
+ .await
+ .map_err(util::log_err_sqlx)?)
+ }
+
+ async fn write_session(&self, session: &Session) -> Result<(), StoreError> {
+ let sqls = get_sqls(&self.sqls_root, WRITE_SESSION);
+
+ let secret_hash = session.secret.as_ref().map(|s| hash_secret(s));
+
+ sqlx::query(&sqls[0])
+ .bind(&session.identity_id)
+ .bind(secret_hash.as_ref())
+ .bind(session.created_at)
+ .bind(OffsetDateTime::now_utc())
+ .bind(session.expires_at)
+ .bind(session.revoked_at)
+ .execute(&self.pool)
+ .await
+ .map_err(util::log_err_sqlx)?;
+
+ Ok(())
+ }
+ async fn read_session(&self, secret: &SessionSecret) -> Result<Session, StoreError> {
+ let sqls = get_sqls(&self.sqls_root, READ_SESSION);
+
+ let secret_hash = hash_secret(secret);
+ let mut session = sqlx::query_as::<_, Session>(&sqls[0])
+ .bind(&secret_hash[..])
+ .fetch_one(&self.pool)
+ .await
+ .map_err(util::log_err_sqlx)?;
+
+ // This should do nothing other than updated touched_at, and then
+ // clear the plaintext secret
+ session.secret = Some(secret.to_string());
+ self.write_session(&session).await?;
+ session.secret = None;
+
+ Ok(session)
+ }
+}
+
+pub struct PgClient {
+ sql: SqlClient<Postgres>,
+}
+
+impl PgClient {
+ pub async fn new(pool: sqlx::Pool<Postgres>) -> Arc<dyn Store + Send + Sync + 'static> {
+ sqlx::migrate!("store/pg/migrations")
+ .run(&pool)
+ .await
+ .expect(ERR_MSG_MIGRATION_FAILED);
+
+ Arc::new(PgClient {
+ sql: SqlClient {
+ pool,
+ sqls_root: PGSQL.to_string(),
+ },
+ })
+ }
+}
+
+#[async_trait::async_trait]
+impl Store for PgClient {
+ async fn write_email(&self, identity_id: Uuid, email_address: &str) -> Result<(), StoreError> {
+ self.sql.write_email(identity_id, email_address).await
+ }
+ async fn find_email_validation(
+ &self,
+ validation_id: Option<&Uuid>,
+ code: Option<&str>,
+ ) -> Result<EmailValidation, StoreError> {
+ self.sql.find_email_validation(validation_id, code).await
+ }
+ async fn write_email_validation(&self, ev: &EmailValidation) -> Result<Uuid, StoreError> {
+ self.sql.write_email_validation(ev).await
+ }
+ async fn find_identity(
+ &self,
+ identity_id: Option<&Uuid>,
+ email: Option<&str>,
+ ) -> Result<Option<Identity>, StoreError> {
+ self.sql.find_identity(identity_id, email).await
+ }
+ async fn find_identity_by_code(&self, code: &str) -> Result<Identity, StoreError> {
+ self.sql.find_identity_by_code(code).await
+ }
+ async fn write_identity(&self, i: &Identity) -> Result<(), StoreError> {
+ self.sql.write_identity(i).await
+ }
+ async fn read_identity(&self, identity_id: &Uuid) -> Result<Identity, StoreError> {
+ self.sql.read_identity(identity_id).await
+ }
+ async fn write_session(&self, session: &Session) -> Result<(), StoreError> {
+ self.sql.write_session(session).await
+ }
+ async fn read_session(&self, secret: &SessionSecret) -> Result<Session, StoreError> {
+ self.sql.read_session(secret).await
+ }
+}
+
+pub struct SqliteClient {
+ sql: SqlClient<Sqlite>,
+}
+
+impl SqliteClient {
+ pub async fn new(pool: sqlx::Pool<Sqlite>) -> Arc<dyn Store + Send + Sync + 'static> {
+ sqlx::migrate!("store/sqlite/migrations")
+ .run(&pool)
+ .await
+ .expect(ERR_MSG_MIGRATION_FAILED);
+
+ sqlx::query("pragma foreign_keys = on")
+ .execute(&pool)
+ .await
+ .expect(
+ "Failed to initialize FK pragma. File a bug at https://www.github.com/secd-lib",
+ );
+
+ Arc::new(SqliteClient {
+ sql: SqlClient {
+ pool,
+ sqls_root: SQLITE.to_string(),
+ },
+ })
+ }
+}
+
+#[async_trait::async_trait]
+impl Store for SqliteClient {
+ async fn write_email(&self, identity_id: Uuid, email_address: &str) -> Result<(), StoreError> {
+ self.sql.write_email(identity_id, email_address).await
+ }
+ async fn find_email_validation(
+ &self,
+ validation_id: Option<&Uuid>,
+ code: Option<&str>,
+ ) -> Result<EmailValidation, StoreError> {
+ self.sql.find_email_validation(validation_id, code).await
+ }
+ async fn write_email_validation(&self, ev: &EmailValidation) -> Result<Uuid, StoreError> {
+ self.sql.write_email_validation(ev).await
+ }
+ async fn find_identity(
+ &self,
+ identity_id: Option<&Uuid>,
+ email: Option<&str>,
+ ) -> Result<Option<Identity>, StoreError> {
+ self.sql.find_identity(identity_id, email).await
+ }
+ async fn find_identity_by_code(&self, code: &str) -> Result<Identity, StoreError> {
+ self.sql.find_identity_by_code(code).await
+ }
+ async fn write_identity(&self, i: &Identity) -> Result<(), StoreError> {
+ self.sql.write_identity(i).await
+ }
+ async fn read_identity(&self, identity_id: &Uuid) -> Result<Identity, StoreError> {
+ self.sql.read_identity(identity_id).await
+ }
+ async fn write_session(&self, session: &Session) -> Result<(), StoreError> {
+ self.sql.write_session(session).await
+ }
+ async fn read_session(&self, secret: &SessionSecret) -> Result<Session, StoreError> {
+ self.sql.read_session(secret).await
+ }
+}