diff --git a/rust/py_test/test_launch_server.py b/rust/py_test/test_launch_server.py index a7a695aa9f6..dcfe423466d 100644 --- a/rust/py_test/test_launch_server.py +++ b/rust/py_test/test_launch_server.py @@ -1,3 +1,4 @@ +import socket import subprocess import time import unittest @@ -49,7 +50,7 @@ def popen_launch_router( # Use current environment env = None - process = subprocess.Popen(command, stdout=None, stderr=None, env=env) + process = subprocess.Popen(command, stdout=None, stderr=None) start_time = time.time() with requests.Session() as session: @@ -57,6 +58,52 @@ def popen_launch_router( try: response = session.get(f"{base_url}/health") if response.status_code == 200: + print(f"Router {base_url} is healthy") + return process + except requests.RequestException: + pass + time.sleep(10) + + raise TimeoutError("Router failed to start within the timeout period.") + + +def find_available_port(): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("127.0.0.1", 0)) + return s.getsockname()[1] + + +def popen_launch_server( + model: str, + base_url: str, + timeout: float, +): + _, host, port = base_url.split(":") + host = host[2:] + + command = [ + "python3", + "-m", + "sglang.launch_server", + "--model-path", + model, + "--host", + host, + "--port", + port, + "--base-gpu-id", + "1", + ] + + process = subprocess.Popen(command, stdout=None, stderr=None) + + start_time = time.time() + with requests.Session() as session: + while time.time() - start_time < timeout: + try: + response = session.get(f"{base_url}/health") + if response.status_code == 200: + print(f"Server {base_url} is healthy") return process except requests.RequestException: pass @@ -76,10 +123,13 @@ def setUpClass(cls): dp_size=1, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, ) + cls.other_process = [] @classmethod def tearDownClass(cls): kill_process_tree(cls.process.pid) + for process in cls.other_process: + kill_process_tree(process.pid) def test_mmlu(self): args = SimpleNamespace( 
@@ -98,6 +148,35 @@ def test_mmlu(self): msg = f"MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})" self.assertGreaterEqual(score, THRESHOLD, msg) + def test_add_worker(self): + # 1. start a worker, and wait until it is healthy + port = find_available_port() + worker_url = f"http://127.0.0.1:{port}" + worker_process = popen_launch_server( + self.model, worker_url, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH + ) + self.other_process.append(worker_process) + # 2. use /add_worker api to add it to the router + with requests.Session() as session: + response = session.post(f"{self.base_url}/add_worker?url={worker_url}") + print(f"status code: {response.status_code}, response: {response.text}") + self.assertEqual(response.status_code, 200) + # 3. run mmlu + args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="mmlu", + num_examples=64, + num_threads=32, + temperature=0.1, + ) + metrics = run_eval(args) + score = metrics["score"] + THRESHOLD = 0.65 + passed = score >= THRESHOLD + msg = f"MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})" + self.assertGreaterEqual(score, THRESHOLD, msg) + if __name__ == "__main__": unittest.main() diff --git a/rust/src/router.rs b/rust/src/router.rs index e17cba874c9..74e47209bd7 100644 --- a/rust/src/router.rs +++ b/rust/src/router.rs @@ -7,18 +7,18 @@ use log::{debug, info}; use std::collections::HashMap; use std::fmt::Debug; use std::sync::atomic::AtomicUsize; -use std::sync::{Arc, Mutex}; +use std::sync::{Arc, Mutex, RwLock}; use std::thread; use std::time::Duration; #[derive(Debug)] pub enum Router { RoundRobin { - worker_urls: Vec<String>, + worker_urls: Arc<RwLock<Vec<String>>>, current_index: AtomicUsize, }, Random { - worker_urls: Vec<String>, + worker_urls: Arc<RwLock<Vec<String>>>, }, CacheAware { /* 
*/ - worker_urls: Vec<String>, + worker_urls: Arc<RwLock<Vec<String>>>, tree: Arc<Mutex<Tree>>, running_queue: Arc<Mutex<HashMap<String, usize>>>, processed_queue: Arc<Mutex<HashMap<String, usize>>>, @@ -129,9 +131,11 @@ fn get_text_from_request(body: &Bytes, route: &str) -> String { impl Router { pub fn new(worker_urls: Vec<String>, policy_config: PolicyConfig) -> Self { match policy_config { - PolicyConfig::RandomConfig => Router::Random { worker_urls }, + PolicyConfig::RandomConfig => Router::Random { + worker_urls: Arc::new(RwLock::new(worker_urls)), + }, PolicyConfig::RoundRobinConfig => Router::RoundRobin { - worker_urls, + worker_urls: Arc::new(RwLock::new(worker_urls)), current_index: std::sync::atomic::AtomicUsize::new(0), }, PolicyConfig::CacheAwareConfig { @@ -183,7 +185,7 @@ impl Router { } Router::CacheAware { - worker_urls, + worker_urls: Arc::new(RwLock::new(worker_urls)), tree, running_queue, processed_queue, @@ -201,10 +203,10 @@ impl Router { Router::RoundRobin { worker_urls, .. } | Router::Random { worker_urls } | Router::CacheAware { worker_urls, .. } => { - if worker_urls.is_empty() { + if worker_urls.read().unwrap().is_empty() { None } else { - Some(worker_urls[0].clone()) + Some(worker_urls.read().unwrap()[0].clone()) } } } @@ -228,15 +230,15 @@ impl Router { .fetch_update( std::sync::atomic::Ordering::SeqCst, std::sync::atomic::Ordering::SeqCst, - |x| Some((x + 1) % worker_urls.len()), + |x| Some((x + 1) % worker_urls.read().unwrap().len()), ) .unwrap(); - worker_urls[idx].clone() + worker_urls.read().unwrap()[idx].clone() } - Router::Random { worker_urls } => { - worker_urls[rand::random::<usize>() % worker_urls.len()].clone() - } + Router::Random { worker_urls } => worker_urls.read().unwrap() + [rand::random::<usize>() % worker_urls.read().unwrap().len()] + .clone(), Router::CacheAware { worker_urls, @@ -277,7 +279,7 @@ impl Router { .iter() .min_by_key(|(_url, &count)| count) .map(|(url, _)| url.clone()) - .unwrap_or_else(|| worker_urls[0].clone()) + .unwrap_or_else(|| worker_urls.read().unwrap()[0].clone()) } else { // Use cache-aware routing when load 
is balanced let (matched_text, matched_worker) = tree.prefix_match(&text); @@ -333,7 +335,10 @@ impl Router { // For non-streaming requests, get response first let response = match res.bytes().await { Ok(body) => HttpResponse::build(status).body(body.to_vec()), - Err(_) => HttpResponse::InternalServerError().finish(), + Err(e) => { + let error_msg = format!("Failed to get response body: {}", e); + HttpResponse::InternalServerError().body(error_msg) + } }; // Then decrement running queue counter if using CacheAware @@ -379,4 +384,16 @@ impl Router { })) } } + + pub fn add_worker(&self, worker_url: String) { + match self { + Router::RoundRobin { worker_urls, .. } + | Router::Random { worker_urls } + | Router::CacheAware { worker_urls, .. } => { + let mut urls = worker_urls.write().unwrap(); + info!("Added worker: {}", worker_url); + urls.push(worker_url); + } + } + } } diff --git a/rust/src/server.rs b/rust/src/server.rs index 3fbe5c3e895..269214acfef 100644 --- a/rust/src/server.rs +++ b/rust/src/server.rs @@ -1,9 +1,12 @@ use crate::router::PolicyConfig; use crate::router::Router; -use actix_web::{get, post, web, App, HttpRequest, HttpResponse, HttpServer, Responder}; +use actix_web::{ + delete, get, post, put, web, App, HttpRequest, HttpResponse, HttpServer, Responder, +}; use bytes::Bytes; use env_logger::Builder; use log::{info, LevelFilter}; +use std::collections::HashMap; use std::io::Write; #[derive(Debug)] @@ -128,6 +131,22 @@ async fn v1_completions( .await } +#[post("/add_worker")] +async fn add_worker( + query: web::Query<HashMap<String, String>>, + data: web::Data<AppState>, +) -> impl Responder { + let worker_url = match query.get("url") { + Some(url) => url.to_string(), + None => { + return HttpResponse::BadRequest() + .body("Worker URL required. 
Provide 'url' query parameter") + } + }; + data.router.add_worker(worker_url); + HttpResponse::Ok().finish() +} + pub struct ServerConfig { pub host: String, pub port: u16, @@ -183,6 +202,7 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> { .service(health) .service(health_generate) .service(get_server_info) + .service(add_worker) }) .bind((config.host, config.port))? .run()