//! Shared storage and email-delivery definitions.
//!
//! Declares the `email`, `sqldb` and `types` submodules, the `EmailMessenger`
//! and `Store` traits together with the error enums they return, sqlx decoding
//! impls for OAuth types, and a per-backend table of SQL query text (`SQLS`)
//! embedded at compile time.

pub(crate) mod email;
pub(crate) mod sqldb;
pub(crate) mod types;

use std::{collections::HashMap, str::FromStr};

use super::Identity;
use crate::{
    EmailValidation, OauthProvider, OauthProviderName, OauthResponseType, OauthValidation,
    Session, SessionSecret, ValidationRequestId, ValidationType,
};
use email_address::EmailAddress;
use lazy_static::lazy_static;
use sqlx::{
    database::HasValueRef, sqlite::SqliteRow, ColumnIndex, Database, Decode, FromRow, Row, Sqlite,
    Type,
};
use thiserror::Error;
use time::OffsetDateTime;
use url::Url;
use uuid::Uuid;

/// Kind of email requested from [`EmailMessenger::send_email`] (its `t` argument).
pub enum EmailType {
    /// Login-link email.
    Login,
    /// Sign-up email.
    Signup,
}

/// Errors an [`EmailMessenger`] implementation can report.
///
/// `Display` is derived via `derive_more::Display`; `Error` via `thiserror`.
#[derive(Error, Debug, derive_more::Display)]
pub enum EmailMessengerError {
    InvalidEmailAddress,
    FailedToSendEmail,
    Unknown,
}

/// Sends validation emails (login links and sign-up confirmations).
///
/// Implementations receive the recipient address, the validation id, the
/// secret code, and the [`EmailType`] selecting which email to send.
#[async_trait::async_trait]
pub trait EmailMessenger {
    async fn send_email(
        &self,
        email_address: &str,
        validation_id: &str,
        secret_code: &str,
        t: EmailType,
    ) -> Result<(), EmailMessengerError>;
}

