Skip to content

Commit

Permalink
refactor main.rs to be more high-level by abstracing away its processes
Browse files Browse the repository at this point in the history
  • Loading branch information
ivinjabraham committed Jan 17, 2025
1 parent 3ba685d commit 44de20c
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 79 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ axum = "0.7.3"
chrono = "0.4.38"
serde = { version = "1.0.188", features = ["derive"] }
sqlx = { version = "0.7.1", features = ["chrono", "postgres", "runtime-tokio"] }
tokio = { version = "1.28.2", features = ["macros", "rt-multi-thread"] } # For async tests
tokio = { version = "1.28.2", features = ["default", "macros", "rt-multi-thread"] } # For async tests
hmac = "0.12.1"
sha2 = "0.10.8"
hex = "0.4.3"
Expand Down
12 changes: 12 additions & 0 deletions src/graphql/mod.rs
Original file line number Diff line number Diff line change
@@ -1,2 +1,14 @@
use async_graphql::MergedObject;
use mutations::{AttendanceMutations, MemberMutations, StreakMutations};
use queries::{AttendanceQueries, MemberQueries, StreakQueries};

pub mod mutations;
pub mod queries;

// This is our main query or QueryRoot. It is made up of structs representing sub-queries, one for each table in the DB. The fields of a relation are exposed via the [`async_graphql::SimpleObject`] directive on the [`models`] themselves. Specific queries, such as getting a member by ID or getting the streak of a member is defined as methods of the sub-query struct. Complex queries, such as those getting related data from multiple tables like querying all members and the streaks of each member, are defined via the [`async_graphql::ComplexObject`] directive on the [`models`] and can be found in the corresponding sub-query module.
#[derive(MergedObject, Default)]
pub struct Query(MemberQueries, AttendanceQueries, StreakQueries);

// Mutations work the same as Queries, sub-modules for each relation in the DB. However, all methods are directly defined on these sub-module structs. But they use slightly modified versions of the [`models`], marked by the Input in the name, to get input.
#[derive(MergedObject, Default)]
pub struct Mutation(MemberMutations, AttendanceMutations, StreakMutations);
162 changes: 86 additions & 76 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,5 @@
use async_graphql::{EmptySubscription, MergedObject};
use async_graphql_axum::GraphQL;
use axum::{
http::{HeaderValue, Method},
routing::get,
Router,
};
use async_graphql::EmptySubscription;
use axum::http::{HeaderValue, Method};
use chrono_tz::Asia::Kolkata;
use sqlx::PgPool;
use std::sync::Arc;
Expand All @@ -14,10 +9,8 @@ use tracing::info;
use tracing_subscriber::{fmt, layer::SubscriberExt, util::SubscriberInitExt, EnvFilter};

use daily_task::execute_daily_task;
use graphql::{
mutations::{AttendanceMutations, MemberMutations, StreakMutations},
queries::{AttendanceQueries, MemberQueries, StreakQueries},
};
use graphql::{Mutation, Query};
use routes::setup_router;

/// Daily task contains the function that is executed daily at midnight, using the thread spawned in main().
pub mod daily_task;
Expand All @@ -28,13 +21,30 @@ pub mod models;
/// Since we really only need one route for a GraphQL server, this just holds a function returning the GraphiQL playground. Probably can clean this up later.
pub mod routes;

// This is our main query or QueryRoot. It is made up of structs representing sub-queries, one for each table in the DB. The fields of a relation are exposed via the [`async_graphql::SimpleObject`] directive on the [`models`] themselves. Specific queries, such as getting a member by ID or getting the streak of a member is defined as methods of the sub-query struct. Complex queries, such as those getting related data from multiple tables like querying all members and the streaks of each member, are defined via the [`async_graphql::ComplexObject`] directive on the [`models`] and can be found in the corresponding sub-query module.
#[derive(MergedObject, Default)]
struct Query(MemberQueries, AttendanceQueries, StreakQueries);
/// Handles all over environment variables in one place.
struct Config {
env: String,
secret_key: String,
database_url: String,
bind_address: String,
}

