diff --git a/src/api/auth.rs b/src/api/auth.rs index 4160efd..5fcab01 100644 --- a/src/api/auth.rs +++ b/src/api/auth.rs @@ -85,7 +85,7 @@ pub fn get_token_from_request(req: &Request) -> Result { }; if token.is_ok() { - Ok(token.unwrap().to_string()) + Ok(token.unwrap().trim_start_matches("Bearer ").to_string()) } else { Err(Status::unauthenticated( "Could not read token from metadata", @@ -107,7 +107,7 @@ pub fn validate_token(token: String, is_refresh_token: bool) -> Result( - token.trim_start_matches("Bearer "), + &token, &DecodingKey::from_secret(secret.as_bytes()), &validation, ) { diff --git a/src/api/ecdar_api.rs b/src/api/ecdar_api.rs index 271c9fe..63cf863 100644 --- a/src/api/ecdar_api.rs +++ b/src/api/ecdar_api.rs @@ -3,6 +3,9 @@ use std::sync::Arc; use crate::api::ecdar_api::helpers::helpers::{setup_db_with_entities, AnyEntity}; use crate::api::server::server::get_auth_token_request::user_credentials; +use crate::entities::access; +use crate::entities::session::Model; +use chrono::Local; use regex::Regex; use sea_orm::SqlErr; use tonic::{Code, Request, Response, Status}; @@ -48,6 +51,58 @@ pub struct ConcreteEcdarApi { in_use_context: Arc, } +/// Updates or creates a session in the database for a given user. +/// +/// +/// # Errors +/// This function will return an error if the database context returns an error +/// or if a session is not found when trying to update an existing one. +async fn handle_session( + session_context: Arc, + request: &Request, + is_new_session: bool, + access_token: String, + refresh_token: String, + uid: String, +) -> Result<(), Status> { + if is_new_session { + session_context + .create(Model { + id: Default::default(), + access_token: access_token.clone(), + refresh_token: refresh_token.clone(), + updated_at: Local::now().naive_local(), + user_id: uid.parse().unwrap(), + }) + .await + .unwrap(); + } else { + let mut session = match session_context + .get_by_refresh_token(auth::get_token_from_request(request)?) + .await + { + Ok(Some(session)) => session, + Ok(None) => { + return Err(Status::new( + Code::Unauthenticated, + "No session found with given refresh token", + )) + } + Err(err) => return Err(Status::new(Code::Internal, err.to_string())), + }; + + session.access_token = access_token.clone(); + session.refresh_token = refresh_token.clone(); + session.updated_at = Local::now().naive_local(); + + match session_context.update(session).await { + Ok(_) => (), + Err(err) => return Err(Status::new(Code::Internal, err.to_string())), + }; + } + Ok(()) +} + fn get_uid_from_request(request: &Request) -> Result { let uid = match request.metadata().get("uid").unwrap().to_str() { Ok(uid) => uid, @@ -62,6 +117,18 @@ fn get_uid_from_request(request: &Request) -> Result { Ok(uid.parse().unwrap()) } +fn is_valid_email(email: &str) -> bool { + Regex::new(r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$") + .unwrap() + .is_match(email) +} + +fn is_valid_username(username: &str) -> bool { + Regex::new(r"^[a-zA-Z0-9_]{3,32}$") + .unwrap() + .is_match(username) +} + impl ConcreteEcdarApi { pub async fn new( model_context: Arc, @@ -191,20 +258,14 @@ impl EcdarApi for ConcreteEcdarApi { } } -fn is_valid_email(email: &str) -> bool { - Regex::new(r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$") - .unwrap() - .is_match(email) -} - -fn is_valid_username(username: &str) -> bool { - Regex::new(r"^[a-zA-Z0-9_]{3,32}$") - .unwrap() - .is_match(username) -} - #[tonic::async_trait] impl EcdarApiAuth for ConcreteEcdarApi { + /// This method is used to get a new access and refresh token for a user. + /// + /// # Errors + /// This function will return an error if the user does not exist in the database, + /// if the password in the request does not match the user's password, + /// or if no user is provided in the request. async fn get_auth_token( &self, request: Request, @@ -212,10 +273,13 @@ impl EcdarApiAuth for ConcreteEcdarApi { let message = request.get_ref().clone(); let uid: String; let user_from_db: User; + let is_new_session: bool; + // Get user from credentials if let Some(user_credentials) = message.user_credentials { if let Some(user) = user_credentials.user { user_from_db = match user { + // Get user from database by username given in request user_credentials::User::Username(username) => { match self.user_context.get_by_username(username).await { Ok(Some(user)) => user, @@ -228,6 +292,7 @@ impl EcdarApiAuth for ConcreteEcdarApi { Err(err) => return Err(Status::new(Code::Internal, err.to_string())), } } + // Get user from database by email given in request user_credentials::User::Email(email) => { match self.user_context.get_by_email(email).await { Ok(Some(user)) => user, @@ -241,29 +306,49 @@ impl EcdarApiAuth for ConcreteEcdarApi { } } }; - - uid = user_from_db.id.to_string(); - + // Check if password in request matches users password if user_credentials.password != user_from_db.password { return Err(Status::new(Code::Unauthenticated, "Wrong password")); } + + uid = user_from_db.id.to_string(); + + // Since the user does not have a refresh_token, a new session has to be made + is_new_session = true; } else { return Err(Status::new(Code::Internal, "No user provided")); } + // Get user from refresh_token } else { let refresh_token = auth::get_token_from_request(&request)?; let token_data = auth::validate_token(refresh_token, true)?; uid = token_data.claims.sub; + + // Since the user does have a refresh_token, a session already exists + is_new_session = false; } + // Create new access and refresh token with user id let access_token = match auth::create_access_token(&uid) { - Ok(token) => token, + Ok(token) => token.to_owned(), Err(e) => return Err(Status::new(Code::Internal, e.to_string())), }; let refresh_token = match auth::create_refresh_token(&uid) { - Ok(token) => token, + Ok(token) => token.to_owned(), Err(e) => return Err(Status::new(Code::Internal, e.to_string())), }; + + // Update or create session in database + handle_session( + self.session_context.clone(), + &request, + is_new_session, + access_token.clone(), + refresh_token.clone(), + uid, + ) + .await?; + Ok(Response::new(GetAuthTokenResponse { access_token, refresh_token, diff --git a/src/database/session_context.rs b/src/database/session_context.rs index bd2d69b..c063e35 100644 --- a/src/database/session_context.rs +++ b/src/database/session_context.rs @@ -2,11 +2,12 @@ use std::fmt::Debug; use sea_orm::prelude::async_trait::async_trait; use sea_orm::ActiveValue::{Set, Unchanged}; -use sea_orm::{ActiveModelTrait, DbErr, EntityTrait}; +use sea_orm::{ActiveModelTrait, ColumnTrait, DbErr, EntityTrait, QueryFilter}; use crate::database::database_context::DatabaseContextTrait; use crate::database::entity_context::EntityContextTrait; use crate::entities::prelude::Session as SessionEntity; +use crate::entities::session::Column as SessionColumn; use crate::entities::session::{ActiveModel, Model as Session}; #[derive(Debug)] @@ -14,9 +15,25 @@ pub struct SessionContext { db_context: Box, } -pub trait SessionContextTrait: EntityContextTrait {} +#[async_trait] +pub trait SessionContextTrait: EntityContextTrait { + async fn get_by_refresh_token(&self, refresh_token: String) -> Result, DbErr>; +} -impl SessionContextTrait for SessionContext {} +#[async_trait] +impl SessionContextTrait for SessionContext { + /// Returns a session by searching for its refresh_token. + /// # Example + /// ```rust + /// let session: Result, DbErr> = session_context.get_by_refresh_token(refresh_token).await; + /// ``` + async fn get_by_refresh_token(&self, refresh_token: String) -> Result, DbErr> { + SessionEntity::find() + .filter(SessionColumn::RefreshToken.eq(refresh_token)) + .one(&self.db_context.get_connection()) + .await + } +} #[async_trait] impl EntityContextTrait for SessionContext { diff --git a/src/tests/database/session_context.rs b/src/tests/database/session_context.rs index da11458..993fada 100644 --- a/src/tests/database/session_context.rs +++ b/src/tests/database/session_context.rs @@ -67,8 +67,8 @@ mod database_tests { let new_session = Model { id: 1, - refresh_token: Uuid::parse_str("4473240f-2acb-422f-bd1a-5214554ed0e0").unwrap(), - access_token: Uuid::parse_str("4473240f-2acb-422f-bd1a-5214554ed0e0").unwrap(), + refresh_token: "test_refresh_token".to_string(), + access_token: "test_token".to_string(), updated_at: Local::now().naive_utc(), user_id, }; @@ -92,8 +92,8 @@ mod database_tests { let new_session = Model { id: 1, - refresh_token: Uuid::parse_str("4473240f-2acb-422f-bd1a-5214554ed0e0").unwrap(), - access_token: Uuid::parse_str("4473240f-2acb-422f-bd1a-5214554ed0e0").unwrap(), + refresh_token: "test_refresh_token".to_string(), + access_token: "test_token".to_string(), updated_at: Local::now().naive_utc(), user_id, }; @@ -121,16 +121,16 @@ mod database_tests { // Create the sessions structs let session1 = Model { id: 1, - refresh_token: Uuid::parse_str("4473240f-2acb-422f-bd1a-5214554ed0e0").unwrap(), - access_token: Uuid::parse_str("4473240f-2acb-422f-bd1a-5214554ed0e0").unwrap(), + refresh_token: "test_refresh_token".to_string(), + access_token: "test_token".to_string(), updated_at: Local::now().naive_utc(), user_id, }; let session2 = Model { id: 2, - refresh_token: Uuid::parse_str("4473240f-2acb-422f-bd1a-5214554ed0e1").unwrap(), - access_token: Uuid::parse_str("75ecdf25-538c-4fe0-872d-525570c96b91").unwrap(), + refresh_token: "test_refresh_token_2".to_string(), + access_token: "test_token_2".to_string(), updated_at: Local::now().naive_utc(), user_id, }; @@ -159,8 +159,8 @@ mod database_tests { let original_session = Model { id: 1, - refresh_token: Uuid::parse_str("4473240f-2acb-422f-bd1a-5214554ed0e0").unwrap(), - access_token: Uuid::parse_str("5c5e9172-9dff-4f35-afde-029a6f99652c").unwrap(), + refresh_token: "test_refresh_token".to_string(), + access_token: "test_token".to_string(), updated_at: Local::now().naive_utc(), user_id, }; @@ -168,8 +168,8 @@ mod database_tests { let original_session = session_context.create(original_session).await.unwrap(); let altered_session = Model { - refresh_token: Uuid::parse_str("4473240f-2acb-422f-bd1a-5214554ed0e1").unwrap(), - access_token: Uuid::parse_str("ddd9b7a3-98ff-43b0-b5b5-aa2abaea9d96").unwrap(), + refresh_token: "test_refresh_token".to_string(), + access_token: "test_token".to_string(), ..original_session }; @@ -190,8 +190,8 @@ mod database_tests { let original_session = Model { id: 1, - refresh_token: Uuid::parse_str("4473240f-2acb-422f-bd1a-5214554ed0e1").unwrap(), - access_token: Uuid::parse_str("5c5e9172-9dff-4f35-afde-029a6f99652c").unwrap(), + refresh_token: "test_refresh_token".to_string(), + access_token: "test_token".to_string(), updated_at: Local::now().naive_utc(), user_id, }; @@ -207,8 +207,8 @@ mod database_tests { let original_session = Model { id: 1, - refresh_token: Uuid::parse_str("4473240f-2acb-422f-bd1a-5214554ed0e1").unwrap(), - access_token: Uuid::parse_str("5c5e9172-9dff-4f35-afde-029a6f99652c").unwrap(), + refresh_token: "test_refresh_token".to_string(), + access_token: "test_token".to_string(), updated_at: Local::now().naive_utc(), user_id, }; @@ -231,8 +231,8 @@ mod database_tests { let original_session = Model { id: 1, - refresh_token: Uuid::parse_str("4473240f-2acb-422f-bd1a-5214554ed0e1").unwrap(), - access_token: Uuid::parse_str("5c5e9172-9dff-4f35-afde-029a6f99652c").unwrap(), + refresh_token: "test_refresh_token".to_string(), + access_token: "test_token".to_string(), updated_at: Local::now().naive_utc(), user_id, };