aboutsummaryrefslogtreecommitdiff
path: root/crates/secd/src/client
diff options
context:
space:
mode:
Diffstat (limited to 'crates/secd/src/client')
-rw-r--r--crates/secd/src/client/mod.rs233
-rw-r--r--crates/secd/src/client/sqldb.rs324
-rw-r--r--crates/secd/src/client/types.rs3
3 files changed, 492 insertions, 68 deletions
diff --git a/crates/secd/src/client/mod.rs b/crates/secd/src/client/mod.rs
index 3925657..38426ef 100644
--- a/crates/secd/src/client/mod.rs
+++ b/crates/secd/src/client/mod.rs
@@ -1,13 +1,24 @@
-pub mod email;
-pub mod sqldb;
+pub(crate) mod email;
+pub(crate) mod sqldb;
+pub(crate) mod types;
-use std::collections::HashMap;
+use std::{collections::HashMap, str::FromStr};
use super::Identity;
-use crate::{EmailValidation, Session, SessionSecret};
+use crate::{
+ EmailValidation, OauthProvider, OauthProviderName, OauthResponseType, OauthValidation, Session,
+ SessionSecret, ValidationRequestId, ValidationType,
+};
+use email_address::EmailAddress;
use lazy_static::lazy_static;
+use sqlx::{
+ database::HasValueRef, sqlite::SqliteRow, ColumnIndex, Database, Decode, FromRow, Row, Sqlite,
+ Type,
+};
use thiserror::Error;
+use time::OffsetDateTime;
+use url::Url;
use uuid::Uuid;
pub enum EmailType {
@@ -36,13 +47,15 @@ pub trait EmailMessenger {
#[derive(Error, Debug, derive_more::Display)]
pub enum StoreError {
SqlxError(#[from] sqlx::Error),
- EmailAlreadyExists,
CodeAppearsMoreThanOnce,
CodeDoesNotExist(String),
IdentityIdMustExistInvariant,
- TooManyEmailValidations,
+ TooManyValidations,
+ TooManyIdentitiesFound,
NoEmailValidationFound,
- Unknown,
+ OauthProviderDoesNotExist(OauthProviderName),
+ OauthValidationDoesNotExist(ValidationRequestId),
+ Other(String),
}
const EMAIL_TEMPLATE_DEFAULT_LOGIN: &str = "You requested a login link. Please click the following link %secd_code% to login as %secd_email_address%";
@@ -56,6 +69,7 @@ const PGSQL: &str = "pgsql";
const WRITE_IDENTITY: &str = "write_identity";
const WRITE_EMAIL_VALIDATION: &str = "write_email_validation";
const FIND_EMAIL_VALIDATION: &str = "find_email_validation";
+const READ_VALIDATION_TYPE: &str = "read_validation_type";
const WRITE_EMAIL: &str = "write_email";
@@ -69,6 +83,11 @@ const READ_EMAIL_RAW_ID: &str = "read_email_raw_id";
const WRITE_SESSION: &str = "write_session";
const READ_SESSION: &str = "read_session";
+const WRITE_OAUTH_PROVIDER: &str = "write_oauth_provider";
+const READ_OAUTH_PROVIDER: &str = "read_oauth_provider";
+const WRITE_OAUTH_VALIDATION: &str = "write_oauth_validation";
+const READ_OAUTH_VALIDATION: &str = "read_oauth_validation";
+
lazy_static! {
static ref SQLS: HashMap<&'static str, HashMap<&'static str, &'static str>> = {
let sqlite_sqls: HashMap<&'static str, &'static str> = [
@@ -116,6 +135,26 @@ lazy_static! {
FIND_EMAIL_VALIDATION,
include_str!("../../store/sqlite/sql/find_email_validation.sql"),
),
+ (
+ WRITE_OAUTH_PROVIDER,
+ include_str!("../../store/sqlite/sql/write_oauth_provider.sql"),
+ ),
+ (
+ READ_OAUTH_PROVIDER,
+ include_str!("../../store/sqlite/sql/read_oauth_provider.sql"),
+ ),
+ (
+ READ_OAUTH_VALIDATION,
+ include_str!("../../store/sqlite/sql/read_oauth_validation.sql"),
+ ),
+ (
+ WRITE_OAUTH_VALIDATION,
+ include_str!("../../store/sqlite/sql/write_oauth_validation.sql"),
+ ),
+ (
+ READ_VALIDATION_TYPE,
+ include_str!("../../store/sqlite/sql/read_validation_type.sql"),
+ ),
]
.iter()
.cloned()
@@ -166,6 +205,26 @@ lazy_static! {
FIND_EMAIL_VALIDATION,
include_str!("../../store/pg/sql/find_email_validation.sql"),
),
+ (
+ WRITE_OAUTH_PROVIDER,
+ include_str!("../../store/pg/sql/write_oauth_provider.sql"),
+ ),
+ (
+ READ_OAUTH_PROVIDER,
+ include_str!("../../store/pg/sql/read_oauth_provider.sql"),
+ ),
+ (
+ READ_OAUTH_VALIDATION,
+ include_str!("../../store/pg/sql/read_oauth_validation.sql"),
+ ),
+ (
+ WRITE_OAUTH_VALIDATION,
+ include_str!("../../store/pg/sql/write_oauth_validation.sql"),
+ ),
+ (
+ READ_VALIDATION_TYPE,
+ include_str!("../../store/pg/sql/read_validation_type.sql"),
+ ),
]
.iter()
.cloned()
@@ -180,9 +239,143 @@ lazy_static! {
};
}
+impl<'a, R: Row> FromRow<'a, R> for OauthValidation
+where
+ &'a str: ColumnIndex<R>,
+ OauthProviderName: Decode<'a, R::Database> + Type<R::Database>,
+ OauthResponseType: Decode<'a, R::Database> + Type<R::Database>,
+ OffsetDateTime: Decode<'a, R::Database> + Type<R::Database>,
+ String: Decode<'a, R::Database> + Type<R::Database>,
+ Uuid: Decode<'a, R::Database> + Type<R::Database>,
+{
+ fn from_row(row: &'a R) -> Result<Self, sqlx::Error> {
+ let id: Option<Uuid> = row.try_get("oauth_validation_public_id")?;
+ let identity_id: Option<Uuid> = row.try_get("identity_public_id")?;
+ let access_token: Option<String> = row.try_get("access_token")?;
+ let raw_response: Option<String> = row.try_get("raw_response")?;
+ let created_at: Option<OffsetDateTime> = row.try_get("created_at")?;
+ let validated_at: Option<OffsetDateTime> = row.try_get("validated_at")?;
+ let revoked_at: Option<OffsetDateTime> = row.try_get("revoked_at")?;
+ let deleted_at: Option<OffsetDateTime> = row.try_get("deleted_at")?;
+
+ let op_name: Option<OauthProviderName> = row.try_get("oauth_provider_name")?;
+ let op_flow: Option<String> = row.try_get("oauth_provider_flow")?;
+ let op_base_url: Option<String> = row.try_get("oauth_provider_base_url")?;
+ let op_response_type: Option<OauthResponseType> =
+ row.try_get("oauth_provider_response_type")?;
+ let op_default_scope: Option<String> = row.try_get("oauth_provider_default_scope")?;
+ let op_client_id: Option<String> = row.try_get("oauth_provider_client_id")?;
+ let op_client_secret: Option<String> = row.try_get("oauth_provider_client_secret")?;
+ let op_redirect_url: Option<String> = row.try_get("oauth_provider_redirect_url")?;
+ let op_created_at: Option<OffsetDateTime> = row.try_get("oauth_provider_created_at")?;
+ let op_deleted_at: Option<OffsetDateTime> = row.try_get("oauth_provider_deleted_at")?;
+
+ let op_base_url = op_base_url
+ .map(|s| Url::from_str(&s).ok())
+ .flatten()
+ .ok_or(sqlx::Error::ColumnDecode {
+ index: "oauth_provider_base_url".into(),
+ source: "secd".into(),
+ })?;
+
+ let op_redirect_url = op_redirect_url
+ .map(|s| Url::from_str(&s).ok())
+ .flatten()
+ .ok_or(sqlx::Error::ColumnDecode {
+ index: "oauth_provider_redirect_url".into(),
+ source: "secd".into(),
+ })?;
+
+ Ok(OauthValidation {
+ id,
+ identity_id,
+ access_token,
+ raw_response,
+ created_at: created_at.ok_or(sqlx::Error::ColumnDecode {
+ index: "created_at".into(),
+ source: "secd".into(),
+ })?,
+ validated_at,
+ revoked_at,
+ deleted_at,
+ oauth_provider: OauthProvider {
+ name: op_name.unwrap(),
+ flow: op_flow,
+ base_url: op_base_url,
+ response: op_response_type.ok_or(sqlx::Error::ColumnDecode {
+ index: "oauth_provider_response_type".into(),
+ source: "secd".into(),
+ })?,
+ default_scope: op_default_scope.ok_or(sqlx::Error::ColumnDecode {
+ index: "oauth_provider_default_scope".into(),
+ source: "secd".into(),
+ })?,
+ client_id: op_client_id.ok_or(sqlx::Error::ColumnDecode {
+ index: "oauth_provider_client_id".into(),
+ source: "secd".into(),
+ })?,
+ client_secret: op_client_secret.ok_or(sqlx::Error::ColumnDecode {
+ index: "oauth_provider_client_secret".into(),
+ source: "secd".into(),
+ })?,
+ redirect_url: op_redirect_url,
+ created_at: op_created_at.ok_or(sqlx::Error::ColumnDecode {
+ index: "oauth_provider_created_at".into(),
+ source: "secd".into(),
+ })?,
+ deleted_at: op_deleted_at,
+ },
+ })
+ }
+}
+
+impl<'a, D: Database> Decode<'a, D> for OauthProviderName
+where
+ &'a str: Decode<'a, D>,
+{
+ fn decode(
+ value: <D as HasValueRef<'a>>::ValueRef,
+ ) -> Result<Self, Box<dyn ::std::error::Error + 'static + Send + Sync>> {
+ let v = <&str as Decode<D>>::decode(value)?;
+ <OauthProviderName as clap::ValueEnum>::from_str(v, true)
+ .map_err(|_| "OauthProviderName should exist and decode to a program value.".into())
+ }
+}
+
+impl<D: Database> Type<D> for OauthProviderName
+where
+ str: Type<D>,
+{
+ fn type_info() -> D::TypeInfo {
+ <&str as Type<D>>::type_info()
+ }
+}
+
+impl<'a, D: Database> Decode<'a, D> for OauthResponseType
+where
+ &'a str: Decode<'a, D>,
+{
+ fn decode(
+ value: <D as HasValueRef<'a>>::ValueRef,
+ ) -> Result<Self, Box<dyn ::std::error::Error + 'static + Send + Sync>> {
+ let v = <&str as Decode<D>>::decode(value)?;
+ <OauthResponseType as clap::ValueEnum>::from_str(v, true)
+ .map_err(|_| "OauthResponseType should exist and decode to a program value.".into())
+ }
+}
+
+impl<D: Database> Type<D> for OauthResponseType
+where
+ str: Type<D>,
+{
+ fn type_info() -> D::TypeInfo {
+ <&str as Type<D>>::type_info()
+ }
+}
+
#[async_trait::async_trait]
pub trait Store {
- async fn write_email(&self, identity_id: Uuid, email_address: &str) -> Result<(), StoreError>;
+ async fn write_email(&self, email_address: &str) -> Result<(), StoreError>;
async fn find_email_validation(
&self,
@@ -193,17 +386,37 @@ pub trait Store {
&self,
ev: &EmailValidation,
// TODO: Make this write an EmailValidation
- ) -> Result<Uuid, StoreError>;
+ ) -> anyhow::Result<Uuid>;
async fn find_identity(
&self,
identity_id: Option<&Uuid>,
email: Option<&str>,
- ) -> Result<Option<Identity>, StoreError>;
+ ) -> anyhow::Result<Option<Identity>>;
async fn find_identity_by_code(&self, code: &str) -> Result<Identity, StoreError>;
async fn write_identity(&self, i: &Identity) -> Result<(), StoreError>;
async fn read_identity(&self, identity_id: &Uuid) -> Result<Identity, StoreError>;
async fn write_session(&self, session: &Session) -> Result<(), StoreError>;
async fn read_session(&self, secret: &SessionSecret) -> Result<Session, StoreError>;
+
+ async fn write_oauth_provider(&self, provider: &OauthProvider) -> Result<(), StoreError>;
+ async fn read_oauth_provider(
+ &self,
+ provider: &OauthProviderName,
+ flow: Option<String>,
+ ) -> Result<OauthProvider, StoreError>;
+ async fn write_oauth_validation(
+ &self,
+ validation: &OauthValidation,
+ ) -> anyhow::Result<ValidationRequestId>;
+ async fn read_oauth_validation(
+ &self,
+ validation_id: &ValidationRequestId,
+ ) -> anyhow::Result<OauthValidation>;
+
+ async fn find_validation_type(
+ &self,
+ validation_id: &ValidationRequestId,
+ ) -> anyhow::Result<ValidationType>;
}
diff --git a/crates/secd/src/client/sqldb.rs b/crates/secd/src/client/sqldb.rs
index 6048c48..15cc4b5 100644
--- a/crates/secd/src/client/sqldb.rs
+++ b/crates/secd/src/client/sqldb.rs
@@ -1,19 +1,23 @@
-use std::sync::Arc;
+use std::{str::FromStr, sync::Arc};
use super::{
- EmailValidation, Identity, Session, SessionSecret, Store, StoreError, ERR_MSG_MIGRATION_FAILED,
- FIND_EMAIL_VALIDATION, FIND_IDENTITY, FIND_IDENTITY_BY_CODE, PGSQL, READ_EMAIL_RAW_ID,
- READ_IDENTITY_RAW_ID, READ_SESSION, SQLITE, SQLS, WRITE_EMAIL, WRITE_EMAIL_VALIDATION,
- WRITE_IDENTITY, WRITE_SESSION,
+ EmailValidation, Identity, OauthProvider, OauthProviderName, OauthResponseType, Session,
+ SessionSecret, Store, StoreError, ERR_MSG_MIGRATION_FAILED, FIND_EMAIL_VALIDATION,
+ FIND_IDENTITY, FIND_IDENTITY_BY_CODE, PGSQL, READ_EMAIL_RAW_ID, READ_IDENTITY_RAW_ID,
+ READ_OAUTH_PROVIDER, READ_OAUTH_VALIDATION, READ_SESSION, READ_VALIDATION_TYPE, SQLITE, SQLS,
+ WRITE_EMAIL, WRITE_EMAIL_VALIDATION, WRITE_IDENTITY, WRITE_OAUTH_PROVIDER,
+ WRITE_OAUTH_VALIDATION, WRITE_SESSION,
};
-use crate::util;
-use log::error;
+use crate::{util, OauthValidation, ValidationRequestId, ValidationType};
+use anyhow::bail;
+use log::{debug, error};
use openssl::sha::Sha256;
use sqlx::{
self, database::HasArguments, ColumnIndex, Database, Decode, Encode, Executor, IntoArguments,
Pool, Postgres, Sqlite, Transaction, Type,
};
use time::OffsetDateTime;
+use url::Url;
use uuid::Uuid;
fn get_sqls(root: &str, file: &str) -> Vec<String> {
@@ -97,6 +101,8 @@ where
for<'c> String: Encode<'c, D> + Type<D>,
for<'c> Option<String>: Decode<'c, D> + Type<D>,
for<'c> Option<String>: Encode<'c, D> + Type<D>,
+ for<'c> OauthProviderName: Decode<'c, D> + Type<D>,
+ for<'c> OauthResponseType: Decode<'c, D> + Type<D>,
for<'c> usize: ColumnIndex<<D as Database>::Row>,
for<'c> Uuid: Decode<'c, D> + Type<D>,
for<'c> Uuid: Encode<'c, D> + Type<D>,
@@ -108,29 +114,11 @@ where
for<'c> &'c Pool<D>: Executor<'c, Database = D>,
for<'c> &'c mut Transaction<'c, D>: Executor<'c, Database = D>,
{
- async fn write_email(&self, identity_id: Uuid, email_address: &str) -> Result<(), StoreError> {
+ async fn write_email(&self, email_address: &str) -> Result<(), StoreError> {
let sqls = get_sqls(&self.sqls_root, WRITE_EMAIL);
- let identity_id = self.read_identity_raw_id(&identity_id).await?;
-
- let email_id: (i64,) = match sqlx::query_as(&sqls[0])
+ sqlx::query(&sqls[0])
.bind(email_address)
- .fetch_one(&self.pool)
- .await
- {
- Ok(i) => i,
- Err(sqlx::Error::RowNotFound) => sqlx::query_as::<_, (i64,)>(&sqls[1])
- .bind(email_address)
- .fetch_one(&self.pool)
- .await
- .map_err(util::log_err_sqlx)?,
- Err(e) => return Err(StoreError::SqlxError(e)),
- };
-
- sqlx::query(&sqls[2])
- .bind(identity_id)
- .bind(email_id.0)
- .bind(OffsetDateTime::now_utc())
.execute(&self.pool)
.await
.map_err(util::log_err_sqlx)?;
@@ -154,57 +142,84 @@ where
match rows.len() {
0 => Err(StoreError::NoEmailValidationFound),
1 => Ok(rows.swap_remove(0)),
- _ => Err(StoreError::TooManyEmailValidations),
+ _ => Err(StoreError::TooManyValidations),
}
}
- async fn write_email_validation(&self, ev: &EmailValidation) -> Result<Uuid, StoreError> {
+ async fn write_email_validation(&self, ev: &EmailValidation) -> anyhow::Result<Uuid> {
let sqls = get_sqls(&self.sqls_root, WRITE_EMAIL_VALIDATION);
- let identity_id = self
- .read_identity_raw_id(
- &ev.identity_id
- .ok_or(StoreError::IdentityIdMustExistInvariant)?,
- )
- .await?;
let email_id = self.read_email_raw_id(&ev.email_address).await?;
-
- let new_id = Uuid::new_v4();
+ let validation_id = ev.id.unwrap_or(Uuid::new_v4());
sqlx::query(&sqls[0])
- .bind(ev.id.unwrap_or(new_id))
- .bind(identity_id)
+ .bind(validation_id)
.bind(email_id)
- .bind(ev.attempts)
.bind(&ev.code)
- .bind(ev.is_validated)
+ .bind(ev.is_oauth_derived)
.bind(ev.created_at)
- .bind(ev.expires_at)
+ .bind(ev.validated_at)
+ .bind(ev.expired_at)
.execute(&self.pool)
.await
.map_err(util::log_err_sqlx)?;
- Ok(new_id)
+ if ev.identity_id.is_some() || ev.revoked_at.is_some() || ev.deleted_at.is_some() {
+ sqlx::query(&sqls[1])
+ .bind(ev.identity_id.as_ref())
+ .bind(validation_id)
+ .bind(ev.revoked_at)
+ .bind(ev.deleted_at)
+ .execute(&self.pool)
+ .await
+ .map_err(util::log_err_sqlx)?;
+ }
+
+ Ok(validation_id)
}
async fn find_identity(
&self,
id: Option<&Uuid>,
email: Option<&str>,
- ) -> Result<Option<Identity>, StoreError> {
+ ) -> anyhow::Result<Option<Identity>> {
let sqls = get_sqls(&self.sqls_root, FIND_IDENTITY);
Ok(
match sqlx::query_as::<_, Identity>(&sqls[0])
.bind(id)
.bind(email)
- .fetch_one(&self.pool)
+ .fetch_all(&self.pool)
.await
{
- Ok(i) => Some(i),
+ Ok(mut is) => match is.len() {
+ // if only 1 found, then that's fine
+ // if multiple are fond, then if they all have the same id, that's okay
+ 1 => {
+ let i = is.swap_remove(0);
+ match i.deleted_at {
+ Some(t) if t > OffsetDateTime::now_utc() => Some(i),
+ None => Some(i),
+ _ => None,
+ }
+ }
+ 0 => None,
+ _ => {
+ match is
+ .iter()
+ .filter(|&i| i.id != is[0].id)
+ .collect::<Vec<&Identity>>()
+ .len()
+ {
+ 0 => Some(is.swap_remove(0)),
+ _ => bail!(StoreError::TooManyIdentitiesFound),
+ }
+ }
+ },
Err(sqlx::Error::RowNotFound) => None,
- Err(e) => return Err(StoreError::SqlxError(e)),
+ Err(e) => bail!(StoreError::SqlxError(e)),
},
)
}
+
async fn find_identity_by_code(&self, code: &str) -> Result<Identity, StoreError> {
let sqls = get_sqls(&self.sqls_root, FIND_IDENTITY_BY_CODE);
@@ -250,14 +265,16 @@ where
Ok(())
}
async fn read_identity(&self, id: &Uuid) -> Result<Identity, StoreError> {
- Ok(sqlx::query_as::<_, Identity>(
+ let identity = sqlx::query_as::<_, Identity>(
"
select identity_public_id, data, created_at from identity where identity_public_id = ?",
)
.bind(id)
.fetch_one(&self.pool)
.await
- .map_err(util::log_err_sqlx)?)
+ .map_err(util::log_err_sqlx)?;
+
+ Ok(identity)
}
async fn write_session(&self, session: &Session) -> Result<(), StoreError> {
@@ -269,7 +286,6 @@ select identity_public_id, data, created_at from identity where identity_public_
.bind(&session.identity_id)
.bind(secret_hash.as_ref())
.bind(session.created_at)
- .bind(OffsetDateTime::now_utc())
.bind(session.expires_at)
.bind(session.revoked_at)
.execute(&self.pool)
@@ -296,6 +312,142 @@ select identity_public_id, data, created_at from identity where identity_public_
Ok(session)
}
+
+ async fn write_oauth_provider(&self, provider: &OauthProvider) -> Result<(), StoreError> {
+ let sqls = get_sqls(&self.sqls_root, WRITE_OAUTH_PROVIDER);
+ sqlx::query(&sqls[0])
+ .bind(&provider.name.to_string())
+ .bind(&provider.flow)
+ .bind(&provider.base_url.to_string())
+ .bind(&provider.response.to_string())
+ .bind(&provider.default_scope)
+ .bind(&provider.client_id)
+ // TODO: encrypt secret before writing
+ .bind(&provider.client_secret)
+ .bind(&provider.redirect_url.to_string())
+ .bind(provider.created_at)
+ .bind(provider.deleted_at)
+ .execute(&self.pool)
+ .await
+ .map_err(util::log_err_sqlx)?;
+ Ok(())
+ }
+
+ async fn read_oauth_provider(
+ &self,
+ provider: &OauthProviderName,
+ flow: Option<String>,
+ ) -> Result<OauthProvider, StoreError> {
+ let sqls = get_sqls(&self.sqls_root, READ_OAUTH_PROVIDER);
+ let flow = flow.unwrap_or("default".into());
+ debug!("provider: {:?}, flow: {:?}", provider, flow);
+ // TODO: Write the generic FromRow impl for OauthProvider...
+ let res = sqlx::query_as::<
+ _,
+ (
+ String,
+ String,
+ String,
+ String,
+ String,
+ String,
+ String,
+ OffsetDateTime,
+ Option<OffsetDateTime>,
+ ),
+ >(&sqls[0])
+ .bind(&provider.to_string())
+ .bind(&flow)
+ .fetch_one(&self.pool)
+ .await
+ .map_err(util::log_err_sqlx)?;
+
+ debug!("res: {:?}", res);
+
+ Ok(OauthProvider {
+ name: provider.clone(),
+ flow: Some(res.0),
+ base_url: Url::from_str(&res.1)
+ .map_err(|_| StoreError::OauthProviderDoesNotExist(*provider))?,
+ response: OauthResponseType::from_str(&res.2)
+ .map_err(|_| StoreError::OauthProviderDoesNotExist(*provider))?,
+ default_scope: res.3,
+ client_id: res.4,
+ client_secret: res.5,
+ redirect_url: Url::from_str(&res.6)
+ .map_err(|_| StoreError::OauthProviderDoesNotExist(*provider))?,
+ created_at: res.7,
+ deleted_at: res.8,
+ })
+ }
+ async fn write_oauth_validation(
+ &self,
+ v: &OauthValidation,
+ ) -> anyhow::Result<ValidationRequestId> {
+ let sqls = get_sqls(&self.sqls_root, WRITE_OAUTH_VALIDATION);
+
+ let validation_id = v.id.unwrap_or(Uuid::new_v4());
+ sqlx::query(&sqls[0])
+ .bind(validation_id)
+ .bind(v.oauth_provider.name.to_string())
+ .bind(v.oauth_provider.flow.clone())
+ .bind(v.access_token.clone())
+ .bind(v.raw_response.clone())
+ .bind(v.created_at)
+ .bind(v.validated_at)
+ .execute(&self.pool)
+ .await?;
+
+ if v.identity_id.is_some() || v.revoked_at.is_some() || v.deleted_at.is_some() {
+ sqlx::query(&sqls[1])
+ .bind(v.identity_id.as_ref())
+ .bind(validation_id)
+ .bind(v.revoked_at)
+ .bind(v.deleted_at)
+ .execute(&self.pool)
+ .await?;
+ }
+
+ Ok(validation_id)
+ }
+ async fn read_oauth_validation(
+ &self,
+ validation_id: &ValidationRequestId,
+ ) -> anyhow::Result<OauthValidation> {
+ let sqls = get_sqls(&self.sqls_root, READ_OAUTH_VALIDATION);
+
+ let mut es = sqlx::query_as::<_, OauthValidation>(&sqls[0])
+ .bind(validation_id)
+ .fetch_all(&self.pool)
+ .await?;
+
+ if es.len() != 1 {
+ bail!(StoreError::OauthValidationDoesNotExist(
+ validation_id.clone()
+ ));
+ }
+
+ Ok(es.swap_remove(0))
+ }
+ async fn find_validation_type(
+ &self,
+ validation_id: &ValidationRequestId,
+ ) -> anyhow::Result<ValidationType> {
+ let sqls = get_sqls(&self.sqls_root, READ_VALIDATION_TYPE);
+
+ let mut es = sqlx::query_as::<_, (String,)>(&sqls[0])
+ .bind(validation_id)
+ .fetch_all(&self.pool)
+ .await
+ .map_err(util::log_err_sqlx)?;
+
+ match es.len() {
+ 1 => Ok(ValidationType::from_str(&es.swap_remove(0).0)?),
+ _ => bail!(StoreError::Other(
+ "expected a single validation but recieved 0 or multiple validations".into()
+ )),
+ }
+ }
}
pub struct PgClient {
@@ -320,8 +472,8 @@ impl PgClient {
#[async_trait::async_trait]
impl Store for PgClient {
- async fn write_email(&self, identity_id: Uuid, email_address: &str) -> Result<(), StoreError> {
- self.sql.write_email(identity_id, email_address).await
+ async fn write_email(&self, email_address: &str) -> Result<(), StoreError> {
+ self.sql.write_email(email_address).await
}
async fn find_email_validation(
&self,
@@ -330,14 +482,14 @@ impl Store for PgClient {
) -> Result<EmailValidation, StoreError> {
self.sql.find_email_validation(validation_id, code).await
}
- async fn write_email_validation(&self, ev: &EmailValidation) -> Result<Uuid, StoreError> {
+ async fn write_email_validation(&self, ev: &EmailValidation) -> anyhow::Result<Uuid> {
self.sql.write_email_validation(ev).await
}
async fn find_identity(
&self,
identity_id: Option<&Uuid>,
email: Option<&str>,
- ) -> Result<Option<Identity>, StoreError> {
+ ) -> anyhow::Result<Option<Identity>> {
self.sql.find_identity(identity_id, email).await
}
async fn find_identity_by_code(&self, code: &str) -> Result<Identity, StoreError> {
@@ -355,6 +507,34 @@ impl Store for PgClient {
async fn read_session(&self, secret: &SessionSecret) -> Result<Session, StoreError> {
self.sql.read_session(secret).await
}
+ async fn write_oauth_provider(&self, provider: &OauthProvider) -> Result<(), StoreError> {
+ self.sql.write_oauth_provider(provider).await
+ }
+ async fn read_oauth_provider(
+ &self,
+ provider: &OauthProviderName,
+ flow: Option<String>,
+ ) -> Result<OauthProvider, StoreError> {
+ self.sql.read_oauth_provider(provider, flow).await
+ }
+ async fn write_oauth_validation(
+ &self,
+ validation: &OauthValidation,
+ ) -> anyhow::Result<ValidationRequestId> {
+ self.sql.write_oauth_validation(validation).await
+ }
+ async fn read_oauth_validation(
+ &self,
+ validation_id: &ValidationRequestId,
+ ) -> anyhow::Result<OauthValidation> {
+ self.sql.read_oauth_validation(validation_id).await
+ }
+ async fn find_validation_type(
+ &self,
+ validation_id: &ValidationRequestId,
+ ) -> anyhow::Result<ValidationType> {
+ self.sql.find_validation_type(validation_id).await
+ }
}
pub struct SqliteClient {
@@ -386,8 +566,8 @@ impl SqliteClient {
#[async_trait::async_trait]
impl Store for SqliteClient {
- async fn write_email(&self, identity_id: Uuid, email_address: &str) -> Result<(), StoreError> {
- self.sql.write_email(identity_id, email_address).await
+ async fn write_email(&self, email_address: &str) -> Result<(), StoreError> {
+ self.sql.write_email(email_address).await
}
async fn find_email_validation(
&self,
@@ -396,14 +576,14 @@ impl Store for SqliteClient {
) -> Result<EmailValidation, StoreError> {
self.sql.find_email_validation(validation_id, code).await
}
- async fn write_email_validation(&self, ev: &EmailValidation) -> Result<Uuid, StoreError> {
+ async fn write_email_validation(&self, ev: &EmailValidation) -> anyhow::Result<Uuid> {
self.sql.write_email_validation(ev).await
}
async fn find_identity(
&self,
identity_id: Option<&Uuid>,
email: Option<&str>,
- ) -> Result<Option<Identity>, StoreError> {
+ ) -> anyhow::Result<Option<Identity>> {
self.sql.find_identity(identity_id, email).await
}
async fn find_identity_by_code(&self, code: &str) -> Result<Identity, StoreError> {
@@ -421,4 +601,32 @@ impl Store for SqliteClient {
async fn read_session(&self, secret: &SessionSecret) -> Result<Session, StoreError> {
self.sql.read_session(secret).await
}
+ async fn write_oauth_provider(&self, provider: &OauthProvider) -> Result<(), StoreError> {
+ self.sql.write_oauth_provider(provider).await
+ }
+ async fn read_oauth_provider(
+ &self,
+ provider: &OauthProviderName,
+ flow: Option<String>,
+ ) -> Result<OauthProvider, StoreError> {
+ self.sql.read_oauth_provider(provider, flow).await
+ }
+ async fn write_oauth_validation(
+ &self,
+ validation: &OauthValidation,
+ ) -> anyhow::Result<ValidationRequestId> {
+ self.sql.write_oauth_validation(validation).await
+ }
+ async fn read_oauth_validation(
+ &self,
+ validation_id: &ValidationRequestId,
+ ) -> anyhow::Result<OauthValidation> {
+ self.sql.read_oauth_validation(validation_id).await
+ }
+ async fn find_validation_type(
+ &self,
+ validation_id: &ValidationRequestId,
+ ) -> anyhow::Result<ValidationType> {
+ self.sql.find_validation_type(validation_id).await
+ }
}
diff --git a/crates/secd/src/client/types.rs b/crates/secd/src/client/types.rs
new file mode 100644
index 0000000..bacade4
--- /dev/null
+++ b/crates/secd/src/client/types.rs
@@ -0,0 +1,3 @@
+pub(crate) struct Email {
+ address: String,
+}