/// Errors returned by [`Store`] operations.
///
/// `sqlx::Error` converts into `StoreError::SqlxError` through `#[from]`, so `?`
/// can be applied to sqlx results in functions returning `StoreError`.
#[derive(Error, Debug, derive_more::Display)]
pub enum StoreError {
    SqlxError(#[from] sqlx::Error),
    CodeAppearsMoreThanOnce,
    CodeDoesNotExist(String),
    IdentityIdMustExistInvariant,
    TooManyValidations,
    TooManyIdentitiesFound,
    NoEmailValidationFound,
    OauthProviderDoesNotExist(OauthProviderName),
    OauthValidationDoesNotExist(ValidationRequestId),
    Other(String),
}

/// Default login-email body. `%secd_code%` and `%secd_email_address%` are
/// placeholder tokens — presumably substituted by the messenger; confirm in `email`.
const EMAIL_TEMPLATE_DEFAULT_LOGIN: &str = "You requested a login link. Please click the following link %secd_code% to login as %secd_email_address%";
/// Default sign-up-email body; uses the same placeholder tokens as the login template.
const EMAIL_TEMPLATE_DEFAULT_SIGNUP: &str = "You requested a sign up. Please click the following link %secd_code% to complete your sign up and validate %secd_email_address%";
/// Message for a failed migration run.
// NOTE: this literal spans two source lines, so the message contains a newline
// before "File a bug".
const ERR_MSG_MIGRATION_FAILED: &str = "Failed to execute migrations. This appears to be a secd issue. 
File a bug at https://www.github.com/secd-lib";

// Backend keys for the outer level of `SQLS`.
const SQLITE: &str = "sqlite";
const PGSQL: &str = "pgsql";

// Query-name keys for the inner maps of `SQLS`. Each value matches the stem of
// a `.sql` file under both `store/sqlite/sql/` and `store/pg/sql/`.
const WRITE_IDENTITY: &str = "write_identity";
const WRITE_EMAIL_VALIDATION: &str = "write_email_validation";
const FIND_EMAIL_VALIDATION: &str = "find_email_validation";
const READ_VALIDATION_TYPE: &str = "read_validation_type";
const WRITE_EMAIL: &str = "write_email";
const READ_IDENTITY: &str = "read_identity";
const FIND_IDENTITY: &str = "find_identity";
const FIND_IDENTITY_BY_CODE: &str = "find_identity_by_code";
const READ_IDENTITY_RAW_ID: &str = "read_identity_raw_id";
const READ_EMAIL_RAW_ID: &str = "read_email_raw_id";
const WRITE_SESSION: &str = "write_session";
const READ_SESSION: &str = "read_session";
const WRITE_OAUTH_PROVIDER: &str = "write_oauth_provider";
const READ_OAUTH_PROVIDER: &str = "read_oauth_provider";
const WRITE_OAUTH_VALIDATION: &str = "write_oauth_validation";
const READ_OAUTH_VALIDATION: &str = "read_oauth_validation";

lazy_static! {
    /// SQL query text, indexed first by backend (`SQLITE` / `PGSQL`) and then
    /// by query name (the query-name constants). Both backends register the
    /// same set of query names.
    ///
    /// Every file is embedded with `include_str!` (paths relative to this
    /// source file), so a missing `.sql` file is a compile error rather than a
    /// runtime lookup failure.
    // NOTE(review): `std::sync::LazyLock`/`OnceLock` could replace `lazy_static`,
    // and `HashMap::from([...])` the `.iter().cloned().collect()` chains.
    static ref SQLS: HashMap<&'static str, HashMap<&'static str, &'static str>> = {
        // SQLite dialect.
        let sqlite_sqls: HashMap<&'static str, &'static str> = [
            (
                WRITE_IDENTITY,
                include_str!("../../store/sqlite/sql/write_identity.sql"),
            ),
            (
                WRITE_EMAIL_VALIDATION,
                include_str!("../../store/sqlite/sql/write_email_validation.sql"),
            ),
            (
                WRITE_EMAIL,
                include_str!("../../store/sqlite/sql/write_email.sql"),
            ),
            (
                READ_IDENTITY,
                include_str!("../../store/sqlite/sql/read_identity.sql"),
            ),
            (
                FIND_IDENTITY,
                include_str!("../../store/sqlite/sql/find_identity.sql"),
            ),
            (
                FIND_IDENTITY_BY_CODE,
                include_str!("../../store/sqlite/sql/find_identity_by_code.sql"),
            ),
            (
                READ_IDENTITY_RAW_ID,
                include_str!("../../store/sqlite/sql/read_identity_raw_id.sql"),
            ),
            (
                READ_EMAIL_RAW_ID,
                include_str!("../../store/sqlite/sql/read_email_raw_id.sql"),
            ),
            (
                WRITE_SESSION,
                include_str!("../../store/sqlite/sql/write_session.sql"),
            ),
            (
                READ_SESSION,
                include_str!("../../store/sqlite/sql/read_session.sql"),
            ),
            (
                FIND_EMAIL_VALIDATION,
                include_str!("../../store/sqlite/sql/find_email_validation.sql"),
            ),
            (
                WRITE_OAUTH_PROVIDER,
                include_str!("../../store/sqlite/sql/write_oauth_provider.sql"),
            ),
            (
                READ_OAUTH_PROVIDER,
                include_str!("../../store/sqlite/sql/read_oauth_provider.sql"),
            ),
            (
                READ_OAUTH_VALIDATION,
                include_str!("../../store/sqlite/sql/read_oauth_validation.sql"),
            ),
            (
                WRITE_OAUTH_VALIDATION,
                include_str!("../../store/sqlite/sql/write_oauth_validation.sql"),
            ),
            (
                READ_VALIDATION_TYPE,
                include_str!("../../store/sqlite/sql/read_validation_type.sql"),
            ),
        ]
        .iter()
        .cloned()
        .collect();

        // PostgreSQL dialect; same keys, same order as above.
        let pg_sqls: HashMap<&'static str, &'static str> = [
            (
                WRITE_IDENTITY,
                include_str!("../../store/pg/sql/write_identity.sql"),
            ),
            (
                WRITE_EMAIL_VALIDATION,
                include_str!("../../store/pg/sql/write_email_validation.sql"),
            ),
            (
                WRITE_EMAIL,
                include_str!("../../store/pg/sql/write_email.sql"),
            ),
            (
                READ_IDENTITY,
                include_str!("../../store/pg/sql/read_identity.sql"),
            ),
            (
                FIND_IDENTITY,
                include_str!("../../store/pg/sql/find_identity.sql"),
            ),
            (
                FIND_IDENTITY_BY_CODE,
                include_str!("../../store/pg/sql/find_identity_by_code.sql"),
            ),
            (
                READ_IDENTITY_RAW_ID,
                include_str!("../../store/pg/sql/read_identity_raw_id.sql"),
            ),
            (
                READ_EMAIL_RAW_ID,
                include_str!("../../store/pg/sql/read_email_raw_id.sql"),
            ),
            (
                WRITE_SESSION,
                include_str!("../../store/pg/sql/write_session.sql"),
            ),
            (
                READ_SESSION,
                include_str!("../../store/pg/sql/read_session.sql"),
            ),
            (
                FIND_EMAIL_VALIDATION,
                include_str!("../../store/pg/sql/find_email_validation.sql"),
            ),
            (
                WRITE_OAUTH_PROVIDER,
                include_str!("../../store/pg/sql/write_oauth_provider.sql"),
            ),
            (
                READ_OAUTH_PROVIDER,
                include_str!("../../store/pg/sql/read_oauth_provider.sql"),
            ),
            (
                READ_OAUTH_VALIDATION,
                include_str!("../../store/pg/sql/read_oauth_validation.sql"),
            ),
            (
                WRITE_OAUTH_VALIDATION,
                include_str!("../../store/pg/sql/write_oauth_validation.sql"),
            ),
            (
                READ_VALIDATION_TYPE,
                include_str!("../../store/pg/sql/read_validation_type.sql"),
            ),
        ]
        .iter()
        .cloned()
        .collect();

        // Outer map: backend key -> that backend's query table.
        let sqls: HashMap<&'static str, HashMap<&'static str, &'static str>> =
            [(SQLITE, sqlite_sqls), (PGSQL, pg_sqls)]
                .iter()
                .cloned()
                .collect();
        sqls
    };
}

