Skip to content

Commit

Permalink
merge dev -> main (#100)
Browse files Browse the repository at this point in the history
* skip manual config step

* better ux

* eslint

* cleanup

* update protobufs (#88)

* feat: config polling (#86)

* CI: fix re-creating manifests

* chore: log version with git commit hash on startup (#89)

* update protobufs (#90)

* Rework instance config fetching (#91)

* instance config fetching rework

* update protobufs

* add teonite link (#92)

* add link

* noreferrer

* add defguard link

* Basic nix flake without rust

* Flake update

* enable ARMv7 build (#93)

Co-authored-by: Maciej Wójcik <[email protected]>

* Make a pre-release and release docker build workflow (#94)

* split builds

* fix vergen

* add flavor to build-docker workflow

* bump version to 1.0.0 (#95)

* OpenID via Proxy (#97)

* Handle auth info

* Use AuthInfoRequest

* Handle AuthCallback

* Use Url crate for URL option

* add frontend

* translations, id_token -> code

* more translations, cleanup

* cleanup

* move to enterprise folder

---------

Co-authored-by: Aleksander <[email protected]>

* Change nonce and csrf cookie names (#99)

* change cookies name

* bump version

* fix cargo lock

---------

Co-authored-by: Robert Olejnik <[email protected]>
Co-authored-by: Jacek Chmielewski <[email protected]>
Co-authored-by: Adam Ciarciński <[email protected]>
Co-authored-by: Maciek <[email protected]>
Co-authored-by: Maciej Wójcik <[email protected]>
  • Loading branch information
6 people authored Nov 19, 2024
1 parent 44d968f commit 00c6b41
Show file tree
Hide file tree
Showing 32 changed files with 1,773 additions and 454 deletions.
926 changes: 568 additions & 358 deletions Cargo.lock

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "defguard-proxy"
version = "1.0.0"
version = "1.1.0"
edition = "2021"
license = "Apache-2.0"
homepage = "https://github.com/DefGuard/proxy"
Expand Down Expand Up @@ -38,7 +38,7 @@ anyhow = "1.0"
clap = { version = "4.5", features = ["derive", "env", "cargo"] }
# other utils
dotenvy = "0.15"
url = "2.5"
url = { version = "2.5", features = ["serde"] }
tower_governor = "0.4"
# UI embedding
rust-embed = { version = "8.5", features = ["include-exclude"] }
Expand Down
2 changes: 1 addition & 1 deletion build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
config.protoc_arg("--experimental_allow_proto3_optional");
// Make all messages serde-serializable
config.type_attribute(".", "#[derive(serde::Serialize,serde::Deserialize)]");
tonic_build::configure().compile_with_config(
tonic_build::configure().compile_protos_with_config(
config,
&["proto/core/proxy.proto"],
&["proto/core"],
Expand Down
2 changes: 1 addition & 1 deletion proto
25 changes: 17 additions & 8 deletions src/config.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use std::{fs, io::Error as IoError};
use std::{fs::read_to_string, path::PathBuf};

use clap::Parser;
use log::LevelFilter;
use serde::Deserialize;
use url::Url;

#[derive(Parser, Debug, Deserialize)]
#[command(version)]
Expand Down Expand Up @@ -35,16 +36,24 @@ pub struct Config {
#[arg(long, env = "DEFGUARD_PROXY_RATELIMIT_BURST", default_value_t = 0)]
pub rate_limit_burst: u32,

#[arg(
long,
env = "DEFGUARD_PROXY_URL",
value_parser = Url::parse,
default_value = "http://localhost:8080"
)]
pub url: Url,

/// Configuration file path
#[arg(long = "config", short)]
#[serde(skip)]
config_path: Option<std::path::PathBuf>,
config_path: Option<PathBuf>,
}

#[derive(thiserror::Error, Debug)]
pub enum ConfigError {
#[error("Failed to read config file")]
IoError(#[from] IoError),
IoError(#[from] std::io::Error),
#[error("Failed to parse config file")]
ParseError(#[from] toml::de::Error),
}
Expand All @@ -55,11 +64,11 @@ pub fn get_config() -> Result<Config, ConfigError> {

// load config from file if one was specified
if let Some(config_path) = cli_config.config_path {
info!("Reading configuration from config file: {config_path:?}");
let config_toml = fs::read_to_string(config_path)?;
info!("Reading configuration from file: {config_path:?}");
let config_toml = read_to_string(config_path)?;
let file_config: Config = toml::from_str(&config_toml)?;
return Ok(file_config);
Ok(file_config)
} else {
Ok(cli_config)
}

Ok(cli_config)
}
1 change: 1 addition & 0 deletions src/enterprise/handlers/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pub mod openid_login;
147 changes: 147 additions & 0 deletions src/enterprise/handlers/openid_login.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
use axum::{
extract::State,
routing::{get, post},
Json, Router,
};
use axum_extra::extract::{
cookie::{Cookie, SameSite},
PrivateCookieJar,
};
use serde::{Deserialize, Serialize};
use time::Duration;

use crate::{
error::ApiError,
handlers::get_core_response,
http::AppState,
proto::{
core_request, core_response, AuthCallbackRequest, AuthCallbackResponse, AuthInfoRequest,
},
};

const COOKIE_MAX_AGE: Duration = Duration::days(1);
static CSRF_COOKIE_NAME: &str = "csrf_proxy";
static NONCE_COOKIE_NAME: &str = "nonce_proxy";

pub(crate) fn router() -> Router<AppState> {
Router::new()
.route("/auth_info", get(auth_info))
.route("/callback", post(auth_callback))
}

#[derive(Serialize)]
struct AuthInfo {
url: String,
button_display_name: Option<String>,
}

impl AuthInfo {
#[must_use]
fn new(url: String, button_display_name: Option<String>) -> Self {
Self {
url,
button_display_name,
}
}
}

/// Request external OAuth2/OpenID provider details from Defguard Core.
#[instrument(level = "debug", skip(state))]
async fn auth_info(
State(state): State<AppState>,
private_cookies: PrivateCookieJar,
) -> Result<(PrivateCookieJar, Json<AuthInfo>), ApiError> {
debug!("Getting auth info for OAuth2/OpenID login");

let request = AuthInfoRequest {
redirect_url: state.callback_url().to_string(),
};

let rx = state
.grpc_server
.send(Some(core_request::Payload::AuthInfo(request)), None)?;
let payload = get_core_response(rx).await?;
if let core_response::Payload::AuthInfo(response) = payload {
debug!("Received auth info {response:?}");

let nonce_cookie = Cookie::build((NONCE_COOKIE_NAME, response.nonce))
// .domain(cookie_domain)
.path("/api/v1/openid/callback")
.http_only(true)
.same_site(SameSite::Strict)
.secure(true)
.max_age(COOKIE_MAX_AGE)
.build();
let csrf_cookie = Cookie::build((CSRF_COOKIE_NAME, response.csrf_token))
// .domain(cookie_domain)
.path("/api/v1/openid/callback")
.http_only(true)
.same_site(SameSite::Strict)
.secure(true)
.max_age(COOKIE_MAX_AGE)
.build();
let private_cookies = private_cookies.add(nonce_cookie).add(csrf_cookie);

let auth_info = AuthInfo::new(response.url, response.button_display_name);
Ok((private_cookies, Json(auth_info)))
} else {
error!("Received invalid gRPC response type: {payload:#?}");
Err(ApiError::InvalidResponseType)
}
}

#[derive(Debug, Deserialize)]
pub struct AuthenticationResponse {
code: String,
state: String,
}

#[derive(Serialize)]
struct CallbackResponseData {
url: String,
token: String,
}

#[instrument(level = "debug", skip(state))]
async fn auth_callback(
State(state): State<AppState>,
mut private_cookies: PrivateCookieJar,
Json(payload): Json<AuthenticationResponse>,
) -> Result<(PrivateCookieJar, Json<CallbackResponseData>), ApiError> {
let nonce = private_cookies
.get(NONCE_COOKIE_NAME)
.ok_or(ApiError::Unauthorized("Nonce cookie not found".into()))?
.value_trimmed()
.to_string();
let csrf = private_cookies
.get(CSRF_COOKIE_NAME)
.ok_or(ApiError::Unauthorized("CSRF cookie not found".into()))?
.value_trimmed()
.to_string();

if payload.state != csrf {
return Err(ApiError::Unauthorized("CSRF token mismatch".into()));
}

private_cookies = private_cookies
.remove(Cookie::from(NONCE_COOKIE_NAME))
.remove(Cookie::from(CSRF_COOKIE_NAME));

let request = AuthCallbackRequest {
code: payload.code,
nonce,
callback_url: state.callback_url().to_string(),
};

let rx = state
.grpc_server
.send(Some(core_request::Payload::AuthCallback(request)), None)?;
let payload = get_core_response(rx).await?;
if let core_response::Payload::AuthCallback(AuthCallbackResponse { url, token }) = payload {
debug!("Received auth callback response {url:?} {token:?}");
Ok((private_cookies, Json(CallbackResponseData { url, token })))
} else {
error!("Received invalid gRPC response type during handling the OpenID authentication callback: {payload:#?}");
Err(ApiError::InvalidResponseType)
}
}
1 change: 1 addition & 0 deletions src/enterprise/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pub mod handlers;
12 changes: 7 additions & 5 deletions src/grpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ pub(crate) struct ProxyServer {
impl ProxyServer {
#[must_use]
/// Create new `ProxyServer`.
pub fn new() -> Self {
pub(crate) fn new() -> Self {
Self {
current_id: Arc::new(AtomicU64::new(1)),
clients: Arc::new(Mutex::new(HashMap::new())),
Expand All @@ -42,7 +42,7 @@ impl ProxyServer {
/// Sends message to the other side of RPC, with given `payload` and optional `device_info`.
/// Returns `tokio::sync::oneshot::Reveicer` to let the caller await reply.
#[instrument(name = "send_grpc_message", level = "debug", skip(self))]
pub fn send(
pub(crate) fn send(
&self,
payload: Option<core_request::Payload>,
device_info: Option<DeviceInfo>,
Expand All @@ -64,9 +64,11 @@ impl ProxyServer {
self.connected.store(true, Ordering::Relaxed);
Ok(rx)
} else {
error!("Defguard core is disconnected");
error!("Defguard Core is not connected");
self.connected.store(false, Ordering::Relaxed);
Err(ApiError::Unexpected("Defguard core is disconnected".into()))
Err(ApiError::Unexpected(
"Defguard Core is not connected".into(),
))
}
}
}
Expand Down Expand Up @@ -96,7 +98,7 @@ impl proxy_server::Proxy for ProxyServer {
error!("Failed to determine client address for request: {request:?}");
return Err(Status::internal("Failed to determine client address"));
};
info!("Defguard core RPC client connected from: {address}");
info!("Defguard Core gRPC client connected from: {address}");

let (tx, rx) = mpsc::unbounded_channel();
self.clients.lock().unwrap().insert(address, tx);
Expand Down
10 changes: 5 additions & 5 deletions src/handlers/enrollment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use crate::{
},
};

pub fn router() -> Router<AppState> {
pub(crate) fn router() -> Router<AppState> {
Router::new()
.route("/start", post(start_enrollment_process))
.route("/activate_user", post(activate_user))
Expand All @@ -21,7 +21,7 @@ pub fn router() -> Router<AppState> {
}

#[instrument(level = "debug", skip(state))]
pub async fn start_enrollment_process(
async fn start_enrollment_process(
State(state): State<AppState>,
mut private_cookies: PrivateCookieJar,
Json(req): Json<EnrollmentStartRequest>,
Expand Down Expand Up @@ -60,7 +60,7 @@ pub async fn start_enrollment_process(
}

#[instrument(level = "debug", skip(state))]
pub async fn activate_user(
async fn activate_user(
State(state): State<AppState>,
device_info: Option<DeviceInfo>,
mut private_cookies: PrivateCookieJar,
Expand Down Expand Up @@ -95,7 +95,7 @@ pub async fn activate_user(
}

#[instrument(level = "debug", skip(state))]
pub async fn create_device(
async fn create_device(
State(state): State<AppState>,
device_info: Option<DeviceInfo>,
private_cookies: PrivateCookieJar,
Expand Down Expand Up @@ -123,7 +123,7 @@ pub async fn create_device(
}

#[instrument(level = "debug", skip(state))]
pub async fn get_network_info(
async fn get_network_info(
State(state): State<AppState>,
private_cookies: PrivateCookieJar,
Json(mut req): Json<ExistingDevice>,
Expand Down
17 changes: 12 additions & 5 deletions src/handlers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use axum::{extract::FromRequestParts, http::request::Parts};
use axum_client_ip::{InsecureClientIp, LeftmostXForwardedFor};
use axum_extra::{headers::UserAgent, TypedHeader};
use tokio::{sync::oneshot::Receiver, time::timeout};
use tonic::Code;

use super::proto::DeviceInfo;
use crate::{error::ApiError, proto::core_response::Payload};
Expand All @@ -13,8 +14,8 @@ pub(crate) mod enrollment;
pub(crate) mod password_reset;
pub(crate) mod polling;

// timeout in seconds for awaiting core response
const CORE_RESPONSE_TIMEOUT: u64 = 5;
// Timeout for awaiting response from Defguard Core.
const CORE_RESPONSE_TIMEOUT: Duration = Duration::from_secs(5);

#[tonic::async_trait]
impl<S> FromRequestParts<S> for DeviceInfo
Expand Down Expand Up @@ -47,11 +48,17 @@ where
/// Helper which awaits core response
///
/// Waits for core response with a given timeout and returns the response payload.
async fn get_core_response(rx: Receiver<Payload>) -> Result<Payload, ApiError> {
pub(crate) async fn get_core_response(rx: Receiver<Payload>) -> Result<Payload, ApiError> {
debug!("Fetching core response...");
if let Ok(core_response) = timeout(Duration::from_secs(CORE_RESPONSE_TIMEOUT), rx).await {
if let Ok(core_response) = timeout(CORE_RESPONSE_TIMEOUT, rx).await {
debug!("Got gRPC response from Defguard core: {core_response:?}");
if let Ok(Payload::CoreError(core_error)) = core_response {
if core_error.status_code == Code::FailedPrecondition as i32
&& core_error.message == "no valid license"
{
debug!("Tried to get core response related to an enterprise feature but the enterprise is not enabled, ignoring it...");
return Err(ApiError::EnterpriseNotEnabled);
}
error!(
"Received an error response from the core service. | status code: {} message: {}",
core_error.status_code, core_error.message
Expand All @@ -61,7 +68,7 @@ async fn get_core_response(rx: Receiver<Payload>) -> Result<Payload, ApiError> {
core_response
.map_err(|err| ApiError::Unexpected(format!("Failed to receive core response: {err}")))
} else {
error!("Did not receive core response within {CORE_RESPONSE_TIMEOUT} seconds");
error!("Did not receive core response within {CORE_RESPONSE_TIMEOUT:?}");
Err(ApiError::CoreTimeout)
}
}
Loading

0 comments on commit 00c6b41

Please sign in to comment.