From c36736c841f735aa3a03bfa0db52c9d603c5fb49 Mon Sep 17 00:00:00 2001 From: Byron Hsu Date: Fri, 6 Dec 2024 17:16:03 -0800 Subject: [PATCH] [router] Add remove worker api (#2380) --- .github/workflows/pr-test-rust.yml | 2 +- rust/py_test/test_launch_server.py | 42 +++++++++++++++++++++++------- rust/src/router.rs | 19 ++++++++++++++ rust/src/server.rs | 14 ++++++++++ 4 files changed, 67 insertions(+), 10 deletions(-) diff --git a/.github/workflows/pr-test-rust.yml b/.github/workflows/pr-test-rust.yml index 0df81b487b5..b9e8c5bcb6b 100644 --- a/.github/workflows/pr-test-rust.yml +++ b/.github/workflows/pr-test-rust.yml @@ -57,7 +57,7 @@ jobs: cd rust pip install setuptools-rust wheel build python3 -m build - pip install dist/*.whl + pip install --force-reinstall dist/*.whl - name: Run e2e test run: | cd rust/py_test diff --git a/rust/py_test/test_launch_server.py b/rust/py_test/test_launch_server.py index 0dacc2c9f7d..b3f82988354 100644 --- a/rust/py_test/test_launch_server.py +++ b/rust/py_test/test_launch_server.py @@ -114,17 +114,12 @@ def popen_launch_server( raise TimeoutError("Server failed to start within the timeout period.") -class TestEvalAccuracyMini(unittest.TestCase): +class TestLaunchServer(unittest.TestCase): @classmethod def setUpClass(cls): cls.model = DEFAULT_MODEL_NAME_FOR_TEST cls.base_url = DEFAULT_URL_FOR_TEST - cls.process = popen_launch_router( - cls.model, - cls.base_url, - dp_size=1, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - ) + cls.process = None cls.other_process = [] @classmethod @@ -134,6 +129,14 @@ def tearDownClass(cls): kill_process_tree(process.pid) def test_mmlu(self): + # DP size = 2 + TestLaunchServer.process = popen_launch_router( + self.model, + self.base_url, + dp_size=2, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + ) + args = SimpleNamespace( base_url=self.base_url, model=self.model, @@ -150,14 +153,21 @@ def test_mmlu(self): msg = f"MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})" self.assertGreaterEqual(score, THRESHOLD, msg) - def test_add_worker(self): + def test_add_and_remove_worker(self): + # DP size = 1 + TestLaunchServer.process = popen_launch_router( + self.model, + self.base_url, + dp_size=1, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + ) # 1. start a worker, and wait until it is healthy port = find_available_port() worker_url = f"http://127.0.0.1:{port}" worker_process = popen_launch_server( self.model, worker_url, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH ) - self.other_process.append(worker_process) + TestLaunchServer.other_process.append(worker_process) # 2. use /add_worker api to add it the the router with requests.Session() as session: response = session.post(f"{self.base_url}/add_worker?url={worker_url}") @@ -179,6 +189,20 @@ def test_add_worker(self): msg = f"MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})" self.assertGreaterEqual(score, THRESHOLD, msg) + # 4. use /remove_worker api to remove it from the router + with requests.Session() as session: + response = session.post(f"{self.base_url}/remove_worker?url={worker_url}") + print(f"status code: {response.status_code}, response: {response.text}") + self.assertEqual(response.status_code, 200) + + # 5. run mmlu again + metrics = run_eval(args) + score = metrics["score"] + THRESHOLD = 0.65 + passed = score >= THRESHOLD + msg = f"MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})" + self.assertGreaterEqual(score, THRESHOLD, msg) + if __name__ == "__main__": unittest.main() diff --git a/rust/src/router.rs b/rust/src/router.rs index 2b6b8d52cff..5641fccbc74 100644 --- a/rust/src/router.rs +++ b/rust/src/router.rs @@ -396,4 +396,23 @@ impl Router { } } } + + pub fn remove_worker(&self, worker_url: String) { + match self { + Router::RoundRobin { worker_urls, .. } + | Router::Random { worker_urls } + | Router::CacheAware { worker_urls, .. } => { + let mut urls = worker_urls.write().unwrap(); + let index = urls.iter().position(|url| url == &worker_url).unwrap(); + urls.remove(index); + info!("Removed worker: {}", worker_url); + } + } + + // if cache aware, remove the worker from the tree + if let Router::CacheAware { tree, .. } = self { + tree.lock().unwrap().remove_tenant(&worker_url); + info!("Removed worker from tree: {}", worker_url); + } + } } diff --git a/rust/src/server.rs b/rust/src/server.rs index 7197b9a2709..d8d2e38e945 100644 --- a/rust/src/server.rs +++ b/rust/src/server.rs @@ -145,6 +145,19 @@ async fn add_worker( HttpResponse::Ok().finish() } +#[post("/remove_worker")] +async fn remove_worker( + query: web::Query>, + data: web::Data, +) -> impl Responder { + let worker_url = match query.get("url") { + Some(url) => url.to_string(), + None => return HttpResponse::BadRequest().finish(), + }; + data.router.remove_worker(worker_url); + HttpResponse::Ok().finish() +} + pub struct ServerConfig { pub host: String, pub port: u16, @@ -201,6 +214,7 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> { .service(health_generate) .service(get_server_info) .service(add_worker) + .service(remove_worker) }) .bind((config.host, config.port))? .run()