/// Builds an [`OauthValidation`], including its nested [`OauthProvider`], from a
/// row whose columns carry the names read below. This is presumably the output
/// of the `read_oauth_validation` query joined with its provider; confirm
/// against the SQL files.
///
/// # Errors
/// Returns `sqlx::Error::ColumnDecode` (with source `"secd"`) when any of these
/// is NULL:
/// - `created_at`
/// - `oauth_provider_response_type`
/// - `oauth_provider_default_scope`
/// - `oauth_provider_client_id`
/// - `oauth_provider_client_secret`
/// - `oauth_provider_created_at`
///
/// Returns the same error when `oauth_provider_base_url` or
/// `oauth_provider_redirect_url` is NULL or does not parse as a URL. Failures
/// from `try_get` itself propagate unchanged.
///
/// # Panics
/// Panics if `oauth_provider_name` is NULL.
impl<'a, R: Row> FromRow<'a, R> for OauthValidation
where
    &'a str: ColumnIndex,
    OauthProviderName: Decode<'a, R::Database> + Type,
    OauthResponseType: Decode<'a, R::Database> + Type,
    OffsetDateTime: Decode<'a, R::Database> + Type,
    String: Decode<'a, R::Database> + Type,
    Uuid: Decode<'a, R::Database> + Type,
{
    fn from_row(row: &'a R) -> Result {
        // Every column is read as an Option. NULLs in required columns are
        // turned into ColumnDecode errors further down; the rest pass through.
        let id: Option = row.try_get("oauth_validation_public_id")?;
        let identity_id: Option = row.try_get("identity_public_id")?;
        let access_token: Option = row.try_get("access_token")?;
        let raw_response: Option = row.try_get("raw_response")?;
        let created_at: Option = row.try_get("created_at")?;
        let validated_at: Option = row.try_get("validated_at")?;
        let revoked_at: Option = row.try_get("revoked_at")?;
        let deleted_at: Option = row.try_get("deleted_at")?;

        // Provider columns, prefixed `oauth_provider_` in the row.
        let op_name: Option = row.try_get("oauth_provider_name")?;
        let op_flow: Option = row.try_get("oauth_provider_flow")?;
        let op_base_url: Option = row.try_get("oauth_provider_base_url")?;
        let op_response_type: Option = row.try_get("oauth_provider_response_type")?;
        let op_default_scope: Option = row.try_get("oauth_provider_default_scope")?;
        let op_client_id: Option = row.try_get("oauth_provider_client_id")?;
        let op_client_secret: Option = row.try_get("oauth_provider_client_secret")?;
        let op_redirect_url: Option = row.try_get("oauth_provider_redirect_url")?;
        let op_created_at: Option = row.try_get("oauth_provider_created_at")?;
        let op_deleted_at: Option = row.try_get("oauth_provider_deleted_at")?;

        // URLs are stored as text; NULL and parse failure both become the same
        // ColumnDecode error.
        // NOTE(review): `.map(..).flatten()` is `.and_then(..)`, and the
        // `url::ParseError` is discarded rather than used as the error source.
        let op_base_url = op_base_url
            .map(|s| Url::from_str(&s).ok())
            .flatten()
            .ok_or(sqlx::Error::ColumnDecode {
                index: "oauth_provider_base_url".into(),
                source: "secd".into(),
            })?;
        let op_redirect_url = op_redirect_url
            .map(|s| Url::from_str(&s).ok())
            .flatten()
            .ok_or(sqlx::Error::ColumnDecode {
                index: "oauth_provider_redirect_url".into(),
                source: "secd".into(),
            })?;

        // NOTE(review): `ok_or(...)` builds its error value (and its String/Box
        // allocations) even when the column is present; `ok_or_else(|| ...)`
        // would defer that work to the NULL case.
        Ok(OauthValidation {
            id,
            identity_id,
            access_token,
            raw_response,
            created_at: created_at.ok_or(sqlx::Error::ColumnDecode {
                index: "created_at".into(),
                source: "secd".into(),
            })?,
            validated_at,
            revoked_at,
            deleted_at,
            oauth_provider: OauthProvider {
                // NOTE(review): every other required column maps NULL to a
                // ColumnDecode error; this one panics instead. Consider
                // `ok_or_else(|| ColumnDecode { index: "oauth_provider_name", .. })?`.
                name: op_name.unwrap(),
                flow: op_flow,
                base_url: op_base_url,
                response: op_response_type.ok_or(sqlx::Error::ColumnDecode {
                    index: "oauth_provider_response_type".into(),
                    source: "secd".into(),
                })?,
                default_scope: op_default_scope.ok_or(sqlx::Error::ColumnDecode {
                    index: "oauth_provider_default_scope".into(),
                    source: "secd".into(),
                })?,
                client_id: op_client_id.ok_or(sqlx::Error::ColumnDecode {
                    index: "oauth_provider_client_id".into(),
                    source: "secd".into(),
                })?,
                client_secret: op_client_secret.ok_or(sqlx::Error::ColumnDecode {
                    index: "oauth_provider_client_secret".into(),
                    source: "secd".into(),
                })?,
                redirect_url: op_redirect_url,
                created_at: op_created_at.ok_or(sqlx::Error::ColumnDecode {
                    index: "oauth_provider_created_at".into(),
                    source: "secd".into(),
                })?,
                deleted_at: op_deleted_at,
            },
        })
    }
}

