aboutsummaryrefslogtreecommitdiff
path: root/crates/secd/src/util/mod.rs
diff options
context:
space:
mode:
Diffstat (limited to 'crates/secd/src/util/mod.rs')
-rw-r--r--crates/secd/src/util/mod.rs158
1 files changed, 157 insertions, 1 deletions
diff --git a/crates/secd/src/util/mod.rs b/crates/secd/src/util/mod.rs
index da16901..bb177cb 100644
--- a/crates/secd/src/util/mod.rs
+++ b/crates/secd/src/util/mod.rs
@@ -1,13 +1,27 @@
+use std::str::FromStr;
+
+use anyhow::{bail, Context};
use log::error;
use rand::distributions::Alphanumeric;
use rand::{thread_rng, Rng};
+use reqwest::header;
+use serde::{Deserialize, Serialize};
+use url::Url;
-use crate::SecdError;
+use crate::{
+ OauthProvider, OauthProviderName, OauthValidation, SecdError, ValidationRequestId,
+ INTERNAL_ERR_MSG,
+};
pub(crate) fn log_err(e: Box<dyn std::error::Error>, new_e: SecdError) -> SecdError {
error!("{:?}", e);
new_e
}
+pub(crate) fn to_secd_err(e: anyhow::Error, new_e: SecdError) -> SecdError {
+ error!("{:?}", e);
+ new_e
+}
+
pub(crate) fn log_err_sqlx(e: sqlx::Error) -> sqlx::Error {
error!("{:?}", e);
e
@@ -19,3 +33,145 @@ pub(crate) fn generate_random_url_safe(n: usize) -> String {
.map(char::from)
.collect()
}
+
+pub(crate) fn remove_trailing_slash(url: &mut Url) -> String {
+ let mut u = url.to_string();
+
+ if u.ends_with('/') {
+ u.pop();
+ }
+
+ u
+}
+
+pub(crate) fn build_oauth_auth_url(
+ p: &OauthProvider,
+ validation_id: ValidationRequestId,
+) -> Result<Url, SecdError> {
+ let redirect_url = remove_trailing_slash(&mut p.redirect_url.clone());
+
+ Ok(Url::from_str(&format!(
+ "{}?client_id={}&response_type={}&redirect_uri={}&scope={}&state={}",
+ p.base_url,
+ p.client_id,
+ p.response.to_string().to_lowercase(),
+ redirect_url,
+ p.default_scope,
+ validation_id.to_string()
+ ))
+ .map_err(|_| SecdError::InternalError(INTERNAL_ERR_MSG.into()))?)
+}
+
+pub(crate) async fn get_oauth_identity_data(
+ validation: &OauthValidation,
+) -> anyhow::Result<OauthAccessIdentity> {
+ let provider = validation.oauth_provider.name;
+ let token = validation
+ .access_token
+ .clone()
+ .ok_or(SecdError::InternalError(
+ "no access token provided with which to build oauth data url".into(),
+ ))?;
+
+ let url = Url::from_str(&format!(
+ "{}{}",
+ match provider {
+ OauthProviderName::Google =>
+ "https://www.googleapis.com/oauth2/v2/userinfo?access_token=",
+ _ => unimplemented!(),
+ },
+ token
+ ))?;
+
+ let resp = reqwest::get(url).await?.json::<serde_json::Value>().await?;
+ let identity = match provider {
+ OauthProviderName::Google => OauthAccessIdentity {
+ email: resp
+ .get("email")
+ .and_then(|v| v.as_str().map(|s| s.to_string())),
+ email_is_verified: resp.get("verified_email").and_then(|v| v.as_bool()),
+ picture_url: resp
+ .get("picture")
+ .and_then(|v| Url::from_str(&v.to_string()).ok()),
+ },
+ _ => unimplemented!(),
+ };
+
+ Ok(identity)
+}
+
+#[derive(Debug, Serialize)]
+pub(crate) struct OauthAccessTokenGoogleRequest {
+ grant_type: String,
+ code: String,
+ client_id: String,
+ client_secret: String,
+ redirect_uri: String,
+}
+
+#[derive(Debug, Deserialize)]
+pub(crate) struct OauthAccessTokenGoogleResponse {
+ access_token: String,
+ expires_in: i32,
+ token_type: String,
+ scope: String,
+ id_token: String,
+}
+
+#[derive(Debug)]
+pub(crate) struct OauthAccessIdentity {
+ pub(crate) email: Option<String>,
+ pub(crate) email_is_verified: Option<bool>,
+ pub(crate) picture_url: Option<Url>,
+}
+
+type AccessTokenRequestData = String;
+
+pub(crate) async fn get_oauth_access_token(
+ validation: &OauthValidation,
+ secret_code: &String,
+) -> anyhow::Result<String> {
+ let provider = validation.oauth_provider.name;
+
+ let url = Url::from_str(match provider {
+ OauthProviderName::Google => "https://accounts.google.com/o/oauth2/token",
+ _ => unimplemented!(),
+ })?;
+
+ let request_data = serde_json::to_string(&match provider {
+ OauthProviderName::Google => OauthAccessTokenGoogleRequest {
+ grant_type: "authorization_code".to_string(),
+ code: secret_code.to_string(),
+ client_id: validation.oauth_provider.client_id.clone(),
+ client_secret: validation.oauth_provider.client_secret.clone(),
+ redirect_uri: remove_trailing_slash(
+ &mut validation.oauth_provider.redirect_url.clone(),
+ ),
+ },
+ _ => unimplemented!(),
+ })?;
+
+ let r = reqwest::Client::new()
+ .post(url)
+ .body(request_data)
+ .header(header::CONTENT_TYPE, "application/json")
+ .send()
+ .await
+ .context(format!(
+ "Failed to successfully POST a new access token for: {}",
+ provider
+ ))?;
+
+ let access_token = match provider {
+ OauthProviderName::Google => {
+ let resp: OauthAccessTokenGoogleResponse = r.json().await.context(format!(
+ "Failed to parse access token response for: {}",
+ provider
+ ))?;
+ resp.access_token
+ }
+ _ => unimplemented!(),
+ };
+
+ Ok(access_token)
+}