aboutsummaryrefslogtreecommitdiff
path: root/crates/secd/src/client/store
diff options
context:
space:
mode:
Diffstat (limited to 'crates/secd/src/client/store')
-rw-r--r--crates/secd/src/client/store/mod.rs78
-rw-r--r--crates/secd/src/client/store/sql_db.rs132
2 files changed, 177 insertions, 33 deletions
diff --git a/crates/secd/src/client/store/mod.rs b/crates/secd/src/client/store/mod.rs
index 8a076c4..7bf01d5 100644
--- a/crates/secd/src/client/store/mod.rs
+++ b/crates/secd/src/client/store/mod.rs
@@ -1,22 +1,28 @@
pub(crate) mod sql_db;
+use async_trait::async_trait;
use sqlx::{Postgres, Sqlite};
use std::sync::Arc;
use uuid::Uuid;
-use crate::{util, Address, AddressType, AddressValidation, Identity, IdentityId, Session};
+use crate::{
+ util, Address, AddressType, AddressValidation, Credential, CredentialId, CredentialType,
+ Identity, IdentityId, Session,
+};
use self::sql_db::SqlClient;
#[derive(Debug, thiserror::Error, derive_more::Display)]
pub enum StoreError {
SqlClientError(#[from] sqlx::Error),
+ SerdeError(#[from] serde_json::Error),
+ ParseError(#[from] strum::ParseError),
StoreValueCannotBeParsedInvariant,
IdempotentCheckAlreadyExists,
}
-#[async_trait::async_trait(?Send)]
-pub trait Store {
+#[async_trait]
+pub trait Store: Send + Sync {
fn get_type(&self) -> StoreType;
}
@@ -25,7 +31,7 @@ pub enum StoreType {
Sqlite { c: Arc<SqlClient<Sqlite>> },
}
-#[async_trait::async_trait(?Send)]
+#[async_trait]
pub(crate) trait Storable<'a> {
type Item;
type Lens;
@@ -64,7 +70,15 @@ pub(crate) struct SessionLens<'a> {
}
impl<'a> Lens for SessionLens<'a> {}
-#[async_trait::async_trait(?Send)]
+pub(crate) struct CredentialLens<'a> {
+ pub id: Option<CredentialId>,
+ pub identity_id: Option<IdentityId>,
+ pub t: Option<&'a CredentialType>,
+ pub restrict_by_key: Option<bool>,
+}
+impl<'a> Lens for CredentialLens<'a> {}
+
+#[async_trait]
impl<'a> Storable<'a> for Address {
type Item = Address;
type Lens = AddressLens<'a>;
@@ -93,7 +107,7 @@ impl<'a> Storable<'a> for Address {
}
}
-#[async_trait::async_trait(?Send)]
+#[async_trait]
impl<'a> Storable<'a> for AddressValidation {
type Item = AddressValidation;
type Lens = AddressValidationLens<'a>;
@@ -116,7 +130,7 @@ impl<'a> Storable<'a> for AddressValidation {
}
}
-#[async_trait::async_trait(?Send)]
+#[async_trait]
impl<'a> Storable<'a> for Identity {
type Item = Identity;
type Lens = IdentityLens<'a>;
@@ -158,7 +172,7 @@ impl<'a> Storable<'a> for Identity {
}
}
-#[async_trait::async_trait(?Send)]
+#[async_trait]
impl<'a> Storable<'a> for Session {
type Item = Session;
type Lens = SessionLens<'a>;
@@ -183,3 +197,51 @@ impl<'a> Storable<'a> for Session {
})
}
}
+
+#[async_trait]
+impl<'a> Storable<'a> for Credential {
+ type Item = Credential;
+ type Lens = CredentialLens<'a>;
+
+ async fn write(&self, store: Arc<dyn Store>) -> Result<(), StoreError> {
+ match store.get_type() {
+ StoreType::Postgres { c } => c.write_credential(self).await?,
+ StoreType::Sqlite { c } => c.write_credential(self).await?,
+ }
+ Ok(())
+ }
+
+ async fn find(
+ store: Arc<dyn Store>,
+ lens: &'a Self::Lens,
+ ) -> Result<Vec<Self::Item>, StoreError> {
+ Ok(match store.get_type() {
+ StoreType::Postgres { c } => {
+ c.find_credential(
+ lens.id,
+ lens.identity_id,
+ lens.t,
+ if let Some(true) = lens.restrict_by_key {
+ true
+ } else {
+ false
+ },
+ )
+ .await?
+ }
+ StoreType::Sqlite { c } => {
+ c.find_credential(
+ lens.id,
+ lens.identity_id,
+ lens.t,
+ if let Some(true) = lens.restrict_by_key {
+ true
+ } else {
+ false
+ },
+ )
+ .await?
+ }
+ })
+ }
+}
diff --git a/crates/secd/src/client/store/sql_db.rs b/crates/secd/src/client/store/sql_db.rs
index ecb13be..3e72fe8 100644
--- a/crates/secd/src/client/store/sql_db.rs
+++ b/crates/secd/src/client/store/sql_db.rs
@@ -1,27 +1,19 @@
-use std::{str::FromStr, sync::Arc};
-
-use email_address::EmailAddress;
-use serde_json::value::RawValue;
-use sqlx::{
- database::HasArguments, types::Json, ColumnIndex, Database, Decode, Encode, Executor,
- IntoArguments, Pool, Transaction, Type,
-};
-use time::OffsetDateTime;
-use uuid::Uuid;
-
+use super::{Store, StoreError, StoreType};
use crate::{
- Address, AddressType, AddressValidation, AddressValidationMethod, Identity, Session,
- SessionToken,
+ Address, AddressType, AddressValidation, AddressValidationMethod, Credential, CredentialId,
+ CredentialType, Identity, IdentityId, Session,
};
-
+use email_address::EmailAddress;
use lazy_static::lazy_static;
+use sqlx::{
+ database::HasArguments, ColumnIndex, Database, Decode, Encode, Executor, IntoArguments, Pool,
+ Transaction, Type,
+};
use sqlx::{Postgres, Sqlite};
use std::collections::HashMap;
-
-use super::{
- AddressLens, AddressValidationLens, IdentityLens, SessionLens, Storable, Store, StoreError,
- StoreType,
-};
+use std::{str::FromStr, sync::Arc};
+use time::OffsetDateTime;
+use uuid::Uuid;
const SQLITE: &str = "sqlite";
const PGSQL: &str = "pgsql";
@@ -30,6 +22,8 @@ const WRITE_ADDRESS: &str = "write_address";
const FIND_ADDRESS: &str = "find_address";
const WRITE_ADDRESS_VALIDATION: &str = "write_address_validation";
const FIND_ADDRESS_VALIDATION: &str = "find_address_validation";
+const WRITE_CREDENTIAL: &str = "write_credential";
+const FIND_CREDENTIAL: &str = "find_credential";
const WRITE_IDENTITY: &str = "write_identity";
const FIND_IDENTITY: &str = "find_identity";
const WRITE_SESSION: &str = "write_session";
@@ -72,6 +66,14 @@ lazy_static! {
FIND_SESSION,
include_str!("../../../store/sqlite/sql/find_session.sql"),
),
+ (
+ WRITE_CREDENTIAL,
+ include_str!("../../../store/sqlite/sql/write_credential.sql"),
+ ),
+ (
+ FIND_CREDENTIAL,
+ include_str!("../../../store/sqlite/sql/find_credential.sql"),
+ ),
]
.iter()
.cloned()
@@ -110,6 +112,14 @@ lazy_static! {
FIND_SESSION,
include_str!("../../../store/pg/sql/find_session.sql"),
),
+ (
+ WRITE_CREDENTIAL,
+ include_str!("../../../store/pg/sql/write_credential.sql"),
+ ),
+ (
+ FIND_CREDENTIAL,
+ include_str!("../../../store/pg/sql/find_credential.sql"),
+ ),
]
.iter()
.cloned()
@@ -131,7 +141,7 @@ pub trait SqlxResultExt<T> {
impl<T> SqlxResultExt<T> for Result<T, sqlx::Error> {
fn extend_err(self) -> Result<T, StoreError> {
if let Err(sqlx::Error::Database(dbe)) = &self {
- if dbe.code() == Some("23505".into()) {
+ if dbe.code() == Some("23505".into()) || dbe.code() == Some("2067".into()) {
return Err(StoreError::IdempotentCheckAlreadyExists);
}
}
@@ -160,7 +170,7 @@ impl Store for PgClient {
impl PgClient {
pub async fn new(pool: sqlx::Pool<Postgres>) -> Arc<dyn Store + Send + Sync + 'static> {
- sqlx::migrate!("store/pg/migrations")
+ sqlx::migrate!("store/pg/migrations", "secd")
.run(&pool)
.await
.expect(ERR_MSG_MIGRATION_FAILED);
@@ -187,7 +197,7 @@ impl Store for SqliteClient {
impl SqliteClient {
pub async fn new(pool: sqlx::Pool<Sqlite>) -> Arc<dyn Store + Send + Sync + 'static> {
- sqlx::migrate!("store/sqlite/migrations")
+ sqlx::migrate!("store/sqlite/migrations", "secd")
.run(&pool)
.await
.expect(ERR_MSG_MIGRATION_FAILED);
@@ -410,8 +420,7 @@ where
let sqls = get_sqls(&self.sqls_root, WRITE_IDENTITY);
sqlx::query(&sqls[0])
.bind(i.id)
- // TODO: validate this is actually Json somewhere way up the chain (when being deserialized)
- .bind(i.metadata.clone().unwrap_or("{}".into()))
+ .bind(i.metadata.clone())
.bind(i.created_at)
.bind(OffsetDateTime::now_utc())
.bind(i.deleted_at)
@@ -449,7 +458,7 @@ where
.extend_err()?;
let mut res = vec![];
- for (id, metadata, created_at, updated_at, deleted_at) in rs.into_iter() {
+ for (id, metadata, created_at, _, deleted_at) in rs.into_iter() {
res.push(Identity {
id,
address_validations: vec![],
@@ -509,6 +518,79 @@ where
}
Ok(res)
}
+
+ pub async fn write_credential(&self, c: &Credential) -> Result<(), StoreError> {
+ let sqls = get_sqls(&self.sqls_root, WRITE_CREDENTIAL);
+ let partial_key = match &c.t {
+ crate::CredentialType::Passphrase { key, value: _ } => Some(key.clone()),
+ _ => None,
+ };
+
+ sqlx::query(&sqls[0])
+ .bind(c.id)
+ .bind(c.identity_id)
+ .bind(partial_key)
+ .bind(c.t.to_string())
+ .bind(serde_json::to_string(&c.t)?)
+ .bind(c.created_at)
+ .bind(c.revoked_at)
+ .bind(c.deleted_at)
+ .execute(&self.pool)
+ .await
+ .extend_err()?;
+ Ok(())
+ }
+ pub async fn find_credential(
+ &self,
+ id: Option<Uuid>,
+ identity_id: Option<Uuid>,
+ t: Option<&CredentialType>,
+ restrict_by_key: bool,
+ ) -> Result<Vec<Credential>, StoreError> {
+ let sqls = get_sqls(&self.sqls_root, FIND_CREDENTIAL);
+ let key = restrict_by_key
+ .then(|| {
+ t.map(|i| match i {
+ CredentialType::Passphrase { key, value: _ } => key.clone(),
+ _ => todo!(),
+ })
+ })
+ .flatten();
+
+ let rs = sqlx::query_as::<
+ _,
+ (
+ CredentialId,
+ IdentityId,
+ String,
+ OffsetDateTime,
+ Option<OffsetDateTime>,
+ Option<OffsetDateTime>,
+ ),
+ >(&sqls[0])
+ .bind(id.as_ref())
+ .bind(identity_id.as_ref())
+ .bind(t.map(|i| i.to_string()))
+ .bind(key)
+ .fetch_all(&self.pool)
+ .await
+ .extend_err()?;
+
+ let mut res = vec![];
+ for (id, identity_id, data, created_at, revoked_at, deleted_at) in rs.into_iter() {
+ let t: CredentialType = serde_json::from_str(&data)?;
+ res.push(Credential {
+ id,
+ identity_id,
+ t,
+ created_at,
+ revoked_at,
+ deleted_at,
+ })
+ }
+
+ Ok(res)
+ }
}
fn get_sqls(root: &str, file: &str) -> Vec<String> {