Skip to content

Commit

Permalink
Merge branch 'main' into ata/ver
Browse files Browse the repository at this point in the history
  • Loading branch information
ByronHsu authored Dec 12, 2024
2 parents b985d3a + 2ac36b9 commit 85b9fd9
Show file tree
Hide file tree
Showing 7 changed files with 82 additions and 3 deletions.
1 change: 1 addition & 0 deletions .github/workflows/pr-test-rust.yml
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ jobs:
pip install --force-reinstall dist/*.whl
- name: Run e2e test
run: |
bash scripts/killall_sglang.sh
cd rust/py_test
python3 run_suite.py
Expand Down
11 changes: 10 additions & 1 deletion rust/py_src/sglang_router/launch_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class RouterArgs:
balance_rel_threshold: float = 1.0001
eviction_interval: int = 60
max_tree_size: int = 2**24
max_payload_size: int = 4 * 1024 * 1024 # 4MB
verbose: bool = False

@staticmethod
Expand Down Expand Up @@ -116,6 +117,12 @@ def add_cli_args(
default=RouterArgs.max_tree_size,
help="Maximum size of the approximation tree for cache-aware routing",
)
parser.add_argument(
f"--{prefix}max-payload-size",
type=int,
default=RouterArgs.max_payload_size,
help="Maximum payload size in bytes",
)
parser.add_argument(
f"--{prefix}verbose",
action="store_true",
Expand Down Expand Up @@ -144,6 +151,7 @@ def from_cli_args(
balance_rel_threshold=getattr(args, f"{prefix}balance_rel_threshold"),
eviction_interval=getattr(args, f"{prefix}eviction_interval"),
max_tree_size=getattr(args, f"{prefix}max_tree_size"),
max_payload_size=getattr(args, f"{prefix}max_payload_size"),
verbose=getattr(args, f"{prefix}verbose", False),
)

Expand Down Expand Up @@ -187,14 +195,15 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
balance_rel_threshold=router_args.balance_rel_threshold,
eviction_interval_secs=router_args.eviction_interval,
max_tree_size=router_args.max_tree_size,
max_payload_size=router_args.max_payload_size,
verbose=router_args.verbose,
)

router.start()
return router

except Exception as e:
logger.error(f"Error starting router: {e}", file=sys.stderr)
logger.error(f"Error starting router: {e}")
return None


Expand Down
3 changes: 3 additions & 0 deletions rust/py_src/sglang_router/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class Router:
AND max_load > min_load * rel_threshold. Otherwise, use cache aware. Default: 1.0001
eviction_interval_secs: Interval in seconds between cache eviction operations in cache-aware
routing. Default: 60
max_payload_size: Maximum payload size in bytes. Default: 4MB
max_tree_size: Maximum size of the approximation tree for cache-aware routing. Default: 2^24
verbose: Enable verbose logging. Default: False
"""
Expand All @@ -41,6 +42,7 @@ def __init__(
balance_rel_threshold: float = 1.0001,
eviction_interval_secs: int = 60,
max_tree_size: int = 2**24,
max_payload_size: int = 4 * 1024 * 1024, # 4MB
verbose: bool = False,
):
self._router = _Router(
Expand All @@ -53,6 +55,7 @@ def __init__(
balance_rel_threshold=balance_rel_threshold,
eviction_interval_secs=eviction_interval_secs,
max_tree_size=max_tree_size,
max_payload_size=max_payload_size,
verbose=verbose,
)

Expand Down
1 change: 1 addition & 0 deletions rust/py_test/test_launch_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def test_launch_router_no_exception(self):
balance_rel_threshold=1.0001,
eviction_interval=60,
max_tree_size=2**24,
max_payload_size=4 * 1024 * 1024, # 4MB
verbose=False,
)

Expand Down
57 changes: 55 additions & 2 deletions rust/py_test/test_launch_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def popen_launch_router(
dp_size: int,
timeout: float,
policy: str = "cache_aware",
max_payload_size: int = None,
):
"""
Launch the router server process.
Expand All @@ -31,6 +32,7 @@ def popen_launch_router(
dp_size: Data parallel size
timeout: Server launch timeout
policy: Router policy, one of "cache_aware", "round_robin", "random"
max_payload_size: Maximum payload size in bytes
"""
_, host, port = base_url.split(":")
host = host[2:]
Expand All @@ -46,13 +48,16 @@ def popen_launch_router(
"--port",
port,
"--dp",
str(dp_size), # Convert dp_size to string
str(dp_size),
"--router-eviction-interval",
"5", # frequent eviction for testing
"5",
"--router-policy",
policy,
]

if max_payload_size is not None:
command.extend(["--router-max-payload-size", str(max_payload_size)])

process = subprocess.Popen(command, stdout=None, stderr=None)

start_time = time.time()
Expand Down Expand Up @@ -280,6 +285,54 @@ def kill_worker():
msg = f"MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})"
self.assertGreaterEqual(score, THRESHOLD, msg)

def test_4_payload_size(self):
print("Running test_4_payload_size...")
# Start router with 3MB limit
self.process = popen_launch_router(
self.model,
self.base_url,
dp_size=1,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
policy="round_robin",
max_payload_size=1 * 1024 * 1024, # 1MB limit
)

# Test case 1: Payload just under 1MB should succeed
payload_0_5_mb = {
"text": "x" * int(0.5 * 1024 * 1024), # 0.5MB of text
"temperature": 0.0,
}

with requests.Session() as session:
response = session.post(
f"{self.base_url}/generate",
json=payload_0_5_mb,
headers={"Content-Type": "application/json"},
)
self.assertEqual(
response.status_code,
200,
f"0.5MB payload should succeed but got status {response.status_code}",
)

# Test case 2: Payload over 1MB should fail
payload_1_plus_mb = {
"text": "x" * int((1.2 * 1024 * 1024)), # 1.2MB of text
"temperature": 0.0,
}

with requests.Session() as session:
response = session.post(
f"{self.base_url}/generate",
json=payload_1_plus_mb,
headers={"Content-Type": "application/json"},
)
self.assertEqual(
response.status_code,
413, # Payload Too Large
f"1.2MB payload should fail with 413 but got status {response.status_code}",
)


if __name__ == "__main__":
unittest.main()
5 changes: 5 additions & 0 deletions rust/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ struct Router {
balance_rel_threshold: f32,
eviction_interval_secs: u64,
max_tree_size: usize,
max_payload_size: usize,
verbose: bool,
}

Expand All @@ -38,6 +39,7 @@ impl Router {
balance_rel_threshold = 1.0001,
eviction_interval_secs = 60,
max_tree_size = 2usize.pow(24),
max_payload_size = 4 * 1024 * 1024,
verbose = false
))]
fn new(
Expand All @@ -50,6 +52,7 @@ impl Router {
balance_rel_threshold: f32,
eviction_interval_secs: u64,
max_tree_size: usize,
max_payload_size: usize,
verbose: bool,
) -> PyResult<Self> {
Ok(Router {
Expand All @@ -62,6 +65,7 @@ impl Router {
balance_rel_threshold,
eviction_interval_secs,
max_tree_size,
max_payload_size,
verbose,
})
}
Expand All @@ -86,6 +90,7 @@ impl Router {
worker_urls: self.worker_urls.clone(),
policy_config,
verbose: self.verbose,
max_payload_size: self.max_payload_size,
})
.await
.unwrap();
Expand Down
7 changes: 7 additions & 0 deletions rust/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ pub struct ServerConfig {
pub worker_urls: Vec<String>,
pub policy_config: PolicyConfig,
pub verbose: bool,
pub max_payload_size: usize,
}

pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
Expand Down Expand Up @@ -164,10 +165,16 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
info!("✅ Starting router on {}:{}", config.host, config.port);
info!("✅ Serving Worker URLs: {:?}", config.worker_urls);
info!("✅ Policy Config: {:?}", config.policy_config);
info!(
"✅ Max payload size: {} MB",
config.max_payload_size / (1024 * 1024)
);

HttpServer::new(move || {
App::new()
.app_data(app_state.clone())
.app_data(web::JsonConfig::default().limit(config.max_payload_size))
.app_data(web::PayloadConfig::default().limit(config.max_payload_size))
.service(generate)
.service(v1_chat_completions)
.service(v1_completions)
Expand Down

0 comments on commit 85b9fd9

Please sign in to comment.