From ef995dae1e9e7cdc7cfa7d78a195a3943d7e3e6b Mon Sep 17 00:00:00 2001
From: Byron Hsu
Date: Sat, 7 Dec 2024 15:39:54 -0800
Subject: [PATCH] [router] Health check on worker before adding to the router (#2392)

---
 rust/Cargo.lock                    |  7 +--
 rust/Cargo.toml                    |  1 +
 rust/py_test/test_launch_server.py | 28 +++++-------
 rust/src/router.rs                 | 71 ++++++++++++++++++++++++++----
 rust/src/server.rs                 |  3 +-
 5 files changed, 79 insertions(+), 31 deletions(-)

diff --git a/rust/Cargo.lock b/rust/Cargo.lock
index 37c2733fdc0..8e7f306589f 100644
--- a/rust/Cargo.lock
+++ b/rust/Cargo.lock
@@ -1,6 +1,6 @@
 # This file is automatically @generated by Cargo.
 # It is not intended for manual editing.
-version = 3
+version = 4
 
 [[package]]
 name = "actix-codec"
@@ -2219,6 +2219,7 @@ dependencies = [
  "serde",
  "serde_json",
  "tokenizers",
+ "tokio",
 ]
 
 [[package]]
@@ -2475,9 +2476,9 @@ dependencies = [
 
 [[package]]
 name = "tokio"
-version = "1.41.0"
+version = "1.42.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "145f3413504347a2be84393cc8a7d2fb4d863b375909ea59f2158261aa258bbb"
+checksum = "5cec9b21b0450273377fc97bd4c33a8acffc8c996c987a7c5b319a0083707551"
 dependencies = [
  "backtrace",
  "bytes",
diff --git a/rust/Cargo.toml b/rust/Cargo.toml
index 5ac77665bcc..d49af81cf56 100644
--- a/rust/Cargo.toml
+++ b/rust/Cargo.toml
@@ -29,6 +29,7 @@ http = "1.1.0"
 env_logger = "0.11.5"
 log = "0.4.22"
 chrono = "0.4.38"
+tokio = "1.42.0"
 
 [profile.release]
 lto = "thin"
diff --git a/rust/py_test/test_launch_server.py b/rust/py_test/test_launch_server.py
index b3f82988354..68945d8fb52 100644
--- a/rust/py_test/test_launch_server.py
+++ b/rust/py_test/test_launch_server.py
@@ -20,6 +20,7 @@ def popen_launch_router(
     base_url: str,
     dp_size: int,
     timeout: float,
+    policy: str = "cache_aware",
 ):
     """
     Launch the router server process.
@@ -29,6 +30,7 @@
         base_url: Server base URL
         dp_size: Data parallel size
         timeout: Server launch timeout
+        policy: Router policy, one of "cache_aware", "round_robin", "random"
     """
     _, host, port = base_url.split(":")
     host = host[2:]
@@ -47,11 +49,10 @@
         str(dp_size),  # Convert dp_size to string
         "--router-eviction-interval",
         "5",  # frequent eviction for testing
+        "--router-policy",
+        policy,
     ]
 
-    # Use current environment
-    env = None
-
     process = subprocess.Popen(command, stdout=None, stderr=None)
 
     start_time = time.time()
@@ -99,19 +100,8 @@
 
     process = subprocess.Popen(command, stdout=None, stderr=None)
 
-    start_time = time.time()
-    with requests.Session() as session:
-        while time.time() - start_time < timeout:
-            try:
-                response = session.get(f"{base_url}/health")
-                if response.status_code == 200:
-                    print(f"Server {base_url} is healthy")
-                    return process
-            except requests.RequestException:
-                pass
-            time.sleep(10)
-
-    raise TimeoutError("Server failed to start within the timeout period.")
+    # intentionally don't wait; defer the health check to the router
+    return process
 
 
 class TestLaunchServer(unittest.TestCase):
@@ -135,6 +125,7 @@ def test_mmlu(self):
             self.base_url,
             dp_size=2,
             timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            policy="cache_aware",
         )
 
         args = SimpleNamespace(
@@ -160,6 +151,7 @@ def test_add_and_remove_worker(self):
             self.base_url,
             dp_size=1,
             timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            policy="round_robin",  # use round robin to make sure every worker processes requests
         )
         # 1. start a worker, and wait until it is healthy
         port = find_available_port()
@@ -168,11 +160,13 @@ def test_add_and_remove_worker(self):
             self.model, worker_url, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
         )
         TestLaunchServer.other_process.append(worker_process)
-        # 2. use /add_worker api to add it the the router
+
+        # 2. use the /add_worker api to add it to the router. It will be used by the router after it is healthy
         with requests.Session() as session:
             response = session.post(f"{self.base_url}/add_worker?url={worker_url}")
             print(f"status code: {response.status_code}, response: {response.text}")
             self.assertEqual(response.status_code, 200)
+
         # 3. run mmlu
         args = SimpleNamespace(
             base_url=self.base_url,
diff --git a/rust/src/router.rs b/rust/src/router.rs
index 5641fccbc74..acba974972c 100644
--- a/rust/src/router.rs
+++ b/rust/src/router.rs
@@ -3,13 +3,14 @@ use actix_web::http::header::{HeaderValue, CONTENT_TYPE};
 use actix_web::{HttpRequest, HttpResponse};
 use bytes::Bytes;
 use futures_util::{StreamExt, TryStreamExt};
-use log::{debug, info};
+use log::{debug, info, warn};
 use std::collections::HashMap;
 use std::fmt::Debug;
 use std::sync::atomic::AtomicUsize;
 use std::sync::{Arc, Mutex, RwLock};
 use std::thread;
 use std::time::Duration;
+use tokio;
 
 #[derive(Debug)]
 pub enum Router {
@@ -385,14 +386,66 @@ impl Router {
         }
     }
 
-    pub fn add_worker(&self, worker_url: String) {
-        match self {
-            Router::RoundRobin { worker_urls, .. }
-            | Router::Random { worker_urls }
-            | Router::CacheAware { worker_urls, .. } => {
-                let mut urls = worker_urls.write().unwrap();
-                info!("Added worker: {}", worker_url);
-                urls.push(worker_url);
+    pub async fn add_worker(&self, worker_url: String) -> HttpResponse {
+        let interval_secs = 10; // check every 10 seconds
+        let timeout_secs = 300; // 5 minutes
+
+        let start_time = std::time::Instant::now();
+        let client = reqwest::Client::new();
+
+        loop {
+            if start_time.elapsed() > Duration::from_secs(timeout_secs) {
+                return HttpResponse::InternalServerError().body(format!(
+                    "Timeout {}s waiting for worker {} to become healthy",
+                    timeout_secs, worker_url
+                ));
+            }
+
+            match client.get(&format!("{}/health", worker_url)).send().await {
+                Ok(res) => {
+                    if res.status().is_success() {
+                        match self {
+                            Router::RoundRobin { worker_urls, .. }
+                            | Router::Random { worker_urls }
+                            | Router::CacheAware { worker_urls, .. } => {
+                                info!("Worker {} health check passed", worker_url);
+                                let mut urls = worker_urls.write().unwrap();
+                                if urls.contains(&worker_url) {
+                                    return HttpResponse::BadRequest()
+                                        .body(format!("Worker {} already exists", worker_url));
+                                }
+                                info!("Added worker: {}", worker_url);
+                                urls.push(worker_url.clone());
+                            }
+                        }
+                        return HttpResponse::Ok()
+                            .body(format!("Successfully added worker: {}", worker_url));
+                    } else {
+                        info!(
+                            "Worker {} health check failed with status: {}. The worker might still be starting up.",
+                            worker_url, res.status()
+                        );
+                        // if the url does not have http or https prefix, warn users
+                        if !worker_url.starts_with("http://") && !worker_url.starts_with("https://")
+                        {
+                            warn!("The worker url {} does not have http or https prefix. Please add the prefix to the url.", worker_url);
+                        }
+
+                        tokio::time::sleep(Duration::from_secs(interval_secs)).await;
+                        continue;
+                    }
+                }
+                Err(e) => {
+                    info!("Worker {} health check failed: {}", worker_url, e);
+
+                    // if the url does not have http or https prefix, warn users
+                    if !worker_url.starts_with("http://") && !worker_url.starts_with("https://") {
+                        warn!("The worker url {} does not have http or https prefix. Please add the prefix to the url.", worker_url);
+                    }
+
+                    tokio::time::sleep(Duration::from_secs(interval_secs)).await;
+                    continue;
+                }
             }
         }
     }
diff --git a/rust/src/server.rs b/rust/src/server.rs
index d8d2e38e945..d7ec6ebc6e5 100644
--- a/rust/src/server.rs
+++ b/rust/src/server.rs
@@ -141,8 +141,7 @@ async fn add_worker(
                 .body("Worker URL required. Provide 'url' query parameter")
         }
     };
-    data.router.add_worker(worker_url);
-    HttpResponse::Ok().finish()
+    data.router.add_worker(worker_url).await
 }
 
 #[post("/remove_worker")]
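For reference, the new endpoint can be exercised the same way the test above does: POST /add_worker?url=<worker_url> now blocks until the worker's /health check succeeds (or the router gives up after the 300 s timeout), so callers no longer need to poll the worker themselves. Below is a minimal sketch in Python, assuming a router at http://127.0.0.1:30000 and a worker at http://127.0.0.1:31000 (both addresses are illustrative, not part of the patch):

    import requests

    ROUTER_URL = "http://127.0.0.1:30000"  # illustrative router address
    WORKER_URL = "http://127.0.0.1:31000"  # illustrative worker address; note the required http:// prefix

    # The router polls the worker's /health endpoint every 10 s (up to 300 s) before registering it.
    response = requests.post(f"{ROUTER_URL}/add_worker?url={WORKER_URL}")

    if response.status_code == 200:
        print(response.text)  # "Successfully added worker: ..."
    elif response.status_code == 400:
        print("Worker is already registered:", response.text)
    else:
        print("Router timed out waiting for the worker:", response.text)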