[router] support /add_worker api #2369

Merged
merged 7 commits into from Dec 6, 2024
81 changes: 80 additions & 1 deletion rust/py_test/test_launch_server.py
@@ -1,3 +1,4 @@
import socket
import subprocess
import time
import unittest
@@ -49,14 +50,60 @@ def popen_launch_router(
# Use current environment
env = None

process = subprocess.Popen(command, stdout=None, stderr=None, env=env)
process = subprocess.Popen(command, stdout=None, stderr=None)

start_time = time.time()
with requests.Session() as session:
while time.time() - start_time < timeout:
try:
response = session.get(f"{base_url}/health")
if response.status_code == 200:
print(f"Router {base_url} is healthy")
return process
except requests.RequestException:
pass
time.sleep(10)

raise TimeoutError("Router failed to start within the timeout period.")


def find_available_port():
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("127.0.0.1", 0))
return s.getsockname()[1]


def popen_launch_server(
model: str,
base_url: str,
timeout: float,
):
_, host, port = base_url.split(":")
host = host[2:]

command = [
"python3",
"-m",
"sglang.launch_server",
"--model-path",
model,
"--host",
host,
"--port",
port,
"--base-gpu-id",
"1",
]

process = subprocess.Popen(command, stdout=None, stderr=None)

start_time = time.time()
with requests.Session() as session:
while time.time() - start_time < timeout:
try:
response = session.get(f"{base_url}/health")
if response.status_code == 200:
print(f"Server {base_url} is healthy")
return process
except requests.RequestException:
pass
@@ -76,10 +123,13 @@ def setUpClass(cls):
dp_size=1,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
)
cls.other_process = []

@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
for process in cls.other_process:
kill_process_tree(process.pid)

def test_mmlu(self):
args = SimpleNamespace(
@@ -98,6 +148,35 @@ def test_mmlu(self):
msg = f"MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})"
self.assertGreaterEqual(score, THRESHOLD, msg)

def test_add_worker(self):
# 1. start a worker, and wait until it is healthy
port = find_available_port()
worker_url = f"http://127.0.0.1:{port}"
worker_process = popen_launch_server(
self.model, worker_url, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
)
self.other_process.append(worker_process)
        # 2. use /add_worker api to add it to the router
with requests.Session() as session:
response = session.post(f"{self.base_url}/add_worker?url={worker_url}")
print(f"status code: {response.status_code}, response: {response.text}")
self.assertEqual(response.status_code, 200)
# 3. run mmlu
args = SimpleNamespace(
base_url=self.base_url,
model=self.model,
eval_name="mmlu",
num_examples=64,
num_threads=32,
temperature=0.1,
)
metrics = run_eval(args)
score = metrics["score"]
THRESHOLD = 0.65
passed = score >= THRESHOLD
msg = f"MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})"
self.assertGreaterEqual(score, THRESHOLD, msg)


if __name__ == "__main__":
unittest.main()
49 changes: 33 additions & 16 deletions rust/src/router.rs
@@ -7,18 +7,18 @@ use log::{debug, info};
use std::collections::HashMap;
use std::fmt::Debug;
use std::sync::atomic::AtomicUsize;
use std::sync::{Arc, Mutex};
use std::sync::{Arc, Mutex, RwLock};
use std::thread;
use std::time::Duration;

#[derive(Debug)]
pub enum Router {
RoundRobin {
worker_urls: Vec<String>,
worker_urls: Arc<RwLock<Vec<String>>>,
current_index: AtomicUsize,
},
Random {
worker_urls: Vec<String>,
worker_urls: Arc<RwLock<Vec<String>>>,
},
CacheAware {
/*
@@ -81,7 +81,7 @@ pub enum Router {
Maximum nodes per tree. When exceeded, LRU leaf nodes are evicted
during the next eviction cycle.
*/
worker_urls: Vec<String>,
worker_urls: Arc<RwLock<Vec<String>>>,
tree: Arc<Mutex<Tree>>,
running_queue: Arc<Mutex<HashMap<String, usize>>>,
processed_queue: Arc<Mutex<HashMap<String, usize>>>,
@@ -129,9 +129,11 @@ fn get_text_from_request(body: &Bytes, route: &str) -> String {
impl Router {
pub fn new(worker_urls: Vec<String>, policy_config: PolicyConfig) -> Self {
match policy_config {
PolicyConfig::RandomConfig => Router::Random { worker_urls },
PolicyConfig::RandomConfig => Router::Random {
worker_urls: Arc::new(RwLock::new(worker_urls)),
},
PolicyConfig::RoundRobinConfig => Router::RoundRobin {
worker_urls,
worker_urls: Arc::new(RwLock::new(worker_urls)),
current_index: std::sync::atomic::AtomicUsize::new(0),
},
PolicyConfig::CacheAwareConfig {
@@ -183,7 +185,7 @@ impl Router {
}

Router::CacheAware {
worker_urls,
worker_urls: Arc::new(RwLock::new(worker_urls)),
tree,
running_queue,
processed_queue,
@@ -201,10 +203,10 @@
Router::RoundRobin { worker_urls, .. }
| Router::Random { worker_urls }
| Router::CacheAware { worker_urls, .. } => {
if worker_urls.is_empty() {
if worker_urls.read().unwrap().is_empty() {
None
} else {
Some(worker_urls[0].clone())
Some(worker_urls.read().unwrap()[0].clone())
}
}
}
@@ -228,15 +230,15 @@
.fetch_update(
std::sync::atomic::Ordering::SeqCst,
std::sync::atomic::Ordering::SeqCst,
|x| Some((x + 1) % worker_urls.len()),
|x| Some((x + 1) % worker_urls.read().unwrap().len()),
)
.unwrap();
worker_urls[idx].clone()
worker_urls.read().unwrap()[idx].clone()
}

Router::Random { worker_urls } => {
worker_urls[rand::random::<usize>() % worker_urls.len()].clone()
}
Router::Random { worker_urls } => worker_urls.read().unwrap()
[rand::random::<usize>() % worker_urls.read().unwrap().len()]
.clone(),

Router::CacheAware {
worker_urls,
@@ -277,7 +279,7 @@ impl Router {
.iter()
.min_by_key(|(_url, &count)| count)
.map(|(url, _)| url.clone())
.unwrap_or_else(|| worker_urls[0].clone())
.unwrap_or_else(|| worker_urls.read().unwrap()[0].clone())
} else {
// Use cache-aware routing when load is balanced
let (matched_text, matched_worker) = tree.prefix_match(&text);
@@ -333,7 +335,10 @@ impl Router {
// For non-streaming requests, get response first
let response = match res.bytes().await {
Ok(body) => HttpResponse::build(status).body(body.to_vec()),
Err(_) => HttpResponse::InternalServerError().finish(),
Err(e) => {
let error_msg = format!("Failed to get response body: {}", e);
HttpResponse::InternalServerError().body(error_msg)
}
};

// Then decrement running queue counter if using CacheAware
@@ -379,4 +384,16 @@ impl Router {
}))
}
}

pub fn add_worker(&self, worker_url: String) {
match self {
Router::RoundRobin { worker_urls, .. }
| Router::Random { worker_urls }
| Router::CacheAware { worker_urls, .. } => {
let mut urls = worker_urls.write().unwrap();
info!("Added worker: {}", worker_url);
urls.push(worker_url);
}
}
}
}
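
The core change in router.rs above is that every Router variant now stores its worker list as Arc<RwLock<Vec<String>>>: the new add_worker method takes the write lock to append a URL at runtime, while request routing only takes short read locks. A rough Python analogue of this pattern (hypothetical names; a plain mutex standing in for the reader-writer lock), not the actual implementation:

import itertools
import threading


class RoundRobinRouter:
    # Rough analogue of the Rust RoundRobin variant: a shared, growable
    # worker list guarded by a lock, plus a monotonic counter for round-robin.
    def __init__(self, worker_urls):
        self._lock = threading.Lock()      # stands in for the Rust RwLock
        self._workers = list(worker_urls)
        self._counter = itertools.count()  # stands in for AtomicUsize

    def add_worker(self, url):
        # Counterpart of Router::add_worker: take write access and append.
        with self._lock:
            self._workers.append(url)

    def select(self):
        # Counterpart of the round-robin arm of worker selection.
        with self._lock:
            idx = next(self._counter) % len(self._workers)
            return self._workers[idx]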
22 changes: 21 additions & 1 deletion rust/src/server.rs
@@ -1,9 +1,12 @@
use crate::router::PolicyConfig;
use crate::router::Router;
use actix_web::{get, post, web, App, HttpRequest, HttpResponse, HttpServer, Responder};
use actix_web::{
delete, get, post, put, web, App, HttpRequest, HttpResponse, HttpServer, Responder,
};
use bytes::Bytes;
use env_logger::Builder;
use log::{info, LevelFilter};
use std::collections::HashMap;
use std::io::Write;

#[derive(Debug)]
@@ -128,6 +131,22 @@ async fn v1_completions(
.await
}

#[post("/add_worker")]
async fn add_worker(
query: web::Query<HashMap<String, String>>,
data: web::Data<AppState>,
) -> impl Responder {
let worker_url = match query.get("url") {
Some(url) => url.to_string(),
None => {
return HttpResponse::BadRequest()
.body("Worker URL required. Provide 'url' query parameter")
}
};
data.router.add_worker(worker_url);
HttpResponse::Ok().finish()
}

pub struct ServerConfig {
pub host: String,
pub port: u16,
Expand Down Expand Up @@ -183,6 +202,7 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
.service(health)
.service(health_generate)
.service(get_server_info)
.service(add_worker)
})
.bind((config.host, config.port))?
.run()
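
For reference, a minimal client-side sketch of the new endpoint (assumed addresses; a router and an already-healthy worker must be running): POST /add_worker takes the worker address in a url query parameter, returns an empty 200 OK on success, and answers 400 Bad Request when the parameter is missing.

import requests

ROUTER_URL = "http://127.0.0.1:30000"  # assumed router address
WORKER_URL = "http://127.0.0.1:31000"  # assumed address of a healthy worker

# Register the worker with the running router.
resp = requests.post(f"{ROUTER_URL}/add_worker", params={"url": WORKER_URL})
assert resp.status_code == 200  # empty 200 OK on success

# Omitting the 'url' query parameter is rejected by the handler.
bad = requests.post(f"{ROUTER_URL}/add_worker")
assert bad.status_code == 400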