From 0920c4d4f30a3345870d385d5c6f3e0919228b56 Mon Sep 17 00:00:00 2001 From: benj Date: Mon, 12 Dec 2022 17:06:57 -0800 Subject: (oauth2 + email added): a mess that may or may not really work and needs to be refactored... --- crates/secd/src/client/mod.rs | 233 ++++++++++++++++++++++- crates/secd/src/client/sqldb.rs | 324 ++++++++++++++++++++++++++------ crates/secd/src/client/types.rs | 3 + crates/secd/src/command/admin.rs | 57 ++++++ crates/secd/src/command/authn.rs | 230 +++++++++++++++++++++++ crates/secd/src/command/mod.rs | 66 +++++++ crates/secd/src/lib.rs | 390 ++++++++++++++++++--------------------- crates/secd/src/util/mod.rs | 158 +++++++++++++++- 8 files changed, 1184 insertions(+), 277 deletions(-) create mode 100644 crates/secd/src/client/types.rs create mode 100644 crates/secd/src/command/admin.rs create mode 100644 crates/secd/src/command/authn.rs create mode 100644 crates/secd/src/command/mod.rs (limited to 'crates/secd/src') diff --git a/crates/secd/src/client/mod.rs b/crates/secd/src/client/mod.rs index 3925657..38426ef 100644 --- a/crates/secd/src/client/mod.rs +++ b/crates/secd/src/client/mod.rs @@ -1,13 +1,24 @@ -pub mod email; -pub mod sqldb; +pub(crate) mod email; +pub(crate) mod sqldb; +pub(crate) mod types; -use std::collections::HashMap; +use std::{collections::HashMap, str::FromStr}; use super::Identity; -use crate::{EmailValidation, Session, SessionSecret}; +use crate::{ + EmailValidation, OauthProvider, OauthProviderName, OauthResponseType, OauthValidation, Session, + SessionSecret, ValidationRequestId, ValidationType, +}; +use email_address::EmailAddress; use lazy_static::lazy_static; +use sqlx::{ + database::HasValueRef, sqlite::SqliteRow, ColumnIndex, Database, Decode, FromRow, Row, Sqlite, + Type, +}; use thiserror::Error; +use time::OffsetDateTime; +use url::Url; use uuid::Uuid; pub enum EmailType { @@ -36,13 +47,15 @@ pub trait EmailMessenger { #[derive(Error, Debug, derive_more::Display)] pub enum StoreError { SqlxError(#[from] sqlx::Error), - EmailAlreadyExists, CodeAppearsMoreThanOnce, CodeDoesNotExist(String), IdentityIdMustExistInvariant, - TooManyEmailValidations, + TooManyValidations, + TooManyIdentitiesFound, NoEmailValidationFound, - Unknown, + OauthProviderDoesNotExist(OauthProviderName), + OauthValidationDoesNotExist(ValidationRequestId), + Other(String), } const EMAIL_TEMPLATE_DEFAULT_LOGIN: &str = "You requested a login link. Please click the following link %secd_code% to login as %secd_email_address%"; @@ -56,6 +69,7 @@ const PGSQL: &str = "pgsql"; const WRITE_IDENTITY: &str = "write_identity"; const WRITE_EMAIL_VALIDATION: &str = "write_email_validation"; const FIND_EMAIL_VALIDATION: &str = "find_email_validation"; +const READ_VALIDATION_TYPE: &str = "read_validation_type"; const WRITE_EMAIL: &str = "write_email"; @@ -69,6 +83,11 @@ const READ_EMAIL_RAW_ID: &str = "read_email_raw_id"; const WRITE_SESSION: &str = "write_session"; const READ_SESSION: &str = "read_session"; +const WRITE_OAUTH_PROVIDER: &str = "write_oauth_provider"; +const READ_OAUTH_PROVIDER: &str = "read_oauth_provider"; +const WRITE_OAUTH_VALIDATION: &str = "write_oauth_validation"; +const READ_OAUTH_VALIDATION: &str = "read_oauth_validation"; + lazy_static! { static ref SQLS: HashMap<&'static str, HashMap<&'static str, &'static str>> = { let sqlite_sqls: HashMap<&'static str, &'static str> = [ @@ -116,6 +135,26 @@ lazy_static! { FIND_EMAIL_VALIDATION, include_str!("../../store/sqlite/sql/find_email_validation.sql"), ), + ( + WRITE_OAUTH_PROVIDER, + include_str!("../../store/sqlite/sql/write_oauth_provider.sql"), + ), + ( + READ_OAUTH_PROVIDER, + include_str!("../../store/sqlite/sql/read_oauth_provider.sql"), + ), + ( + READ_OAUTH_VALIDATION, + include_str!("../../store/sqlite/sql/read_oauth_validation.sql"), + ), + ( + WRITE_OAUTH_VALIDATION, + include_str!("../../store/sqlite/sql/write_oauth_validation.sql"), + ), + ( + READ_VALIDATION_TYPE, + include_str!("../../store/sqlite/sql/read_validation_type.sql"), + ), ] .iter() .cloned() @@ -166,6 +205,26 @@ lazy_static! { FIND_EMAIL_VALIDATION, include_str!("../../store/pg/sql/find_email_validation.sql"), ), + ( + WRITE_OAUTH_PROVIDER, + include_str!("../../store/pg/sql/write_oauth_provider.sql"), + ), + ( + READ_OAUTH_PROVIDER, + include_str!("../../store/pg/sql/read_oauth_provider.sql"), + ), + ( + READ_OAUTH_VALIDATION, + include_str!("../../store/pg/sql/read_oauth_validation.sql"), + ), + ( + WRITE_OAUTH_VALIDATION, + include_str!("../../store/pg/sql/write_oauth_validation.sql"), + ), + ( + READ_VALIDATION_TYPE, + include_str!("../../store/pg/sql/read_validation_type.sql"), + ), ] .iter() .cloned() @@ -180,9 +239,143 @@ lazy_static! { }; } +impl<'a, R: Row> FromRow<'a, R> for OauthValidation +where + &'a str: ColumnIndex, + OauthProviderName: Decode<'a, R::Database> + Type, + OauthResponseType: Decode<'a, R::Database> + Type, + OffsetDateTime: Decode<'a, R::Database> + Type, + String: Decode<'a, R::Database> + Type, + Uuid: Decode<'a, R::Database> + Type, +{ + fn from_row(row: &'a R) -> Result { + let id: Option = row.try_get("oauth_validation_public_id")?; + let identity_id: Option = row.try_get("identity_public_id")?; + let access_token: Option = row.try_get("access_token")?; + let raw_response: Option = row.try_get("raw_response")?; + let created_at: Option = row.try_get("created_at")?; + let validated_at: Option = row.try_get("validated_at")?; + let revoked_at: Option = row.try_get("revoked_at")?; + let deleted_at: Option = row.try_get("deleted_at")?; + + let op_name: Option = row.try_get("oauth_provider_name")?; + let op_flow: Option = row.try_get("oauth_provider_flow")?; + let op_base_url: Option = row.try_get("oauth_provider_base_url")?; + let op_response_type: Option = + row.try_get("oauth_provider_response_type")?; + let op_default_scope: Option = row.try_get("oauth_provider_default_scope")?; + let op_client_id: Option = row.try_get("oauth_provider_client_id")?; + let op_client_secret: Option = row.try_get("oauth_provider_client_secret")?; + let op_redirect_url: Option = row.try_get("oauth_provider_redirect_url")?; + let op_created_at: Option = row.try_get("oauth_provider_created_at")?; + let op_deleted_at: Option = row.try_get("oauth_provider_deleted_at")?; + + let op_base_url = op_base_url + .map(|s| Url::from_str(&s).ok()) + .flatten() + .ok_or(sqlx::Error::ColumnDecode { + index: "oauth_provider_base_url".into(), + source: "secd".into(), + })?; + + let op_redirect_url = op_redirect_url + .map(|s| Url::from_str(&s).ok()) + .flatten() + .ok_or(sqlx::Error::ColumnDecode { + index: "oauth_provider_redirect_url".into(), + source: "secd".into(), + })?; + + Ok(OauthValidation { + id, + identity_id, + access_token, + raw_response, + created_at: created_at.ok_or(sqlx::Error::ColumnDecode { + index: "created_at".into(), + source: "secd".into(), + })?, + validated_at, + revoked_at, + deleted_at, + oauth_provider: OauthProvider { + name: op_name.unwrap(), + flow: op_flow, + base_url: op_base_url, + response: op_response_type.ok_or(sqlx::Error::ColumnDecode { + index: "oauth_provider_response_type".into(), + source: "secd".into(), + })?, + default_scope: op_default_scope.ok_or(sqlx::Error::ColumnDecode { + index: "oauth_provider_default_scope".into(), + source: "secd".into(), + })?, + client_id: op_client_id.ok_or(sqlx::Error::ColumnDecode { + index: "oauth_provider_client_id".into(), + source: "secd".into(), + })?, + client_secret: op_client_secret.ok_or(sqlx::Error::ColumnDecode { + index: "oauth_provider_client_secret".into(), + source: "secd".into(), + })?, + redirect_url: op_redirect_url, + created_at: op_created_at.ok_or(sqlx::Error::ColumnDecode { + index: "oauth_provider_created_at".into(), + source: "secd".into(), + })?, + deleted_at: op_deleted_at, + }, + }) + } +} + +impl<'a, D: Database> Decode<'a, D> for OauthProviderName +where + &'a str: Decode<'a, D>, +{ + fn decode( + value: >::ValueRef, + ) -> Result> { + let v = <&str as Decode>::decode(value)?; + ::from_str(v, true) + .map_err(|_| "OauthProviderName should exist and decode to a program value.".into()) + } +} + +impl Type for OauthProviderName +where + str: Type, +{ + fn type_info() -> D::TypeInfo { + <&str as Type>::type_info() + } +} + +impl<'a, D: Database> Decode<'a, D> for OauthResponseType +where + &'a str: Decode<'a, D>, +{ + fn decode( + value: >::ValueRef, + ) -> Result> { + let v = <&str as Decode>::decode(value)?; + ::from_str(v, true) + .map_err(|_| "OauthResponseType should exist and decode to a program value.".into()) + } +} + +impl Type for OauthResponseType +where + str: Type, +{ + fn type_info() -> D::TypeInfo { + <&str as Type>::type_info() + } +} + #[async_trait::async_trait] pub trait Store { - async fn write_email(&self, identity_id: Uuid, email_address: &str) -> Result<(), StoreError>; + async fn write_email(&self, email_address: &str) -> Result<(), StoreError>; async fn find_email_validation( &self, @@ -193,17 +386,37 @@ pub trait Store { &self, ev: &EmailValidation, // TODO: Make this write an EmailValidation - ) -> Result; + ) -> anyhow::Result; async fn find_identity( &self, identity_id: Option<&Uuid>, email: Option<&str>, - ) -> Result, StoreError>; + ) -> anyhow::Result>; async fn find_identity_by_code(&self, code: &str) -> Result; async fn write_identity(&self, i: &Identity) -> Result<(), StoreError>; async fn read_identity(&self, identity_id: &Uuid) -> Result; async fn write_session(&self, session: &Session) -> Result<(), StoreError>; async fn read_session(&self, secret: &SessionSecret) -> Result; + + async fn write_oauth_provider(&self, provider: &OauthProvider) -> Result<(), StoreError>; + async fn read_oauth_provider( + &self, + provider: &OauthProviderName, + flow: Option, + ) -> Result; + async fn write_oauth_validation( + &self, + validation: &OauthValidation, + ) -> anyhow::Result; + async fn read_oauth_validation( + &self, + validation_id: &ValidationRequestId, + ) -> anyhow::Result; + + async fn find_validation_type( + &self, + validation_id: &ValidationRequestId, + ) -> anyhow::Result; } diff --git a/crates/secd/src/client/sqldb.rs b/crates/secd/src/client/sqldb.rs index 6048c48..15cc4b5 100644 --- a/crates/secd/src/client/sqldb.rs +++ b/crates/secd/src/client/sqldb.rs @@ -1,19 +1,23 @@ -use std::sync::Arc; +use std::{str::FromStr, sync::Arc}; use super::{ - EmailValidation, Identity, Session, SessionSecret, Store, StoreError, ERR_MSG_MIGRATION_FAILED, - FIND_EMAIL_VALIDATION, FIND_IDENTITY, FIND_IDENTITY_BY_CODE, PGSQL, READ_EMAIL_RAW_ID, - READ_IDENTITY_RAW_ID, READ_SESSION, SQLITE, SQLS, WRITE_EMAIL, WRITE_EMAIL_VALIDATION, - WRITE_IDENTITY, WRITE_SESSION, + EmailValidation, Identity, OauthProvider, OauthProviderName, OauthResponseType, Session, + SessionSecret, Store, StoreError, ERR_MSG_MIGRATION_FAILED, FIND_EMAIL_VALIDATION, + FIND_IDENTITY, FIND_IDENTITY_BY_CODE, PGSQL, READ_EMAIL_RAW_ID, READ_IDENTITY_RAW_ID, + READ_OAUTH_PROVIDER, READ_OAUTH_VALIDATION, READ_SESSION, READ_VALIDATION_TYPE, SQLITE, SQLS, + WRITE_EMAIL, WRITE_EMAIL_VALIDATION, WRITE_IDENTITY, WRITE_OAUTH_PROVIDER, + WRITE_OAUTH_VALIDATION, WRITE_SESSION, }; -use crate::util; -use log::error; +use crate::{util, OauthValidation, ValidationRequestId, ValidationType}; +use anyhow::bail; +use log::{debug, error}; use openssl::sha::Sha256; use sqlx::{ self, database::HasArguments, ColumnIndex, Database, Decode, Encode, Executor, IntoArguments, Pool, Postgres, Sqlite, Transaction, Type, }; use time::OffsetDateTime; +use url::Url; use uuid::Uuid; fn get_sqls(root: &str, file: &str) -> Vec { @@ -97,6 +101,8 @@ where for<'c> String: Encode<'c, D> + Type, for<'c> Option: Decode<'c, D> + Type, for<'c> Option: Encode<'c, D> + Type, + for<'c> OauthProviderName: Decode<'c, D> + Type, + for<'c> OauthResponseType: Decode<'c, D> + Type, for<'c> usize: ColumnIndex<::Row>, for<'c> Uuid: Decode<'c, D> + Type, for<'c> Uuid: Encode<'c, D> + Type, @@ -108,29 +114,11 @@ where for<'c> &'c Pool: Executor<'c, Database = D>, for<'c> &'c mut Transaction<'c, D>: Executor<'c, Database = D>, { - async fn write_email(&self, identity_id: Uuid, email_address: &str) -> Result<(), StoreError> { + async fn write_email(&self, email_address: &str) -> Result<(), StoreError> { let sqls = get_sqls(&self.sqls_root, WRITE_EMAIL); - let identity_id = self.read_identity_raw_id(&identity_id).await?; - - let email_id: (i64,) = match sqlx::query_as(&sqls[0]) + sqlx::query(&sqls[0]) .bind(email_address) - .fetch_one(&self.pool) - .await - { - Ok(i) => i, - Err(sqlx::Error::RowNotFound) => sqlx::query_as::<_, (i64,)>(&sqls[1]) - .bind(email_address) - .fetch_one(&self.pool) - .await - .map_err(util::log_err_sqlx)?, - Err(e) => return Err(StoreError::SqlxError(e)), - }; - - sqlx::query(&sqls[2]) - .bind(identity_id) - .bind(email_id.0) - .bind(OffsetDateTime::now_utc()) .execute(&self.pool) .await .map_err(util::log_err_sqlx)?; @@ -154,57 +142,84 @@ where match rows.len() { 0 => Err(StoreError::NoEmailValidationFound), 1 => Ok(rows.swap_remove(0)), - _ => Err(StoreError::TooManyEmailValidations), + _ => Err(StoreError::TooManyValidations), } } - async fn write_email_validation(&self, ev: &EmailValidation) -> Result { + async fn write_email_validation(&self, ev: &EmailValidation) -> anyhow::Result { let sqls = get_sqls(&self.sqls_root, WRITE_EMAIL_VALIDATION); - let identity_id = self - .read_identity_raw_id( - &ev.identity_id - .ok_or(StoreError::IdentityIdMustExistInvariant)?, - ) - .await?; let email_id = self.read_email_raw_id(&ev.email_address).await?; - - let new_id = Uuid::new_v4(); + let validation_id = ev.id.unwrap_or(Uuid::new_v4()); sqlx::query(&sqls[0]) - .bind(ev.id.unwrap_or(new_id)) - .bind(identity_id) + .bind(validation_id) .bind(email_id) - .bind(ev.attempts) .bind(&ev.code) - .bind(ev.is_validated) + .bind(ev.is_oauth_derived) .bind(ev.created_at) - .bind(ev.expires_at) + .bind(ev.validated_at) + .bind(ev.expired_at) .execute(&self.pool) .await .map_err(util::log_err_sqlx)?; - Ok(new_id) + if ev.identity_id.is_some() || ev.revoked_at.is_some() || ev.deleted_at.is_some() { + sqlx::query(&sqls[1]) + .bind(ev.identity_id.as_ref()) + .bind(validation_id) + .bind(ev.revoked_at) + .bind(ev.deleted_at) + .execute(&self.pool) + .await + .map_err(util::log_err_sqlx)?; + } + + Ok(validation_id) } async fn find_identity( &self, id: Option<&Uuid>, email: Option<&str>, - ) -> Result, StoreError> { + ) -> anyhow::Result> { let sqls = get_sqls(&self.sqls_root, FIND_IDENTITY); Ok( match sqlx::query_as::<_, Identity>(&sqls[0]) .bind(id) .bind(email) - .fetch_one(&self.pool) + .fetch_all(&self.pool) .await { - Ok(i) => Some(i), + Ok(mut is) => match is.len() { + // if only 1 found, then that's fine + // if multiple are fond, then if they all have the same id, that's okay + 1 => { + let i = is.swap_remove(0); + match i.deleted_at { + Some(t) if t > OffsetDateTime::now_utc() => Some(i), + None => Some(i), + _ => None, + } + } + 0 => None, + _ => { + match is + .iter() + .filter(|&i| i.id != is[0].id) + .collect::>() + .len() + { + 0 => Some(is.swap_remove(0)), + _ => bail!(StoreError::TooManyIdentitiesFound), + } + } + }, Err(sqlx::Error::RowNotFound) => None, - Err(e) => return Err(StoreError::SqlxError(e)), + Err(e) => bail!(StoreError::SqlxError(e)), }, ) } + async fn find_identity_by_code(&self, code: &str) -> Result { let sqls = get_sqls(&self.sqls_root, FIND_IDENTITY_BY_CODE); @@ -250,14 +265,16 @@ where Ok(()) } async fn read_identity(&self, id: &Uuid) -> Result { - Ok(sqlx::query_as::<_, Identity>( + let identity = sqlx::query_as::<_, Identity>( " select identity_public_id, data, created_at from identity where identity_public_id = ?", ) .bind(id) .fetch_one(&self.pool) .await - .map_err(util::log_err_sqlx)?) + .map_err(util::log_err_sqlx)?; + + Ok(identity) } async fn write_session(&self, session: &Session) -> Result<(), StoreError> { @@ -269,7 +286,6 @@ select identity_public_id, data, created_at from identity where identity_public_ .bind(&session.identity_id) .bind(secret_hash.as_ref()) .bind(session.created_at) - .bind(OffsetDateTime::now_utc()) .bind(session.expires_at) .bind(session.revoked_at) .execute(&self.pool) @@ -296,6 +312,142 @@ select identity_public_id, data, created_at from identity where identity_public_ Ok(session) } + + async fn write_oauth_provider(&self, provider: &OauthProvider) -> Result<(), StoreError> { + let sqls = get_sqls(&self.sqls_root, WRITE_OAUTH_PROVIDER); + sqlx::query(&sqls[0]) + .bind(&provider.name.to_string()) + .bind(&provider.flow) + .bind(&provider.base_url.to_string()) + .bind(&provider.response.to_string()) + .bind(&provider.default_scope) + .bind(&provider.client_id) + // TODO: encrypt secret before writing + .bind(&provider.client_secret) + .bind(&provider.redirect_url.to_string()) + .bind(provider.created_at) + .bind(provider.deleted_at) + .execute(&self.pool) + .await + .map_err(util::log_err_sqlx)?; + Ok(()) + } + + async fn read_oauth_provider( + &self, + provider: &OauthProviderName, + flow: Option, + ) -> Result { + let sqls = get_sqls(&self.sqls_root, READ_OAUTH_PROVIDER); + let flow = flow.unwrap_or("default".into()); + debug!("provider: {:?}, flow: {:?}", provider, flow); + // TODO: Write the generic FromRow impl for OauthProvider... + let res = sqlx::query_as::< + _, + ( + String, + String, + String, + String, + String, + String, + String, + OffsetDateTime, + Option, + ), + >(&sqls[0]) + .bind(&provider.to_string()) + .bind(&flow) + .fetch_one(&self.pool) + .await + .map_err(util::log_err_sqlx)?; + + debug!("res: {:?}", res); + + Ok(OauthProvider { + name: provider.clone(), + flow: Some(res.0), + base_url: Url::from_str(&res.1) + .map_err(|_| StoreError::OauthProviderDoesNotExist(*provider))?, + response: OauthResponseType::from_str(&res.2) + .map_err(|_| StoreError::OauthProviderDoesNotExist(*provider))?, + default_scope: res.3, + client_id: res.4, + client_secret: res.5, + redirect_url: Url::from_str(&res.6) + .map_err(|_| StoreError::OauthProviderDoesNotExist(*provider))?, + created_at: res.7, + deleted_at: res.8, + }) + } + async fn write_oauth_validation( + &self, + v: &OauthValidation, + ) -> anyhow::Result { + let sqls = get_sqls(&self.sqls_root, WRITE_OAUTH_VALIDATION); + + let validation_id = v.id.unwrap_or(Uuid::new_v4()); + sqlx::query(&sqls[0]) + .bind(validation_id) + .bind(v.oauth_provider.name.to_string()) + .bind(v.oauth_provider.flow.clone()) + .bind(v.access_token.clone()) + .bind(v.raw_response.clone()) + .bind(v.created_at) + .bind(v.validated_at) + .execute(&self.pool) + .await?; + + if v.identity_id.is_some() || v.revoked_at.is_some() || v.deleted_at.is_some() { + sqlx::query(&sqls[1]) + .bind(v.identity_id.as_ref()) + .bind(validation_id) + .bind(v.revoked_at) + .bind(v.deleted_at) + .execute(&self.pool) + .await?; + } + + Ok(validation_id) + } + async fn read_oauth_validation( + &self, + validation_id: &ValidationRequestId, + ) -> anyhow::Result { + let sqls = get_sqls(&self.sqls_root, READ_OAUTH_VALIDATION); + + let mut es = sqlx::query_as::<_, OauthValidation>(&sqls[0]) + .bind(validation_id) + .fetch_all(&self.pool) + .await?; + + if es.len() != 1 { + bail!(StoreError::OauthValidationDoesNotExist( + validation_id.clone() + )); + } + + Ok(es.swap_remove(0)) + } + async fn find_validation_type( + &self, + validation_id: &ValidationRequestId, + ) -> anyhow::Result { + let sqls = get_sqls(&self.sqls_root, READ_VALIDATION_TYPE); + + let mut es = sqlx::query_as::<_, (String,)>(&sqls[0]) + .bind(validation_id) + .fetch_all(&self.pool) + .await + .map_err(util::log_err_sqlx)?; + + match es.len() { + 1 => Ok(ValidationType::from_str(&es.swap_remove(0).0)?), + _ => bail!(StoreError::Other( + "expected a single validation but recieved 0 or multiple validations".into() + )), + } + } } pub struct PgClient { @@ -320,8 +472,8 @@ impl PgClient { #[async_trait::async_trait] impl Store for PgClient { - async fn write_email(&self, identity_id: Uuid, email_address: &str) -> Result<(), StoreError> { - self.sql.write_email(identity_id, email_address).await + async fn write_email(&self, email_address: &str) -> Result<(), StoreError> { + self.sql.write_email(email_address).await } async fn find_email_validation( &self, @@ -330,14 +482,14 @@ impl Store for PgClient { ) -> Result { self.sql.find_email_validation(validation_id, code).await } - async fn write_email_validation(&self, ev: &EmailValidation) -> Result { + async fn write_email_validation(&self, ev: &EmailValidation) -> anyhow::Result { self.sql.write_email_validation(ev).await } async fn find_identity( &self, identity_id: Option<&Uuid>, email: Option<&str>, - ) -> Result, StoreError> { + ) -> anyhow::Result> { self.sql.find_identity(identity_id, email).await } async fn find_identity_by_code(&self, code: &str) -> Result { @@ -355,6 +507,34 @@ impl Store for PgClient { async fn read_session(&self, secret: &SessionSecret) -> Result { self.sql.read_session(secret).await } + async fn write_oauth_provider(&self, provider: &OauthProvider) -> Result<(), StoreError> { + self.sql.write_oauth_provider(provider).await + } + async fn read_oauth_provider( + &self, + provider: &OauthProviderName, + flow: Option, + ) -> Result { + self.sql.read_oauth_provider(provider, flow).await + } + async fn write_oauth_validation( + &self, + validation: &OauthValidation, + ) -> anyhow::Result { + self.sql.write_oauth_validation(validation).await + } + async fn read_oauth_validation( + &self, + validation_id: &ValidationRequestId, + ) -> anyhow::Result { + self.sql.read_oauth_validation(validation_id).await + } + async fn find_validation_type( + &self, + validation_id: &ValidationRequestId, + ) -> anyhow::Result { + self.sql.find_validation_type(validation_id).await + } } pub struct SqliteClient { @@ -386,8 +566,8 @@ impl SqliteClient { #[async_trait::async_trait] impl Store for SqliteClient { - async fn write_email(&self, identity_id: Uuid, email_address: &str) -> Result<(), StoreError> { - self.sql.write_email(identity_id, email_address).await + async fn write_email(&self, email_address: &str) -> Result<(), StoreError> { + self.sql.write_email(email_address).await } async fn find_email_validation( &self, @@ -396,14 +576,14 @@ impl Store for SqliteClient { ) -> Result { self.sql.find_email_validation(validation_id, code).await } - async fn write_email_validation(&self, ev: &EmailValidation) -> Result { + async fn write_email_validation(&self, ev: &EmailValidation) -> anyhow::Result { self.sql.write_email_validation(ev).await } async fn find_identity( &self, identity_id: Option<&Uuid>, email: Option<&str>, - ) -> Result, StoreError> { + ) -> anyhow::Result> { self.sql.find_identity(identity_id, email).await } async fn find_identity_by_code(&self, code: &str) -> Result { @@ -421,4 +601,32 @@ impl Store for SqliteClient { async fn read_session(&self, secret: &SessionSecret) -> Result { self.sql.read_session(secret).await } + async fn write_oauth_provider(&self, provider: &OauthProvider) -> Result<(), StoreError> { + self.sql.write_oauth_provider(provider).await + } + async fn read_oauth_provider( + &self, + provider: &OauthProviderName, + flow: Option, + ) -> Result { + self.sql.read_oauth_provider(provider, flow).await + } + async fn write_oauth_validation( + &self, + validation: &OauthValidation, + ) -> anyhow::Result { + self.sql.write_oauth_validation(validation).await + } + async fn read_oauth_validation( + &self, + validation_id: &ValidationRequestId, + ) -> anyhow::Result { + self.sql.read_oauth_validation(validation_id).await + } + async fn find_validation_type( + &self, + validation_id: &ValidationRequestId, + ) -> anyhow::Result { + self.sql.find_validation_type(validation_id).await + } } diff --git a/crates/secd/src/client/types.rs b/crates/secd/src/client/types.rs new file mode 100644 index 0000000..bacade4 --- /dev/null +++ b/crates/secd/src/client/types.rs @@ -0,0 +1,3 @@ +pub(crate) struct Email { + address: String, +} diff --git a/crates/secd/src/command/admin.rs b/crates/secd/src/command/admin.rs new file mode 100644 index 0000000..b04dbef --- /dev/null +++ b/crates/secd/src/command/admin.rs @@ -0,0 +1,57 @@ +use std::str::FromStr; + +use time::OffsetDateTime; +use url::Url; + +use crate::{OauthProviderName, Secd, SecdError}; + +impl OauthProviderName { + fn base_url(&self) -> Url { + match self { + OauthProviderName::Google => { + Url::from_str("https://accounts.google.com/o/oauth2/v2/auth").unwrap() + } + OauthProviderName::Microsoft => { + Url::from_str("https://login.microsoftonline.com/common/oauth2/v2.0/authorize") + .unwrap() + } + _ => unimplemented!(), + } + } + + fn default_scope(&self) -> String { + match self { + OauthProviderName::Google => "openid%20email".into(), + OauthProviderName::Microsoft => "openid%20email".into(), + _ => unimplemented!(), + } + } +} + +impl Secd { + pub async fn create_oauth_provider( + &self, + provider: &OauthProviderName, + client_id: String, + client_secret: String, + redirect_url: Url, + ) -> Result<(), SecdError> { + self.store + .write_oauth_provider(&crate::OauthProvider { + name: provider.clone(), + flow: Some("default".into()), + base_url: provider.base_url(), + response: crate::OauthResponseType::Code, + default_scope: provider.default_scope(), + client_id, + client_secret, + redirect_url, + created_at: OffsetDateTime::now_utc(), + deleted_at: None, + }) + .await + .map_err(|_| SecdError::Todo)?; + + Ok(()) + } +} diff --git a/crates/secd/src/command/authn.rs b/crates/secd/src/command/authn.rs new file mode 100644 index 0000000..862d921 --- /dev/null +++ b/crates/secd/src/command/authn.rs @@ -0,0 +1,230 @@ +use email_address::EmailAddress; +use log::debug; +use rand::distributions::{Alphanumeric, DistString}; +use time::Duration; +use time::OffsetDateTime; +use uuid::Uuid; + +use crate::util::{build_oauth_auth_url, get_oauth_access_token}; +use crate::OauthRedirectAuthUrl; +use crate::Validation; +use crate::ValidationType; +use crate::INTERNAL_ERR_MSG; +use crate::{ + client, util, EmailValidation, Identity, OauthProviderName, Secd, SecdError, Session, + ValidationRequestId, ValidationSecretCode, EMAIL_VALIDATION_DURATION, SESSION_DURATION, + SESSION_SIZE_BYTES, VALIDATION_CODE_SIZE, +}; + +impl Secd { + /// create_validation_request_oauth + /// + /// Generate a request to validate with the specified oauth provider.[ + // TODO: How to handle different oauth "flows"? e.g. web app vs desktop vs mobile... + pub async fn create_validation_request_oauth( + &self, + provider: &OauthProviderName, + scope: Option, + ) -> Result { + if scope.is_some() { + return Err(SecdError::NotImplemented( + "Only default scopes are currently supported.".into(), + )); + } + + let p = self + .store + .read_oauth_provider(provider, None) + .await + .map_err(|_| SecdError::InternalError(INTERNAL_ERR_MSG.to_string()))?; + + let req_id = self + .store + .write_oauth_validation(&crate::OauthValidation { + id: Some(Uuid::new_v4()), + identity_id: None, + oauth_provider: p.clone(), + access_token: None, + raw_response: None, + created_at: OffsetDateTime::now_utc(), + validated_at: None, + revoked_at: None, + deleted_at: None, + }) + .await + .map_err(|e| util::to_secd_err(e, SecdError::OauthValidationRequestError))?; + + build_oauth_auth_url(&p, req_id) + } + /// create_validation_request_email + /// + /// Generate a request to validate the provided email. + pub async fn create_validation_request_email( + &self, + email: Option<&str>, + ) -> Result { + let now = OffsetDateTime::now_utc(); + + let email = match email { + Some(ea) => { + if EmailAddress::is_valid(ea) { + ea + } else { + return Err(SecdError::InvalidEmailAddress); + } + } + None => return Err(SecdError::InvalidEmailAddress), + }; + + let mut ev = EmailValidation { + id: None, + identity_id: None, + email_address: email.to_string(), + code: Some( + Alphanumeric + .sample_string(&mut rand::thread_rng(), VALIDATION_CODE_SIZE) + .to_lowercase(), + ), + is_oauth_derived: false, + created_at: now, + expired_at: now + .checked_add(Duration::new(EMAIL_VALIDATION_DURATION, 0)) + .ok_or(SecdError::EmailValidationExpiryOverflow)?, + validated_at: None, + revoked_at: None, + deleted_at: None, + }; + + let (req_id, mail_type) = match self + .store + .find_identity(None, Some(email)) + .await + .map_err(|e| util::log_err(e.into(), SecdError::Todo))? + { + Some(identity) => { + let req_id = { + ev.identity_id = Some(identity.id); + self.store + .write_email_validation(&ev) + .await + .map_err(|e| util::log_err(e.into(), SecdError::Todo))? + }; + (req_id, client::EmailType::Login) + } + None => { + self.store + .write_email(email) + .await + .map_err(|e| util::log_err(e.into(), SecdError::Todo))?; + + let req_id = { + self.store + .write_email_validation(&ev) + .await + .map_err(|e| util::log_err(e.into(), SecdError::Todo))? + }; + + (req_id, client::EmailType::Signup) + } + }; + + self.email_messenger + .send_email(email, &req_id.to_string(), &ev.code.unwrap(), mail_type) + .await?; + + Ok(req_id) + } + /// exchange_secret_for_session + /// + /// Exchanges a secret, which consists of a validation_request_id and secret_code + /// for a session which allows authentication on behalf of the associated identity. + /// + /// Session secrets should be used to return authorization for the associated identity. + pub async fn exchange_code_for_session( + &self, + validation_request_id: ValidationRequestId, + code: ValidationSecretCode, + ) -> Result { + let mut v: Box = match self + .store + .find_validation_type(&validation_request_id) + .await + .map_err(|e| util::to_secd_err(e, SecdError::Todo))? + { + ValidationType::Email => Box::new( + self.store + .find_email_validation(Some(&validation_request_id), Some(&code)) + .await + .map_err(|e| { + util::log_err(e.into(), SecdError::EmailValidationExpiryOverflow) + })?, + ), + ValidationType::Oauth => Box::new({ + let mut t = self + .store + .read_oauth_validation(&validation_request_id) + .await + .map_err(|e| util::to_secd_err(e, SecdError::Todo))?; + + let access_token = get_oauth_access_token(&t, &code) + .await + .map_err(|_| SecdError::Todo)?; + + t.access_token = Some(access_token); + t + }), + }; + + if v.expired() || v.is_validated() { + return Err(SecdError::InvalidCode); + }; + + let mut identity = Identity { + id: Uuid::new_v4(), + data: None, + created_at: OffsetDateTime::now_utc(), + deleted_at: None, + }; + + match v + .find_associated_identities(self.store.clone()) + .await + .map_err(|e| util::to_secd_err(e, SecdError::IdentityIdShouldExistInvariant))? + { + Some(i) => identity.id = i.id, + _ => self.store.write_identity(&identity).await.map_err(|_| { + SecdError::InternalError("failed to write identity during session exchange".into()) + })?, + }; + + v.validate(&identity, self.store.clone()) + .await + .map_err(|e| { + util::to_secd_err( + e, + SecdError::InternalError( + "failed to update validation during session exchange".into(), + ), + ) + })?; + + // TODO: clear previous sessions if they fit the criteria + let now = OffsetDateTime::now_utc(); + let s = Session { + identity_id: identity.id, + secret: Some(Alphanumeric.sample_string(&mut rand::thread_rng(), SESSION_SIZE_BYTES)), + created_at: now, + expires_at: now + .checked_add(Duration::new(SESSION_DURATION, 0)) + .ok_or(SecdError::SessionExpiryOverflow)?, + revoked_at: None, + }; + + self.store + .write_session(&s) + .await + .map_err(|e| util::log_err(e.into(), SecdError::Todo))?; + + Ok(s) + } +} diff --git a/crates/secd/src/command/mod.rs b/crates/secd/src/command/mod.rs new file mode 100644 index 0000000..cd0d8c3 --- /dev/null +++ b/crates/secd/src/command/mod.rs @@ -0,0 +1,66 @@ +pub mod admin; +pub mod authn; + +use crate::client::{ + email, + sqldb::{PgClient, SqliteClient}, +}; +use crate::{AuthEmail, AuthStore, Secd, SecdError}; +use log::error; +use std::sync::Arc; + +impl Secd { + /// init + /// + /// Initialize SecD with the specified configuration, established the necessary + /// constraints, persistance stores, and options. + pub async fn init( + auth_store: AuthStore, + conn_string: Option<&str>, + email_messenger: AuthEmail, + email_template_login: Option, + email_template_signup: Option, + ) -> Result { + let store = match auth_store { + AuthStore::Sqlite => { + SqliteClient::new( + sqlx::sqlite::SqlitePoolOptions::new() + .connect(conn_string.unwrap_or("sqlite::memory:".into())) + .await + .map_err(|e| SecdError::InitializationFailure(e))?, + ) + .await + } + AuthStore::Postgres => { + PgClient::new( + sqlx::postgres::PgPoolOptions::new() + .connect(conn_string.expect("No postgres connection string provided.")) + .await + .map_err(|e| SecdError::InitializationFailure(e))?, + ) + .await + } + rest @ _ => { + error!( + "requested an AuthStore which has not yet been implemented: {:?}", + rest + ); + unimplemented!() + } + }; + + let email_sender = match email_messenger { + // TODO: initialize email and SMS templates with secd + AuthEmail::LocalStub => email::LocalEmailStubber { + email_template_login, + email_template_signup, + }, + _ => unimplemented!(), + }; + + Ok(Secd { + store, + email_messenger: Arc::new(email_sender), + }) + } +} diff --git a/crates/secd/src/lib.rs b/crates/secd/src/lib.rs index 4feda04..faa92ca 100644 --- a/crates/secd/src/lib.rs +++ b/crates/secd/src/lib.rs @@ -1,28 +1,28 @@ mod client; +mod command; mod util; use std::sync::Arc; -use client::{ - email, - sqldb::{PgClient, SqliteClient}, - EmailMessenger, EmailMessengerError, Store, StoreError, -}; +use clap::ValueEnum; +use client::{EmailMessenger, EmailMessengerError, Store}; use derive_more::Display; use email_address::EmailAddress; -use log::error; -use rand::distributions::{Alphanumeric, DistString}; use serde::{Deserialize, Serialize}; +use sqlx::FromRow; use strum_macros::{EnumString, EnumVariantNames}; -use time::{Duration, OffsetDateTime}; +use time::OffsetDateTime; +use url::Url; +use util::get_oauth_identity_data; use uuid::Uuid; const SESSION_SIZE_BYTES: usize = 32; const SESSION_DURATION: i64 = 60 /* seconds*/ * 60 /* minutes */ * 24 /* hours */ * 360 /* days */; const EMAIL_VALIDATION_DURATION: i64 = 60 /* seconds*/ * 15 /* minutes */; -const VALIDATION_ATTEMPTS_MAX: i32 = 5; const VALIDATION_CODE_SIZE: usize = 6; +const INTERNAL_ERR_MSG: &str = "It seems an invariant was borked or something non-deterministic happened. Please file a bug with secd."; + #[derive(sqlx::FromRow, Debug, Serialize)] pub struct ApiKey { pub public_key: String, @@ -38,9 +38,11 @@ pub struct Authorization { pub struct Identity { #[sqlx(rename = "identity_public_id")] id: Uuid, - created_at: OffsetDateTime, #[serde(skip_serializing_if = "Option::is_none")] data: Option, + created_at: OffsetDateTime, + #[serde(skip_serializing_if = "Option::is_none")] + deleted_at: Option, } #[derive(sqlx::FromRow, Debug, Serialize)] @@ -58,6 +60,121 @@ pub struct Session { pub revoked_at: Option, } +#[async_trait::async_trait] +trait Validation { + fn expired(&self) -> bool; + fn is_validated(&self) -> bool; + async fn find_associated_identities( + &self, + store: Arc, + ) -> anyhow::Result>; + async fn validate( + &mut self, + i: &Identity, + store: Arc, + ) -> anyhow::Result<()>; +} + +#[async_trait::async_trait] +impl Validation for EmailValidation { + fn expired(&self) -> bool { + let now = OffsetDateTime::now_utc(); + self.expired_at < now + || self.revoked_at.map(|t| t < now).unwrap_or(false) + || self.deleted_at.map(|t| t < now).unwrap_or(false) + } + fn is_validated(&self) -> bool { + self.validated_at + .map(|t| t >= OffsetDateTime::now_utc()) + .unwrap_or(false) + } + async fn find_associated_identities( + &self, + store: Arc, + ) -> anyhow::Result> { + store.find_identity(None, Some(&self.email_address)).await + } + async fn validate( + &mut self, + i: &Identity, + store: Arc, + ) -> anyhow::Result<()> { + self.identity_id = Some(i.id); + self.validated_at = Some(OffsetDateTime::now_utc()); + store.write_email_validation(&self).await?; + Ok(()) + } +} + +#[async_trait::async_trait] +impl Validation for OauthValidation { + fn expired(&self) -> bool { + let now = OffsetDateTime::now_utc(); + self.revoked_at.map(|t| t < now).unwrap_or(false) + || self.deleted_at.map(|t| t < now).unwrap_or(false) + } + fn is_validated(&self) -> bool { + self.validated_at + .map(|t| t >= OffsetDateTime::now_utc()) + .unwrap_or(false) + } + async fn find_associated_identities( + &self, + store: Arc, + ) -> anyhow::Result> { + let oauth_identity = get_oauth_identity_data(&self).await?; + + let identity = store + .find_identity(None, oauth_identity.email.as_deref()) + .await?; + + let now = OffsetDateTime::now_utc(); + if let Some(email) = oauth_identity.email.clone() { + let identity = identity.unwrap_or(Identity { + id: Uuid::new_v4(), + data: None, + created_at: OffsetDateTime::now_utc(), + deleted_at: None, + }); + store.write_identity(&identity).await?; + store.write_email(&email).await?; + store + .write_email_validation(&EmailValidation { + id: Some(Uuid::new_v4()), + identity_id: Some(identity.id), + email_address: email, + code: None, + is_oauth_derived: true, + created_at: now, + expired_at: now, + validated_at: Some(now), + revoked_at: None, + deleted_at: None, + }) + .await?; + Ok(Some(identity)) + } else { + Ok(identity) + } + } + async fn validate( + &mut self, + i: &Identity, + store: Arc, + ) -> anyhow::Result<()> { + self.identity_id = Some(i.id); + self.validated_at = Some(OffsetDateTime::now_utc()); + store.write_oauth_validation(&self).await?; + Ok(()) + } +} + +#[derive(Debug, EnumString)] +pub enum ValidationType { + Email, + Oauth, +} + #[derive(sqlx::FromRow, Debug)] pub struct EmailValidation { #[sqlx(rename = "email_validation_public_id")] @@ -66,16 +183,53 @@ pub struct EmailValidation { identity_id: Option, #[sqlx(rename = "address")] email_address: String, - attempts: i32, - code: String, - is_validated: bool, + code: Option, + is_oauth_derived: bool, + created_at: OffsetDateTime, + expired_at: OffsetDateTime, + validated_at: Option, + revoked_at: Option, + deleted_at: Option, +} + +#[derive(Debug)] +pub struct OauthValidation { + id: Option, + identity_id: Option, + oauth_provider: OauthProvider, + access_token: Option, + raw_response: Option, created_at: OffsetDateTime, - expires_at: OffsetDateTime, + validated_at: Option, revoked_at: Option, + deleted_at: Option, +} + +#[derive(Debug, Clone)] +pub struct OauthProvider { + pub name: OauthProviderName, + pub flow: Option, + pub base_url: Url, + pub response: OauthResponseType, + pub default_scope: String, + pub client_id: String, + pub client_secret: String, + pub redirect_url: Url, + pub created_at: OffsetDateTime, + pub deleted_at: Option, +} + +#[derive(Debug, Display, Clone, Copy, ValueEnum, EnumString)] +pub enum OauthResponseType { + Code, + IdToken, + None, + Token, } -#[derive(Copy, Display, Clone, Debug)] -pub enum OauthProvider { +// TODO: feature gate ValueEnum since it's only needed for iam builds +#[derive(Copy, Display, Clone, Debug, ValueEnum, EnumString)] +pub enum OauthProviderName { Amazon, Apple, Dropbox, @@ -121,19 +275,24 @@ pub type SessionSecret = String; pub type SessionSecretHash = String; pub type ValidationRequestId = Uuid; pub type ValidationSecretCode = String; +pub type OauthRedirectAuthUrl = Url; #[derive(Debug, derive_more::Display, thiserror::Error)] pub enum SecdError { - InvalidEmailAddress, - InvalidCode, - InitializationFailure(sqlx::Error), - IdentityIdShouldExistInvariant, EmailSendError(#[from] EmailMessengerError), - EmailValidationRequestError, EmailValidationExpiryOverflow, + EmailValidationRequestError, + OauthValidationRequestError, + IdentityIdShouldExistInvariant, + InitializationFailure(sqlx::Error), + InvalidCode, + InvalidEmailAddress, + InputValidation(String), + InternalError(String), + NotImplemented(String), SessionExpiryOverflow, Unauthenticated, - Unknown, + Todo, } pub struct Secd { @@ -142,191 +301,6 @@ pub struct Secd { } impl Secd { - pub async fn init( - auth_store: AuthStore, - conn_string: Option<&str>, - email_messenger: AuthEmail, - email_template_login: Option, - email_template_signup: Option, - ) -> Result { - let store = match auth_store { - AuthStore::Sqlite => { - SqliteClient::new( - sqlx::sqlite::SqlitePoolOptions::new() - .connect(conn_string.unwrap_or("sqlite::memory:".into())) - .await - .map_err(|e| SecdError::InitializationFailure(e))?, - ) - .await - } - AuthStore::Postgres => { - PgClient::new( - sqlx::postgres::PgPoolOptions::new() - .connect(conn_string.expect("No postgres connection string provided.")) - .await - .map_err(|e| SecdError::InitializationFailure(e))?, - ) - .await - } - rest @ _ => { - error!( - "requested an AuthStore which has not yet been implemented: {:?}", - rest - ); - unimplemented!() - } - }; - - let email_sender = match email_messenger { - // TODO: initialize email and SMS templates with secd - AuthEmail::LocalStub => email::LocalEmailStubber { - email_template_login, - email_template_signup, - }, - _ => unimplemented!(), - }; - - Ok(Secd { - store, - email_messenger: Arc::new(email_sender), - }) - } - /// create_validation_request - /// - /// Generate a request to validate the provided email. - pub async fn create_validation_request( - &self, - email: Option<&str>, - ) -> Result { - let now = OffsetDateTime::now_utc(); - - let email = match email { - Some(ea) => { - if EmailAddress::is_valid(ea) { - ea - } else { - return Err(SecdError::InvalidEmailAddress); - } - } - None => return Err(SecdError::InvalidEmailAddress), - }; - - let mut ev = EmailValidation { - id: None, - identity_id: None, - email_address: email.to_string(), - attempts: 0, - code: Alphanumeric - .sample_string(&mut rand::thread_rng(), VALIDATION_CODE_SIZE) - .to_lowercase(), - is_validated: false, - created_at: now, - expires_at: now - .checked_add(Duration::new(EMAIL_VALIDATION_DURATION, 0)) - .ok_or(SecdError::EmailValidationExpiryOverflow)?, - revoked_at: None, - }; - - let (req_id, mail_type) = match self - .store - .find_identity(None, Some(email)) - .await - .map_err(|e| util::log_err(e.into(), SecdError::Unknown))? - { - Some(identity) => { - let req_id = { - ev.identity_id = Some(identity.id); - self.store - .write_email_validation(&ev) - .await - .map_err(|e| util::log_err(e.into(), SecdError::Unknown))? - }; - (req_id, client::EmailType::Login) - } - None => { - let identity = Identity { - id: Uuid::new_v4(), - created_at: OffsetDateTime::now_utc(), - data: None, - }; - self.store - .write_identity(&identity) - .await - .map_err(|e| util::log_err(e.into(), SecdError::Unknown))?; - self.store - .write_email(identity.id, email) - .await - .map_err(|e| util::log_err(e.into(), SecdError::Unknown))?; - - let req_id = { - ev.identity_id = Some(identity.id); - self.store - .write_email_validation(&ev) - .await - .map_err(|e| util::log_err(e.into(), SecdError::Unknown))? - }; - - (req_id, client::EmailType::Signup) - } - }; - - self.email_messenger - .send_email(email, &req_id.to_string(), &ev.code, mail_type) - .await?; - - Ok(req_id) - } - /// exchange_secret_for_session - /// - /// Exchanges a secret, which consists of a validation_request_id and secret_code - /// for a session which allows authentication on behalf of the associated identity. - /// - /// Session secrets should be used to return authorization for the associated identity. - pub async fn exchange_code_for_session( - &self, - validation_request_id: ValidationRequestId, - code: ValidationSecretCode, - ) -> Result { - let mut ev = self - .store - .find_email_validation(Some(&validation_request_id), Some(&code)) - .await - .map_err(|e| util::log_err(e.into(), SecdError::EmailValidationExpiryOverflow))?; - - if ev.is_validated - || ev.expires_at < OffsetDateTime::now_utc() - || ev.attempts >= VALIDATION_ATTEMPTS_MAX - { - return Err(SecdError::InvalidCode); - }; - - ev.is_validated = true; - ev.attempts += 1; - self.store - .write_email_validation(&ev) - .await - .map_err(|e| util::log_err(e.into(), SecdError::Unknown))?; - - // TODO: clear previous sessions if they fit the criteria - let now = OffsetDateTime::now_utc(); - let s = Session { - identity_id: ev - .identity_id - .ok_or(SecdError::IdentityIdShouldExistInvariant)?, - secret: Some(Alphanumeric.sample_string(&mut rand::thread_rng(), SESSION_SIZE_BYTES)), - created_at: now, - expires_at: now - .checked_add(Duration::new(SESSION_DURATION, 0)) - .ok_or(SecdError::SessionExpiryOverflow)?, - revoked_at: None, - }; - self.store - .write_session(&s) - .await - .map_err(|e| util::log_err(e.into(), SecdError::Unknown))?; - - Ok(s) - } /// get_identity /// /// Return all information associated with the identity id. @@ -350,7 +324,7 @@ impl Secd { Ok(Authorization { session }) } Ok(_) => Err(SecdError::Unauthenticated), - Err(_e) => Err(SecdError::Unknown), + Err(_e) => Err(SecdError::Todo), } } /// revoke_session diff --git a/crates/secd/src/util/mod.rs b/crates/secd/src/util/mod.rs index da16901..bb177cb 100644 --- a/crates/secd/src/util/mod.rs +++ b/crates/secd/src/util/mod.rs @@ -1,13 +1,27 @@ +use std::str::FromStr; + +use anyhow::{bail, Context}; use log::error; use rand::distributions::Alphanumeric; use rand::{thread_rng, Rng}; +use reqwest::header; +use serde::{Deserialize, Serialize}; +use url::Url; -use crate::SecdError; +use crate::{ + OauthProvider, OauthProviderName, OauthValidation, SecdError, ValidationRequestId, + INTERNAL_ERR_MSG, +}; pub(crate) fn log_err(e: Box, new_e: SecdError) -> SecdError { error!("{:?}", e); new_e } +pub(crate) fn to_secd_err(e: anyhow::Error, new_e: SecdError) -> SecdError { + error!("{:?}", e); + new_e +} + pub(crate) fn log_err_sqlx(e: sqlx::Error) -> sqlx::Error { error!("{:?}", e); e @@ -19,3 +33,145 @@ pub(crate) fn generate_random_url_safe(n: usize) -> String { .map(char::from) .collect() } + +pub(crate) fn remove_trailing_slash(url: &mut Url) -> String { + let mut u = url.to_string(); + + if u.ends_with('/') { + u.pop(); + } + + u +} + +pub(crate) fn build_oauth_auth_url( + p: &OauthProvider, + validation_id: ValidationRequestId, +) -> Result { + let redirect_url = remove_trailing_slash(&mut p.redirect_url.clone()); + + Ok(Url::from_str(&format!( + "{}?client_id={}&response_type={}&redirect_uri={}&scope={}&state={}", + p.base_url, + p.client_id, + p.response.to_string().to_lowercase(), + redirect_url, + p.default_scope, + validation_id.to_string() + )) + .map_err(|_| SecdError::InternalError(INTERNAL_ERR_MSG.into()))?) +} + +pub(crate) async fn get_oauth_identity_data( + validation: &OauthValidation, +) -> anyhow::Result { + let provider = validation.oauth_provider.name; + let token = validation + .access_token + .clone() + .ok_or(SecdError::InternalError( + "no access token provided with which to build oauth data url".into(), + ))?; + + let url = Url::from_str(&format!( + "{}{}", + match provider { + OauthProviderName::Google => + "https://www.googleapis.com/oauth2/v2/userinfo?access_token=", + _ => unimplemented!(), + }, + token + ))?; + + let resp = reqwest::get(url).await?.json::().await?; + let identity = match provider { + OauthProviderName::Google => OauthAccessIdentity { + email: resp + .get("email") + .and_then(|v| v.as_str().map(|s| s.to_string())), + email_is_verified: resp.get("verified_email").and_then(|v| v.as_bool()), + picture_url: resp + .get("picture") + .and_then(|v| Url::from_str(&v.to_string()).ok()), + }, + _ => unimplemented!(), + }; + + Ok(identity) +} + +#[derive(Debug, Serialize)] +pub(crate) struct OauthAccessTokenGoogleRequest { + grant_type: String, + code: String, + client_id: String, + client_secret: String, + redirect_uri: String, +} + +#[derive(Debug, Deserialize)] +pub(crate) struct OauthAccessTokenGoogleResponse { + access_token: String, + expires_in: i32, + token_type: String, + scope: String, + id_token: String, +} + +#[derive(Debug)] +pub(crate) struct OauthAccessIdentity { + pub(crate) email: Option, + pub(crate) email_is_verified: Option, + pub(crate) picture_url: Option, +} + +type AccessTokenRequestData = String; + +pub(crate) async fn get_oauth_access_token( + validation: &OauthValidation, + secret_code: &String, +) -> anyhow::Result { + let provider = validation.oauth_provider.name; + + let url = Url::from_str(match provider { + OauthProviderName::Google => "https://accounts.google.com/o/oauth2/token", + _ => unimplemented!(), + })?; + + let request_data = serde_json::to_string(&match provider { + OauthProviderName::Google => OauthAccessTokenGoogleRequest { + grant_type: "authorization_code".to_string(), + code: secret_code.to_string(), + client_id: validation.oauth_provider.client_id.clone(), + client_secret: validation.oauth_provider.client_secret.clone(), + redirect_uri: remove_trailing_slash( + &mut validation.oauth_provider.redirect_url.clone(), + ), + }, + _ => unimplemented!(), + })?; + + let r = reqwest::Client::new() + .post(url) + .body(request_data) + .header(header::CONTENT_TYPE, "application/json") + .send() + .await + .context(format!( + "Failed to successfully POST a new access token for: {}", + provider + ))?; + + let access_token = match provider { + OauthProviderName::Google => { + let resp: OauthAccessTokenGoogleResponse = r.json().await.context(format!( + "Failed to parse access token response for: {}", + provider + ))?; + resp.access_token + } + _ => unimplemented!(), + }; + + Ok(access_token) +} -- cgit v1.2.3