feat: support custom task runner (#2407)
Showing 3 changed files with 392 additions and 0 deletions.
@@ -0,0 +1,26 @@
name: Experiment Runner

on:
  workflow_dispatch:

concurrency:
  group: experiment-runner-${{ github.ref }}
  cancel-in-progress: true

jobs:
  experiment-runner-1-gpu:
    if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
    runs-on: 1-gpu-runner
    steps:
      - name: Checkout code
        uses: actions/checkout@v3

      - name: Install dependencies
        run: |
          bash scripts/ci_install_dependency.sh

      - name: Test experiment runner
        timeout-minutes: 10
        run: |
          cd test/srt
          python3 experiment_runner.py --config configs/sharegpt_config.yaml
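The job above only fires on manual dispatch (`workflow_dispatch`), so nothing runs automatically on push. Below is a minimal sketch of kicking it off through the GitHub REST API instead of the Actions web UI; the workflow file name and target ref are assumptions to substitute with the real ones, and the token must be allowed to dispatch workflows.

```python
# Sketch only: trigger the workflow_dispatch event via the GitHub REST API.
# WORKFLOW_FILE and the ref are assumed values, not taken from this commit.
import os

import requests

OWNER_REPO = "sgl-project/sglang"
WORKFLOW_FILE = "experiment-runner.yml"  # assumed file name of the workflow above

resp = requests.post(
    f"https://api.github.com/repos/{OWNER_REPO}/actions/workflows/{WORKFLOW_FILE}/dispatches",
    headers={
        "Accept": "application/vnd.github+json",
        "Authorization": f"Bearer {os.environ['GITHUB_TOKEN']}",
    },
    json={"ref": "main"},  # branch or tag containing the workflow
)
resp.raise_for_status()  # the dispatch endpoint returns 204 No Content on success
```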
@@ -0,0 +1,7 @@
tasks:
  - name: sglang-benchmark
    server_cmd: python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disable-radix-cache
    client_cmd: python3 -m sglang.bench_serving --backend sglang --dataset-name sharegpt --request-rate 16
  - name: vllm-benchmark
    server_cmd: python3 -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-3.1-8B-Instruct --disable-log-requests
    client_cmd: python3 -m sglang.bench_serving --backend vllm --dataset-name sharegpt --request-rate 16
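The two entries above only set `name`, `server_cmd`, and `client_cmd`; the loader in `experiment_runner.py` (added later in this commit) also accepts an optional `server_type` and falls back to `task-<n>` when `name` is omitted. A minimal parsing sketch of that behavior follows, using an illustrative inline config that is not part of the commit:

```python
# Sketch of how a task entry is read; field handling mirrors load_config() in
# experiment_runner.py. The inline YAML (port, request rate) is illustrative only.
import yaml

raw = yaml.safe_load(
    """
tasks:
  - server_type: sglang
    server_cmd: python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --port 30010
    client_cmd: python3 -m sglang.bench_serving --backend sglang --dataset-name sharegpt --request-rate 8
"""
)

for idx, entry in enumerate(raw.get("tasks", [])):
    # name falls back to task-<n>; server_type may be None, in which case it is
    # auto-detected from server_cmd by detect_server_type().
    print(entry.get("name", f"task-{idx + 1}"), entry.get("server_type"))
```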
@@ -0,0 +1,359 @@
import argparse
import logging
import os
import queue
import re
import subprocess
import threading
import time
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import psutil
import requests
import yaml


@dataclass
class ServerConfig:
    command: str
    process_names: List[str]
    default_port: int


@dataclass
class TaskConfig:
    server_cmd: str
    client_cmd: str
    name: Optional[str] = None
    server_type: Optional[str] = None


@dataclass
class TaskResult:
    name: str
    success: bool
    output: str
    runtime: float
    timestamp: str


SERVER_DEFAULTS = {
    "sglang": ServerConfig(
        command="sglang.launch_server",
        process_names=["sglang.launch_server"],
        default_port=30000,
    ),
    "vllm": ServerConfig(
        command="vllm.entrypoints.openai.api_server",
        process_names=["vllm.entrypoints.openai.api_server"],
        default_port=8000,
    ),
}

def parse_key_info(output: str) -> str:
    """Extract and format key information from the output"""
    key_info = []

    # Extract Args namespace
    args_match = re.search(r"Namespace\(.*?\)", output, re.DOTALL)
    if args_match:
        key_info.append(args_match.group(0))

    # Extract input/output token counts
    token_matches = re.findall(r"#(Input|Output) tokens: \d+", output)
    key_info.extend(token_matches)

    # Extract benchmark result section
    result_match = re.search(
        r"============ Serving Benchmark Result ============.*?={50,}",
        output,
        re.DOTALL,
    )
    if result_match:
        key_info.append(result_match.group(0))

    return "\n\n".join(key_info)


def extract_port_from_command(cmd: str, server_type: str) -> int:
    port_match = re.search(r"--port[= ](\d+)", cmd)
    if port_match:
        return int(port_match.group(1))
    return SERVER_DEFAULTS.get(server_type, ServerConfig("", [], 8000)).default_port


def detect_server_type(cmd: str) -> str:
    for server_type, config in SERVER_DEFAULTS.items():
        if config.command in cmd:
            return server_type
    return "unknown"

def stream_output(
    process: subprocess.Popen, prefix: str, logger: logging.Logger
) -> Tuple[queue.Queue, Tuple[threading.Thread, threading.Thread]]:
    """Tee a process's stdout/stderr to the logger; queue CLIENT lines for later parsing."""
    output_queue = queue.Queue()

    def stream_pipe(pipe, prefix):
        for line in iter(pipe.readline, ""):
            if prefix == "CLIENT":
                output_queue.put(line.rstrip())
            logger.debug(f"{prefix} | {line.rstrip()}")

    stdout_thread = threading.Thread(
        target=stream_pipe, args=(process.stdout, prefix), daemon=True
    )
    stderr_thread = threading.Thread(
        target=stream_pipe, args=(process.stderr, prefix), daemon=True
    )

    stdout_thread.start()
    stderr_thread.start()
    return output_queue, (stdout_thread, stderr_thread)

class ProcessManager:
    def __init__(self):
        self.server_process: Optional[subprocess.Popen] = None
        self.client_process: Optional[subprocess.Popen] = None
        self.logger = logging.getLogger(__name__)

    def start_process(
        self, command: str, prefix: str
    ) -> Tuple[subprocess.Popen, queue.Queue, Tuple[threading.Thread, threading.Thread]]:
        """Launch a shell command and start streaming its output."""
        process = subprocess.Popen(
            command,
            shell=True,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            bufsize=1,
        )

        output_queue, threads = stream_output(process, prefix, self.logger)
        return process, output_queue, threads

    def kill_process_tree(self, process: subprocess.Popen):
        try:
            parent = psutil.Process(process.pid)
            children = parent.children(recursive=True)

            for child in children:
                try:
                    child.kill()
                except psutil.NoSuchProcess:
                    pass

            parent.kill()
            gone, alive = psutil.wait_procs(children + [parent], timeout=3)

            for p in alive:
                try:
                    p.kill()
                except psutil.NoSuchProcess:
                    pass

        except psutil.NoSuchProcess:
            pass

    def cleanup(self, process_names: List[str]):
        if self.client_process:
            self.kill_process_tree(self.client_process)
            self.client_process = None

        if self.server_process:
            self.kill_process_tree(self.server_process)
            self.server_process = None

        for proc in psutil.process_iter(["pid", "name", "cmdline"]):
            try:
                cmdline = " ".join(proc.cmdline())
                if any(name in cmdline for name in process_names):
                    proc.kill()
            except (psutil.NoSuchProcess, psutil.AccessDenied):
                continue

class ExperimentRunner:
    def __init__(self):
        self.process_manager = ProcessManager()
        self.logger = logging.getLogger(__name__)

    def wait_for_server(self, port: int, timeout: int = 300) -> bool:
        start_time = time.time()

        while time.time() - start_time < timeout:
            try:
                response = requests.get(f"http://localhost:{port}/health")
                if response.status_code == 200:
                    self.logger.debug(f"Server ready on port {port}")
                    return True
            except requests.RequestException:
                time.sleep(2)
        return False

    def run_task(self, config: TaskConfig) -> TaskResult:
        start_time = time.time()
        client_output = []

        try:
            if not config.server_type:
                config.server_type = detect_server_type(config.server_cmd)

            server_config = SERVER_DEFAULTS.get(config.server_type)
            if not server_config:
                raise ValueError(f"Unknown server type: {config.server_type}")

            port = extract_port_from_command(config.server_cmd, config.server_type)

            self.process_manager.cleanup(server_config.process_names)

            self.logger.debug(f"Starting server: {config.name}")
            self.process_manager.server_process, _, server_threads = (
                self.process_manager.start_process(config.server_cmd, "SERVER")
            )

            if not self.wait_for_server(port):
                raise TimeoutError("Server startup timeout")

            time.sleep(10)

            self.logger.debug("Starting client")
            self.process_manager.client_process, output_queue, client_threads = (
                self.process_manager.start_process(config.client_cmd, "CLIENT")
            )

            returncode = self.process_manager.client_process.wait()

            while True:
                try:
                    line = output_queue.get_nowait()
                    client_output.append(line)
                except queue.Empty:
                    break

            if returncode != 0:
                raise RuntimeError(f"Client failed with code {returncode}")

            # Parse and format the output
            full_output = "\n".join(client_output)
            formatted_output = parse_key_info(full_output)

            return TaskResult(
                name=config.name,
                success=True,
                output=formatted_output,
                runtime=time.time() - start_time,
                timestamp=datetime.now().isoformat(),
            )

        except Exception as e:
            return TaskResult(
                name=config.name,
                success=False,
                output=str(e),
                runtime=time.time() - start_time,
                timestamp=datetime.now().isoformat(),
            )

        finally:
            if config.server_type in SERVER_DEFAULTS:
                self.process_manager.cleanup(
                    SERVER_DEFAULTS[config.server_type].process_names
                )
            time.sleep(10)

def load_config(config_path: str) -> List[TaskConfig]:
    with open(config_path, "r") as f:
        config_data = yaml.safe_load(f)

    configs = []
    for idx, entry in enumerate(config_data.get("tasks", [])):
        if not isinstance(entry, dict):
            raise ValueError(f"Invalid entry at index {idx}")

        config = TaskConfig(
            server_cmd=entry.get("server_cmd"),
            client_cmd=entry.get("client_cmd"),
            name=entry.get("name", f"task-{idx+1}"),
            server_type=entry.get("server_type"),
        )

        if not config.server_cmd or not config.client_cmd:
            raise ValueError(f"Missing commands in {config.name}")

        configs.append(config)

    return configs

def setup_logging(debug: bool = False):
    level = logging.DEBUG if debug else logging.INFO
    logging.basicConfig(
        level=level,
        format="%(asctime)s - %(levelname)s - %(message)s",
        handlers=[logging.StreamHandler(), logging.FileHandler("experiment.log")],
    )

def format_results(results: List[TaskResult]) -> str:
    """Format experiment results in Markdown for GitHub step summary."""
    output = ["# Experiment Results\n"]

    for result in results:
        output.append(f"## {result.name}")
        output.append(f"**Status**: {'✅ Success' if result.success else '❌ Failed'}")
        output.append(f"**Runtime**: {result.runtime:.2f} seconds")
        output.append(f"**Timestamp**: {result.timestamp}")
        output.append("\n**Output**:\n```")
        output.append(result.output)
        output.append("```\n")

    return "\n".join(output)

def write_in_github_step_summary(results: List[TaskResult]):
    """Write formatted results to GitHub step summary."""
    if not os.environ.get("GITHUB_STEP_SUMMARY"):
        logging.warning("GITHUB_STEP_SUMMARY environment variable not set")
        return

    formatted_content = format_results(results)
    with open(os.environ["GITHUB_STEP_SUMMARY"], "a") as f:
        f.write(formatted_content)

def main():
    parser = argparse.ArgumentParser(description="Experiment Runner")
    parser.add_argument(
        "--config", type=str, required=True, help="Path to YAML config file"
    )
    parser.add_argument("--debug", action="store_true", help="Enable debug output")
    args = parser.parse_args()

    setup_logging(args.debug)
    logger = logging.getLogger(__name__)
    results = []

    try:
        configs = load_config(args.config)
        runner = ExperimentRunner()

        for config in configs:
            logger.info(f"Running {config.name}")
            result = runner.run_task(config)
            results.append(result)

        write_in_github_step_summary(results)
    except Exception as e:
        logger.error(f"Error: {e}")
        raise


if __name__ == "__main__":
    main()
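Because the runner, task config, and logging setup are all module-level definitions, a single task can also be run programmatically rather than through the CLI. A minimal usage sketch, assuming the script is importable as `experiment_runner` (e.g. when working from `test/srt`) and using example model and benchmark arguments:

```python
# Sketch only: run one ad-hoc task without a YAML config file. The model path
# and request rate are example values, not taken from this commit.
from experiment_runner import ExperimentRunner, TaskConfig, setup_logging

setup_logging(debug=True)

runner = ExperimentRunner()
result = runner.run_task(
    TaskConfig(
        name="adhoc-sglang",
        server_cmd="python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct",
        client_cmd="python3 -m sglang.bench_serving --backend sglang --dataset-name sharegpt --request-rate 4",
    )
)

# TaskResult carries success, runtime, timestamp, and the parsed benchmark output.
print(result.success, f"{result.runtime:.1f}s")
print(result.output)
```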