Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make request payload size configurable #2444

Merged
merged 6 commits into from
Dec 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/pr-test-rust.yml
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ jobs:
pip install --force-reinstall dist/*.whl
- name: Run e2e test
run: |
bash scripts/killall_sglang.sh
cd rust/py_test
python3 run_suite.py

Expand Down
11 changes: 10 additions & 1 deletion rust/py_src/sglang_router/launch_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class RouterArgs:
balance_rel_threshold: float = 1.0001
eviction_interval: int = 60
max_tree_size: int = 2**24
max_payload_size: int = 4 * 1024 * 1024 # 4MB
verbose: bool = False

@staticmethod
Expand Down Expand Up @@ -116,6 +117,12 @@ def add_cli_args(
default=RouterArgs.max_tree_size,
help="Maximum size of the approximation tree for cache-aware routing",
)
parser.add_argument(
f"--{prefix}max-payload-size",
type=int,
default=RouterArgs.max_payload_size,
help="Maximum payload size in bytes",
)
parser.add_argument(
f"--{prefix}verbose",
action="store_true",
Expand Down Expand Up @@ -144,6 +151,7 @@ def from_cli_args(
balance_rel_threshold=getattr(args, f"{prefix}balance_rel_threshold"),
eviction_interval=getattr(args, f"{prefix}eviction_interval"),
max_tree_size=getattr(args, f"{prefix}max_tree_size"),
max_payload_size=getattr(args, f"{prefix}max_payload_size"),
verbose=getattr(args, f"{prefix}verbose", False),
)

Expand Down Expand Up @@ -187,14 +195,15 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
balance_rel_threshold=router_args.balance_rel_threshold,
eviction_interval_secs=router_args.eviction_interval,
max_tree_size=router_args.max_tree_size,
max_payload_size=router_args.max_payload_size,
verbose=router_args.verbose,
)

router.start()
return router

except Exception as e:
logger.error(f"Error starting router: {e}", file=sys.stderr)
logger.error(f"Error starting router: {e}")
return None


Expand Down
3 changes: 3 additions & 0 deletions rust/py_src/sglang_router/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class Router:
AND max_load > min_load * rel_threshold. Otherwise, use cache aware. Default: 1.0001
eviction_interval_secs: Interval in seconds between cache eviction operations in cache-aware
routing. Default: 60
max_payload_size: Maximum payload size in bytes. Default: 4MB
max_tree_size: Maximum size of the approximation tree for cache-aware routing. Default: 2^24
verbose: Enable verbose logging. Default: False
"""
Expand All @@ -41,6 +42,7 @@ def __init__(
balance_rel_threshold: float = 1.0001,
eviction_interval_secs: int = 60,
max_tree_size: int = 2**24,
max_payload_size: int = 4 * 1024 * 1024, # 4MB
verbose: bool = False,
):
self._router = _Router(
Expand All @@ -53,6 +55,7 @@ def __init__(
balance_rel_threshold=balance_rel_threshold,
eviction_interval_secs=eviction_interval_secs,
max_tree_size=max_tree_size,
max_payload_size=max_payload_size,
verbose=verbose,
)

Expand Down
1 change: 1 addition & 0 deletions rust/py_test/test_launch_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def test_launch_router_no_exception(self):
balance_rel_threshold=1.0001,
eviction_interval=60,
max_tree_size=2**24,
max_payload_size=4 * 1024 * 1024, # 4MB
verbose=False,
)

Expand Down
57 changes: 55 additions & 2 deletions rust/py_test/test_launch_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def popen_launch_router(
dp_size: int,
timeout: float,
policy: str = "cache_aware",
max_payload_size: int = None,
):
"""
Launch the router server process.
Expand All @@ -31,6 +32,7 @@ def popen_launch_router(
dp_size: Data parallel size
timeout: Server launch timeout
policy: Router policy, one of "cache_aware", "round_robin", "random"
max_payload_size: Maximum payload size in bytes
"""
_, host, port = base_url.split(":")
host = host[2:]
Expand All @@ -46,13 +48,16 @@ def popen_launch_router(
"--port",
port,
"--dp",
str(dp_size), # Convert dp_size to string
str(dp_size),
"--router-eviction-interval",
"5", # frequent eviction for testing
"5",
"--router-policy",
policy,
]

if max_payload_size is not None:
command.extend(["--router-max-payload-size", str(max_payload_size)])

process = subprocess.Popen(command, stdout=None, stderr=None)

start_time = time.time()
Expand Down Expand Up @@ -280,6 +285,54 @@ def kill_worker():
msg = f"MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})"
self.assertGreaterEqual(score, THRESHOLD, msg)

def test_4_payload_size(self):
print("Running test_4_payload_size...")
# Start router with 3MB limit
self.process = popen_launch_router(
self.model,
self.base_url,
dp_size=1,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
policy="round_robin",
max_payload_size=1 * 1024 * 1024, # 1MB limit
)

# Test case 1: Payload just under 1MB should succeed
payload_0_5_mb = {
"text": "x" * int(0.5 * 1024 * 1024), # 0.5MB of text
"temperature": 0.0,
}

with requests.Session() as session:
response = session.post(
f"{self.base_url}/generate",
json=payload_0_5_mb,
headers={"Content-Type": "application/json"},
)
self.assertEqual(
response.status_code,
200,
f"0.5MB payload should succeed but got status {response.status_code}",
)

# Test case 2: Payload over 1MB should fail
payload_1_plus_mb = {
"text": "x" * int((1.2 * 1024 * 1024)), # 1.2MB of text
"temperature": 0.0,
}

with requests.Session() as session:
response = session.post(
f"{self.base_url}/generate",
json=payload_1_plus_mb,
headers={"Content-Type": "application/json"},
)
self.assertEqual(
response.status_code,
413, # Payload Too Large
f"1.2MB payload should fail with 413 but got status {response.status_code}",
)


if __name__ == "__main__":
unittest.main()
5 changes: 5 additions & 0 deletions rust/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ struct Router {
balance_rel_threshold: f32,
eviction_interval_secs: u64,
max_tree_size: usize,
max_payload_size: usize,
verbose: bool,
}

Expand All @@ -38,6 +39,7 @@ impl Router {
balance_rel_threshold = 1.0001,
eviction_interval_secs = 60,
max_tree_size = 2usize.pow(24),
max_payload_size = 4 * 1024 * 1024,
verbose = false
))]
fn new(
Expand All @@ -50,6 +52,7 @@ impl Router {
balance_rel_threshold: f32,
eviction_interval_secs: u64,
max_tree_size: usize,
max_payload_size: usize,
verbose: bool,
) -> PyResult<Self> {
Ok(Router {
Expand All @@ -62,6 +65,7 @@ impl Router {
balance_rel_threshold,
eviction_interval_secs,
max_tree_size,
max_payload_size,
verbose,
})
}
Expand All @@ -86,6 +90,7 @@ impl Router {
worker_urls: self.worker_urls.clone(),
policy_config,
verbose: self.verbose,
max_payload_size: self.max_payload_size,
})
.await
.unwrap();
Expand Down
7 changes: 7 additions & 0 deletions rust/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ pub struct ServerConfig {
pub worker_urls: Vec<String>,
pub policy_config: PolicyConfig,
pub verbose: bool,
pub max_payload_size: usize,
}

pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
Expand Down Expand Up @@ -164,10 +165,16 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
info!("✅ Starting router on {}:{}", config.host, config.port);
info!("✅ Serving Worker URLs: {:?}", config.worker_urls);
info!("✅ Policy Config: {:?}", config.policy_config);
info!(
"✅ Max payload size: {} MB",
config.max_payload_size / (1024 * 1024)
);

HttpServer::new(move || {
App::new()
.app_data(app_state.clone())
.app_data(web::JsonConfig::default().limit(config.max_payload_size))
.app_data(web::PayloadConfig::default().limit(config.max_payload_size))
.service(generate)
.service(v1_chat_completions)
.service(v1_completions)
Expand Down
Loading