From e3c8171f0892d98efd07f3378e42ea2f12d64688 Mon Sep 17 00:00:00 2001 From: Ata Fatahi Date: Wed, 11 Dec 2024 14:57:20 -0800 Subject: [PATCH] add tests Signed-off-by: Ata Fatahi --- rust/py_test/test_launch_server.py | 82 +++++++++++++++++++++++++++++- rust/src/lib.rs | 5 ++ rust/src/server.rs | 7 +++ 3 files changed, 92 insertions(+), 2 deletions(-) diff --git a/rust/py_test/test_launch_server.py b/rust/py_test/test_launch_server.py index daa0b821ed6..8eeaf651396 100644 --- a/rust/py_test/test_launch_server.py +++ b/rust/py_test/test_launch_server.py @@ -21,6 +21,7 @@ def popen_launch_router( dp_size: int, timeout: float, policy: str = "cache_aware", + max_payload_size: int = None, ): """ Launch the router server process. @@ -31,6 +32,7 @@ def popen_launch_router( dp_size: Data parallel size timeout: Server launch timeout policy: Router policy, one of "cache_aware", "round_robin", "random" + max_payload_size: Maximum payload size in bytes """ _, host, port = base_url.split(":") host = host[2:] @@ -46,13 +48,16 @@ def popen_launch_router( "--port", port, "--dp", - str(dp_size), # Convert dp_size to string + str(dp_size), "--router-eviction-interval", - "5", # frequent eviction for testing + "5", "--router-policy", policy, ] + if max_payload_size is not None: + command.extend(["--router-max-payload-size", str(max_payload_size)]) + process = subprocess.Popen(command, stdout=None, stderr=None) start_time = time.time() @@ -280,6 +285,79 @@ def kill_worker(): msg = f"MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})" self.assertGreaterEqual(score, THRESHOLD, msg) + def test_4_payload_size(self): + print("Running test_4_payload_size...") + # Start router with default 4MB limit + self.process = popen_launch_router( + self.model, + self.base_url, + dp_size=1, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + policy="round_robin", + ) + + # Test case 1: Payload just under 4MB should succeed + payload_3mb = { + "text": "x" * (3 * 1024 
* 1024), # 3MB of text + "temperature": 0.0, + } + + with requests.Session() as session: + response = session.post( + f"{self.base_url}/generate", + json=payload_3mb, + headers={"Content-Type": "application/json"}, + ) + self.assertEqual( + response.status_code, + 200, + f"3MB payload should succeed but got status {response.status_code}", + ) + + # Test case 2: Payload over 4MB should fail + payload_5mb = { + "text": "x" * (5 * 1024 * 1024), # 5MB of text + "temperature": 0.0, + } + + with requests.Session() as session: + response = session.post( + f"{self.base_url}/generate", + json=payload_5mb, + headers={"Content-Type": "application/json"}, + ) + self.assertEqual( + response.status_code, + 413, # Payload Too Large + f"5MB payload should fail with 413 but got status {response.status_code}", + ) + + # Test case 3: Start router with custom 8MB limit + if self.process: + terminate_and_wait(self.process) + + self.process = popen_launch_router( + self.model, + self.base_url, + dp_size=1, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + policy="round_robin", + max_payload_size=8 * 1024 * 1024, # 8MB limit + ) + + # Now 5MB payload should succeed with 8MB limit + with requests.Session() as session: + response = session.post( + f"{self.base_url}/generate", + json=payload_5mb, + headers={"Content-Type": "application/json"}, + ) + self.assertEqual( + response.status_code, + 200, + f"5MB payload should succeed with 8MB limit but got status {response.status_code}", + ) + if __name__ == "__main__": unittest.main() diff --git a/rust/src/lib.rs b/rust/src/lib.rs index 63d5bfe324a..2d8cf4c0c8d 100644 --- a/rust/src/lib.rs +++ b/rust/src/lib.rs @@ -22,6 +22,7 @@ struct Router { balance_rel_threshold: f32, eviction_interval_secs: u64, max_tree_size: usize, + max_payload_size: usize, verbose: bool, } @@ -38,6 +39,7 @@ impl Router { balance_rel_threshold = 1.0001, eviction_interval_secs = 60, max_tree_size = 2usize.pow(24), + max_payload_size = 4 * 1024 * 1024, verbose = false ))] fn 
new( @@ -50,6 +52,7 @@ impl Router { balance_rel_threshold: f32, eviction_interval_secs: u64, max_tree_size: usize, + max_payload_size: usize, verbose: bool, ) -> PyResult<Self> { Ok(Router { @@ -62,6 +65,7 @@ impl Router { balance_rel_threshold, eviction_interval_secs, max_tree_size, + max_payload_size, verbose, }) } @@ -86,6 +90,7 @@ impl Router { worker_urls: self.worker_urls.clone(), policy_config, verbose: self.verbose, + max_payload_size: self.max_payload_size, }) .await .unwrap(); diff --git a/rust/src/server.rs b/rust/src/server.rs index 7d0d23ccde1..09878f07f8e 100644 --- a/rust/src/server.rs +++ b/rust/src/server.rs @@ -127,6 +127,7 @@ pub struct ServerConfig { pub worker_urls: Vec<String>, pub policy_config: PolicyConfig, pub verbose: bool, + pub max_payload_size: usize, } pub async fn startup(config: ServerConfig) -> std::io::Result<()> { @@ -164,10 +165,16 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> { info!("✅ Starting router on {}:{}", config.host, config.port); info!("✅ Serving Worker URLs: {:?}", config.worker_urls); info!("✅ Policy Config: {:?}", config.policy_config); + info!( + "✅ Max payload size: {} MB", + config.max_payload_size / (1024 * 1024) + ); HttpServer::new(move || { App::new() .app_data(app_state.clone()) + .app_data(web::JsonConfig::default().limit(config.max_payload_size)) + .app_data(web::PayloadConfig::default().limit(config.max_payload_size)) .service(generate) .service(v1_chat_completions) .service(v1_completions)