aboutsummaryrefslogtreecommitdiff
path: root/crates/secd/src
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--crates/secd/src/client/mod.rs233
-rw-r--r--crates/secd/src/client/sqldb.rs324
-rw-r--r--crates/secd/src/client/types.rs3
-rw-r--r--crates/secd/src/command/admin.rs57
-rw-r--r--crates/secd/src/command/authn.rs230
-rw-r--r--crates/secd/src/command/mod.rs66
-rw-r--r--crates/secd/src/lib.rs390
-rw-r--r--crates/secd/src/util/mod.rs158
8 files changed, 1184 insertions, 277 deletions
diff --git a/crates/secd/src/client/mod.rs b/crates/secd/src/client/mod.rs
index 3925657..38426ef 100644
--- a/crates/secd/src/client/mod.rs
+++ b/crates/secd/src/client/mod.rs
@@ -1,13 +1,24 @@
-pub mod email;
-pub mod sqldb;
+pub(crate) mod email;
+pub(crate) mod sqldb;
+pub(crate) mod types;
-use std::collections::HashMap;
+use std::{collections::HashMap, str::FromStr};
use super::Identity;
-use crate::{EmailValidation, Session, SessionSecret};
+use crate::{
+ EmailValidation, OauthProvider, OauthProviderName, OauthResponseType, OauthValidation, Session,
+ SessionSecret, ValidationRequestId, ValidationType,
+};
+use email_address::EmailAddress;
use lazy_static::lazy_static;
+use sqlx::{
+ database::HasValueRef, sqlite::SqliteRow, ColumnIndex, Database, Decode, FromRow, Row, Sqlite,
+ Type,
+};
use thiserror::Error;
+use time::OffsetDateTime;
+use url::Url;
use uuid::Uuid;
pub enum EmailType {
@@ -36,13 +47,15 @@ pub trait EmailMessenger {
#[derive(Error, Debug, derive_more::Display)]
pub enum StoreError {
SqlxError(#[from] sqlx::Error),
- EmailAlreadyExists,
CodeAppearsMoreThanOnce,
CodeDoesNotExist(String),
IdentityIdMustExistInvariant,
- TooManyEmailValidations,
+ TooManyValidations,
+ TooManyIdentitiesFound,
NoEmailValidationFound,
- Unknown,
+ OauthProviderDoesNotExist(OauthProviderName),
+ OauthValidationDoesNotExist(ValidationRequestId),
+ Other(String),
}
const EMAIL_TEMPLATE_DEFAULT_LOGIN: &str = "You requested a login link. Please click the following link %secd_code% to login as %secd_email_address%";
@@ -56,6 +69,7 @@ const PGSQL: &str = "pgsql";
const WRITE_IDENTITY: &str = "write_identity";
const WRITE_EMAIL_VALIDATION: &str = "write_email_validation";
const FIND_EMAIL_VALIDATION: &str = "find_email_validation";
+const READ_VALIDATION_TYPE: &str = "read_validation_type";
const WRITE_EMAIL: &str = "write_email";
@@ -69,6 +83,11 @@ const READ_EMAIL_RAW_ID: &str = "read_email_raw_id";
const WRITE_SESSION: &str = "write_session";
const READ_SESSION: &str = "read_session";
+const WRITE_OAUTH_PROVIDER: &str = "write_oauth_provider";
+const READ_OAUTH_PROVIDER: &str = "read_oauth_provider";
+const WRITE_OAUTH_VALIDATION: &str = "write_oauth_validation";
+const READ_OAUTH_VALIDATION: &str = "read_oauth_validation";
+
lazy_static! {
static ref SQLS: HashMap<&'static str, HashMap<&'static str, &'static str>> = {
let sqlite_sqls: HashMap<&'static str, &'static str> = [
@@ -116,6 +135,26 @@ lazy_static! {
FIND_EMAIL_VALIDATION,
include_str!("../../store/sqlite/sql/find_email_validation.sql"),
),
+ (
+ WRITE_OAUTH_PROVIDER,
+ include_str!("../../store/sqlite/sql/write_oauth_provider.sql"),
+ ),
+ (
+ READ_OAUTH_PROVIDER,
+ include_str!("../../store/sqlite/sql/read_oauth_provider.sql"),
+ ),
+ (
+ READ_OAUTH_VALIDATION,
+ include_str!("../../store/sqlite/sql/read_oauth_validation.sql"),
+ ),
+ (
+ WRITE_OAUTH_VALIDATION,
+ include_str!("../../store/sqlite/sql/write_oauth_validation.sql"),
+ ),
+ (
+ READ_VALIDATION_TYPE,
+ include_str!("../../store/sqlite/sql/read_validation_type.sql"),
+ ),
]
.iter()
.cloned()
@@ -166,6 +205,26 @@ lazy_static! {
FIND_EMAIL_VALIDATION,
include_str!("../../store/pg/sql/find_email_validation.sql"),
),
+ (
+ WRITE_OAUTH_PROVIDER,
+ include_str!("../../store/pg/sql/write_oauth_provider.sql"),
+ ),
+ (
+ READ_OAUTH_PROVIDER,
+ include_str!("../../store/pg/sql/read_oauth_provider.sql"),
+ ),
+ (
+ READ_OAUTH_VALIDATION,
+ include_str!("../../store/pg/sql/read_oauth_validation.sql"),
+ ),
+ (
+ WRITE_OAUTH_VALIDATION,
+ include_str!("../../store/pg/sql/write_oauth_validation.sql"),
+ ),
+ (
+ READ_VALIDATION_TYPE,
+ include_str!("../../store/pg/sql/read_validation_type.sql"),
+ ),
]
.iter()
.cloned()
@@ -180,9 +239,143 @@ lazy_static! {
};
}
+impl<'a, R: Row> FromRow<'a, R> for OauthValidation
+where
+ &'a str: ColumnIndex<R>,
+ OauthProviderName: Decode<'a, R::Database> + Type<R::Database>,
+ OauthResponseType: Decode<'a, R::Database> + Type<R::Database>,
+ OffsetDateTime: Decode<'a, R::Database> + Type<R::Database>,
+ String: Decode<'a, R::Database> + Type<R::Database>,
+ Uuid: Decode<'a, R::Database> + Type<R::Database>,
+{
+ fn from_row(row: &'a R) -> Result<Self, sqlx::Error> {
+ let id: Option<Uuid> = row.try_get("oauth_validation_public_id")?;
+ let identity_id: Option<Uuid> = row.try_get("identity_public_id")?;
+ let access_token: Option<String> = row.try_get("access_token")?;
+ let raw_response: Option<String> = row.try_get("raw_response")?;
+ let created_at: Option<OffsetDateTime> = row.try_get("created_at")?;
+ let validated_at: Option<OffsetDateTime> = row.try_get("validated_at")?;
+ let revoked_at: Option<OffsetDateTime> = row.try_get("revoked_at")?;
+ let deleted_at: Option<OffsetDateTime> = row.try_get("deleted_at")?;
+
+ let op_name: Option<OauthProviderName> = row.try_get("oauth_provider_name")?;
+ let op_flow: Option<String> = row.try_get("oauth_provider_flow")?;
+ let op_base_url: Option<String> = row.try_get("oauth_provider_base_url")?;
+ let op_response_type: Option<OauthResponseType> =
+ row.try_get("oauth_provider_response_type")?;
+ let op_default_scope: Option<String> = row.try_get("oauth_provider_default_scope")?;
+ let op_client_id: Option<String> = row.try_get("oauth_provider_client_id")?;
+ let op_client_secret: Option<String> = row.try_get("oauth_provider_client_secret")?;
+ let op_redirect_url: Option<String> = row.try_get("oauth_provider_redirect_url")?;
+ let op_created_at: Option<OffsetDateTime> = row.try_get("oauth_provider_created_at")?;
+ let op_deleted_at: Option<OffsetDateTime> = row.try_get("oauth_provider_deleted_at")?;
+
+ let op_base_url = op_base_url
+ .map(|s| Url::from_str(&s).ok())
+ .flatten()
+ .ok_or(sqlx::Error::ColumnDecode {
+ index: "oauth_provider_base_url".into(),
+ source: "secd".into(),
+ })?;
+
+ let op_redirect_url = op_redirect_url
+ .map(|s| Url::from_str(&s).ok())
+ .flatten()
+ .ok_or(sqlx::Error::ColumnDecode {
+ index: "oauth_provider_redirect_url".into(),
+ source: "secd".into(),
+ })?;
+
+ Ok(OauthValidation {
+ id,
+ identity_id,
+ access_token,
+ raw_response,
+ created_at: created_at.ok_or(sqlx::Error::ColumnDecode {
+ index: "created_at".into(),
+ source: "secd".into(),
+ })?,
+ validated_at,
+ revoked_at,
+ deleted_at,
+ oauth_provider: OauthProvider {
+ name: op_name.unwrap(),
+ flow: op_flow,
+ base_url: op_base_url,
+ response: op_response_type.ok_or(sqlx::Error::ColumnDecode {
+ index: "oauth_provider_response_type".into(),
+ source: "secd".into(),
+ })?,
+ default_scope: op_default_scope.ok_or(sqlx::Error::ColumnDecode {
+ index: "oauth_provider_default_scope".into(),
+ source: "secd".into(),
+ })?,
+ client_id: op_client_id.ok_or(sqlx::Error::ColumnDecode {
+ index: "oauth_provider_client_id".into(),
+ source: "secd".into(),
+ })?,
+ client_secret: op_client_secret.ok_or(sqlx::Error::ColumnDecode {
+ index: "oauth_provider_client_secret".into(),
+ source: "secd".into(),
+ })?,
+ redirect_url: op_redirect_url,
+ created_at: op_created_at.ok_or(sqlx::Error::ColumnDecode {
+ index: "oauth_provider_created_at".into(),
+ source: "secd".into(),
+ })?,
+ deleted_at: op_deleted_at,
+ },
+ })
+ }
+}
+
+impl<'a, D: Database> Decode<'a, D> for OauthProviderName
+where
+ &'a str: Decode<'a, D>,
+{
+ fn decode(
+ value: <D as HasValueRef<'a>>::ValueRef,
+ ) -> Result<Self, Box<dyn ::std::error::Error + 'static + Send + Sync>> {
+ let v = <&str as Decode<D>>::decode(value)?;
+ <OauthProviderName as clap::ValueEnum>::from_str(v, true)
+ .map_err(|_| "OauthProviderName should exist and decode to a program value.".into())
+ }
+}
+
+impl<D: Database> Type<D> for OauthProviderName
+where
+ str: Type<D>,
+{
+ fn type_info() -> D::TypeInfo {
+ <&str as Type<D>>::type_info()
+ }
+}
+
+impl<'a, D: Database> Decode<'a, D> for OauthResponseType
+where
+ &'a str: Decode<'a, D>,
+{
+ fn decode(
+ value: <D as HasValueRef<'a>>::ValueRef,
+ ) -> Result<Self, Box<dyn ::std::error::Error + 'static + Send + Sync>> {
+ let v = <&str as Decode<D>>::decode(value)?;
+ <OauthResponseType as clap::ValueEnum>::from_str(v, true)
+ .map_err(|_| "OauthResponseType should exist and decode to a program value.".into())
+ }
+}
+
+impl<D: Database> Type<D> for OauthResponseType
+where
+ str: Type<D>,
+{
+ fn type_info() -> D::TypeInfo {
+ <&str as Type<D>>::type_info()
+ }
+}
+
#[async_trait::async_trait]
pub trait Store {
- async fn write_email(&self, identity_id: Uuid, email_address: &str) -> Result<(), StoreError>;
+ async fn write_email(&self, email_address: &str) -> Result<(), StoreError>;
async fn find_email_validation(
&self,
@@ -193,17 +386,37 @@ pub trait Store {
&self,
ev: &EmailValidation,
// TODO: Make this write an EmailValidation
- ) -> Result<Uuid, StoreError>;
+ ) -> anyhow::Result<Uuid>;
async fn find_identity(
&self,
identity_id: Option<&Uuid>,
email: Option<&str>,
- ) -> Result<Option<Identity>, StoreError>;
+ ) -> anyhow::Result<Option<Identity>>;
async fn find_identity_by_code(&self, code: &str) -> Result<Identity, StoreError>;
async fn write_identity(&self, i: &Identity) -> Result<(), StoreError>;
async fn read_identity(&self, identity_id: &Uuid) -> Result<Identity, StoreError>;
async fn write_session(&self, session: &Session) -> Result<(), StoreError>;
async fn read_session(&self, secret: &SessionSecret) -> Result<Session, StoreError>;
+
+ async fn write_oauth_provider(&self, provider: &OauthProvider) -> Result<(), StoreError>;
+ async fn read_oauth_provider(
+ &self,
+ provider: &OauthProviderName,
+ flow: Option<String>,
+ ) -> Result<OauthProvider, StoreError>;
+ async fn write_oauth_validation(
+ &self,
+ validation: &OauthValidation,
+ ) -> anyhow::Result<ValidationRequestId>;
+ async fn read_oauth_validation(
+ &self,
+ validation_id: &ValidationRequestId,
+ ) -> anyhow::Result<OauthValidation>;
+
+ async fn find_validation_type(
+ &self,
+ validation_id: &ValidationRequestId,
+ ) -> anyhow::Result<ValidationType>;
}
diff --git a/crates/secd/src/client/sqldb.rs b/crates/secd/src/client/sqldb.rs
index 6048c48..15cc4b5 100644
--- a/crates/secd/src/client/sqldb.rs
+++ b/crates/secd/src/client/sqldb.rs
@@ -1,19 +1,23 @@
-use std::sync::Arc;
+use std::{str::FromStr, sync::Arc};
use super::{
- EmailValidation, Identity, Session, SessionSecret, Store, StoreError, ERR_MSG_MIGRATION_FAILED,
- FIND_EMAIL_VALIDATION, FIND_IDENTITY, FIND_IDENTITY_BY_CODE, PGSQL, READ_EMAIL_RAW_ID,
- READ_IDENTITY_RAW_ID, READ_SESSION, SQLITE, SQLS, WRITE_EMAIL, WRITE_EMAIL_VALIDATION,
- WRITE_IDENTITY, WRITE_SESSION,
+ EmailValidation, Identity, OauthProvider, OauthProviderName, OauthResponseType, Session,
+ SessionSecret, Store, StoreError, ERR_MSG_MIGRATION_FAILED, FIND_EMAIL_VALIDATION,
+ FIND_IDENTITY, FIND_IDENTITY_BY_CODE, PGSQL, READ_EMAIL_RAW_ID, READ_IDENTITY_RAW_ID,
+ READ_OAUTH_PROVIDER, READ_OAUTH_VALIDATION, READ_SESSION, READ_VALIDATION_TYPE, SQLITE, SQLS,
+ WRITE_EMAIL, WRITE_EMAIL_VALIDATION, WRITE_IDENTITY, WRITE_OAUTH_PROVIDER,
+ WRITE_OAUTH_VALIDATION, WRITE_SESSION,
};
-use crate::util;
-use log::error;
+use crate::{util, OauthValidation, ValidationRequestId, ValidationType};
+use anyhow::bail;
+use log::{debug, error};
use openssl::sha::Sha256;
use sqlx::{
self, database::HasArguments, ColumnIndex, Database, Decode, Encode, Executor, IntoArguments,
Pool, Postgres, Sqlite, Transaction, Type,
};
use time::OffsetDateTime;
+use url::Url;
use uuid::Uuid;
fn get_sqls(root: &str, file: &str) -> Vec<String> {
@@ -97,6 +101,8 @@ where
for<'c> String: Encode<'c, D> + Type<D>,
for<'c> Option<String>: Decode<'c, D> + Type<D>,
for<'c> Option<String>: Encode<'c, D> + Type<D>,
+ for<'c> OauthProviderName: Decode<'c, D> + Type<D>,
+ for<'c> OauthResponseType: Decode<'c, D> + Type<D>,
for<'c> usize: ColumnIndex<<D as Database>::Row>,
for<'c> Uuid: Decode<'c, D> + Type<D>,
for<'c> Uuid: Encode<'c, D> + Type<D>,
@@ -108,29 +114,11 @@ where
for<'c> &'c Pool<D>: Executor<'c, Database = D>,
for<'c> &'c mut Transaction<'c, D>: Executor<'c, Database = D>,
{
- async fn write_email(&self, identity_id: Uuid, email_address: &str) -> Result<(), StoreError> {
+ async fn write_email(&self, email_address: &str) -> Result<(), StoreError> {
let sqls = get_sqls(&self.sqls_root, WRITE_EMAIL);
- let identity_id = self.read_identity_raw_id(&identity_id).await?;
-
- let email_id: (i64,) = match sqlx::query_as(&sqls[0])
+ sqlx::query(&sqls[0])
.bind(email_address)
- .fetch_one(&self.pool)
- .await
- {
- Ok(i) => i,
- Err(sqlx::Error::RowNotFound) => sqlx::query_as::<_, (i64,)>(&sqls[1])
- .bind(email_address)
- .fetch_one(&self.pool)
- .await
- .map_err(util::log_err_sqlx)?,
- Err(e) => return Err(StoreError::SqlxError(e)),
- };
-
- sqlx::query(&sqls[2])
- .bind(identity_id)
- .bind(email_id.0)
- .bind(OffsetDateTime::now_utc())
.execute(&self.pool)
.await
.map_err(util::log_err_sqlx)?;
@@ -154,57 +142,84 @@ where
match rows.len() {
0 => Err(StoreError::NoEmailValidationFound),
1 => Ok(rows.swap_remove(0)),
- _ => Err(StoreError::TooManyEmailValidations),
+ _ => Err(StoreError::TooManyValidations),
}
}
- async fn write_email_validation(&self, ev: &EmailValidation) -> Result<Uuid, StoreError> {
+ async fn write_email_validation(&self, ev: &EmailValidation) -> anyhow::Result<Uuid> {
let sqls = get_sqls(&self.sqls_root, WRITE_EMAIL_VALIDATION);
- let identity_id = self
- .read_identity_raw_id(
- &ev.identity_id
- .ok_or(StoreError::IdentityIdMustExistInvariant)?,
- )
- .await?;
let email_id = self.read_email_raw_id(&ev.email_address).await?;
-
- let new_id = Uuid::new_v4();
+ let validation_id = ev.id.unwrap_or(Uuid::new_v4());
sqlx::query(&sqls[0])
- .bind(ev.id.unwrap_or(new_id))
- .bind(identity_id)
+ .bind(validation_id)
.bind(email_id)
- .bind(ev.attempts)
.bind(&ev.code)
- .bind(ev.is_validated)
+ .bind(ev.is_oauth_derived)
.bind(ev.created_at)
- .bind(ev.expires_at)
+ .bind(ev.validated_at)
+ .bind(ev.expired_at)
.execute(&self.pool)
.await
.map_err(util::log_err_sqlx)?;
- Ok(new_id)
+ if ev.identity_id.is_some() || ev.revoked_at.is_some() || ev.deleted_at.is_some() {
+ sqlx::query(&sqls[1])
+ .bind(ev.identity_id.as_ref())
+ .bind(validation_id)
+ .bind(ev.revoked_at)
+ .bind(ev.deleted_at)
+ .execute(&self.pool)
+ .await
+ .map_err(util::log_err_sqlx)?;
+ }
+
+ Ok(validation_id)
}
async fn find_identity(
&self,
id: Option<&Uuid>,
email: Option<&str>,
- ) -> Result<Option<Identity>, StoreError> {
+ ) -> anyhow::Result<Option<Identity>> {
let sqls = get_sqls(&self.sqls_root, FIND_IDENTITY);
Ok(
match sqlx::query_as::<_, Identity>(&sqls[0])
.bind(id)
.bind(email)
- .fetch_one(&self.pool)
+ .fetch_all(&self.pool)
.await
{
- Ok(i) => Some(i),
+ Ok(mut is) => match is.len() {
+ // if only 1 found, then that's fine
+ // if multiple are fond, then if they all have the same id, that's okay
+ 1 => {
+ let i = is.swap_remove(0);
+ match i.deleted_at {
+ Some(t) if t > OffsetDateTime::now_utc() => Some(i),
+ None => Some(i),
+ _ => None,
+ }
+ }
+ 0 => None,
+ _ => {
+ match is
+ .iter()
+ .filter(|&i| i.id != is[0].id)
+ .collect::<Vec<&Identity>>()
+ .len()
+ {
+ 0 => Some(is.swap_remove(0)),
+ _ => bail!(StoreError::TooManyIdentitiesFound),
+ }
+ }
+ },
Err(sqlx::Error::RowNotFound) => None,
- Err(e) => return Err(StoreError::SqlxError(e)),
+ Err(e) => bail!(StoreError::SqlxError(e)),
},
)
}
+
async fn find_identity_by_code(&self, code: &str) -> Result<Identity, StoreError> {
let sqls = get_sqls(&self.sqls_root, FIND_IDENTITY_BY_CODE);
@@ -250,14 +265,16 @@ where
Ok(())
}
async fn read_identity(&self, id: &Uuid) -> Result<Identity, StoreError> {
- Ok(sqlx::query_as::<_, Identity>(
+ let identity = sqlx::query_as::<_, Identity>(
"
select identity_public_id, data, created_at from identity where identity_public_id = ?",
)
.bind(id)
.fetch_one(&self.pool)
.await
- .map_err(util::log_err_sqlx)?)
+ .map_err(util::log_err_sqlx)?;
+
+ Ok(identity)
}
async fn write_session(&self, session: &Session) -> Result<(), StoreError> {
@@ -269,7 +286,6 @@ select identity_public_id, data, created_at from identity where identity_public_
.bind(&session.identity_id)
.bind(secret_hash.as_ref())
.bind(session.created_at)
- .bind(OffsetDateTime::now_utc())
.bind(session.expires_at)
.bind(session.revoked_at)
.execute(&self.pool)
@@ -296,6 +312,142 @@ select identity_public_id, data, created_at from identity where identity_public_
Ok(session)
}
+
+ async fn write_oauth_provider(&self, provider: &OauthProvider) -> Result<(), StoreError> {
+ let sqls = get_sqls(&self.sqls_root, WRITE_OAUTH_PROVIDER);
+ sqlx::query(&sqls[0])
+ .bind(&provider.name.to_string())
+ .bind(&provider.flow)
+ .bind(&provider.base_url.to_string())
+ .bind(&provider.response.to_string())
+ .bind(&provider.default_scope)
+ .bind(&provider.client_id)
+ // TODO: encrypt secret before writing
+ .bind(&provider.client_secret)
+ .bind(&provider.redirect_url.to_string())
+ .bind(provider.created_at)
+ .bind(provider.deleted_at)
+ .execute(&self.pool)
+ .await
+ .map_err(util::log_err_sqlx)?;
+ Ok(())
+ }
+
+ async fn read_oauth_provider(
+ &self,
+ provider: &OauthProviderName,
+ flow: Option<String>,
+ ) -> Result<OauthProvider, StoreError> {
+ let sqls = get_sqls(&self.sqls_root, READ_OAUTH_PROVIDER);
+ let flow = flow.unwrap_or("default".into());
+ debug!("provider: {:?}, flow: {:?}", provider, flow);
+ // TODO: Write the generic FromRow impl for OauthProvider...
+ let res = sqlx::query_as::<
+ _,
+ (
+ String,
+ String,
+ String,
+ String,
+ String,
+ String,
+ String,
+ OffsetDateTime,
+ Option<OffsetDateTime>,
+ ),
+ >(&sqls[0])
+ .bind(&provider.to_string())
+ .bind(&flow)
+ .fetch_one(&self.pool)
+ .await
+ .map_err(util::log_err_sqlx)?;
+
+ debug!("res: {:?}", res);
+
+ Ok(OauthProvider {
+ name: provider.clone(),
+ flow: Some(res.0),
+ base_url: Url::from_str(&res.1)
+ .map_err(|_| StoreError::OauthProviderDoesNotExist(*provider))?,
+ response: OauthResponseType::from_str(&res.2)
+ .map_err(|_| StoreError::OauthProviderDoesNotExist(*provider))?,
+ default_scope: res.3,
+ client_id: res.4,
+ client_secret: res.5,
+ redirect_url: Url::from_str(&res.6)
+ .map_err(|_| StoreError::OauthProviderDoesNotExist(*provider))?,
+ created_at: res.7,
+ deleted_at: res.8,
+ })
+ }
+ async fn write_oauth_validation(
+ &self,
+ v: &OauthValidation,
+ ) -> anyhow::Result<ValidationRequestId> {
+ let sqls = get_sqls(&self.sqls_root, WRITE_OAUTH_VALIDATION);
+
+ let validation_id = v.id.unwrap_or(Uuid::new_v4());
+ sqlx::query(&sqls[0])
+ .bind(validation_id)
+ .bind(v.oauth_provider.name.to_string())
+ .bind(v.oauth_provider.flow.clone())
+ .bind(v.access_token.clone())
+ .bind(v.raw_response.clone())
+ .bind(v.created_at)
+ .bind(v.validated_at)
+ .execute(&self.pool)
+ .await?;
+
+ if v.identity_id.is_some() || v.revoked_at.is_some() || v.deleted_at.is_some() {
+ sqlx::query(&sqls[1])
+ .bind(v.identity_id.as_ref())
+ .bind(validation_id)
+ .bind(v.revoked_at)
+ .bind(v.deleted_at)
+ .execute(&self.pool)
+ .await?;
+ }
+
+ Ok(validation_id)
+ }
+ async fn read_oauth_validation(
+ &self,
+ validation_id: &ValidationRequestId,
+ ) -> anyhow::Result<OauthValidation> {
+ let sqls = get_sqls(&self.sqls_root, READ_OAUTH_VALIDATION);
+
+ let mut es = sqlx::query_as::<_, OauthValidation>(&sqls[0])
+ .bind(validation_id)
+ .fetch_all(&self.pool)
+ .await?;
+
+ if es.len() != 1 {
+ bail!(StoreError::OauthValidationDoesNotExist(
+ validation_id.clone()
+ ));
+ }
+
+ Ok(es.swap_remove(0))
+ }
+ async fn find_validation_type(
+ &self,
+ validation_id: &ValidationRequestId,
+ ) -> anyhow::Result<ValidationType> {
+ let sqls = get_sqls(&self.sqls_root, READ_VALIDATION_TYPE);
+
+ let mut es = sqlx::query_as::<_, (String,)>(&sqls[0])
+ .bind(validation_id)
+ .fetch_all(&self.pool)
+ .await
+ .map_err(util::log_err_sqlx)?;
+
+ match es.len() {
+ 1 => Ok(ValidationType::from_str(&es.swap_remove(0).0)?),
+ _ => bail!(StoreError::Other(
+ "expected a single validation but recieved 0 or multiple validations".into()
+ )),
+ }
+ }
}
pub struct PgClient {
@@ -320,8 +472,8 @@ impl PgClient {
#[async_trait::async_trait]
impl Store for PgClient {
- async fn write_email(&self, identity_id: Uuid, email_address: &str) -> Result<(), StoreError> {
- self.sql.write_email(identity_id, email_address).await
+ async fn write_email(&self, email_address: &str) -> Result<(), StoreError> {
+ self.sql.write_email(email_address).await
}
async fn find_email_validation(
&self,
@@ -330,14 +482,14 @@ impl Store for PgClient {
) -> Result<EmailValidation, StoreError> {
self.sql.find_email_validation(validation_id, code).await
}
- async fn write_email_validation(&self, ev: &EmailValidation) -> Result<Uuid, StoreError> {
+ async fn write_email_validation(&self, ev: &EmailValidation) -> anyhow::Result<Uuid> {
self.sql.write_email_validation(ev).await
}
async fn find_identity(
&self,
identity_id: Option<&Uuid>,
email: Option<&str>,
- ) -> Result<Option<Identity>, StoreError> {
+ ) -> anyhow::Result<Option<Identity>> {
self.sql.find_identity(identity_id, email).await
}
async fn find_identity_by_code(&self, code: &str) -> Result<Identity, StoreError> {
@@ -355,6 +507,34 @@ impl Store for PgClient {
async fn read_session(&self, secret: &SessionSecret) -> Result<Session, StoreError> {
self.sql.read_session(secret).await
}
+ async fn write_oauth_provider(&self, provider: &OauthProvider) -> Result<(), StoreError> {
+ self.sql.write_oauth_provider(provider).await
+ }
+ async fn read_oauth_provider(
+ &self,
+ provider: &OauthProviderName,
+ flow: Option<String>,
+ ) -> Result<OauthProvider, StoreError> {
+ self.sql.read_oauth_provider(provider, flow).await
+ }
+ async fn write_oauth_validation(
+ &self,
+ validation: &OauthValidation,
+ ) -> anyhow::Result<ValidationRequestId> {
+ self.sql.write_oauth_validation(validation).await
+ }
+ async fn read_oauth_validation(
+ &self,
+ validation_id: &ValidationRequestId,
+ ) -> anyhow::Result<OauthValidation> {
+ self.sql.read_oauth_validation(validation_id).await
+ }
+ async fn find_validation_type(
+ &self,
+ validation_id: &ValidationRequestId,
+ ) -> anyhow::Result<ValidationType> {
+ self.sql.find_validation_type(validation_id).await
+ }
}
pub struct SqliteClient {
@@ -386,8 +566,8 @@ impl SqliteClient {
#[async_trait::async_trait]
impl Store for SqliteClient {
- async fn write_email(&self, identity_id: Uuid, email_address: &str) -> Result<(), StoreError> {
- self.sql.write_email(identity_id, email_address).await
+ async fn write_email(&self, email_address: &str) -> Result<(), StoreError> {
+ self.sql.write_email(email_address).await
}
async fn find_email_validation(
&self,
@@ -396,14 +576,14 @@ impl Store for SqliteClient {
) -> Result<EmailValidation, StoreError> {
self.sql.find_email_validation(validation_id, code).await
}
- async fn write_email_validation(&self, ev: &EmailValidation) -> Result<Uuid, StoreError> {
+ async fn write_email_validation(&self, ev: &EmailValidation) -> anyhow::Result<Uuid> {
self.sql.write_email_validation(ev).await
}
async fn find_identity(
&self,
identity_id: Option<&Uuid>,
email: Option<&str>,
- ) -> Result<Option<Identity>, StoreError> {
+ ) -> anyhow::Result<Option<Identity>> {
self.sql.find_identity(identity_id, email).await
}
async fn find_identity_by_code(&self, code: &str) -> Result<Identity, StoreError> {
@@ -421,4 +601,32 @@ impl Store for SqliteClient {
async fn read_session(&self, secret: &SessionSecret) -> Result<Session, StoreError> {
self.sql.read_session(secret).await
}
+ async fn write_oauth_provider(&self, provider: &OauthProvider) -> Result<(), StoreError> {
+ self.sql.write_oauth_provider(provider).await
+ }
+ async fn read_oauth_provider(
+ &self,
+ provider: &OauthProviderName,
+ flow: Option<String>,
+ ) -> Result<OauthProvider, StoreError> {
+ self.sql.read_oauth_provider(provider, flow).await
+ }
+ async fn write_oauth_validation(
+ &self,
+ validation: &OauthValidation,
+ ) -> anyhow::Result<ValidationRequestId> {
+ self.sql.write_oauth_validation(validation).await
+ }
+ async fn read_oauth_validation(
+ &self,
+ validation_id: &ValidationRequestId,
+ ) -> anyhow::Result<OauthValidation> {
+ self.sql.read_oauth_validation(validation_id).await
+ }
+ async fn find_validation_type(
+ &self,
+ validation_id: &ValidationRequestId,
+ ) -> anyhow::Result<ValidationType> {
+ self.sql.find_validation_type(validation_id).await
+ }
}
diff --git a/crates/secd/src/client/types.rs b/crates/secd/src/client/types.rs
new file mode 100644
index 0000000..bacade4
--- /dev/null
+++ b/crates/secd/src/client/types.rs
@@ -0,0 +1,3 @@
+pub(crate) struct Email {
+ address: String,
+}
diff --git a/crates/secd/src/command/admin.rs b/crates/secd/src/command/admin.rs
new file mode 100644
index 0000000..b04dbef
--- /dev/null
+++ b/crates/secd/src/command/admin.rs
@@ -0,0 +1,57 @@
+use std::str::FromStr;
+
+use time::OffsetDateTime;
+use url::Url;
+
+use crate::{OauthProviderName, Secd, SecdError};
+
+impl OauthProviderName {
+ fn base_url(&self) -> Url {
+ match self {
+ OauthProviderName::Google => {
+ Url::from_str("https://accounts.google.com/o/oauth2/v2/auth").unwrap()
+ }
+ OauthProviderName::Microsoft => {
+ Url::from_str("https://login.microsoftonline.com/common/oauth2/v2.0/authorize")
+ .unwrap()
+ }
+ _ => unimplemented!(),
+ }
+ }
+
+ fn default_scope(&self) -> String {
+ match self {
+ OauthProviderName::Google => "openid%20email".into(),
+ OauthProviderName::Microsoft => "openid%20email".into(),
+ _ => unimplemented!(),
+ }
+ }
+}
+
+impl Secd {
+ pub async fn create_oauth_provider(
+ &self,
+ provider: &OauthProviderName,
+ client_id: String,
+ client_secret: String,
+ redirect_url: Url,
+ ) -> Result<(), SecdError> {
+ self.store
+ .write_oauth_provider(&crate::OauthProvider {
+ name: provider.clone(),
+ flow: Some("default".into()),
+ base_url: provider.base_url(),
+ response: crate::OauthResponseType::Code,
+ default_scope: provider.default_scope(),
+ client_id,
+ client_secret,
+ redirect_url,
+ created_at: OffsetDateTime::now_utc(),
+ deleted_at: None,
+ })
+ .await
+ .map_err(|_| SecdError::Todo)?;
+
+ Ok(())
+ }
+}
diff --git a/crates/secd/src/command/authn.rs b/crates/secd/src/command/authn.rs
new file mode 100644
index 0000000..862d921
--- /dev/null
+++ b/crates/secd/src/command/authn.rs
@@ -0,0 +1,230 @@
+use email_address::EmailAddress;
+use log::debug;
+use rand::distributions::{Alphanumeric, DistString};
+use time::Duration;
+use time::OffsetDateTime;
+use uuid::Uuid;
+
+use crate::util::{build_oauth_auth_url, get_oauth_access_token};
+use crate::OauthRedirectAuthUrl;
+use crate::Validation;
+use crate::ValidationType;
+use crate::INTERNAL_ERR_MSG;
+use crate::{
+ client, util, EmailValidation, Identity, OauthProviderName, Secd, SecdError, Session,
+ ValidationRequestId, ValidationSecretCode, EMAIL_VALIDATION_DURATION, SESSION_DURATION,
+ SESSION_SIZE_BYTES, VALIDATION_CODE_SIZE,
+};
+
+impl Secd {
+ /// create_validation_request_oauth
+ ///
+ /// Generate a request to validate with the specified oauth provider.[
+ // TODO: How to handle different oauth "flows"? e.g. web app vs desktop vs mobile...
+ pub async fn create_validation_request_oauth(
+ &self,
+ provider: &OauthProviderName,
+ scope: Option<String>,
+ ) -> Result<OauthRedirectAuthUrl, SecdError> {
+ if scope.is_some() {
+ return Err(SecdError::NotImplemented(
+ "Only default scopes are currently supported.".into(),
+ ));
+ }
+
+ let p = self
+ .store
+ .read_oauth_provider(provider, None)
+ .await
+ .map_err(|_| SecdError::InternalError(INTERNAL_ERR_MSG.to_string()))?;
+
+ let req_id = self
+ .store
+ .write_oauth_validation(&crate::OauthValidation {
+ id: Some(Uuid::new_v4()),
+ identity_id: None,
+ oauth_provider: p.clone(),
+ access_token: None,
+ raw_response: None,
+ created_at: OffsetDateTime::now_utc(),
+ validated_at: None,
+ revoked_at: None,
+ deleted_at: None,
+ })
+ .await
+ .map_err(|e| util::to_secd_err(e, SecdError::OauthValidationRequestError))?;
+
+ build_oauth_auth_url(&p, req_id)
+ }
+ /// create_validation_request_email
+ ///
+ /// Generate a request to validate the provided email.
+ pub async fn create_validation_request_email(
+ &self,
+ email: Option<&str>,
+ ) -> Result<ValidationRequestId, SecdError> {
+ let now = OffsetDateTime::now_utc();
+
+ let email = match email {
+ Some(ea) => {
+ if EmailAddress::is_valid(ea) {
+ ea
+ } else {
+ return Err(SecdError::InvalidEmailAddress);
+ }
+ }
+ None => return Err(SecdError::InvalidEmailAddress),
+ };
+
+ let mut ev = EmailValidation {
+ id: None,
+ identity_id: None,
+ email_address: email.to_string(),
+ code: Some(
+ Alphanumeric
+ .sample_string(&mut rand::thread_rng(), VALIDATION_CODE_SIZE)
+ .to_lowercase(),
+ ),
+ is_oauth_derived: false,
+ created_at: now,
+ expired_at: now
+ .checked_add(Duration::new(EMAIL_VALIDATION_DURATION, 0))
+ .ok_or(SecdError::EmailValidationExpiryOverflow)?,
+ validated_at: None,
+ revoked_at: None,
+ deleted_at: None,
+ };
+
+ let (req_id, mail_type) = match self
+ .store
+ .find_identity(None, Some(email))
+ .await
+ .map_err(|e| util::log_err(e.into(), SecdError::Todo))?
+ {
+ Some(identity) => {
+ let req_id = {
+ ev.identity_id = Some(identity.id);
+ self.store
+ .write_email_validation(&ev)
+ .await
+ .map_err(|e| util::log_err(e.into(), SecdError::Todo))?
+ };
+ (req_id, client::EmailType::Login)
+ }
+ None => {
+ self.store
+ .write_email(email)
+ .await
+ .map_err(|e| util::log_err(e.into(), SecdError::Todo))?;
+
+ let req_id = {
+ self.store
+ .write_email_validation(&ev)
+ .await
+ .map_err(|e| util::log_err(e.into(), SecdError::Todo))?
+ };
+
+ (req_id, client::EmailType::Signup)
+ }
+ };
+
+ self.email_messenger
+ .send_email(email, &req_id.to_string(), &ev.code.unwrap(), mail_type)
+ .await?;
+
+ Ok(req_id)
+ }
+ /// exchange_secret_for_session
+ ///
+ /// Exchanges a secret, which consists of a validation_request_id and secret_code
+ /// for a session which allows authentication on behalf of the associated identity.
+ ///
+ /// Session secrets should be used to return authorization for the associated identity.
+ pub async fn exchange_code_for_session(
+ &self,
+ validation_request_id: ValidationRequestId,
+ code: ValidationSecretCode,
+ ) -> Result<Session, SecdError> {
+ let mut v: Box<dyn Validation> = match self
+ .store
+ .find_validation_type(&validation_request_id)
+ .await
+ .map_err(|e| util::to_secd_err(e, SecdError::Todo))?
+ {
+ ValidationType::Email => Box::new(
+ self.store
+ .find_email_validation(Some(&validation_request_id), Some(&code))
+ .await
+ .map_err(|e| {
+ util::log_err(e.into(), SecdError::EmailValidationExpiryOverflow)
+ })?,
+ ),
+ ValidationType::Oauth => Box::new({
+ let mut t = self
+ .store
+ .read_oauth_validation(&validation_request_id)
+ .await
+ .map_err(|e| util::to_secd_err(e, SecdError::Todo))?;
+
+ let access_token = get_oauth_access_token(&t, &code)
+ .await
+ .map_err(|_| SecdError::Todo)?;
+
+ t.access_token = Some(access_token);
+ t
+ }),
+ };
+
+ if v.expired() || v.is_validated() {
+ return Err(SecdError::InvalidCode);
+ };
+
+ let mut identity = Identity {
+ id: Uuid::new_v4(),
+ data: None,
+ created_at: OffsetDateTime::now_utc(),
+ deleted_at: None,
+ };
+
+ match v
+ .find_associated_identities(self.store.clone())
+ .await
+ .map_err(|e| util::to_secd_err(e, SecdError::IdentityIdShouldExistInvariant))?
+ {
+ Some(i) => identity.id = i.id,
+ _ => self.store.write_identity(&identity).await.map_err(|_| {
+ SecdError::InternalError("failed to write identity during session exchange".into())
+ })?,
+ };
+
+ v.validate(&identity, self.store.clone())
+ .await
+ .map_err(|e| {
+ util::to_secd_err(
+ e,
+ SecdError::InternalError(
+ "failed to update validation during session exchange".into(),
+ ),
+ )
+ })?;
+
+ // TODO: clear previous sessions if they fit the criteria
+ let now = OffsetDateTime::now_utc();
+ let s = Session {
+ identity_id: identity.id,
+ secret: Some(Alphanumeric.sample_string(&mut rand::thread_rng(), SESSION_SIZE_BYTES)),
+ created_at: now,
+ expires_at: now
+ .checked_add(Duration::new(SESSION_DURATION, 0))
+ .ok_or(SecdError::SessionExpiryOverflow)?,
+ revoked_at: None,
+ };
+
+ self.store
+ .write_session(&s)
+ .await
+ .map_err(|e| util::log_err(e.into(), SecdError::Todo))?;
+
+ Ok(s)
+ }
+}
diff --git a/crates/secd/src/command/mod.rs b/crates/secd/src/command/mod.rs
new file mode 100644
index 0000000..cd0d8c3
--- /dev/null
+++ b/crates/secd/src/command/mod.rs
@@ -0,0 +1,66 @@
+pub mod admin;
+pub mod authn;
+
+use crate::client::{
+ email,
+ sqldb::{PgClient, SqliteClient},
+};
+use crate::{AuthEmail, AuthStore, Secd, SecdError};
+use log::error;
+use std::sync::Arc;
+
+impl Secd {
+ /// init
+ ///
+ /// Initialize SecD with the specified configuration, established the necessary
+ /// constraints, persistance stores, and options.
+ pub async fn init(
+ auth_store: AuthStore,
+ conn_string: Option<&str>,
+ email_messenger: AuthEmail,
+ email_template_login: Option<String>,
+ email_template_signup: Option<String>,
+ ) -> Result<Self, SecdError> {
+ let store = match auth_store {
+ AuthStore::Sqlite => {
+ SqliteClient::new(
+ sqlx::sqlite::SqlitePoolOptions::new()
+ .connect(conn_string.unwrap_or("sqlite::memory:".into()))
+ .await
+ .map_err(|e| SecdError::InitializationFailure(e))?,
+ )
+ .await
+ }
+ AuthStore::Postgres => {
+ PgClient::new(
+ sqlx::postgres::PgPoolOptions::new()
+ .connect(conn_string.expect("No postgres connection string provided."))
+ .await
+ .map_err(|e| SecdError::InitializationFailure(e))?,
+ )
+ .await
+ }
+ rest @ _ => {
+ error!(
+ "requested an AuthStore which has not yet been implemented: {:?}",
+ rest
+ );
+ unimplemented!()
+ }
+ };
+
+ let email_sender = match email_messenger {
+ // TODO: initialize email and SMS templates with secd
+ AuthEmail::LocalStub => email::LocalEmailStubber {
+ email_template_login,
+ email_template_signup,
+ },
+ _ => unimplemented!(),
+ };
+
+ Ok(Secd {
+ store,
+ email_messenger: Arc::new(email_sender),
+ })
+ }
+}
diff --git a/crates/secd/src/lib.rs b/crates/secd/src/lib.rs
index 4feda04..faa92ca 100644
--- a/crates/secd/src/lib.rs
+++ b/crates/secd/src/lib.rs
@@ -1,28 +1,28 @@
mod client;
+mod command;
mod util;
use std::sync::Arc;
-use client::{
- email,
- sqldb::{PgClient, SqliteClient},
- EmailMessenger, EmailMessengerError, Store, StoreError,
-};
+use clap::ValueEnum;
+use client::{EmailMessenger, EmailMessengerError, Store};
use derive_more::Display;
use email_address::EmailAddress;
-use log::error;
-use rand::distributions::{Alphanumeric, DistString};
use serde::{Deserialize, Serialize};
+use sqlx::FromRow;
use strum_macros::{EnumString, EnumVariantNames};
-use time::{Duration, OffsetDateTime};
+use time::OffsetDateTime;
+use url::Url;
+use util::get_oauth_identity_data;
use uuid::Uuid;
const SESSION_SIZE_BYTES: usize = 32;
const SESSION_DURATION: i64 = 60 /* seconds*/ * 60 /* minutes */ * 24 /* hours */ * 360 /* days */;
const EMAIL_VALIDATION_DURATION: i64 = 60 /* seconds*/ * 15 /* minutes */;
-const VALIDATION_ATTEMPTS_MAX: i32 = 5;
const VALIDATION_CODE_SIZE: usize = 6;
+const INTERNAL_ERR_MSG: &str = "It seems an invariant was borked or something non-deterministic happened. Please file a bug with secd.";
+
#[derive(sqlx::FromRow, Debug, Serialize)]
pub struct ApiKey {
pub public_key: String,
@@ -38,9 +38,11 @@ pub struct Authorization {
pub struct Identity {
#[sqlx(rename = "identity_public_id")]
id: Uuid,
- created_at: OffsetDateTime,
#[serde(skip_serializing_if = "Option::is_none")]
data: Option<String>,
+ created_at: OffsetDateTime,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ deleted_at: Option<OffsetDateTime>,
}
#[derive(sqlx::FromRow, Debug, Serialize)]
@@ -58,6 +60,121 @@ pub struct Session {
pub revoked_at: Option<OffsetDateTime>,
}
+#[async_trait::async_trait]
+trait Validation {
+ fn expired(&self) -> bool;
+ fn is_validated(&self) -> bool;
+ async fn find_associated_identities(
+ &self,
+ store: Arc<dyn Store + Send + Sync>,
+ ) -> anyhow::Result<Option<Identity>>;
+ async fn validate(
+ &mut self,
+ i: &Identity,
+ store: Arc<dyn Store + Send + Sync>,
+ ) -> anyhow::Result<()>;
+}
+
+#[async_trait::async_trait]
+impl Validation for EmailValidation {
+ fn expired(&self) -> bool {
+ let now = OffsetDateTime::now_utc();
+ self.expired_at < now
+ || self.revoked_at.map(|t| t < now).unwrap_or(false)
+ || self.deleted_at.map(|t| t < now).unwrap_or(false)
+ }
+ fn is_validated(&self) -> bool {
+ self.validated_at
+ .map(|t| t >= OffsetDateTime::now_utc())
+ .unwrap_or(false)
+ }
+ async fn find_associated_identities(
+ &self,
+ store: Arc<dyn Store + Send + Sync>,
+ ) -> anyhow::Result<Option<Identity>> {
+ store.find_identity(None, Some(&self.email_address)).await
+ }
+ async fn validate(
+ &mut self,
+ i: &Identity,
+ store: Arc<dyn Store + Send + Sync>,
+ ) -> anyhow::Result<()> {
+ self.identity_id = Some(i.id);
+ self.validated_at = Some(OffsetDateTime::now_utc());
+ store.write_email_validation(&self).await?;
+ Ok(())
+ }
+}
+
+#[async_trait::async_trait]
+impl Validation for OauthValidation {
+ fn expired(&self) -> bool {
+ let now = OffsetDateTime::now_utc();
+ self.revoked_at.map(|t| t < now).unwrap_or(false)
+ || self.deleted_at.map(|t| t < now).unwrap_or(false)
+ }
+ fn is_validated(&self) -> bool {
+ self.validated_at
+ .map(|t| t >= OffsetDateTime::now_utc())
+ .unwrap_or(false)
+ }
+ async fn find_associated_identities(
+ &self,
+ store: Arc<dyn Store + Send + Sync>,
+ ) -> anyhow::Result<Option<Identity>> {
+ let oauth_identity = get_oauth_identity_data(&self).await?;
+
+ let identity = store
+ .find_identity(None, oauth_identity.email.as_deref())
+ .await?;
+
+ let now = OffsetDateTime::now_utc();
+ if let Some(email) = oauth_identity.email.clone() {
+ let identity = identity.unwrap_or(Identity {
+ id: Uuid::new_v4(),
+ data: None,
+ created_at: OffsetDateTime::now_utc(),
+ deleted_at: None,
+ });
+ store.write_identity(&identity).await?;
+ store.write_email(&email).await?;
+ store
+ .write_email_validation(&EmailValidation {
+ id: Some(Uuid::new_v4()),
+ identity_id: Some(identity.id),
+ email_address: email,
+ code: None,
+ is_oauth_derived: true,
+ created_at: now,
+ expired_at: now,
+ validated_at: Some(now),
+ revoked_at: None,
+ deleted_at: None,
+ })
+ .await?;
+ Ok(Some(identity))
+ } else {
+ Ok(identity)
+ }
+ }
+ async fn validate(
+ &mut self,
+ i: &Identity,
+ store: Arc<dyn Store + Send + Sync>,
+ ) -> anyhow::Result<()> {
+ self.identity_id = Some(i.id);
+ self.validated_at = Some(OffsetDateTime::now_utc());
+ store.write_oauth_validation(&self).await?;
+ Ok(())
+ }
+}
+
+#[derive(Debug, EnumString)]
+pub enum ValidationType {
+ Email,
+ Oauth,
+}
+
#[derive(sqlx::FromRow, Debug)]
pub struct EmailValidation {
#[sqlx(rename = "email_validation_public_id")]
@@ -66,16 +183,53 @@ pub struct EmailValidation {
identity_id: Option<IdentityId>,
#[sqlx(rename = "address")]
email_address: String,
- attempts: i32,
- code: String,
- is_validated: bool,
+ code: Option<String>,
+ is_oauth_derived: bool,
+ created_at: OffsetDateTime,
+ expired_at: OffsetDateTime,
+ validated_at: Option<OffsetDateTime>,
+ revoked_at: Option<OffsetDateTime>,
+ deleted_at: Option<OffsetDateTime>,
+}
+
+#[derive(Debug)]
+pub struct OauthValidation {
+ id: Option<Uuid>,
+ identity_id: Option<IdentityId>,
+ oauth_provider: OauthProvider,
+ access_token: Option<String>,
+ raw_response: Option<String>,
created_at: OffsetDateTime,
- expires_at: OffsetDateTime,
+ validated_at: Option<OffsetDateTime>,
revoked_at: Option<OffsetDateTime>,
+ deleted_at: Option<OffsetDateTime>,
+}
+
+#[derive(Debug, Clone)]
+pub struct OauthProvider {
+ pub name: OauthProviderName,
+ pub flow: Option<String>,
+ pub base_url: Url,
+ pub response: OauthResponseType,
+ pub default_scope: String,
+ pub client_id: String,
+ pub client_secret: String,
+ pub redirect_url: Url,
+ pub created_at: OffsetDateTime,
+ pub deleted_at: Option<OffsetDateTime>,
+}
+
+#[derive(Debug, Display, Clone, Copy, ValueEnum, EnumString)]
+pub enum OauthResponseType {
+ Code,
+ IdToken,
+ None,
+ Token,
}
-#[derive(Copy, Display, Clone, Debug)]
-pub enum OauthProvider {
+// TODO: feature gate ValueEnum since it's only needed for iam builds
+#[derive(Copy, Display, Clone, Debug, ValueEnum, EnumString)]
+pub enum OauthProviderName {
Amazon,
Apple,
Dropbox,
@@ -121,19 +275,24 @@ pub type SessionSecret = String;
pub type SessionSecretHash = String;
pub type ValidationRequestId = Uuid;
pub type ValidationSecretCode = String;
+pub type OauthRedirectAuthUrl = Url;
#[derive(Debug, derive_more::Display, thiserror::Error)]
pub enum SecdError {
- InvalidEmailAddress,
- InvalidCode,
- InitializationFailure(sqlx::Error),
- IdentityIdShouldExistInvariant,
EmailSendError(#[from] EmailMessengerError),
- EmailValidationRequestError,
EmailValidationExpiryOverflow,
+ EmailValidationRequestError,
+ OauthValidationRequestError,
+ IdentityIdShouldExistInvariant,
+ InitializationFailure(sqlx::Error),
+ InvalidCode,
+ InvalidEmailAddress,
+ InputValidation(String),
+ InternalError(String),
+ NotImplemented(String),
SessionExpiryOverflow,
Unauthenticated,
- Unknown,
+ Todo,
}
pub struct Secd {
@@ -142,191 +301,6 @@ pub struct Secd {
}
impl Secd {
- pub async fn init(
- auth_store: AuthStore,
- conn_string: Option<&str>,
- email_messenger: AuthEmail,
- email_template_login: Option<String>,
- email_template_signup: Option<String>,
- ) -> Result<Self, SecdError> {
- let store = match auth_store {
- AuthStore::Sqlite => {
- SqliteClient::new(
- sqlx::sqlite::SqlitePoolOptions::new()
- .connect(conn_string.unwrap_or("sqlite::memory:".into()))
- .await
- .map_err(|e| SecdError::InitializationFailure(e))?,
- )
- .await
- }
- AuthStore::Postgres => {
- PgClient::new(
- sqlx::postgres::PgPoolOptions::new()
- .connect(conn_string.expect("No postgres connection string provided."))
- .await
- .map_err(|e| SecdError::InitializationFailure(e))?,
- )
- .await
- }
- rest @ _ => {
- error!(
- "requested an AuthStore which has not yet been implemented: {:?}",
- rest
- );
- unimplemented!()
- }
- };
-
- let email_sender = match email_messenger {
- // TODO: initialize email and SMS templates with secd
- AuthEmail::LocalStub => email::LocalEmailStubber {
- email_template_login,
- email_template_signup,
- },
- _ => unimplemented!(),
- };
-
- Ok(Secd {
- store,
- email_messenger: Arc::new(email_sender),
- })
- }
- /// create_validation_request
- ///
- /// Generate a request to validate the provided email.
- pub async fn create_validation_request(
- &self,
- email: Option<&str>,
- ) -> Result<ValidationRequestId, SecdError> {
- let now = OffsetDateTime::now_utc();
-
- let email = match email {
- Some(ea) => {
- if EmailAddress::is_valid(ea) {
- ea
- } else {
- return Err(SecdError::InvalidEmailAddress);
- }
- }
- None => return Err(SecdError::InvalidEmailAddress),
- };
-
- let mut ev = EmailValidation {
- id: None,
- identity_id: None,
- email_address: email.to_string(),
- attempts: 0,
- code: Alphanumeric
- .sample_string(&mut rand::thread_rng(), VALIDATION_CODE_SIZE)
- .to_lowercase(),
- is_validated: false,
- created_at: now,
- expires_at: now
- .checked_add(Duration::new(EMAIL_VALIDATION_DURATION, 0))
- .ok_or(SecdError::EmailValidationExpiryOverflow)?,
- revoked_at: None,
- };
-
- let (req_id, mail_type) = match self
- .store
- .find_identity(None, Some(email))
- .await
- .map_err(|e| util::log_err(e.into(), SecdError::Unknown))?
- {
- Some(identity) => {
- let req_id = {
- ev.identity_id = Some(identity.id);
- self.store
- .write_email_validation(&ev)
- .await
- .map_err(|e| util::log_err(e.into(), SecdError::Unknown))?
- };
- (req_id, client::EmailType::Login)
- }
- None => {
- let identity = Identity {
- id: Uuid::new_v4(),
- created_at: OffsetDateTime::now_utc(),
- data: None,
- };
- self.store
- .write_identity(&identity)
- .await
- .map_err(|e| util::log_err(e.into(), SecdError::Unknown))?;
- self.store
- .write_email(identity.id, email)
- .await
- .map_err(|e| util::log_err(e.into(), SecdError::Unknown))?;
-
- let req_id = {
- ev.identity_id = Some(identity.id);
- self.store
- .write_email_validation(&ev)
- .await
- .map_err(|e| util::log_err(e.into(), SecdError::Unknown))?
- };
-
- (req_id, client::EmailType::Signup)
- }
- };
-
- self.email_messenger
- .send_email(email, &req_id.to_string(), &ev.code, mail_type)
- .await?;
-
- Ok(req_id)
- }
- /// exchange_secret_for_session
- ///
- /// Exchanges a secret, which consists of a validation_request_id and secret_code
- /// for a session which allows authentication on behalf of the associated identity.
- ///
- /// Session secrets should be used to return authorization for the associated identity.
- pub async fn exchange_code_for_session(
- &self,
- validation_request_id: ValidationRequestId,
- code: ValidationSecretCode,
- ) -> Result<Session, SecdError> {
- let mut ev = self
- .store
- .find_email_validation(Some(&validation_request_id), Some(&code))
- .await
- .map_err(|e| util::log_err(e.into(), SecdError::EmailValidationExpiryOverflow))?;
-
- if ev.is_validated
- || ev.expires_at < OffsetDateTime::now_utc()
- || ev.attempts >= VALIDATION_ATTEMPTS_MAX
- {
- return Err(SecdError::InvalidCode);
- };
-
- ev.is_validated = true;
- ev.attempts += 1;
- self.store
- .write_email_validation(&ev)
- .await
- .map_err(|e| util::log_err(e.into(), SecdError::Unknown))?;
-
- // TODO: clear previous sessions if they fit the criteria
- let now = OffsetDateTime::now_utc();
- let s = Session {
- identity_id: ev
- .identity_id
- .ok_or(SecdError::IdentityIdShouldExistInvariant)?,
- secret: Some(Alphanumeric.sample_string(&mut rand::thread_rng(), SESSION_SIZE_BYTES)),
- created_at: now,
- expires_at: now
- .checked_add(Duration::new(SESSION_DURATION, 0))
- .ok_or(SecdError::SessionExpiryOverflow)?,
- revoked_at: None,
- };
- self.store
- .write_session(&s)
- .await
- .map_err(|e| util::log_err(e.into(), SecdError::Unknown))?;
-
- Ok(s)
- }
/// get_identity
///
/// Return all information associated with the identity id.
@@ -350,7 +324,7 @@ impl Secd {
Ok(Authorization { session })
}
Ok(_) => Err(SecdError::Unauthenticated),
- Err(_e) => Err(SecdError::Unknown),
+ Err(_e) => Err(SecdError::Todo),
}
}
/// revoke_session
diff --git a/crates/secd/src/util/mod.rs b/crates/secd/src/util/mod.rs
index da16901..bb177cb 100644
--- a/crates/secd/src/util/mod.rs
+++ b/crates/secd/src/util/mod.rs
@@ -1,13 +1,27 @@
+use std::str::FromStr;
+
+use anyhow::{bail, Context};
use log::error;
use rand::distributions::Alphanumeric;
use rand::{thread_rng, Rng};
+use reqwest::header;
+use serde::{Deserialize, Serialize};
+use url::Url;
-use crate::SecdError;
+use crate::{
+ OauthProvider, OauthProviderName, OauthValidation, SecdError, ValidationRequestId,
+ INTERNAL_ERR_MSG,
+};
pub(crate) fn log_err(e: Box<dyn std::error::Error>, new_e: SecdError) -> SecdError {
error!("{:?}", e);
new_e
}
+pub(crate) fn to_secd_err(e: anyhow::Error, new_e: SecdError) -> SecdError {
+ error!("{:?}", e);
+ new_e
+}
+
pub(crate) fn log_err_sqlx(e: sqlx::Error) -> sqlx::Error {
error!("{:?}", e);
e
@@ -19,3 +33,145 @@ pub(crate) fn generate_random_url_safe(n: usize) -> String {
.map(char::from)
.collect()
}
+
+pub(crate) fn remove_trailing_slash(url: &mut Url) -> String {
+ let mut u = url.to_string();
+
+ if u.ends_with('/') {
+ u.pop();
+ }
+
+ u
+}
+
+pub(crate) fn build_oauth_auth_url(
+ p: &OauthProvider,
+ validation_id: ValidationRequestId,
+) -> Result<Url, SecdError> {
+ let redirect_url = remove_trailing_slash(&mut p.redirect_url.clone());
+
+ Ok(Url::from_str(&format!(
+ "{}?client_id={}&response_type={}&redirect_uri={}&scope={}&state={}",
+ p.base_url,
+ p.client_id,
+ p.response.to_string().to_lowercase(),
+ redirect_url,
+ p.default_scope,
+ validation_id.to_string()
+ ))
+ .map_err(|_| SecdError::InternalError(INTERNAL_ERR_MSG.into()))?)
+}
+
+pub(crate) async fn get_oauth_identity_data(
+ validation: &OauthValidation,
+) -> anyhow::Result<OauthAccessIdentity> {
+ let provider = validation.oauth_provider.name;
+ let token = validation
+ .access_token
+ .clone()
+ .ok_or(SecdError::InternalError(
+ "no access token provided with which to build oauth data url".into(),
+ ))?;
+
+ let url = Url::from_str(&format!(
+ "{}{}",
+ match provider {
+ OauthProviderName::Google =>
+ "https://www.googleapis.com/oauth2/v2/userinfo?access_token=",
+ _ => unimplemented!(),
+ },
+ token
+ ))?;
+
+ let resp = reqwest::get(url).await?.json::<serde_json::Value>().await?;
+ let identity = match provider {
+ OauthProviderName::Google => OauthAccessIdentity {
+ email: resp
+ .get("email")
+ .and_then(|v| v.as_str().map(|s| s.to_string())),
+ email_is_verified: resp.get("verified_email").and_then(|v| v.as_bool()),
+ picture_url: resp
+ .get("picture")
+ .and_then(|v| Url::from_str(&v.to_string()).ok()),
+ },
+ _ => unimplemented!(),
+ };
+
+ Ok(identity)
+}
+
+#[derive(Debug, Serialize)]
+pub(crate) struct OauthAccessTokenGoogleRequest {
+ grant_type: String,
+ code: String,
+ client_id: String,
+ client_secret: String,
+ redirect_uri: String,
+}
+
+#[derive(Debug, Deserialize)]
+pub(crate) struct OauthAccessTokenGoogleResponse {
+ access_token: String,
+ expires_in: i32,
+ token_type: String,
+ scope: String,
+ id_token: String,
+}
+
+#[derive(Debug)]
+pub(crate) struct OauthAccessIdentity {
+ pub(crate) email: Option<String>,
+ pub(crate) email_is_verified: Option<bool>,
+ pub(crate) picture_url: Option<Url>,
+}
+
+type AccessTokenRequestData = String;
+
+pub(crate) async fn get_oauth_access_token(
+ validation: &OauthValidation,
+ secret_code: &String,
+) -> anyhow::Result<String> {
+ let provider = validation.oauth_provider.name;
+
+ let url = Url::from_str(match provider {
+ OauthProviderName::Google => "https://accounts.google.com/o/oauth2/token",
+ _ => unimplemented!(),
+ })?;
+
+ let request_data = serde_json::to_string(&match provider {
+ OauthProviderName::Google => OauthAccessTokenGoogleRequest {
+ grant_type: "authorization_code".to_string(),
+ code: secret_code.to_string(),
+ client_id: validation.oauth_provider.client_id.clone(),
+ client_secret: validation.oauth_provider.client_secret.clone(),
+ redirect_uri: remove_trailing_slash(
+ &mut validation.oauth_provider.redirect_url.clone(),
+ ),
+ },
+ _ => unimplemented!(),
+ })?;
+
+ let r = reqwest::Client::new()
+ .post(url)
+ .body(request_data)
+ .header(header::CONTENT_TYPE, "application/json")
+ .send()
+ .await
+ .context(format!(
+ "Failed to successfully POST a new access token for: {}",
+ provider
+ ))?;
+
+ let access_token = match provider {
+ OauthProviderName::Google => {
+ let resp: OauthAccessTokenGoogleResponse = r.json().await.context(format!(
+ "Failed to parse access token response for: {}",
+ provider
+ ))?;
+ resp.access_token
+ }
+ _ => unimplemented!(),
+ };
+
+ Ok(access_token)
+}