[router] support /add_worker api (#2369)
ByronHsu authored Dec 6, 2024
1 parent 37ee906 commit 67b6579
Showing 3 changed files with 134 additions and 18 deletions.
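In short: all three routing policies (RoundRobin, Random, CacheAware) now keep their worker list behind Arc<RwLock<Vec<String>>> instead of a plain Vec<String>, a new Router::add_worker method appends a URL under the write lock, and rust/src/server.rs exposes it as POST /add_worker?url=<worker_url>. The Python test exercises the flow end to end: launch a second sglang server on a free port, register it with the running router, and re-run the MMLU eval through the router.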
81 changes: 80 additions & 1 deletion rust/py_test/test_launch_server.py
@@ -1,3 +1,4 @@
+import socket
 import subprocess
 import time
 import unittest
@@ -49,14 +50,60 @@ def popen_launch_router(
         # Use current environment
         env = None

-    process = subprocess.Popen(command, stdout=None, stderr=None, env=env)
+    process = subprocess.Popen(command, stdout=None, stderr=None)

     start_time = time.time()
     with requests.Session() as session:
         while time.time() - start_time < timeout:
             try:
                 response = session.get(f"{base_url}/health")
                 if response.status_code == 200:
                     print(f"Router {base_url} is healthy")
                     return process
             except requests.RequestException:
                 pass
             time.sleep(10)

     raise TimeoutError("Router failed to start within the timeout period.")
+
+
+def find_available_port():
+    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+        s.bind(("127.0.0.1", 0))
+        return s.getsockname()[1]
+
+
+def popen_launch_server(
+    model: str,
+    base_url: str,
+    timeout: float,
+):
+    _, host, port = base_url.split(":")
+    host = host[2:]
+
+    command = [
+        "python3",
+        "-m",
+        "sglang.launch_server",
+        "--model-path",
+        model,
+        "--host",
+        host,
+        "--port",
+        port,
+        "--base-gpu-id",
+        "1",
+    ]
+
+    process = subprocess.Popen(command, stdout=None, stderr=None)
+
+    start_time = time.time()
+    with requests.Session() as session:
+        while time.time() - start_time < timeout:
+            try:
+                response = session.get(f"{base_url}/health")
+                if response.status_code == 200:
+                    print(f"Server {base_url} is healthy")
+                    return process
+            except requests.RequestException:
+                pass
@@ -76,10 +123,13 @@ def setUpClass(cls):
             dp_size=1,
             timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
         )
+        cls.other_process = []

     @classmethod
     def tearDownClass(cls):
         kill_process_tree(cls.process.pid)
+        for process in cls.other_process:
+            kill_process_tree(process.pid)

     def test_mmlu(self):
         args = SimpleNamespace(
@@ -98,6 +148,35 @@ def test_mmlu(self):
         msg = f"MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})"
         self.assertGreaterEqual(score, THRESHOLD, msg)

+    def test_add_worker(self):
+        # 1. start a worker, and wait until it is healthy
+        port = find_available_port()
+        worker_url = f"http://127.0.0.1:{port}"
+        worker_process = popen_launch_server(
+            self.model, worker_url, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
+        )
+        self.other_process.append(worker_process)
+        # 2. use the /add_worker api to add it to the router
+        with requests.Session() as session:
+            response = session.post(f"{self.base_url}/add_worker?url={worker_url}")
+            print(f"status code: {response.status_code}, response: {response.text}")
+            self.assertEqual(response.status_code, 200)
+        # 3. run mmlu
+        args = SimpleNamespace(
+            base_url=self.base_url,
+            model=self.model,
+            eval_name="mmlu",
+            num_examples=64,
+            num_threads=32,
+            temperature=0.1,
+        )
+        metrics = run_eval(args)
+        score = metrics["score"]
+        THRESHOLD = 0.65
+        passed = score >= THRESHOLD
+        msg = f"MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})"
+        self.assertGreaterEqual(score, THRESHOLD, msg)
+

 if __name__ == "__main__":
     unittest.main()
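Outside the test suite, registering a worker is a single HTTP POST with the worker's URL in the url query parameter. A minimal client sketch of the same call the test makes; the router and worker addresses are placeholders, and it assumes the reqwest crate with its blocking feature enabled:

// Register a worker with a running router (addresses are hypothetical).
fn main() -> Result<(), reqwest::Error> {
    let client = reqwest::blocking::Client::new();
    let resp = client
        .post("http://127.0.0.1:30000/add_worker")
        .query(&[("url", "http://127.0.0.1:8000")])
        .send()?;
    // The handler returns 200 with an empty body on success, 400 if url is missing.
    println!("add_worker -> {}", resp.status());
    Ok(())
}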
49 changes: 33 additions & 16 deletions rust/src/router.rs
@@ -7,18 +7,18 @@ use log::{debug, info};
 use std::collections::HashMap;
 use std::fmt::Debug;
 use std::sync::atomic::AtomicUsize;
-use std::sync::{Arc, Mutex};
+use std::sync::{Arc, Mutex, RwLock};
 use std::thread;
 use std::time::Duration;

 #[derive(Debug)]
 pub enum Router {
     RoundRobin {
-        worker_urls: Vec<String>,
+        worker_urls: Arc<RwLock<Vec<String>>>,
         current_index: AtomicUsize,
     },
     Random {
-        worker_urls: Vec<String>,
+        worker_urls: Arc<RwLock<Vec<String>>>,
     },
     CacheAware {
         /*
@@ -81,7 +81,7 @@ pub enum Router {
         Maximum nodes per tree. When exceeded, LRU leaf nodes are evicted
         during the next eviction cycle.
         */
-        worker_urls: Vec<String>,
+        worker_urls: Arc<RwLock<Vec<String>>>,
         tree: Arc<Mutex<Tree>>,
         running_queue: Arc<Mutex<HashMap<String, usize>>>,
         processed_queue: Arc<Mutex<HashMap<String, usize>>>,
@@ -129,9 +129,11 @@ fn get_text_from_request(body: &Bytes, route: &str) -> String {
 impl Router {
     pub fn new(worker_urls: Vec<String>, policy_config: PolicyConfig) -> Self {
         match policy_config {
-            PolicyConfig::RandomConfig => Router::Random { worker_urls },
+            PolicyConfig::RandomConfig => Router::Random {
+                worker_urls: Arc::new(RwLock::new(worker_urls)),
+            },
             PolicyConfig::RoundRobinConfig => Router::RoundRobin {
-                worker_urls,
+                worker_urls: Arc::new(RwLock::new(worker_urls)),
                 current_index: std::sync::atomic::AtomicUsize::new(0),
             },
             PolicyConfig::CacheAwareConfig {
@@ -183,7 +185,7 @@ impl Router {
             }

             Router::CacheAware {
-                worker_urls,
+                worker_urls: Arc::new(RwLock::new(worker_urls)),
                 tree,
                 running_queue,
                 processed_queue,
@@ -201,10 +203,10 @@ impl Router {
             Router::RoundRobin { worker_urls, .. }
             | Router::Random { worker_urls }
             | Router::CacheAware { worker_urls, .. } => {
-                if worker_urls.is_empty() {
+                if worker_urls.read().unwrap().is_empty() {
                     None
                 } else {
-                    Some(worker_urls[0].clone())
+                    Some(worker_urls.read().unwrap()[0].clone())
                 }
             }
         }
@@ -228,15 +230,15 @@ impl Router {
                     .fetch_update(
                         std::sync::atomic::Ordering::SeqCst,
                         std::sync::atomic::Ordering::SeqCst,
-                        |x| Some((x + 1) % worker_urls.len()),
+                        |x| Some((x + 1) % worker_urls.read().unwrap().len()),
                     )
                     .unwrap();
-                worker_urls[idx].clone()
+                worker_urls.read().unwrap()[idx].clone()
             }

-            Router::Random { worker_urls } => {
-                worker_urls[rand::random::<usize>() % worker_urls.len()].clone()
-            }
+            Router::Random { worker_urls } => worker_urls.read().unwrap()
+                [rand::random::<usize>() % worker_urls.read().unwrap().len()]
+            .clone(),

             Router::CacheAware {
                 worker_urls,
@@ -277,7 +279,7 @@ impl Router {
                     .iter()
                     .min_by_key(|(_url, &count)| count)
                     .map(|(url, _)| url.clone())
-                    .unwrap_or_else(|| worker_urls[0].clone())
+                    .unwrap_or_else(|| worker_urls.read().unwrap()[0].clone())
                 } else {
                     // Use cache-aware routing when load is balanced
                     let (matched_text, matched_worker) = tree.prefix_match(&text);
@@ -333,7 +335,10 @@ impl Router {
             // For non-streaming requests, get response first
             let response = match res.bytes().await {
                 Ok(body) => HttpResponse::build(status).body(body.to_vec()),
-                Err(_) => HttpResponse::InternalServerError().finish(),
+                Err(e) => {
+                    let error_msg = format!("Failed to get response body: {}", e);
+                    HttpResponse::InternalServerError().body(error_msg)
+                }
             };

             // Then decrement running queue counter if using CacheAware
@@ -379,4 +384,16 @@ impl Router {
             }))
         }
     }
+
+    pub fn add_worker(&self, worker_url: String) {
+        match self {
+            Router::RoundRobin { worker_urls, .. }
+            | Router::Random { worker_urls }
+            | Router::CacheAware { worker_urls, .. } => {
+                let mut urls = worker_urls.write().unwrap();
+                info!("Added worker: {}", worker_url);
+                urls.push(worker_url);
+            }
+        }
+    }
 }
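The concurrency pattern behind the router change is worth spelling out: every routing policy takes short-lived read locks on the request path, and add_worker is the only writer. A self-contained sketch of that pattern, using simplified names rather than the router's actual types:

use std::sync::{Arc, RwLock};
use std::thread;

fn main() {
    // Shared worker list: many concurrent readers, an occasional writer.
    let worker_urls = Arc::new(RwLock::new(vec!["http://127.0.0.1:8000".to_string()]));

    // Request path: take a read lock, pick a worker, drop the guard.
    let urls = Arc::clone(&worker_urls);
    let reader = thread::spawn(move || urls.read().unwrap()[0].clone());

    // add_worker path: hold the write lock only for the push.
    let urls = Arc::clone(&worker_urls);
    let writer = thread::spawn(move || {
        urls.write().unwrap().push("http://127.0.0.1:8001".to_string());
    });

    reader.join().unwrap();
    writer.join().unwrap();
    println!("{} workers", worker_urls.read().unwrap().len());
}

One caveat: the Random arm in the diff acquires the read lock twice within a single expression. Both are shared locks, so this works, but std::sync::RwLock::read makes no re-entrancy guarantee (its documentation notes it may panic or deadlock if the current thread already holds the lock), so hoisting a single guard into a local is the more cautious style.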
22 changes: 21 additions & 1 deletion rust/src/server.rs
@@ -1,9 +1,12 @@
 use crate::router::PolicyConfig;
 use crate::router::Router;
-use actix_web::{get, post, web, App, HttpRequest, HttpResponse, HttpServer, Responder};
+use actix_web::{
+    delete, get, post, put, web, App, HttpRequest, HttpResponse, HttpServer, Responder,
+};
 use bytes::Bytes;
 use env_logger::Builder;
 use log::{info, LevelFilter};
+use std::collections::HashMap;
 use std::io::Write;

 #[derive(Debug)]
@@ -128,6 +131,22 @@ async fn v1_completions(
     .await
 }

+#[post("/add_worker")]
+async fn add_worker(
+    query: web::Query<HashMap<String, String>>,
+    data: web::Data<AppState>,
+) -> impl Responder {
+    let worker_url = match query.get("url") {
+        Some(url) => url.to_string(),
+        None => {
+            return HttpResponse::BadRequest()
+                .body("Worker URL required. Provide 'url' query parameter")
+        }
+    };
+    data.router.add_worker(worker_url);
+    HttpResponse::Ok().finish()
+}
+
 pub struct ServerConfig {
     pub host: String,
     pub port: u16,
Expand Down Expand Up @@ -183,6 +202,7 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
.service(health)
.service(health_generate)
.service(get_server_info)
.service(add_worker)
})
.bind((config.host, config.port))?
.run()

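Two details of the new endpoint are worth noting. The handler validates only that the url query parameter is present; a missing parameter returns 400 Bad Request, and no health check is performed on the new worker before it can start receiving traffic. Also, the widened actix-web import pulls in delete and put alongside post, which suggests follow-up endpoints for removing or updating workers.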