Skip to content

Commit

Permalink
feat: add support to multitenancy (#37)
Browse files Browse the repository at this point in the history
  • Loading branch information
dracarys18 authored Jan 6, 2025
1 parent 1ef277b commit e4dc7b3
Show file tree
Hide file tree
Showing 29 changed files with 418 additions and 111 deletions.
36 changes: 22 additions & 14 deletions benches/encryption_bench.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
use std::sync::Arc;

use cripta::{
app::AppState,
config,
core::{crypto::custodian::Custodian, datakey::create::generate_and_create_data_key},
multitenancy::TenantId,
types::{
core::{DecryptedData, DecryptedDataGroup, Identifier},
method::EncryptionType,
Expand All @@ -18,6 +17,7 @@ use tokio::runtime::Runtime;
// Note: modify this to run for different size inputs
const SINGLE_BENCH_ITERATION: u32 = 10;
const BATCH_BENCH_ITERATION: u32 = 10;
const PUBLIC_TENANT_ID: &str = "public";

criterion_main!(benches);
criterion_group!(
Expand All @@ -36,18 +36,20 @@ pub fn criterion_data_encryption_decryption(c: &mut Criterion) {
let config = config::Config::with_config_path(config::Environment::Dev, None);
let state = rt.block_on(async { AppState::from_config(config).await });
// create a DataKey in data_key_store
let config2 = config::Config::with_config_path(config::Environment::Dev, None);
let identifier = Identifier::User(String::from("bench_user"));
let key_create_req: CreateDataKeyRequest = CreateDataKeyRequest {
identifier: identifier.clone(),
};
let key_create_state = rt.block_on(async {
let state = AppState::from_config(config2).await;
<AppState as Into<Arc<AppState>>>::into(state)
});
let tenant_state = state
.tenant_states
.get(&TenantId::new(PUBLIC_TENANT_ID.to_string()))
.cloned()
.expect("Invalid tenant");

rt.block_on(async {
let _ =
generate_and_create_data_key(key_create_state, custodian.clone(), key_create_req).await;
generate_and_create_data_key(tenant_state.clone(), custodian.clone(), key_create_req)
.await;
});

{
Expand All @@ -68,7 +70,7 @@ pub fn criterion_data_encryption_decryption(c: &mut Criterion) {
black_box(rt.block_on(async {
bench_input
.clone()
.encrypt(&state, &identifier.clone(), custodian.clone())
.encrypt(&tenant_state, &identifier.clone(), custodian.clone())
.await
.expect("Failed while encrypting")
}))
Expand All @@ -88,7 +90,7 @@ pub fn criterion_data_encryption_decryption(c: &mut Criterion) {
let bench_input = EncryptionType::Single(DecryptedData::from_data(value.into()));
let encrypted_data = rt.block_on(async {
bench_input
.encrypt(&state, &identifier, custodian.clone())
.encrypt(&tenant_state, &identifier, custodian.clone())
.await
.expect("Failed while encrypting")
});
Expand All @@ -102,7 +104,7 @@ pub fn criterion_data_encryption_decryption(c: &mut Criterion) {
black_box(rt.block_on(async {
encrypted_data
.clone()
.decrypt(&state, &identifier.clone(), custodian.clone())
.decrypt(&tenant_state, &identifier.clone(), custodian.clone())
.await
.expect("Failed while decrypting")
}))
Expand Down Expand Up @@ -138,6 +140,12 @@ pub fn criterion_batch_data_encryption_decryption(c: &mut Criterion) {
let custodian = Custodian::new(Some(("key".to_string(), "value".to_string())));
let config = config::Config::with_config_path(config::Environment::Dev, None);
let state = rt.block_on(async { AppState::from_config(config).await });
let tenant_state = state
.tenant_states
.get(&TenantId::new(PUBLIC_TENANT_ID.to_string()))
.cloned()
.expect("Invalid tenant");

let identifier = Identifier::User(String::from("bench_user"));
{
let mut group = c.benchmark_group("data-encryption-batch");
Expand All @@ -153,7 +161,7 @@ pub fn criterion_batch_data_encryption_decryption(c: &mut Criterion) {
black_box(rt.block_on(async {
bench_input
.clone()
.encrypt(&state, &identifier.clone(), custodian.clone())
.encrypt(&tenant_state, &identifier.clone(), custodian.clone())
.await
.expect("Failed while encrypting")
}))
Expand All @@ -169,7 +177,7 @@ pub fn criterion_batch_data_encryption_decryption(c: &mut Criterion) {
let decrypted_input = EncryptionType::Batch(generate_batch_data(input_size));
let encrypted_bench_input = rt.block_on(async {
decrypted_input
.encrypt(&state, &identifier, custodian.clone())
.encrypt(&tenant_state, &identifier, custodian.clone())
.await
.expect("Failed while encrypting")
});
Expand All @@ -182,7 +190,7 @@ pub fn criterion_batch_data_encryption_decryption(c: &mut Criterion) {
black_box(rt.block_on(async {
encrypted_bench_input
.clone()
.decrypt(&state, &identifier.clone(), custodian.clone())
.decrypt(&tenant_state, &identifier.clone(), custodian.clone())
.await
.expect("Failed while decrypting")
}))
Expand Down
8 changes: 8 additions & 0 deletions config/development.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,14 @@ pool_size = 5
min_idle = 2
enable_ssl = false

[multitenancy.tenants.public]
cache_prefix = "public"
schema = "public"

[multitenancy.tenants.global]
cache_prefix = "global"
schema = "global"

[log]
log_level = "debug"
log_format = "console"
Expand Down
56 changes: 45 additions & 11 deletions src/app.rs
Original file line number Diff line number Diff line change
@@ -1,44 +1,74 @@
#[cfg(feature = "mtls")]
pub mod tls;

use crate::{config::Config, crypto::blake3::Blake3, crypto::KeyManagerClient, storage::DbState};

use crate::storage::adapter;
use crate::{
config::{Config, TenantConfig},
crypto::blake3::Blake3,
crypto::KeyManagerClient,
multitenancy::{MultiTenant, TenantId, TenantState},
storage::DbState,
};
use std::sync::Arc;

#[cfg(not(feature = "cassandra"))]
use diesel_async::pooled_connection::bb8::Pool;
#[cfg(not(feature = "cassandra"))]
use diesel_async::AsyncPgConnection;

use rayon::{ThreadPool, ThreadPoolBuilder};
use rustc_hash::FxHashMap;

#[cfg(not(feature = "cassandra"))]
type StorageState = DbState<Pool<AsyncPgConnection>, adapter::PostgreSQL>;
pub(crate) type StorageState = DbState<Pool<AsyncPgConnection>, adapter::PostgreSQL>;

#[cfg(feature = "cassandra")]
type StorageState = DbState<scylla::CachingSession, adapter::Cassandra>;
pub(crate) type StorageState = DbState<scylla::CachingSession, adapter::Cassandra>;

pub struct AppState {
pub conf: Config,
pub db_pool: StorageState,
pub keymanager_client: KeyManagerClient,
pub tenant_states: MultiTenant<TenantState>,
}

impl AppState {
pub async fn from_config(config: Config) -> Self {
let mut tenants = FxHashMap::default();

for (tenant_id, tenant) in &config.multitenancy.tenants.0 {
tenants.insert(
TenantId::new(tenant_id.clone()),
TenantState::new(Arc::new(SessionState::from_config(&config, tenant).await)),
);
}

Self {
conf: config,
tenant_states: tenants,
}
}
}

pub struct SessionState {
pub cache_prefix: String,
pub thread_pool: ThreadPool,
pub keymanager_client: KeyManagerClient,
db_pool: StorageState,
pub hash_client: Blake3,
}

impl AppState {
impl SessionState {
/// # Panics
///
/// Panics if failed to build thread pool
#[allow(clippy::expect_used)]
pub async fn from_config(config: Config) -> Self {
pub async fn from_config(config: &Config, tenant_config: &TenantConfig) -> Self {
let secrets = config.secrets.clone();
let db_pool = StorageState::from_config(&config).await;
let db_pool = StorageState::from_config(config, &tenant_config.schema).await;
let num_threads = config.pool_config.pool;
let hash_client = Blake3::from_config(&config).await;
let hash_client = Blake3::from_config(config).await;

Self {
conf: config,
cache_prefix: tenant_config.cache_prefix.clone(),
keymanager_client: secrets.create_keymanager_client().await,
db_pool,
hash_client,
Expand All @@ -48,4 +78,8 @@ impl AppState {
.expect("Failed to create a threadpool"),
}
}

pub fn db_pool(&self) -> &StorageState {
&self.db_pool
}
}
16 changes: 12 additions & 4 deletions src/bin/cripta.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,12 @@ use tower::ServiceBuilder;
use tower_http::{trace::TraceLayer, ServiceBuilderExt};

use cripta::{
app::AppState, config, env::observability, env::observability as logger, request_id::MakeUlid,
app::AppState,
config,
consts::{TENANT_HEADER, X_REQUEST_ID},
env::observability,
env::observability as logger,
request_id::MakeUlid,
routes::*,
};
use std::sync::Arc;
Expand All @@ -33,10 +38,13 @@ async fn main() {
.propagate_x_request_id()
.layer(
TraceLayer::new_for_http().make_span_with(|request: &Request<Body>| {
let request_id = request.headers().get("x-request-id").and_then(|r| r.to_str().ok()).unwrap_or("unknown_id");
let tenant_id = request.headers().get(TENANT_HEADER).and_then(|r| r.to_str().ok()).unwrap_or("invalid_tenant");
let request_id = request.headers().get(X_REQUEST_ID).and_then(|r| r.to_str().ok()).unwrap_or("unknown_id");

tracing::debug_span!("request",request_id = %request_id,method = %request.method(), uri=%request.uri())
}),
tracing::debug_span!("request",request_id = %request_id,method = %request.method(), uri=%request.uri(), tenant_id=%tenant_id)
})
.on_request(logger::OnRequest::with_level(logger::LogLevel::Info))
.on_response(logger::OnResponse::with_level(logger::LogLevel::Info))
);

let app = Router::new()
Expand Down
33 changes: 31 additions & 2 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use crate::{
use std::num::NonZeroUsize;

use config::File;
use rustc_hash::FxHashMap;
use serde::Deserialize;
use std::sync::Arc;

Expand Down Expand Up @@ -139,15 +140,29 @@ pub struct Config {
#[serde(default)]
pub cassandra: Cassandra,
pub log: LogConfig,
pub multitenancy: MultiTenancy,
pub pool_config: PoolConfig,
#[cfg(feature = "mtls")]
pub certs: Certs,
}

#[derive(Deserialize, Debug)]
pub struct MultiTenancy {
pub tenants: TenantsConfig,
}

#[derive(Deserialize, Debug)]
pub struct TenantsConfig(pub FxHashMap<String, TenantConfig>);

#[derive(Deserialize, Debug)]
pub struct TenantConfig {
pub schema: String,
pub cache_prefix: String,
}

#[derive(Deserialize, Debug, Eq, PartialEq)]
pub struct Cassandra {
pub known_nodes: Vec<String>,
pub keyspace: String,
pub timeout: u32,
pub pool_size: NonZeroUsize,
pub cache_size: usize,
Expand Down Expand Up @@ -216,7 +231,6 @@ impl Default for Cassandra {
fn default() -> Self {
Self {
known_nodes: Vec::new(),
keyspace: String::new(),
timeout: 0,
cache_size: 0,
pool_size: NonZeroUsize::new(1).expect("The provided number is non zero"),
Expand All @@ -240,6 +254,17 @@ impl Cassandra {
}
}

impl MultiTenancy {
fn validate(&self) -> CustomResult<(), errors::ParsingError> {
error_stack::ensure!(
!self.tenants.0.is_empty(),
errors::ParsingError::DecodingFailed("Failed to validate multitenancy configuration. You need to configure atleast one tenant".to_string()
)
);
Ok(())
}
}

impl Config {
pub fn config_path(environment: Environment, explicit_config_path: Option<PathBuf>) -> PathBuf {
let mut config_path = PathBuf::new();
Expand Down Expand Up @@ -287,6 +312,10 @@ impl Config {
self.cassandra
.validate()
.expect("Failed to valdiate cassandra some missing configuration found");

self.multitenancy
.validate()
.expect("Failed to validate multitenancy, some missing configuration found");
}
}

Expand Down
3 changes: 3 additions & 0 deletions src/consts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,6 @@ pub mod base64 {
pub const BASE64_ENGINE: base64::engine::GeneralPurpose =
base64::engine::general_purpose::STANDARD;
}

pub const X_REQUEST_ID: &str = "x-request-id";
pub const TENANT_HEADER: &str = "x-tenant-id";
9 changes: 4 additions & 5 deletions src/core/crypto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,21 @@ mod encryption;
pub use crux::*;

use crate::{
app::AppState,
errors, metrics,
multitenancy::TenantState,
types::{
requests::{DecryptionRequest, EncryptDataRequest},
response::{DecryptionResponse, EncryptionResponse},
},
utils,
};
use axum::extract::{Json, State};
use axum::extract::Json;
use opentelemetry::KeyValue;
use std::sync::Arc;

use self::custodian::Custodian;

pub async fn encrypt_data(
State(state): State<Arc<AppState>>,
state: TenantState,
custodian: Custodian,
Json(req): Json<EncryptDataRequest>,
) -> errors::ApiResponseResult<Json<EncryptionResponse>> {
Expand All @@ -39,7 +38,7 @@ pub async fn encrypt_data(
}

pub async fn decrypt_data(
State(state): State<Arc<AppState>>,
state: TenantState,
custodian: Custodian,
Json(req): Json<DecryptionRequest>,
) -> errors::ApiResponseResult<Json<DecryptionResponse>> {
Expand Down
Loading

0 comments on commit e4dc7b3

Please sign in to comment.