Skip to content

Commit

Permalink
Merge pull request #40 from ECDAR-AAU-SW-P5/39-implement-create-session
Browse files Browse the repository at this point in the history
39 implement create session
  • Loading branch information
sabotack authored Nov 14, 2023
2 parents 947ad5f + 0161a91 commit 9adcea0
Show file tree
Hide file tree
Showing 4 changed files with 142 additions and 40 deletions.
4 changes: 2 additions & 2 deletions src/api/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ pub fn get_token_from_request<T>(req: &Request<T>) -> Result<String, Status> {
};

if token.is_ok() {
Ok(token.unwrap().to_string())
Ok(token.unwrap().trim_start_matches("Bearer ").to_string())
} else {
Err(Status::unauthenticated(
"Could not read token from metadata",
Expand All @@ -107,7 +107,7 @@ pub fn validate_token(token: String, is_refresh_token: bool) -> Result<TokenData
validation.validate_exp = true;

match decode::<Claims>(
token.trim_start_matches("Bearer "),
&token,
&DecodingKey::from_secret(secret.as_bytes()),
&validation,
) {
Expand Down
119 changes: 102 additions & 17 deletions src/api/ecdar_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ use std::sync::Arc;

use crate::api::ecdar_api::helpers::helpers::{setup_db_with_entities, AnyEntity};
use crate::api::server::server::get_auth_token_request::user_credentials;
use crate::entities::access;

Check warning on line 6 in src/api/ecdar_api.rs

View workflow job for this annotation

GitHub Actions / Clippy lint and check

unused import: `crate::entities::access`

warning: unused import: `crate::entities::access` --> src/api/ecdar_api.rs:6:5 | 6 | use crate::entities::access; | ^^^^^^^^^^^^^^^^^^^^^^^ | = note: `#[warn(unused_imports)]` on by default

Check warning on line 6 in src/api/ecdar_api.rs

View workflow job for this annotation

GitHub Actions / Clippy lint and check

unused import: `crate::entities::access`

warning: unused import: `crate::entities::access` --> src/api/ecdar_api.rs:6:5 | 6 | use crate::entities::access; | ^^^^^^^^^^^^^^^^^^^^^^^ | = note: `#[warn(unused_imports)]` on by default
use crate::entities::session::Model;
use chrono::Local;
use regex::Regex;
use sea_orm::SqlErr;
use tonic::{Code, Request, Response, Status};
Expand Down Expand Up @@ -48,6 +51,58 @@ pub struct ConcreteEcdarApi {
in_use_context: Arc<dyn InUseContextTrait>,
}

/// Updates or creates a session in the database for a given user.
///
///
/// # Errors
/// This function will return an error if the database context returns an error
/// or if a session is not found when trying to update an existing one.
async fn handle_session(
session_context: Arc<dyn SessionContextTrait>,
request: &Request<GetAuthTokenRequest>,
is_new_session: bool,
access_token: String,
refresh_token: String,
uid: String,
) -> Result<(), Status> {
if is_new_session {
session_context
.create(Model {
id: Default::default(),
access_token: access_token.clone(),
refresh_token: refresh_token.clone(),
updated_at: Local::now().naive_local(),
user_id: uid.parse().unwrap(),
})
.await
.unwrap();
} else {
let mut session = match session_context
.get_by_refresh_token(auth::get_token_from_request(request)?)
.await
{
Ok(Some(session)) => session,
Ok(None) => {
return Err(Status::new(
Code::Unauthenticated,
"No session found with given refresh token",
))
}
Err(err) => return Err(Status::new(Code::Internal, err.to_string())),
};

session.access_token = access_token.clone();
session.refresh_token = refresh_token.clone();
session.updated_at = Local::now().naive_local();

match session_context.update(session).await {
Ok(_) => (),
Err(err) => return Err(Status::new(Code::Internal, err.to_string())),
};
}
Ok(())
}

fn get_uid_from_request<T>(request: &Request<T>) -> Result<i32, Status> {
let uid = match request.metadata().get("uid").unwrap().to_str() {
Ok(uid) => uid,
Expand All @@ -62,6 +117,18 @@ fn get_uid_from_request<T>(request: &Request<T>) -> Result<i32, Status> {
Ok(uid.parse().unwrap())
}

fn is_valid_email(email: &str) -> bool {
Regex::new(r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$")
.unwrap()
.is_match(email)
}

fn is_valid_username(username: &str) -> bool {
Regex::new(r"^[a-zA-Z0-9_]{3,32}$")
.unwrap()
.is_match(username)
}

impl ConcreteEcdarApi {
pub async fn new(
model_context: Arc<dyn ModelContextTrait>,
Expand Down Expand Up @@ -191,31 +258,28 @@ impl EcdarApi for ConcreteEcdarApi {
}
}

fn is_valid_email(email: &str) -> bool {
Regex::new(r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$")
.unwrap()
.is_match(email)
}

fn is_valid_username(username: &str) -> bool {
Regex::new(r"^[a-zA-Z0-9_]{3,32}$")
.unwrap()
.is_match(username)
}

#[tonic::async_trait]
impl EcdarApiAuth for ConcreteEcdarApi {
/// This method is used to get a new access and refresh token for a user.
///
/// # Errors
/// This function will return an error if the user does not exist in the database,
/// if the password in the request does not match the user's password,
/// or if no user is provided in the request.
async fn get_auth_token(
&self,
request: Request<GetAuthTokenRequest>,
) -> Result<Response<GetAuthTokenResponse>, Status> {
let message = request.get_ref().clone();
let uid: String;
let user_from_db: User;
let is_new_session: bool;

// Get user from credentials
if let Some(user_credentials) = message.user_credentials {
if let Some(user) = user_credentials.user {
user_from_db = match user {
// Get user from database by username given in request
user_credentials::User::Username(username) => {
match self.user_context.get_by_username(username).await {
Ok(Some(user)) => user,
Expand All @@ -228,6 +292,7 @@ impl EcdarApiAuth for ConcreteEcdarApi {
Err(err) => return Err(Status::new(Code::Internal, err.to_string())),
}
}
// Get user from database by email given in request
user_credentials::User::Email(email) => {
match self.user_context.get_by_email(email).await {
Ok(Some(user)) => user,
Expand All @@ -241,29 +306,49 @@ impl EcdarApiAuth for ConcreteEcdarApi {
}
}
};

uid = user_from_db.id.to_string();

// Check if password in request matches users password
if user_credentials.password != user_from_db.password {
return Err(Status::new(Code::Unauthenticated, "Wrong password"));
}

uid = user_from_db.id.to_string();

// Since the user does not have a refresh_token, a new session has to be made
is_new_session = true;
} else {
return Err(Status::new(Code::Internal, "No user provided"));
}
// Get user from refresh_token
} else {
let refresh_token = auth::get_token_from_request(&request)?;
let token_data = auth::validate_token(refresh_token, true)?;
uid = token_data.claims.sub;

// Since the user does have a refresh_token, a session already exists
is_new_session = false;
}

// Create new access and refresh token with user id
let access_token = match auth::create_access_token(&uid) {
Ok(token) => token,
Ok(token) => token.to_owned(),
Err(e) => return Err(Status::new(Code::Internal, e.to_string())),
};
let refresh_token = match auth::create_refresh_token(&uid) {
Ok(token) => token,
Ok(token) => token.to_owned(),
Err(e) => return Err(Status::new(Code::Internal, e.to_string())),
};

// Update or create session in database
handle_session(
self.session_context.clone(),
&request,
is_new_session,
access_token.clone(),
refresh_token.clone(),
uid,
)
.await?;

Ok(Response::new(GetAuthTokenResponse {
access_token,
refresh_token,
Expand Down
23 changes: 20 additions & 3 deletions src/database/session_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,38 @@ use std::fmt::Debug;

use sea_orm::prelude::async_trait::async_trait;
use sea_orm::ActiveValue::{Set, Unchanged};
use sea_orm::{ActiveModelTrait, DbErr, EntityTrait};
use sea_orm::{ActiveModelTrait, ColumnTrait, DbErr, EntityTrait, QueryFilter};

use crate::database::database_context::DatabaseContextTrait;
use crate::database::entity_context::EntityContextTrait;
use crate::entities::prelude::Session as SessionEntity;
use crate::entities::session::Column as SessionColumn;
use crate::entities::session::{ActiveModel, Model as Session};

#[derive(Debug)]
pub struct SessionContext {
db_context: Box<dyn DatabaseContextTrait>,
}

pub trait SessionContextTrait: EntityContextTrait<Session> {}
#[async_trait]
pub trait SessionContextTrait: EntityContextTrait<Session> {
async fn get_by_refresh_token(&self, refresh_token: String) -> Result<Option<Session>, DbErr>;
}

impl SessionContextTrait for SessionContext {}
#[async_trait]
impl SessionContextTrait for SessionContext {
/// Returns a session by searching for its refresh_token.
/// # Example
/// ```rust
/// let session: Result<Option<Model>, DbErr> = session_context.get_by_refresh_token(refresh_token).await;
/// ```
async fn get_by_refresh_token(&self, refresh_token: String) -> Result<Option<Session>, DbErr> {
SessionEntity::find()
.filter(SessionColumn::RefreshToken.eq(refresh_token))
.one(&self.db_context.get_connection())
.await
}
}

#[async_trait]
impl EntityContextTrait<Session> for SessionContext {
Expand Down
36 changes: 18 additions & 18 deletions src/tests/database/session_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ mod database_tests {

let new_session = Model {
id: 1,
refresh_token: Uuid::parse_str("4473240f-2acb-422f-bd1a-5214554ed0e0").unwrap(),
access_token: Uuid::parse_str("4473240f-2acb-422f-bd1a-5214554ed0e0").unwrap(),
refresh_token: "test_refresh_token".to_string(),
access_token: "test_token".to_string(),
updated_at: Local::now().naive_utc(),
user_id,
};
Expand All @@ -92,8 +92,8 @@ mod database_tests {

let new_session = Model {
id: 1,
refresh_token: Uuid::parse_str("4473240f-2acb-422f-bd1a-5214554ed0e0").unwrap(),
access_token: Uuid::parse_str("4473240f-2acb-422f-bd1a-5214554ed0e0").unwrap(),
refresh_token: "test_refresh_token".to_string(),
access_token: "test_token".to_string(),
updated_at: Local::now().naive_utc(),
user_id,
};
Expand Down Expand Up @@ -121,16 +121,16 @@ mod database_tests {
// Create the sessions structs
let session1 = Model {
id: 1,
refresh_token: Uuid::parse_str("4473240f-2acb-422f-bd1a-5214554ed0e0").unwrap(),
access_token: Uuid::parse_str("4473240f-2acb-422f-bd1a-5214554ed0e0").unwrap(),
refresh_token: "test_refresh_token".to_string(),
access_token: "test_token".to_string(),
updated_at: Local::now().naive_utc(),
user_id,
};

let session2 = Model {
id: 2,
refresh_token: Uuid::parse_str("4473240f-2acb-422f-bd1a-5214554ed0e1").unwrap(),
access_token: Uuid::parse_str("75ecdf25-538c-4fe0-872d-525570c96b91").unwrap(),
refresh_token: "test_refresh_token_2".to_string(),
access_token: "test_token_2".to_string(),
updated_at: Local::now().naive_utc(),
user_id,
};
Expand Down Expand Up @@ -159,17 +159,17 @@ mod database_tests {

let original_session = Model {
id: 1,
refresh_token: Uuid::parse_str("4473240f-2acb-422f-bd1a-5214554ed0e0").unwrap(),
access_token: Uuid::parse_str("5c5e9172-9dff-4f35-afde-029a6f99652c").unwrap(),
refresh_token: "test_refresh_token".to_string(),
access_token: "test_token".to_string(),
updated_at: Local::now().naive_utc(),
user_id,
};

let original_session = session_context.create(original_session).await.unwrap();

let altered_session = Model {
refresh_token: Uuid::parse_str("4473240f-2acb-422f-bd1a-5214554ed0e1").unwrap(),
access_token: Uuid::parse_str("ddd9b7a3-98ff-43b0-b5b5-aa2abaea9d96").unwrap(),
refresh_token: "test_refresh_token".to_string(),
access_token: "test_token".to_string(),
..original_session
};

Expand All @@ -190,8 +190,8 @@ mod database_tests {

let original_session = Model {
id: 1,
refresh_token: Uuid::parse_str("4473240f-2acb-422f-bd1a-5214554ed0e1").unwrap(),
access_token: Uuid::parse_str("5c5e9172-9dff-4f35-afde-029a6f99652c").unwrap(),
refresh_token: "test_refresh_token".to_string(),
access_token: "test_token".to_string(),
updated_at: Local::now().naive_utc(),
user_id,
};
Expand All @@ -207,8 +207,8 @@ mod database_tests {

let original_session = Model {
id: 1,
refresh_token: Uuid::parse_str("4473240f-2acb-422f-bd1a-5214554ed0e1").unwrap(),
access_token: Uuid::parse_str("5c5e9172-9dff-4f35-afde-029a6f99652c").unwrap(),
refresh_token: "test_refresh_token".to_string(),
access_token: "test_token".to_string(),
updated_at: Local::now().naive_utc(),
user_id,
};
Expand All @@ -231,8 +231,8 @@ mod database_tests {

let original_session = Model {
id: 1,
refresh_token: Uuid::parse_str("4473240f-2acb-422f-bd1a-5214554ed0e1").unwrap(),
access_token: Uuid::parse_str("5c5e9172-9dff-4f35-afde-029a6f99652c").unwrap(),
refresh_token: "test_refresh_token".to_string(),
access_token: "test_token".to_string(),
updated_at: Local::now().naive_utc(),
user_id,
};
Expand Down

0 comments on commit 9adcea0

Please sign in to comment.