aboutsummaryrefslogtreecommitdiff
path: root/crates/secd/src/client/mod.rs
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--crates/secd/src/client/mod.rs233
1 files changed, 223 insertions, 10 deletions
diff --git a/crates/secd/src/client/mod.rs b/crates/secd/src/client/mod.rs
index 3925657..38426ef 100644
--- a/crates/secd/src/client/mod.rs
+++ b/crates/secd/src/client/mod.rs
@@ -1,13 +1,24 @@
-pub mod email;
-pub mod sqldb;
+pub(crate) mod email;
+pub(crate) mod sqldb;
+pub(crate) mod types;
-use std::collections::HashMap;
+use std::{collections::HashMap, str::FromStr};
use super::Identity;
-use crate::{EmailValidation, Session, SessionSecret};
+use crate::{
+ EmailValidation, OauthProvider, OauthProviderName, OauthResponseType, OauthValidation, Session,
+ SessionSecret, ValidationRequestId, ValidationType,
+};
+use email_address::EmailAddress;
use lazy_static::lazy_static;
+use sqlx::{
+ database::HasValueRef, sqlite::SqliteRow, ColumnIndex, Database, Decode, FromRow, Row, Sqlite,
+ Type,
+};
use thiserror::Error;
+use time::OffsetDateTime;
+use url::Url;
use uuid::Uuid;
pub enum EmailType {
@@ -36,13 +47,15 @@ pub trait EmailMessenger {
#[derive(Error, Debug, derive_more::Display)]
pub enum StoreError {
SqlxError(#[from] sqlx::Error),
- EmailAlreadyExists,
CodeAppearsMoreThanOnce,
CodeDoesNotExist(String),
IdentityIdMustExistInvariant,
- TooManyEmailValidations,
+ TooManyValidations,
+ TooManyIdentitiesFound,
NoEmailValidationFound,
- Unknown,
+ OauthProviderDoesNotExist(OauthProviderName),
+ OauthValidationDoesNotExist(ValidationRequestId),
+ Other(String),
}
const EMAIL_TEMPLATE_DEFAULT_LOGIN: &str = "You requested a login link. Please click the following link %secd_code% to login as %secd_email_address%";
@@ -56,6 +69,7 @@ const PGSQL: &str = "pgsql";
const WRITE_IDENTITY: &str = "write_identity";
const WRITE_EMAIL_VALIDATION: &str = "write_email_validation";
const FIND_EMAIL_VALIDATION: &str = "find_email_validation";
+const READ_VALIDATION_TYPE: &str = "read_validation_type";
const WRITE_EMAIL: &str = "write_email";
@@ -69,6 +83,11 @@ const READ_EMAIL_RAW_ID: &str = "read_email_raw_id";
const WRITE_SESSION: &str = "write_session";
const READ_SESSION: &str = "read_session";
+const WRITE_OAUTH_PROVIDER: &str = "write_oauth_provider";
+const READ_OAUTH_PROVIDER: &str = "read_oauth_provider";
+const WRITE_OAUTH_VALIDATION: &str = "write_oauth_validation";
+const READ_OAUTH_VALIDATION: &str = "read_oauth_validation";
+
lazy_static! {
static ref SQLS: HashMap<&'static str, HashMap<&'static str, &'static str>> = {
let sqlite_sqls: HashMap<&'static str, &'static str> = [
@@ -116,6 +135,26 @@ lazy_static! {
FIND_EMAIL_VALIDATION,
include_str!("../../store/sqlite/sql/find_email_validation.sql"),
),
+ (
+ WRITE_OAUTH_PROVIDER,
+ include_str!("../../store/sqlite/sql/write_oauth_provider.sql"),
+ ),
+ (
+ READ_OAUTH_PROVIDER,
+ include_str!("../../store/sqlite/sql/read_oauth_provider.sql"),
+ ),
+ (
+ READ_OAUTH_VALIDATION,
+ include_str!("../../store/sqlite/sql/read_oauth_validation.sql"),
+ ),
+ (
+ WRITE_OAUTH_VALIDATION,
+ include_str!("../../store/sqlite/sql/write_oauth_validation.sql"),
+ ),
+ (
+ READ_VALIDATION_TYPE,
+ include_str!("../../store/sqlite/sql/read_validation_type.sql"),
+ ),
]
.iter()
.cloned()
@@ -166,6 +205,26 @@ lazy_static! {
FIND_EMAIL_VALIDATION,
include_str!("../../store/pg/sql/find_email_validation.sql"),
),
+ (
+ WRITE_OAUTH_PROVIDER,
+ include_str!("../../store/pg/sql/write_oauth_provider.sql"),
+ ),
+ (
+ READ_OAUTH_PROVIDER,
+ include_str!("../../store/pg/sql/read_oauth_provider.sql"),
+ ),
+ (
+ READ_OAUTH_VALIDATION,
+ include_str!("../../store/pg/sql/read_oauth_validation.sql"),
+ ),
+ (
+ WRITE_OAUTH_VALIDATION,
+ include_str!("../../store/pg/sql/write_oauth_validation.sql"),
+ ),
+ (
+ READ_VALIDATION_TYPE,
+ include_str!("../../store/pg/sql/read_validation_type.sql"),
+ ),
]
.iter()
.cloned()
@@ -180,9 +239,143 @@ lazy_static! {
};
}
+impl<'a, R: Row> FromRow<'a, R> for OauthValidation
+where
+ &'a str: ColumnIndex<R>,
+ OauthProviderName: Decode<'a, R::Database> + Type<R::Database>,
+ OauthResponseType: Decode<'a, R::Database> + Type<R::Database>,
+ OffsetDateTime: Decode<'a, R::Database> + Type<R::Database>,
+ String: Decode<'a, R::Database> + Type<R::Database>,
+ Uuid: Decode<'a, R::Database> + Type<R::Database>,
+{
+ fn from_row(row: &'a R) -> Result<Self, sqlx::Error> {
+ let id: Option<Uuid> = row.try_get("oauth_validation_public_id")?;
+ let identity_id: Option<Uuid> = row.try_get("identity_public_id")?;
+ let access_token: Option<String> = row.try_get("access_token")?;
+ let raw_response: Option<String> = row.try_get("raw_response")?;
+ let created_at: Option<OffsetDateTime> = row.try_get("created_at")?;
+ let validated_at: Option<OffsetDateTime> = row.try_get("validated_at")?;
+ let revoked_at: Option<OffsetDateTime> = row.try_get("revoked_at")?;
+ let deleted_at: Option<OffsetDateTime> = row.try_get("deleted_at")?;
+
+ let op_name: Option<OauthProviderName> = row.try_get("oauth_provider_name")?;
+ let op_flow: Option<String> = row.try_get("oauth_provider_flow")?;
+ let op_base_url: Option<String> = row.try_get("oauth_provider_base_url")?;
+ let op_response_type: Option<OauthResponseType> =
+ row.try_get("oauth_provider_response_type")?;
+ let op_default_scope: Option<String> = row.try_get("oauth_provider_default_scope")?;
+ let op_client_id: Option<String> = row.try_get("oauth_provider_client_id")?;
+ let op_client_secret: Option<String> = row.try_get("oauth_provider_client_secret")?;
+ let op_redirect_url: Option<String> = row.try_get("oauth_provider_redirect_url")?;
+ let op_created_at: Option<OffsetDateTime> = row.try_get("oauth_provider_created_at")?;
+ let op_deleted_at: Option<OffsetDateTime> = row.try_get("oauth_provider_deleted_at")?;
+
+ let op_base_url = op_base_url
+ .map(|s| Url::from_str(&s).ok())
+ .flatten()
+ .ok_or(sqlx::Error::ColumnDecode {
+ index: "oauth_provider_base_url".into(),
+ source: "secd".into(),
+ })?;
+
+ let op_redirect_url = op_redirect_url
+ .map(|s| Url::from_str(&s).ok())
+ .flatten()
+ .ok_or(sqlx::Error::ColumnDecode {
+ index: "oauth_provider_redirect_url".into(),
+ source: "secd".into(),
+ })?;
+
+ Ok(OauthValidation {
+ id,
+ identity_id,
+ access_token,
+ raw_response,
+ created_at: created_at.ok_or(sqlx::Error::ColumnDecode {
+ index: "created_at".into(),
+ source: "secd".into(),
+ })?,
+ validated_at,
+ revoked_at,
+ deleted_at,
+ oauth_provider: OauthProvider {
+ name: op_name.unwrap(),
+ flow: op_flow,
+ base_url: op_base_url,
+ response: op_response_type.ok_or(sqlx::Error::ColumnDecode {
+ index: "oauth_provider_response_type".into(),
+ source: "secd".into(),
+ })?,
+ default_scope: op_default_scope.ok_or(sqlx::Error::ColumnDecode {
+ index: "oauth_provider_default_scope".into(),
+ source: "secd".into(),
+ })?,
+ client_id: op_client_id.ok_or(sqlx::Error::ColumnDecode {
+ index: "oauth_provider_client_id".into(),
+ source: "secd".into(),
+ })?,
+ client_secret: op_client_secret.ok_or(sqlx::Error::ColumnDecode {
+ index: "oauth_provider_client_secret".into(),
+ source: "secd".into(),
+ })?,
+ redirect_url: op_redirect_url,
+ created_at: op_created_at.ok_or(sqlx::Error::ColumnDecode {
+ index: "oauth_provider_created_at".into(),
+ source: "secd".into(),
+ })?,
+ deleted_at: op_deleted_at,
+ },
+ })
+ }
+}
+
+impl<'a, D: Database> Decode<'a, D> for OauthProviderName
+where
+ &'a str: Decode<'a, D>,
+{
+ fn decode(
+ value: <D as HasValueRef<'a>>::ValueRef,
+ ) -> Result<Self, Box<dyn ::std::error::Error + 'static + Send + Sync>> {
+ let v = <&str as Decode<D>>::decode(value)?;
+ <OauthProviderName as clap::ValueEnum>::from_str(v, true)
+ .map_err(|_| "OauthProviderName should exist and decode to a program value.".into())
+ }
+}
+
+impl<D: Database> Type<D> for OauthProviderName
+where
+ str: Type<D>,
+{
+ fn type_info() -> D::TypeInfo {
+ <&str as Type<D>>::type_info()
+ }
+}
+
+impl<'a, D: Database> Decode<'a, D> for OauthResponseType
+where
+ &'a str: Decode<'a, D>,
+{
+ fn decode(
+ value: <D as HasValueRef<'a>>::ValueRef,
+ ) -> Result<Self, Box<dyn ::std::error::Error + 'static + Send + Sync>> {
+ let v = <&str as Decode<D>>::decode(value)?;
+ <OauthResponseType as clap::ValueEnum>::from_str(v, true)
+ .map_err(|_| "OauthResponseType should exist and decode to a program value.".into())
+ }
+}
+
+impl<D: Database> Type<D> for OauthResponseType
+where
+ str: Type<D>,
+{
+ fn type_info() -> D::TypeInfo {
+ <&str as Type<D>>::type_info()
+ }
+}
+
#[async_trait::async_trait]
pub trait Store {
- async fn write_email(&self, identity_id: Uuid, email_address: &str) -> Result<(), StoreError>;
+ async fn write_email(&self, email_address: &str) -> Result<(), StoreError>;
async fn find_email_validation(
&self,
@@ -193,17 +386,37 @@ pub trait Store {
&self,
ev: &EmailValidation,
// TODO: Make this write an EmailValidation
- ) -> Result<Uuid, StoreError>;
+ ) -> anyhow::Result<Uuid>;
async fn find_identity(
&self,
identity_id: Option<&Uuid>,
email: Option<&str>,
- ) -> Result<Option<Identity>, StoreError>;
+ ) -> anyhow::Result<Option<Identity>>;
async fn find_identity_by_code(&self, code: &str) -> Result<Identity, StoreError>;
async fn write_identity(&self, i: &Identity) -> Result<(), StoreError>;
async fn read_identity(&self, identity_id: &Uuid) -> Result<Identity, StoreError>;
async fn write_session(&self, session: &Session) -> Result<(), StoreError>;
async fn read_session(&self, secret: &SessionSecret) -> Result<Session, StoreError>;
+
+ async fn write_oauth_provider(&self, provider: &OauthProvider) -> Result<(), StoreError>;
+ async fn read_oauth_provider(
+ &self,
+ provider: &OauthProviderName,
+ flow: Option<String>,
+ ) -> Result<OauthProvider, StoreError>;
+ async fn write_oauth_validation(
+ &self,
+ validation: &OauthValidation,
+ ) -> anyhow::Result<ValidationRequestId>;
+ async fn read_oauth_validation(
+ &self,
+ validation_id: &ValidationRequestId,
+ ) -> anyhow::Result<OauthValidation>;
+
+ async fn find_validation_type(
+ &self,
+ validation_id: &ValidationRequestId,
+ ) -> anyhow::Result<ValidationType>;
}