/// Decodes an [`OauthProviderName`] from a column by first decoding it as
/// `&str` and then parsing with `from_str(v, true)`.
///
/// Presumably `true` requests case-insensitive matching; confirm against the
/// trait that provides `from_str`. Parse failures become a boxed string error.
impl<'a, D: Database> Decode<'a, D> for OauthProviderName
where
    &'a str: Decode<'a, D>,
{
    fn decode(
        value: >::ValueRef,
    ) -> Result> {
        let v = <&str as Decode>::decode(value)?;
        ::from_str(v, true)
            .map_err(|_| "OauthProviderName should exist and decode to a program value.".into())
    }
}

/// Reports the SQL type of [`OauthProviderName`] as that of `&str`, matching
/// the text-based decoding.
impl Type for OauthProviderName
where
    str: Type,
{
    fn type_info() -> D::TypeInfo {
        <&str as Type>::type_info()
    }
}

/// Decodes an [`OauthResponseType`] from a column by first decoding it as
/// `&str` and then parsing with `from_str(v, true)`.
///
/// Presumably `true` requests case-insensitive matching; confirm. Parse
/// failures become a boxed string error.
impl<'a, D: Database> Decode<'a, D> for OauthResponseType
where
    &'a str: Decode<'a, D>,
{
    fn decode(
        value: >::ValueRef,
    ) -> Result> {
        let v = <&str as Decode>::decode(value)?;
        ::from_str(v, true)
            .map_err(|_| "OauthResponseType should exist and decode to a program value.".into())
    }
}

