diff --git a/rust/py_test/test_launch_server.py b/rust/py_test/test_launch_server.py index a7a695aa9f6..13fcc8f4e4c 100644 --- a/rust/py_test/test_launch_server.py +++ b/rust/py_test/test_launch_server.py @@ -1,3 +1,4 @@ +import socket import subprocess import time import unittest @@ -49,7 +50,49 @@ def popen_launch_router( # Use current environment env = None - process = subprocess.Popen(command, stdout=None, stderr=None, env=env) + process = subprocess.Popen(command, stdout=None, stderr=None) + + start_time = time.time() + with requests.Session() as session: + while time.time() - start_time < timeout: + try: + response = session.get(f"{base_url}/health") + if response.status_code == 200: + return process + except requests.RequestException: + pass + time.sleep(10) + + raise TimeoutError("Router failed to start within the timeout period.") + + +def find_available_port(): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + return s.getsockname()[1] + + +def popen_launch_server( + model: str, + base_url: str, + timeout: float, +): + _, host, port = base_url.split(":") + host = host[2:] + + command = [ + "python3", + "-m", + "sglang.launch_server", + "--model-path", + model, + "--host", + host, + "--port", + port, + ] + + process = subprocess.Popen(command, stdout=None, stderr=None) start_time = time.time() with requests.Session() as session: @@ -76,12 +119,46 @@ def setUpClass(cls): dp_size=1, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, ) + cls.other_process = [] @classmethod def tearDownClass(cls): kill_process_tree(cls.process.pid) - - def test_mmlu(self): + for process in cls.other_process: + kill_process_tree(process.pid) + + # def test_mmlu(self): + # args = SimpleNamespace( + # base_url=self.base_url, + # model=self.model, + # eval_name="mmlu", + # num_examples=64, + # num_threads=32, + # temperature=0.1, + # ) + + # metrics = run_eval(args) + # score = metrics["score"] + # THRESHOLD = 0.65 + # passed = score >= THRESHOLD + # msg = 
f"MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})" + # self.assertGreaterEqual(score, THRESHOLD, msg) + + def test_add_worker(self): + # 1. start a worker, and wait until it is healthy + port = find_available_port() + worker_url = f"http://127.0.0.1:{port}" + worker_process = popen_launch_server( + self.model, worker_url, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH + ) + self.other_process.append(worker_process) + # 2. use /add_worker api to add it to the router + with requests.Session() as session: + response = session.post( + f"{self.base_url}/add_worker", json={"url": worker_url} + ) + self.assertEqual(response.status_code, 200) + # 3. run mmlu args = SimpleNamespace( base_url=self.base_url, model=self.model, @@ -90,7 +167,6 @@ def test_mmlu(self): num_threads=32, temperature=0.1, ) - metrics = run_eval(args) score = metrics["score"] THRESHOLD = 0.65