use std::sync::Arc; use super::{ EmailValidation, Identity, Session, SessionSecret, Store, StoreError, ERR_MSG_MIGRATION_FAILED, FIND_EMAIL_VALIDATION, FIND_IDENTITY, FIND_IDENTITY_BY_CODE, PGSQL, READ_EMAIL_RAW_ID, READ_IDENTITY_RAW_ID, READ_SESSION, SQLITE, SQLS, WRITE_EMAIL, WRITE_EMAIL_VALIDATION, WRITE_IDENTITY, WRITE_SESSION, }; use crate::util; use log::error; use openssl::sha::Sha256; use sqlx::{ self, database::HasArguments, ColumnIndex, Database, Decode, Encode, Executor, IntoArguments, Pool, Postgres, Sqlite, Transaction, Type, }; use time::OffsetDateTime; use uuid::Uuid; fn get_sqls(root: &str, file: &str) -> Vec { SQLS.get(root) .unwrap() .get(file) .unwrap() .split("--") .map(|p| p.to_string()) .collect() } fn hash_secret(secret: &str) -> Vec { let mut hasher = Sha256::new(); hasher.update(secret.as_bytes()); hasher.finish().to_vec() } struct SqlClient where D: sqlx::Database, { pool: sqlx::Pool, sqls_root: String, } impl SqlClient where D: sqlx::Database, for<'c> >::Arguments: IntoArguments<'c, D>, for<'c> i64: Decode<'c, D> + Type, for<'c> &'c str: Decode<'c, D> + Type, for<'c> &'c str: Encode<'c, D> + Type, for<'c> usize: ColumnIndex<::Row>, for<'c> Uuid: Decode<'c, D> + Type, for<'c> Uuid: Encode<'c, D> + Type, for<'c> &'c Pool: Executor<'c, Database = D>, { async fn read_identity_raw_id(&self, id: &Uuid) -> Result { let sqls = get_sqls(&self.sqls_root, READ_IDENTITY_RAW_ID); Ok(sqlx::query_as::<_, (i64,)>(&sqls[0]) .bind(id) .fetch_one(&self.pool) .await .map_err(util::log_err_sqlx)? .0) } async fn read_email_raw_id(&self, address: &str) -> Result { let sqls = get_sqls(&self.sqls_root, READ_EMAIL_RAW_ID); Ok(sqlx::query_as::<_, (i64,)>(&sqls[0]) .bind(address) .fetch_one(&self.pool) .await .map_err(util::log_err_sqlx)? .0) } } #[async_trait::async_trait] impl Store for SqlClient where D: sqlx::Database, for<'c> >::Arguments: IntoArguments<'c, D>, for<'c> bool: Decode<'c, D> + Type, for<'c> bool: Encode<'c, D> + Type, for<'c> i64: Decode<'c, D> + Type, for<'c> i64: Encode<'c, D> + Type, for<'c> i32: Decode<'c, D> + Type, for<'c> i32: Encode<'c, D> + Type, for<'c> OffsetDateTime: Decode<'c, D> + Type, for<'c> OffsetDateTime: Encode<'c, D> + Type, for<'c> &'c str: ColumnIndex<::Row>, for<'c> &'c str: Decode<'c, D> + Type, for<'c> &'c str: Encode<'c, D> + Type, for<'c> Option<&'c str>: Decode<'c, D> + Type, for<'c> Option<&'c str>: Encode<'c, D> + Type, for<'c> String: Decode<'c, D> + Type, for<'c> String: Encode<'c, D> + Type, for<'c> Option: Decode<'c, D> + Type, for<'c> Option: Encode<'c, D> + Type, for<'c> usize: ColumnIndex<::Row>, for<'c> Uuid: Decode<'c, D> + Type, for<'c> Uuid: Encode<'c, D> + Type, for<'c> &'c [u8]: Encode<'c, D> + Type, for<'c> Option<&'c Uuid>: Encode<'c, D> + Type, for<'c> Option<&'c Vec>: Encode<'c, D> + Type, for<'c> Option: Decode<'c, D> + Type, for<'c> Option: Encode<'c, D> + Type, for<'c> &'c Pool: Executor<'c, Database = D>, for<'c> &'c mut Transaction<'c, D>: Executor<'c, Database = D>, { async fn write_email(&self, identity_id: Uuid, email_address: &str) -> Result<(), StoreError> { let sqls = get_sqls(&self.sqls_root, WRITE_EMAIL); let identity_id = self.read_identity_raw_id(&identity_id).await?; let email_id: (i64,) = match sqlx::query_as(&sqls[0]) .bind(email_address) .fetch_one(&self.pool) .await { Ok(i) => i, Err(sqlx::Error::RowNotFound) => sqlx::query_as::<_, (i64,)>(&sqls[1]) .bind(email_address) .fetch_one(&self.pool) .await .map_err(util::log_err_sqlx)?, Err(e) => return Err(StoreError::SqlxError(e)), }; sqlx::query(&sqls[2]) .bind(identity_id) .bind(email_id.0) .bind(OffsetDateTime::now_utc()) .execute(&self.pool) .await .map_err(util::log_err_sqlx)?; Ok(()) } async fn find_email_validation( &self, validation_id: Option<&Uuid>, code: Option<&str>, ) -> Result { let sqls = get_sqls(&self.sqls_root, FIND_EMAIL_VALIDATION); let mut rows = sqlx::query_as::<_, EmailValidation>(&sqls[0]) .bind(validation_id) .bind(code) .fetch_all(&self.pool) .await .map_err(util::log_err_sqlx)?; match rows.len() { 0 => Err(StoreError::NoEmailValidationFound), 1 => Ok(rows.swap_remove(0)), _ => Err(StoreError::TooManyEmailValidations), } } async fn write_email_validation(&self, ev: &EmailValidation) -> Result { let sqls = get_sqls(&self.sqls_root, WRITE_EMAIL_VALIDATION); let identity_id = self .read_identity_raw_id( &ev.identity_id .ok_or(StoreError::IdentityIdMustExistInvariant)?, ) .await?; let email_id = self.read_email_raw_id(&ev.email_address).await?; let new_id = Uuid::new_v4(); sqlx::query(&sqls[0]) .bind(ev.id.unwrap_or(new_id)) .bind(identity_id) .bind(email_id) .bind(ev.attempts) .bind(&ev.code) .bind(ev.is_validated) .bind(ev.created_at) .bind(ev.expires_at) .execute(&self.pool) .await .map_err(util::log_err_sqlx)?; Ok(new_id) } async fn find_identity( &self, id: Option<&Uuid>, email: Option<&str>, ) -> Result, StoreError> { let sqls = get_sqls(&self.sqls_root, FIND_IDENTITY); Ok( match sqlx::query_as::<_, Identity>(&sqls[0]) .bind(id) .bind(email) .fetch_one(&self.pool) .await { Ok(i) => Some(i), Err(sqlx::Error::RowNotFound) => None, Err(e) => return Err(StoreError::SqlxError(e)), }, ) } async fn find_identity_by_code(&self, code: &str) -> Result { let sqls = get_sqls(&self.sqls_root, FIND_IDENTITY_BY_CODE); let rows = sqlx::query_as::<_, (i32,)>(&sqls[0]) .bind(code) .fetch_all(&self.pool) .await .map_err(util::log_err_sqlx)?; if rows.len() == 0 { return Err(StoreError::CodeDoesNotExist(code.to_string())); } if rows.len() != 1 { return Err(StoreError::CodeAppearsMoreThanOnce); } let identity_email_id = rows.get(0).unwrap().0; // TODO: IF we expand beyond email codes, then we'll need to join against a bunch of identity tables. // but since a single code was found, only one of them should pop... Ok(sqlx::query_as::<_, Identity>(&sqls[1]) .bind(identity_email_id) .fetch_one(&self.pool) .await .map_err(util::log_err_sqlx)?) } async fn write_identity(&self, i: &Identity) -> Result<(), StoreError> { let sqls = get_sqls(&self.sqls_root, WRITE_IDENTITY); sqlx::query(&sqls[0]) .bind(i.id) .bind(i.data.clone()) .bind(i.created_at) .execute(&self.pool) .await .map_err(|e| { error!("write_identity_failure"); error!("{:?}", e); e })?; Ok(()) } async fn read_identity(&self, id: &Uuid) -> Result { Ok(sqlx::query_as::<_, Identity>( " select identity_public_id, data, created_at from identity where identity_public_id = ?", ) .bind(id) .fetch_one(&self.pool) .await .map_err(util::log_err_sqlx)?) } async fn write_session(&self, session: &Session) -> Result<(), StoreError> { let sqls = get_sqls(&self.sqls_root, WRITE_SESSION); let secret_hash = session.secret.as_ref().map(|s| hash_secret(s)); sqlx::query(&sqls[0]) .bind(&session.identity_id) .bind(secret_hash.as_ref()) .bind(session.created_at) .bind(OffsetDateTime::now_utc()) .bind(session.expires_at) .bind(session.revoked_at) .execute(&self.pool) .await .map_err(util::log_err_sqlx)?; Ok(()) } async fn read_session(&self, secret: &SessionSecret) -> Result { let sqls = get_sqls(&self.sqls_root, READ_SESSION); let secret_hash = hash_secret(secret); let mut session = sqlx::query_as::<_, Session>(&sqls[0]) .bind(&secret_hash[..]) .fetch_one(&self.pool) .await .map_err(util::log_err_sqlx)?; // This should do nothing other than updated touched_at, and then // clear the plaintext secret session.secret = Some(secret.to_string()); self.write_session(&session).await?; session.secret = None; Ok(session) } } pub struct PgClient { sql: SqlClient, } impl PgClient { pub async fn new(pool: sqlx::Pool) -> Arc { sqlx::migrate!("store/pg/migrations") .run(&pool) .await .expect(ERR_MSG_MIGRATION_FAILED); Arc::new(PgClient { sql: SqlClient { pool, sqls_root: PGSQL.to_string(), }, }) } } #[async_trait::async_trait] impl Store for PgClient { async fn write_email(&self, identity_id: Uuid, email_address: &str) -> Result<(), StoreError> { self.sql.write_email(identity_id, email_address).await } async fn find_email_validation( &self, validation_id: Option<&Uuid>, code: Option<&str>, ) -> Result { self.sql.find_email_validation(validation_id, code).await } async fn write_email_validation(&self, ev: &EmailValidation) -> Result { self.sql.write_email_validation(ev).await } async fn find_identity( &self, identity_id: Option<&Uuid>, email: Option<&str>, ) -> Result, StoreError> { self.sql.find_identity(identity_id, email).await } async fn find_identity_by_code(&self, code: &str) -> Result { self.sql.find_identity_by_code(code).await } async fn write_identity(&self, i: &Identity) -> Result<(), StoreError> { self.sql.write_identity(i).await } async fn read_identity(&self, identity_id: &Uuid) -> Result { self.sql.read_identity(identity_id).await } async fn write_session(&self, session: &Session) -> Result<(), StoreError> { self.sql.write_session(session).await } async fn read_session(&self, secret: &SessionSecret) -> Result { self.sql.read_session(secret).await } } pub struct SqliteClient { sql: SqlClient, } impl SqliteClient { pub async fn new(pool: sqlx::Pool) -> Arc { sqlx::migrate!("store/sqlite/migrations") .run(&pool) .await .expect(ERR_MSG_MIGRATION_FAILED); sqlx::query("pragma foreign_keys = on") .execute(&pool) .await .expect( "Failed to initialize FK pragma. File a bug at https://www.github.com/secd-lib", ); Arc::new(SqliteClient { sql: SqlClient { pool, sqls_root: SQLITE.to_string(), }, }) } } #[async_trait::async_trait] impl Store for SqliteClient { async fn write_email(&self, identity_id: Uuid, email_address: &str) -> Result<(), StoreError> { self.sql.write_email(identity_id, email_address).await } async fn find_email_validation( &self, validation_id: Option<&Uuid>, code: Option<&str>, ) -> Result { self.sql.find_email_validation(validation_id, code).await } async fn write_email_validation(&self, ev: &EmailValidation) -> Result { self.sql.write_email_validation(ev).await } async fn find_identity( &self, identity_id: Option<&Uuid>, email: Option<&str>, ) -> Result, StoreError> { self.sql.find_identity(identity_id, email).await } async fn find_identity_by_code(&self, code: &str) -> Result { self.sql.find_identity_by_code(code).await } async fn write_identity(&self, i: &Identity) -> Result<(), StoreError> { self.sql.write_identity(i).await } async fn read_identity(&self, identity_id: &Uuid) -> Result { self.sql.read_identity(identity_id).await } async fn write_session(&self, session: &Session) -> Result<(), StoreError> { self.sql.write_session(session).await } async fn read_session(&self, secret: &SessionSecret) -> Result { self.sql.read_session(secret).await } }