From ff33cf80b17ea26215dc9f7670584609be9248df Mon Sep 17 00:00:00 2001 From: ByronHsu Date: Fri, 6 Dec 2024 01:32:43 +0000 Subject: [PATCH 1/7] wip --- rust/src/router.rs | 40 +++++++++++++++++++++++++++------------- rust/src/server.rs | 18 +++++++++++++++++- 2 files changed, 44 insertions(+), 14 deletions(-) diff --git a/rust/src/router.rs b/rust/src/router.rs index e17cba874c9..62b149bb838 100644 --- a/rust/src/router.rs +++ b/rust/src/router.rs @@ -7,18 +7,18 @@ use log::{debug, info}; use std::collections::HashMap; use std::fmt::Debug; use std::sync::atomic::AtomicUsize; -use std::sync::{Arc, Mutex}; +use std::sync::{Arc, Mutex, RwLock}; use std::thread; use std::time::Duration; #[derive(Debug)] pub enum Router { RoundRobin { - worker_urls: Vec, + worker_urls: Arc>>, current_index: AtomicUsize, }, Random { - worker_urls: Vec, + worker_urls: Arc>>, }, CacheAware { /* @@ -81,7 +81,7 @@ pub enum Router { Maximum nodes per tree. When exceeded, LRU leaf nodes are evicted during the next eviction cycle. */ - worker_urls: Vec, + worker_urls: Arc>>, tree: Arc>, running_queue: Arc>>, processed_queue: Arc>>, @@ -129,9 +129,11 @@ fn get_text_from_request(body: &Bytes, route: &str) -> String { impl Router { pub fn new(worker_urls: Vec, policy_config: PolicyConfig) -> Self { match policy_config { - PolicyConfig::RandomConfig => Router::Random { worker_urls }, + PolicyConfig::RandomConfig => Router::Random { + worker_urls: Arc::new(RwLock::new(worker_urls)), + }, PolicyConfig::RoundRobinConfig => Router::RoundRobin { - worker_urls, + worker_urls: Arc::new(RwLock::new(worker_urls)), current_index: std::sync::atomic::AtomicUsize::new(0), }, PolicyConfig::CacheAwareConfig { @@ -183,7 +185,7 @@ impl Router { } Router::CacheAware { - worker_urls, + worker_urls: Arc::new(RwLock::new(worker_urls)), tree, running_queue, processed_queue, @@ -201,10 +203,10 @@ impl Router { Router::RoundRobin { worker_urls, .. } | Router::Random { worker_urls } | Router::CacheAware { worker_urls, .. } => { - if worker_urls.is_empty() { + if worker_urls.read().unwrap().is_empty() { None } else { - Some(worker_urls[0].clone()) + Some(worker_urls.read().unwrap()[0].clone()) } } } @@ -228,14 +230,14 @@ impl Router { .fetch_update( std::sync::atomic::Ordering::SeqCst, std::sync::atomic::Ordering::SeqCst, - |x| Some((x + 1) % worker_urls.len()), + |x| Some((x + 1) % worker_urls.read().unwrap().len()), ) .unwrap(); - worker_urls[idx].clone() + worker_urls.read().unwrap()[idx].clone() } Router::Random { worker_urls } => { - worker_urls[rand::random::() % worker_urls.len()].clone() + worker_urls.read().unwrap()[rand::random::() % worker_urls.read().unwrap().len()].clone() } Router::CacheAware { @@ -277,7 +279,7 @@ impl Router { .iter() .min_by_key(|(_url, &count)| count) .map(|(url, _)| url.clone()) - .unwrap_or_else(|| worker_urls[0].clone()) + .unwrap_or_else(|| worker_urls.read().unwrap()[0].clone()) } else { // Use cache-aware routing when load is balanced let (matched_text, matched_worker) = tree.prefix_match(&text); @@ -379,4 +381,16 @@ impl Router { })) } } + + pub fn add_worker(&self, worker_url: String) { + match self { + Router::RoundRobin { worker_urls, .. } + | Router::Random { worker_urls } + | Router::CacheAware { worker_urls, .. } => { + let mut urls = worker_urls.write().unwrap(); + info!("Added worker: {}", worker_url); + urls.push(worker_url); + } + } + } } diff --git a/rust/src/server.rs b/rust/src/server.rs index 3fbe5c3e895..1a6f85d8515 100644 --- a/rust/src/server.rs +++ b/rust/src/server.rs @@ -1,9 +1,10 @@ use crate::router::PolicyConfig; use crate::router::Router; -use actix_web::{get, post, web, App, HttpRequest, HttpResponse, HttpServer, Responder}; +use actix_web::{get, post, web, App, HttpRequest, HttpResponse, HttpServer, Responder, put, delete}; use bytes::Bytes; use env_logger::Builder; use log::{info, LevelFilter}; +use std::collections::HashMap; use std::io::Write; #[derive(Debug)] @@ -128,6 +129,20 @@ async fn v1_completions( .await } +#[post("/add_worker")] +async fn add_worker( + query: web::Query>, + data: web::Data, +) -> impl Responder { + let worker_url = match query.get("url") { + Some(url) => url.to_string(), + None => return HttpResponse::BadRequest() + .body("Worker URL required. Provide 'url' query parameter"), + }; + data.router.add_worker(worker_url); + HttpResponse::Ok().finish() +} + pub struct ServerConfig { pub host: String, pub port: u16, @@ -183,6 +198,7 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> { .service(health) .service(health_generate) .service(get_server_info) + .service(add_worker) }) .bind((config.host, config.port))? .run() From 63fd7ed2faf0f79b097a70a87896c0f52b0823d5 Mon Sep 17 00:00:00 2001 From: ByronHsu Date: Fri, 6 Dec 2024 01:38:03 +0000 Subject: [PATCH 2/7] format --- rust/src/router.rs | 6 +++--- rust/src/server.rs | 10 +++++++--- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/rust/src/router.rs b/rust/src/router.rs index 62b149bb838..12e2b717ad8 100644 --- a/rust/src/router.rs +++ b/rust/src/router.rs @@ -236,9 +236,9 @@ impl Router { worker_urls.read().unwrap()[idx].clone() } - Router::Random { worker_urls } => { - worker_urls.read().unwrap()[rand::random::() % worker_urls.read().unwrap().len()].clone() - } + Router::Random { worker_urls } => worker_urls.read().unwrap() + [rand::random::() % worker_urls.read().unwrap().len()] + .clone(), Router::CacheAware { worker_urls, diff --git a/rust/src/server.rs b/rust/src/server.rs index 1a6f85d8515..269214acfef 100644 --- a/rust/src/server.rs +++ b/rust/src/server.rs @@ -1,6 +1,8 @@ use crate::router::PolicyConfig; use crate::router::Router; -use actix_web::{get, post, web, App, HttpRequest, HttpResponse, HttpServer, Responder, put, delete}; +use actix_web::{ + delete, get, post, put, web, App, HttpRequest, HttpResponse, HttpServer, Responder, +}; use bytes::Bytes; use env_logger::Builder; use log::{info, LevelFilter}; @@ -136,8 +138,10 @@ async fn add_worker( ) -> impl Responder { let worker_url = match query.get("url") { Some(url) => url.to_string(), - None => return HttpResponse::BadRequest() - .body("Worker URL required. Provide 'url' query parameter"), + None => { + return HttpResponse::BadRequest() + .body("Worker URL required. Provide 'url' query parameter") + } }; data.router.add_worker(worker_url); HttpResponse::Ok().finish() From 8b58519e126134b0ca39898502dfd844bb687e0b Mon Sep 17 00:00:00 2001 From: ByronHsu Date: Fri, 6 Dec 2024 02:01:43 +0000 Subject: [PATCH 3/7] run on ci --- rust/py_test/test_launch_server.py | 84 ++++++++++++++++++++++++++++-- 1 file changed, 80 insertions(+), 4 deletions(-) diff --git a/rust/py_test/test_launch_server.py b/rust/py_test/test_launch_server.py index a7a695aa9f6..13fcc8f4e4c 100644 --- a/rust/py_test/test_launch_server.py +++ b/rust/py_test/test_launch_server.py @@ -1,3 +1,4 @@ +import socket import subprocess import time import unittest @@ -49,7 +50,49 @@ def popen_launch_router( # Use current environment env = None - process = subprocess.Popen(command, stdout=None, stderr=None, env=env) + process = subprocess.Popen(command, stdout=None, stderr=None) + + start_time = time.time() + with requests.Session() as session: + while time.time() - start_time < timeout: + try: + response = session.get(f"{base_url}/health") + if response.status_code == 200: + return process + except requests.RequestException: + pass + time.sleep(10) + + raise TimeoutError("Router failed to start within the timeout period.") + + +def find_available_port(): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + return s.getsockname()[1] + + +def popen_launch_server( + model: str, + base_url: str, + timeout: float, +): + _, host, port = base_url.split(":") + host = host[2:] + + command = [ + "python3", + "-m", + "sglang.launch_server", + "--model-path", + model, + "--host", + host, + "--port", + port, + ] + + process = subprocess.Popen(command, stdout=None, stderr=None) start_time = time.time() with requests.Session() as session: @@ -76,12 +119,46 @@ def setUpClass(cls): dp_size=1, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, ) + cls.other_process = [] @classmethod def tearDownClass(cls): kill_process_tree(cls.process.pid) - - def test_mmlu(self): + for process in cls.other_process: + kill_process_tree(process.pid) + + # def test_mmlu(self): + # args = SimpleNamespace( + # base_url=self.base_url, + # model=self.model, + # eval_name="mmlu", + # num_examples=64, + # num_threads=32, + # temperature=0.1, + # ) + + # metrics = run_eval(args) + # score = metrics["score"] + # THRESHOLD = 0.65 + # passed = score >= THRESHOLD + # msg = f"MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})" + # self.assertGreaterEqual(score, THRESHOLD, msg) + + def test_add_worker(self): + # 1. start a worker, and wait until it is healthy + port = find_available_port() + worker_url = f"http://127.0.0.1:{port}" + worker_process = popen_launch_server( + self.model, worker_url, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH + ) + self.other_process.append(worker_process) + # 2. use /add_worker api to add it the the router + with requests.Session() as session: + response = session.post( + f"{self.base_url}/add_worker", json={"url": worker_url} + ) + self.assertEqual(response.status_code, 200) + # 3. run mmlu args = SimpleNamespace( base_url=self.base_url, model=self.model, @@ -90,7 +167,6 @@ def test_mmlu(self): num_threads=32, temperature=0.1, ) - metrics = run_eval(args) score = metrics["score"] THRESHOLD = 0.65 From b60772a3c2b7222fcec5025ca44244d2ebfe0f03 Mon Sep 17 00:00:00 2001 From: ByronHsu Date: Fri, 6 Dec 2024 02:07:59 +0000 Subject: [PATCH 4/7] base gpu id --- rust/py_test/test_launch_server.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/rust/py_test/test_launch_server.py b/rust/py_test/test_launch_server.py index 13fcc8f4e4c..abcf4cb59e1 100644 --- a/rust/py_test/test_launch_server.py +++ b/rust/py_test/test_launch_server.py @@ -90,6 +90,8 @@ def popen_launch_server( host, "--port", port, + "--base-gpu-id", + "1", ] process = subprocess.Popen(command, stdout=None, stderr=None) From 6a48f56fac7af880c8f56aaf74feffe899d71950 Mon Sep 17 00:00:00 2001 From: ByronHsu Date: Fri, 6 Dec 2024 09:00:32 +0000 Subject: [PATCH 5/7] fix failed test --- rust/py_test/test_launch_server.py | 37 ++++++++++++++++-------------- rust/src/router.rs | 5 +++- 2 files changed, 24 insertions(+), 18 deletions(-) diff --git a/rust/py_test/test_launch_server.py b/rust/py_test/test_launch_server.py index abcf4cb59e1..ef24ef3302d 100644 --- a/rust/py_test/test_launch_server.py +++ b/rust/py_test/test_launch_server.py @@ -58,6 +58,7 @@ def popen_launch_router( try: response = session.get(f"{base_url}/health") if response.status_code == 200: + print(f"Router {base_url} is healthy") return process except requests.RequestException: pass @@ -102,6 +103,7 @@ def popen_launch_server( try: response = session.get(f"{base_url}/health") if response.status_code == 200: + print(f"Server {base_url} is healthy") return process except requests.RequestException: pass @@ -129,22 +131,22 @@ def tearDownClass(cls): for process in cls.other_process: kill_process_tree(process.pid) - # def test_mmlu(self): - # args = SimpleNamespace( - # base_url=self.base_url, - # model=self.model, - # eval_name="mmlu", - # num_examples=64, - # num_threads=32, - # temperature=0.1, - # ) - - # metrics = run_eval(args) - # score = metrics["score"] - # THRESHOLD = 0.65 - # passed = score >= THRESHOLD - # msg = f"MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})" - # self.assertGreaterEqual(score, THRESHOLD, msg) + def test_mmlu(self): + args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="mmlu", + num_examples=64, + num_threads=32, + temperature=0.1, + ) + + metrics = run_eval(args) + score = metrics["score"] + THRESHOLD = 0.65 + passed = score >= THRESHOLD + msg = f"MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})" + self.assertGreaterEqual(score, THRESHOLD, msg) def test_add_worker(self): # 1. start a worker, and wait until it is healthy @@ -157,8 +159,9 @@ def test_add_worker(self): # 2. use /add_worker api to add it the the router with requests.Session() as session: response = session.post( - f"{self.base_url}/add_worker", json={"url": worker_url} + f"{self.base_url}/add_worker?url={worker_url}" ) + print(f"status code: {response.status_code}, response: {response.text}") self.assertEqual(response.status_code, 200) # 3. run mmlu args = SimpleNamespace( diff --git a/rust/src/router.rs b/rust/src/router.rs index 12e2b717ad8..74e47209bd7 100644 --- a/rust/src/router.rs +++ b/rust/src/router.rs @@ -335,7 +335,10 @@ impl Router { // For non-streaming requests, get response first let response = match res.bytes().await { Ok(body) => HttpResponse::build(status).body(body.to_vec()), - Err(_) => HttpResponse::InternalServerError().finish(), + Err(e) => { + let error_msg = format!("Failed to get response body: {}", e); + HttpResponse::InternalServerError().body(error_msg) + } }; // Then decrement running queue counter if using CacheAware From cbc003fd4deca5c60be50b95063365f3128ecd29 Mon Sep 17 00:00:00 2001 From: ByronHsu Date: Fri, 6 Dec 2024 09:02:05 +0000 Subject: [PATCH 6/7] fix security --- rust/py_test/test_launch_server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rust/py_test/test_launch_server.py b/rust/py_test/test_launch_server.py index ef24ef3302d..299837c4da5 100644 --- a/rust/py_test/test_launch_server.py +++ b/rust/py_test/test_launch_server.py @@ -69,7 +69,7 @@ def popen_launch_router( def find_available_port(): with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.bind(("", 0)) + s.bind(("127.0.0.1", 0)) return s.getsockname()[1] From afa715f7d28747ec515c31def55b20e4ee0b5559 Mon Sep 17 00:00:00 2001 From: ByronHsu Date: Fri, 6 Dec 2024 09:04:13 +0000 Subject: [PATCH 7/7] fmt --- rust/py_test/test_launch_server.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/rust/py_test/test_launch_server.py b/rust/py_test/test_launch_server.py index 299837c4da5..dcfe423466d 100644 --- a/rust/py_test/test_launch_server.py +++ b/rust/py_test/test_launch_server.py @@ -158,9 +158,7 @@ def test_add_worker(self): self.other_process.append(worker_process) # 2. use /add_worker api to add it the the router with requests.Session() as session: - response = session.post( - f"{self.base_url}/add_worker?url={worker_url}" - ) + response = session.post(f"{self.base_url}/add_worker?url={worker_url}") print(f"status code: {response.status_code}, response: {response.text}") self.assertEqual(response.status_code, 200) # 3. run mmlu