aboutsummaryrefslogtreecommitdiff
path: root/crates/secd/src/client/sqldb.rs
diff options
context:
space:
mode:
authorbenj <benj@rse8.com>2022-12-12 17:06:57 -0800
committerbenj <benj@rse8.com>2022-12-12 17:06:57 -0800
commit0920c4d4f30a3345870d385d5c6f3e0919228b56 (patch)
treef54668d91db469b7304758893a51b590c8f9b0de /crates/secd/src/client/sqldb.rs
parent3a4de13528fc85dcbe6bc9055d97ba5cc87f5712 (diff)
downloadsecdiam-0920c4d4f30a3345870d385d5c6f3e0919228b56.tar
secdiam-0920c4d4f30a3345870d385d5c6f3e0919228b56.tar.gz
secdiam-0920c4d4f30a3345870d385d5c6f3e0919228b56.tar.bz2
secdiam-0920c4d4f30a3345870d385d5c6f3e0919228b56.tar.lz
secdiam-0920c4d4f30a3345870d385d5c6f3e0919228b56.tar.xz
secdiam-0920c4d4f30a3345870d385d5c6f3e0919228b56.tar.zst
secdiam-0920c4d4f30a3345870d385d5c6f3e0919228b56.zip
(oauth2 + email added): a mess that may or may not really work and needs to be refactored...
Diffstat (limited to 'crates/secd/src/client/sqldb.rs')
-rw-r--r--crates/secd/src/client/sqldb.rs324
1 files changed, 266 insertions, 58 deletions
diff --git a/crates/secd/src/client/sqldb.rs b/crates/secd/src/client/sqldb.rs
index 6048c48..15cc4b5 100644
--- a/crates/secd/src/client/sqldb.rs
+++ b/crates/secd/src/client/sqldb.rs
@@ -1,19 +1,23 @@
-use std::sync::Arc;
+use std::{str::FromStr, sync::Arc};
use super::{
- EmailValidation, Identity, Session, SessionSecret, Store, StoreError, ERR_MSG_MIGRATION_FAILED,
- FIND_EMAIL_VALIDATION, FIND_IDENTITY, FIND_IDENTITY_BY_CODE, PGSQL, READ_EMAIL_RAW_ID,
- READ_IDENTITY_RAW_ID, READ_SESSION, SQLITE, SQLS, WRITE_EMAIL, WRITE_EMAIL_VALIDATION,
- WRITE_IDENTITY, WRITE_SESSION,
+ EmailValidation, Identity, OauthProvider, OauthProviderName, OauthResponseType, Session,
+ SessionSecret, Store, StoreError, ERR_MSG_MIGRATION_FAILED, FIND_EMAIL_VALIDATION,
+ FIND_IDENTITY, FIND_IDENTITY_BY_CODE, PGSQL, READ_EMAIL_RAW_ID, READ_IDENTITY_RAW_ID,
+ READ_OAUTH_PROVIDER, READ_OAUTH_VALIDATION, READ_SESSION, READ_VALIDATION_TYPE, SQLITE, SQLS,
+ WRITE_EMAIL, WRITE_EMAIL_VALIDATION, WRITE_IDENTITY, WRITE_OAUTH_PROVIDER,
+ WRITE_OAUTH_VALIDATION, WRITE_SESSION,
};
-use crate::util;
-use log::error;
+use crate::{util, OauthValidation, ValidationRequestId, ValidationType};
+use anyhow::bail;
+use log::{debug, error};
use openssl::sha::Sha256;
use sqlx::{
self, database::HasArguments, ColumnIndex, Database, Decode, Encode, Executor, IntoArguments,
Pool, Postgres, Sqlite, Transaction, Type,
};
use time::OffsetDateTime;
+use url::Url;
use uuid::Uuid;
fn get_sqls(root: &str, file: &str) -> Vec<String> {
@@ -97,6 +101,8 @@ where
for<'c> String: Encode<'c, D> + Type<D>,
for<'c> Option<String>: Decode<'c, D> + Type<D>,
for<'c> Option<String>: Encode<'c, D> + Type<D>,
+ for<'c> OauthProviderName: Decode<'c, D> + Type<D>,
+ for<'c> OauthResponseType: Decode<'c, D> + Type<D>,
for<'c> usize: ColumnIndex<<D as Database>::Row>,
for<'c> Uuid: Decode<'c, D> + Type<D>,
for<'c> Uuid: Encode<'c, D> + Type<D>,
@@ -108,29 +114,11 @@ where
for<'c> &'c Pool<D>: Executor<'c, Database = D>,
for<'c> &'c mut Transaction<'c, D>: Executor<'c, Database = D>,
{
- async fn write_email(&self, identity_id: Uuid, email_address: &str) -> Result<(), StoreError> {
+ async fn write_email(&self, email_address: &str) -> Result<(), StoreError> {
let sqls = get_sqls(&self.sqls_root, WRITE_EMAIL);
- let identity_id = self.read_identity_raw_id(&identity_id).await?;
-
- let email_id: (i64,) = match sqlx::query_as(&sqls[0])
+ sqlx::query(&sqls[0])
.bind(email_address)
- .fetch_one(&self.pool)
- .await
- {
- Ok(i) => i,
- Err(sqlx::Error::RowNotFound) => sqlx::query_as::<_, (i64,)>(&sqls[1])
- .bind(email_address)
- .fetch_one(&self.pool)
- .await
- .map_err(util::log_err_sqlx)?,
- Err(e) => return Err(StoreError::SqlxError(e)),
- };
-
- sqlx::query(&sqls[2])
- .bind(identity_id)
- .bind(email_id.0)
- .bind(OffsetDateTime::now_utc())
.execute(&self.pool)
.await
.map_err(util::log_err_sqlx)?;
@@ -154,57 +142,84 @@ where
match rows.len() {
0 => Err(StoreError::NoEmailValidationFound),
1 => Ok(rows.swap_remove(0)),
- _ => Err(StoreError::TooManyEmailValidations),
+ _ => Err(StoreError::TooManyValidations),
}
}
- async fn write_email_validation(&self, ev: &EmailValidation) -> Result<Uuid, StoreError> {
+ async fn write_email_validation(&self, ev: &EmailValidation) -> anyhow::Result<Uuid> {
let sqls = get_sqls(&self.sqls_root, WRITE_EMAIL_VALIDATION);
- let identity_id = self
- .read_identity_raw_id(
- &ev.identity_id
- .ok_or(StoreError::IdentityIdMustExistInvariant)?,
- )
- .await?;
let email_id = self.read_email_raw_id(&ev.email_address).await?;
-
- let new_id = Uuid::new_v4();
+ let validation_id = ev.id.unwrap_or(Uuid::new_v4());
sqlx::query(&sqls[0])
- .bind(ev.id.unwrap_or(new_id))
- .bind(identity_id)
+ .bind(validation_id)
.bind(email_id)
- .bind(ev.attempts)
.bind(&ev.code)
- .bind(ev.is_validated)
+ .bind(ev.is_oauth_derived)
.bind(ev.created_at)
- .bind(ev.expires_at)
+ .bind(ev.validated_at)
+ .bind(ev.expired_at)
.execute(&self.pool)
.await
.map_err(util::log_err_sqlx)?;
- Ok(new_id)
+ if ev.identity_id.is_some() || ev.revoked_at.is_some() || ev.deleted_at.is_some() {
+ sqlx::query(&sqls[1])
+ .bind(ev.identity_id.as_ref())
+ .bind(validation_id)
+ .bind(ev.revoked_at)
+ .bind(ev.deleted_at)
+ .execute(&self.pool)
+ .await
+ .map_err(util::log_err_sqlx)?;
+ }
+
+ Ok(validation_id)
}
async fn find_identity(
&self,
id: Option<&Uuid>,
email: Option<&str>,
- ) -> Result<Option<Identity>, StoreError> {
+ ) -> anyhow::Result<Option<Identity>> {
let sqls = get_sqls(&self.sqls_root, FIND_IDENTITY);
Ok(
match sqlx::query_as::<_, Identity>(&sqls[0])
.bind(id)
.bind(email)
- .fetch_one(&self.pool)
+ .fetch_all(&self.pool)
.await
{
- Ok(i) => Some(i),
+ Ok(mut is) => match is.len() {
+ // if only 1 found, then that's fine
+ // if multiple are fond, then if they all have the same id, that's okay
+ 1 => {
+ let i = is.swap_remove(0);
+ match i.deleted_at {
+ Some(t) if t > OffsetDateTime::now_utc() => Some(i),
+ None => Some(i),
+ _ => None,
+ }
+ }
+ 0 => None,
+ _ => {
+ match is
+ .iter()
+ .filter(|&i| i.id != is[0].id)
+ .collect::<Vec<&Identity>>()
+ .len()
+ {
+ 0 => Some(is.swap_remove(0)),
+ _ => bail!(StoreError::TooManyIdentitiesFound),
+ }
+ }
+ },
Err(sqlx::Error::RowNotFound) => None,
- Err(e) => return Err(StoreError::SqlxError(e)),
+ Err(e) => bail!(StoreError::SqlxError(e)),
},
)
}
+
async fn find_identity_by_code(&self, code: &str) -> Result<Identity, StoreError> {
let sqls = get_sqls(&self.sqls_root, FIND_IDENTITY_BY_CODE);
@@ -250,14 +265,16 @@ where
Ok(())
}
async fn read_identity(&self, id: &Uuid) -> Result<Identity, StoreError> {
- Ok(sqlx::query_as::<_, Identity>(
+ let identity = sqlx::query_as::<_, Identity>(
"
select identity_public_id, data, created_at from identity where identity_public_id = ?",
)
.bind(id)
.fetch_one(&self.pool)
.await
- .map_err(util::log_err_sqlx)?)
+ .map_err(util::log_err_sqlx)?;
+
+ Ok(identity)
}
async fn write_session(&self, session: &Session) -> Result<(), StoreError> {
@@ -269,7 +286,6 @@ select identity_public_id, data, created_at from identity where identity_public_
.bind(&session.identity_id)
.bind(secret_hash.as_ref())
.bind(session.created_at)
- .bind(OffsetDateTime::now_utc())
.bind(session.expires_at)
.bind(session.revoked_at)
.execute(&self.pool)
@@ -296,6 +312,142 @@ select identity_public_id, data, created_at from identity where identity_public_
Ok(session)
}
+
+ async fn write_oauth_provider(&self, provider: &OauthProvider) -> Result<(), StoreError> {
+ let sqls = get_sqls(&self.sqls_root, WRITE_OAUTH_PROVIDER);
+ sqlx::query(&sqls[0])
+ .bind(&provider.name.to_string())
+ .bind(&provider.flow)
+ .bind(&provider.base_url.to_string())
+ .bind(&provider.response.to_string())
+ .bind(&provider.default_scope)
+ .bind(&provider.client_id)
+ // TODO: encrypt secret before writing
+ .bind(&provider.client_secret)
+ .bind(&provider.redirect_url.to_string())
+ .bind(provider.created_at)
+ .bind(provider.deleted_at)
+ .execute(&self.pool)
+ .await
+ .map_err(util::log_err_sqlx)?;
+ Ok(())
+ }
+
+ async fn read_oauth_provider(
+ &self,
+ provider: &OauthProviderName,
+ flow: Option<String>,
+ ) -> Result<OauthProvider, StoreError> {
+ let sqls = get_sqls(&self.sqls_root, READ_OAUTH_PROVIDER);
+ let flow = flow.unwrap_or("default".into());
+ debug!("provider: {:?}, flow: {:?}", provider, flow);
+ // TODO: Write the generic FromRow impl for OauthProvider...
+ let res = sqlx::query_as::<
+ _,
+ (
+ String,
+ String,
+ String,
+ String,
+ String,
+ String,
+ String,
+ OffsetDateTime,
+ Option<OffsetDateTime>,
+ ),
+ >(&sqls[0])
+ .bind(&provider.to_string())
+ .bind(&flow)
+ .fetch_one(&self.pool)
+ .await
+ .map_err(util::log_err_sqlx)?;
+
+ debug!("res: {:?}", res);
+
+ Ok(OauthProvider {
+ name: provider.clone(),
+ flow: Some(res.0),
+ base_url: Url::from_str(&res.1)
+ .map_err(|_| StoreError::OauthProviderDoesNotExist(*provider))?,
+ response: OauthResponseType::from_str(&res.2)
+ .map_err(|_| StoreError::OauthProviderDoesNotExist(*provider))?,
+ default_scope: res.3,
+ client_id: res.4,
+ client_secret: res.5,
+ redirect_url: Url::from_str(&res.6)
+ .map_err(|_| StoreError::OauthProviderDoesNotExist(*provider))?,
+ created_at: res.7,
+ deleted_at: res.8,
+ })
+ }
+ async fn write_oauth_validation(
+ &self,
+ v: &OauthValidation,
+ ) -> anyhow::Result<ValidationRequestId> {
+ let sqls = get_sqls(&self.sqls_root, WRITE_OAUTH_VALIDATION);
+
+ let validation_id = v.id.unwrap_or(Uuid::new_v4());
+ sqlx::query(&sqls[0])
+ .bind(validation_id)
+ .bind(v.oauth_provider.name.to_string())
+ .bind(v.oauth_provider.flow.clone())
+ .bind(v.access_token.clone())
+ .bind(v.raw_response.clone())
+ .bind(v.created_at)
+ .bind(v.validated_at)
+ .execute(&self.pool)
+ .await?;
+
+ if v.identity_id.is_some() || v.revoked_at.is_some() || v.deleted_at.is_some() {
+ sqlx::query(&sqls[1])
+ .bind(v.identity_id.as_ref())
+ .bind(validation_id)
+ .bind(v.revoked_at)
+ .bind(v.deleted_at)
+ .execute(&self.pool)
+ .await?;
+ }
+
+ Ok(validation_id)
+ }
+ async fn read_oauth_validation(
+ &self,
+ validation_id: &ValidationRequestId,
+ ) -> anyhow::Result<OauthValidation> {
+ let sqls = get_sqls(&self.sqls_root, READ_OAUTH_VALIDATION);
+
+ let mut es = sqlx::query_as::<_, OauthValidation>(&sqls[0])
+ .bind(validation_id)
+ .fetch_all(&self.pool)
+ .await?;
+
+ if es.len() != 1 {
+ bail!(StoreError::OauthValidationDoesNotExist(
+ validation_id.clone()
+ ));
+ }
+
+ Ok(es.swap_remove(0))
+ }
+ async fn find_validation_type(
+ &self,
+ validation_id: &ValidationRequestId,
+ ) -> anyhow::Result<ValidationType> {
+ let sqls = get_sqls(&self.sqls_root, READ_VALIDATION_TYPE);
+
+ let mut es = sqlx::query_as::<_, (String,)>(&sqls[0])
+ .bind(validation_id)
+ .fetch_all(&self.pool)
+ .await
+ .map_err(util::log_err_sqlx)?;
+
+ match es.len() {
+ 1 => Ok(ValidationType::from_str(&es.swap_remove(0).0)?),
+ _ => bail!(StoreError::Other(
+ "expected a single validation but recieved 0 or multiple validations".into()
+ )),
+ }
+ }
}
pub struct PgClient {
@@ -320,8 +472,8 @@ impl PgClient {
#[async_trait::async_trait]
impl Store for PgClient {
- async fn write_email(&self, identity_id: Uuid, email_address: &str) -> Result<(), StoreError> {
- self.sql.write_email(identity_id, email_address).await
+ async fn write_email(&self, email_address: &str) -> Result<(), StoreError> {
+ self.sql.write_email(email_address).await
}
async fn find_email_validation(
&self,
@@ -330,14 +482,14 @@ impl Store for PgClient {
) -> Result<EmailValidation, StoreError> {
self.sql.find_email_validation(validation_id, code).await
}
- async fn write_email_validation(&self, ev: &EmailValidation) -> Result<Uuid, StoreError> {
+ async fn write_email_validation(&self, ev: &EmailValidation) -> anyhow::Result<Uuid> {
self.sql.write_email_validation(ev).await
}
async fn find_identity(
&self,
identity_id: Option<&Uuid>,
email: Option<&str>,
- ) -> Result<Option<Identity>, StoreError> {
+ ) -> anyhow::Result<Option<Identity>> {
self.sql.find_identity(identity_id, email).await
}
async fn find_identity_by_code(&self, code: &str) -> Result<Identity, StoreError> {
@@ -355,6 +507,34 @@ impl Store for PgClient {
async fn read_session(&self, secret: &SessionSecret) -> Result<Session, StoreError> {
self.sql.read_session(secret).await
}
+ async fn write_oauth_provider(&self, provider: &OauthProvider) -> Result<(), StoreError> {
+ self.sql.write_oauth_provider(provider).await
+ }
+ async fn read_oauth_provider(
+ &self,
+ provider: &OauthProviderName,
+ flow: Option<String>,
+ ) -> Result<OauthProvider, StoreError> {
+ self.sql.read_oauth_provider(provider, flow).await
+ }
+ async fn write_oauth_validation(
+ &self,
+ validation: &OauthValidation,
+ ) -> anyhow::Result<ValidationRequestId> {
+ self.sql.write_oauth_validation(validation).await
+ }
+ async fn read_oauth_validation(
+ &self,
+ validation_id: &ValidationRequestId,
+ ) -> anyhow::Result<OauthValidation> {
+ self.sql.read_oauth_validation(validation_id).await
+ }
+ async fn find_validation_type(
+ &self,
+ validation_id: &ValidationRequestId,
+ ) -> anyhow::Result<ValidationType> {
+ self.sql.find_validation_type(validation_id).await
+ }
}
pub struct SqliteClient {
@@ -386,8 +566,8 @@ impl SqliteClient {
#[async_trait::async_trait]
impl Store for SqliteClient {
- async fn write_email(&self, identity_id: Uuid, email_address: &str) -> Result<(), StoreError> {
- self.sql.write_email(identity_id, email_address).await
+ async fn write_email(&self, email_address: &str) -> Result<(), StoreError> {
+ self.sql.write_email(email_address).await
}
async fn find_email_validation(
&self,
@@ -396,14 +576,14 @@ impl Store for SqliteClient {
) -> Result<EmailValidation, StoreError> {
self.sql.find_email_validation(validation_id, code).await
}
- async fn write_email_validation(&self, ev: &EmailValidation) -> Result<Uuid, StoreError> {
+ async fn write_email_validation(&self, ev: &EmailValidation) -> anyhow::Result<Uuid> {
self.sql.write_email_validation(ev).await
}
async fn find_identity(
&self,
identity_id: Option<&Uuid>,
email: Option<&str>,
- ) -> Result<Option<Identity>, StoreError> {
+ ) -> anyhow::Result<Option<Identity>> {
self.sql.find_identity(identity_id, email).await
}
async fn find_identity_by_code(&self, code: &str) -> Result<Identity, StoreError> {
@@ -421,4 +601,32 @@ impl Store for SqliteClient {
async fn read_session(&self, secret: &SessionSecret) -> Result<Session, StoreError> {
self.sql.read_session(secret).await
}
+ async fn write_oauth_provider(&self, provider: &OauthProvider) -> Result<(), StoreError> {
+ self.sql.write_oauth_provider(provider).await
+ }
+ async fn read_oauth_provider(
+ &self,
+ provider: &OauthProviderName,
+ flow: Option<String>,
+ ) -> Result<OauthProvider, StoreError> {
+ self.sql.read_oauth_provider(provider, flow).await
+ }
+ async fn write_oauth_validation(
+ &self,
+ validation: &OauthValidation,
+ ) -> anyhow::Result<ValidationRequestId> {
+ self.sql.write_oauth_validation(validation).await
+ }
+ async fn read_oauth_validation(
+ &self,
+ validation_id: &ValidationRequestId,
+ ) -> anyhow::Result<OauthValidation> {
+ self.sql.read_oauth_validation(validation_id).await
+ }
+ async fn find_validation_type(
+ &self,
+ validation_id: &ValidationRequestId,
+ ) -> anyhow::Result<ValidationType> {
+ self.sql.find_validation_type(validation_id).await
+ }
}