/// Reports the SQL type of [`OauthResponseType`] as that of `&str`, matching
/// the text-based decoding.
impl Type for OauthResponseType
where
    str: Type,
{
    fn type_info() -> D::TypeInfo {
        <&str as Type>::type_info()
    }
}

/// Persistence backend for emails, email/OAuth validations, identities,
/// sessions, and OAuth providers.
#[async_trait::async_trait]
pub trait Store {
    /// Persists an email address.
    async fn write_email(&self, email_address: &str) -> Result<(), StoreError>;

    /// Looks up an email validation by validation id and/or secret code.
    /// NOTE(review): presumably at least one filter must be `Some`; confirm in impls.
    async fn find_email_validation(
        &self,
        validation_id: Option<&Uuid>,
        code: Option<&str>,
    ) -> Result;

    /// Persists an email validation.
    async fn write_email_validation(
        &self,
        ev: &EmailValidation, // TODO: Make this write an EmailValidation
    ) -> anyhow::Result;

    /// Looks up an identity by identity id and/or email address.
    async fn find_identity(
        &self,
        identity_id: Option<&Uuid>,
        email: Option<&str>,
    ) -> anyhow::Result>;

    /// Looks up the identity associated with a validation code.
    async fn find_identity_by_code(&self, code: &str) -> Result;

    /// Persists an identity.
    async fn write_identity(&self, i: &Identity) -> Result<(), StoreError>;

    /// Reads an identity by its id.
    async fn read_identity(&self, identity_id: &Uuid) -> Result;

    /// Persists a session.
    async fn write_session(&self, session: &Session) -> Result<(), StoreError>;

    /// Reads the session matching a session secret.
    async fn read_session(&self, secret: &SessionSecret) -> Result;

    /// Persists an OAuth provider configuration.
    async fn write_oauth_provider(&self, provider: &OauthProvider) -> Result<(), StoreError>;

    /// Reads an OAuth provider by name, optionally narrowed by `flow`.
    async fn read_oauth_provider(
        &self,
        provider: &OauthProviderName,
        flow: Option,
    ) -> Result;

    /// Persists an OAuth validation.
    async fn write_oauth_validation(
        &self,
        validation: &OauthValidation,
    ) -> anyhow::Result;

    /// Reads an OAuth validation by its request id.
    async fn read_oauth_validation(
        &self,
        validation_id: &ValidationRequestId,
    ) -> anyhow::Result;

    /// Determines the kind of validation behind a request id, presumably a
    /// [`ValidationType`]; confirm in impls.
    async fn find_validation_type(
        &self,
        validation_id: &ValidationRequestId,
    ) -> anyhow::Result;
}