// Mutations work the same as Queries, sub-modules for each relation in the DB. However, all methods are directly defined on these sub-module structs. But they use slightly modified versions of the [`models`], marked by the Input in the name, to get input.
#[derive(MergedObject, Default)]
struct Mutations(MemberMutations, AttendanceMutations, StreakMutations);
impl Config {
fn from_env() -> Self {
dotenv::dotenv().ok();
Self {
// RUST_ENV is used to check if it's in production to avoid unnecessary logging and exposing the
// graphiql interface. Make sure to set it to "production" before deployment.
env: std::env::var("RUST_ENV").unwrap_or_else(|_| "development".to_string()),
// ROOT_SECRET is used to cryptographically verify the origin of attendance updation requests.
secret_key: std::env::var("ROOT_SECRET").expect("ROOT_SECRET must be set."),
// DATABASE_URL provides the connection string for the PostgreSQL database.
database_url: std::env::var("DATABASE_URL").expect("DATABASE_URL must be set."),
// BIND_ADDRESS is used to determine the IP address for the server's socket to bind to.
bind_address: std::env::var("BIND_ADDRESS").expect("BIND_ADDRESS must be set."),
}
}
}

#[tokio::main]
async fn main() {
Expand All @@ -45,18 +55,44 @@ async fn main() {
// Currently, we need the DATABASE_URL to be loaded in through the .env.
// In the future, if we use any other configuration (say Github Secrets), we
// can allow dotenv() to err.
dotenv::dotenv().expect("Failed to load .env file.");

// RUST_ENV is used to check if it's in production to avoid unnecessary logging and exposing the
// graphiql interface. Make sure to set it to "production" before deployment.
let env = std::env::var("RUST_ENV").unwrap_or_else(|_| "development".to_string());
// ROOT_SECRET is used to cryptographically verify the origin of attendance updation requests.
let secret_key = std::env::var("ROOT_SECRET").expect("ROOT_SECRET must be set.");
// DATABASE_URL provides the connection string for the PostgreSQL database.
let database_url = std::env::var("DATABASE_URL").expect("DATABASE_URL must be set.");
// BIND_ADDRESS is used to determine the IP address for the server's socket to bind to.
let bind_addr = std::env::var("BIND_ADDRESS").expect("BIND_ADDRESS must be set.");
let config = Config::from_env();
setup_tracing(&config.env);

let pool = setup_database(&config.database_url).await;
let schema = build_graphql_schema(pool.clone(), config.secret_key);

tokio::task::spawn(async {
run_daily_task_at_midnight(pool).await;
});

let cors = setup_cors();
let router = setup_router(schema, cors, config.env == "development");

info!("Starting Root...");
let listener = tokio::net::TcpListener::bind(config.bind_address).await.unwrap();
axum::serve(listener, router).await.unwrap();
}

/// Continuously sleep till midnight, then run the 'execute_daily_task' function.
async fn run_daily_task_at_midnight(pool: Arc<PgPool>) {
loop {
let now = chrono::Local::now().with_timezone(&Kolkata);
let next_midnight = (now + chrono::Duration::days(1))
.date_naive()
.and_hms_opt(0, 0, 0)
.unwrap();

let duration_until_midnight = next_midnight.signed_duration_since(now.naive_local());
let sleep_duration =
tokio::time::Duration::from_secs(duration_until_midnight.num_seconds() as u64);

sleep_until(tokio::time::Instant::now() + sleep_duration).await;
execute_daily_task(pool.clone()).await;
}
}

/// Abstraction over initializing the global subscriber for tracing depending on whether it's in production or dev.
fn setup_tracing(env: &str) {
if env == "production" {
tracing_subscriber::registry()
// In production, no need to write to stdout, write directly to file.
Expand Down Expand Up @@ -85,11 +121,14 @@ async fn main() {
.init();
info!("Running in development mode.");
}
}

/// Abstraction over setting up the database pool.
async fn setup_database(database_url: &str) -> Arc<PgPool> {
let pool = sqlx::postgres::PgPoolOptions::new()
.min_connections(2) // Maintain at least two connections, one for amD and one for Home. It should be
.max_connections(3) // pretty unlikely that amD, Home and the web interface is used simultaneously.
.connect(&database_url)
.min_connections(2)
.max_connections(3)
.connect(database_url)
.await
.expect("Pool must be initialized properly.");

Expand All @@ -98,55 +137,26 @@ async fn main() {
.await
.expect("Failed to run migrations.");

// Wrap pool in an Arc to share across threads.
let pool = Arc::new(pool);

let schema =
async_graphql::Schema::build(Query::default(), Mutations::default(), EmptySubscription)
.data(pool.clone())
.data(secret_key)
.finish();
Arc::new(pool)
}

// This thread will sleep until it's time to run the daily task.
// Also takes ownership of pool.
tokio::task::spawn(async {
run_daily_task_at_midnight(pool).await;
});
/// Abstraction over setting up the GraphQL schema from [`Query`] and [`Mutation`], and adding a reference to [`pool`] and [`secret_key`].
fn build_graphql_schema(
pool: Arc<PgPool>,
secret_key: String,
) -> async_graphql::Schema<Query, Mutation, EmptySubscription> {
async_graphql::Schema::build(Query::default(), Mutation::default(), EmptySubscription)
.data(pool)
.data(secret_key)
.finish()
}

let cors = CorsLayer::new()
/// Abstraction over making the CORSLayer.
fn setup_cors() -> CorsLayer {
CorsLayer::new()
// Home should be the only website that accesses the API, bots and scripts do not trigger CORS AFAIK.
// This lets us restrict who has access to what in the API on the Home frontend.
.allow_origin(HeaderValue::from_static("https://home.amfoss.in"))
.allow_methods([Method::GET, Method::POST, Method::OPTIONS])
.allow_headers(tower_http::cors::Any);

info!("Starting Root...");
// TODO: Avoid exposing the GraphiQL interface in prod.
let router = Router::new()
.route(
"/",
get(routes::graphiql).post_service(GraphQL::new(schema.clone())),
)
.layer(cors);

let listener = tokio::net::TcpListener::bind(bind_addr).await.unwrap();
axum::serve(listener, router).await.unwrap();
}

/// Continuously sleep till midnight, then run the 'execute_daily_task' function.
async fn run_daily_task_at_midnight(pool: Arc<PgPool>) {
loop {
let now = chrono::Local::now().with_timezone(&Kolkata);
let next_midnight = (now + chrono::Duration::days(1))
.date_naive()
.and_hms_opt(0, 0, 0)
.unwrap();

let duration_until_midnight = next_midnight.signed_duration_since(now.naive_local());
let sleep_duration =
tokio::time::Duration::from_secs(duration_until_midnight.num_seconds() as u64);

sleep_until(tokio::time::Instant::now() + sleep_duration).await;
execute_daily_task(pool.clone()).await;
}
.allow_headers(tower_http::cors::Any)
}
30 changes: 28 additions & 2 deletions src/routes.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,31 @@
use async_graphql::http::GraphiQLSource;
use axum::response::{Html, IntoResponse};
use async_graphql::{http::GraphiQLSource, EmptySubscription, Schema};
use async_graphql_axum::GraphQL;
use axum::{
response::{Html, IntoResponse},
routing::get,
Router,
};
use tower_http::cors::CorsLayer;

use crate::graphql::{Mutation, Query};

/// Setups the router with the given Schema and CORSLayer. Additionally, adds the GraphiQL playground if `is_dev` is true.
pub fn setup_router(
schema: Schema<Query, Mutation, EmptySubscription>,
cors: CorsLayer,
is_dev: bool,
) -> Router {
let mut router = Router::new()
.route_service("/graphql", GraphQL::new(schema.clone()))
.layer(cors);

if is_dev {
// Add GraphiQL playground only in development mode
router = router.route("/", get(graphiql));
}

router
}

// TODO: We do not want to expose GraphiQL unless in dev.
/// Returns the built-in GraphQL playground from async_graphql.
Expand Down

0 comments on commit 44de20c

Please sign in to comment.