From c955a89c433419a1ea4529732982cc802a8ec8eb Mon Sep 17 00:00:00 2001 From: ByronHsu Date: Sun, 8 Dec 2024 23:38:53 +0000 Subject: [PATCH] wip --- rust/Cargo.lock | 2 + rust/Cargo.toml | 2 +- rust/py_src/sglang_router/launch_server.py | 7 -- rust/src/router.rs | 110 ++++++++++++++++++--- rust/src/server.rs | 23 +++-- 5 files changed, 114 insertions(+), 30 deletions(-) diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 8e7f306589f..dc9c46a7146 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -851,6 +851,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10" dependencies = [ "futures-core", + "futures-sink", ] [[package]] @@ -1986,6 +1987,7 @@ dependencies = [ "base64 0.22.1", "bytes", "encoding_rs", + "futures-channel", "futures-core", "futures-util", "h2 0.4.6", diff --git a/rust/Cargo.toml b/rust/Cargo.toml index d49af81cf56..d20a381ee7b 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -19,7 +19,7 @@ serde = { version = "1.0", features = ["derive"] } clap = { version = "4.4", features = ["derive"] } bytes = "1.8.0" rand = "0.8.5" -reqwest = { version = "0.12.8", features = ["stream"] } +reqwest = { version = "0.12.8", features = ["stream", "blocking"] } futures-util = "0.3" serde_json = "1.0" pyo3 = { version = "0.22.5", features = ["extension-module"] } diff --git a/rust/py_src/sglang_router/launch_server.py b/rust/py_src/sglang_router/launch_server.py index 9c482e48986..6bb07cc0c96 100644 --- a/rust/py_src/sglang_router/launch_server.py +++ b/rust/py_src/sglang_router/launch_server.py @@ -137,13 +137,6 @@ def main(): signal.SIGQUIT, lambda sig, frame: cleanup_processes(server_processes) ) - for port in worker_ports: - if not wait_for_server_health(server_args.host, port): - logger.error(f"Server on port {port} failed to become healthy") - break - - logger.info("All servers are healthy. Starting router...") - # Update router args with worker URLs router_args.worker_urls = [ f"http://{server_args.host}:{port}" for port in worker_ports diff --git a/rust/src/router.rs b/rust/src/router.rs index acba974972c..ceb9814f02a 100644 --- a/rust/src/router.rs +++ b/rust/src/router.rs @@ -93,7 +93,7 @@ pub enum Router { }, } -#[derive(Debug)] +#[derive(Debug, Clone)] pub enum PolicyConfig { RandomConfig, RoundRobinConfig, @@ -127,9 +127,14 @@ fn get_text_from_request(body: &Bytes, route: &str) -> String { return "".to_string(); } + impl Router { - pub fn new(worker_urls: Vec, policy_config: PolicyConfig) -> Self { - match policy_config { + pub fn new(worker_urls: Vec, policy_config: PolicyConfig) -> Result { + // Wait until all workers are healthy + Self::wait_for_healthy_workers(&worker_urls, 300, 10)?; + + // Create router based on policy... + Ok(match policy_config { PolicyConfig::RandomConfig => Router::Random { worker_urls: Arc::new(RwLock::new(worker_urls)), }, @@ -196,7 +201,7 @@ impl Router { _eviction_thread: Some(eviction_thread), } } - } + }) } pub fn get_first(&self) -> Option { @@ -213,6 +218,61 @@ impl Router { } } + fn wait_for_healthy_workers( + worker_urls: &[String], + timeout_secs: u64, + interval_secs: u64, + ) -> Result<(), String> { + let start_time = std::time::Instant::now(); + let sync_client = reqwest::blocking::Client::new(); + + loop { + if start_time.elapsed() > Duration::from_secs(timeout_secs) { + return Err(format!( + "Timeout {}s waiting for workers to become healthy", + timeout_secs + )); + } + + let mut all_healthy = true; + let mut unhealthy_workers = Vec::new(); + + for url in worker_urls { + match sync_client.get(&format!("{}/health", url)).send() { + Ok(res) => { + if !res.status().is_success() { + info!( + "Worker {} health check is pending with status: {}.", + url, res.status() + ); + all_healthy = false; + unhealthy_workers.push((url, format!("Status: {}", res.status()))); + } + } + Err(e) => { + info!( + "Worker {} health check is pending with error: {}", + url, e + ); + all_healthy = false; + unhealthy_workers.push((url, format!("Error: {}", e))); + } + } + } + + if all_healthy { + info!("All workers are healthy"); + return Ok(()); + } else { + info!("Unhealthy workers:"); + for (url, reason) in &unhealthy_workers { + info!(" {} - {}", url, reason); + } + thread::sleep(Duration::from_secs(interval_secs)); + } + } + } + pub async fn dispatch( &self, client: &reqwest::Client, @@ -386,7 +446,7 @@ impl Router { } } - pub async fn add_worker(&self, worker_url: String) -> HttpResponse { + pub async fn add_worker(&self, worker_url: String) -> Result { let interval_secs = 10; // check every 10 seconds let timeout_secs = 300; // 5 minutes @@ -395,7 +455,7 @@ impl Router { loop { if start_time.elapsed() > Duration::from_secs(timeout_secs) { - return HttpResponse::InternalServerError().body(format!( + return Err(format!( "Timeout {}s waiting for worker {} to become healthy", timeout_secs, worker_url )); @@ -411,18 +471,38 @@ impl Router { info!("Worker {} health check passed", worker_url); let mut urls = worker_urls.write().unwrap(); if urls.contains(&worker_url) { - return HttpResponse::BadRequest() - .body(format!("Worker {} already exists", worker_url)); + return Err(format!("Worker {} already exists", worker_url)); } info!("Added worker: {}", worker_url); urls.push(worker_url.clone()); } } - return HttpResponse::Ok() - .body(format!("Successfully added worker: {}", worker_url)); + + // If cache aware, initialize the queues for the new worker + if let Router::CacheAware { + running_queue, + processed_queue, + tree, + .. + } = self + { + // Add worker to running queue with initial count of 0 + running_queue.lock().unwrap().insert(worker_url.clone(), 0); + + // Add worker to processed queue with initial count of 0 + processed_queue + .lock() + .unwrap() + .insert(worker_url.clone(), 0); + + // Add worker to tree + tree.lock().unwrap().insert(&"".to_string(), &worker_url); + } + + return Ok(format!("Successfully added worker: {}", worker_url)); } else { info!( - "Worker {} health check failed with status: {}. The worker might still be starting up.", + "Worker {} health check is pending with status: {}.", worker_url, res.status() ); // if the url does not have http or https prefix, warn users @@ -436,7 +516,7 @@ impl Router { } } Err(e) => { - info!("Worker {} health check failed: {}", worker_url, e); + info!("Worker {} health check is pending with error: {}", worker_url, e); // if the url does not have http or https prefix, warn users if !worker_url.starts_with("http://") && !worker_url.starts_with("https://") { @@ -463,9 +543,11 @@ impl Router { } // if cache aware, remove the worker from the tree - if let Router::CacheAware { tree, .. } = self { + if let Router::CacheAware { tree, running_queue, processed_queue, .. } = self { tree.lock().unwrap().remove_tenant(&worker_url); - info!("Removed worker from tree: {}", worker_url); + running_queue.lock().unwrap().remove(&worker_url); + processed_queue.lock().unwrap().remove(&worker_url); + info!("Removed worker from tree and cleaned up queues: {}", worker_url); } } } diff --git a/rust/src/server.rs b/rust/src/server.rs index d7ec6ebc6e5..8a0eb1547d6 100644 --- a/rust/src/server.rs +++ b/rust/src/server.rs @@ -20,7 +20,10 @@ impl AppState { policy_config: PolicyConfig, ) -> Self { // Create router based on policy - let router = Router::new(worker_urls, policy_config); + let router = match Router::new(worker_urls, policy_config) { + Ok(router) => router, + Err(error) => panic!("Failed to create router: {}", error), + }; Self { router, client } } @@ -141,7 +144,11 @@ async fn add_worker( .body("Worker URL required. Provide 'url' query parameter") } }; - data.router.add_worker(worker_url).await + + match data.router.add_worker(worker_url).await { + Ok(message) => HttpResponse::Ok().body(message), + Err(error) => HttpResponse::BadRequest().body(error), + } } #[post("/remove_worker")] @@ -187,20 +194,20 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> { ) .init(); - info!("Starting server on {}:{}", config.host, config.port); - info!("Worker URLs: {:?}", config.worker_urls); - info!("Policy Config: {:?}", config.policy_config); - let client = reqwest::Client::builder() .build() .expect("Failed to create HTTP client"); let app_state = web::Data::new(AppState::new( - config.worker_urls, + config.worker_urls.clone(), client, - config.policy_config, + config.policy_config.clone(), )); + info!("✅ Starting router on {}:{}", config.host, config.port); + info!("✅ Serving Worker URLs: {:?}", config.worker_urls); + info!("✅ Policy Config: {:?}", config.policy_config); + HttpServer::new(move || { App::new() .app_data(app_state.clone())