diff --git a/.github/workflows/experiment-runner.yml b/.github/workflows/experiment-runner.yml
new file mode 100644
index 00000000000..9cac407df91
--- /dev/null
+++ b/.github/workflows/experiment-runner.yml
@@ -0,0 +1,26 @@
+name: Experiment Runner
+
+on:
+  workflow_dispatch:
+
+concurrency:
+  group: experiment-runner-${{ github.ref }}
+  cancel-in-progress: true
+
+jobs:
+  experiment-runner-1-gpu:
+    if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
+    runs-on: 1-gpu-runner
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v3
+
+      - name: Install dependencies
+        run: |
+          bash scripts/ci_install_dependency.sh
+
+      - name: Test experiment runner
+        timeout-minutes: 10
+        run: |
+          cd test/srt
+          python3 experiment_runner.py --config configs/sharegpt_config.yaml
diff --git a/test/srt/configs/sharegpt_config.yaml b/test/srt/configs/sharegpt_config.yaml
new file mode 100644
index 00000000000..a80b96c8eae
--- /dev/null
+++ b/test/srt/configs/sharegpt_config.yaml
@@ -0,0 +1,7 @@
+tasks:
+  - name: sglang-benchmark
+    server_cmd: python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disable-radix-cache
+    client_cmd: python3 -m sglang.bench_serving --backend sglang --dataset-name sharegpt --request-rate 16
+  - name: vllm-benchmark
+    server_cmd: python3 -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-3.1-8B-Instruct --disable-log-requests
+    client_cmd: python3 -m sglang.bench_serving --backend vllm --dataset-name sharegpt --request-rate 16
diff --git a/test/srt/experiment_runner.py b/test/srt/experiment_runner.py
new file mode 100644
index 00000000000..c4966dc77ba
--- /dev/null
+++ b/test/srt/experiment_runner.py
@@ -0,0 +1,359 @@
+import argparse
+import logging
+import os
+import queue
+import re
+import subprocess
+import threading
+import time
+from dataclasses import dataclass
+from datetime import datetime
+from typing import List, Optional, Tuple
+
+import psutil
+import requests
+import yaml
+
+
+@dataclass
+class ServerConfig:
+    command: str
+    process_names: List[str]
+    default_port: int
+
+
+@dataclass
+class TaskConfig:
+    server_cmd: str
+    client_cmd: str
+    name: Optional[str] = None
+    server_type: Optional[str] = None
+
+
+@dataclass
+class TaskResult:
+    name: str
+    success: bool
+    output: str
+    runtime: float
+    timestamp: str
+
+
+SERVER_DEFAULTS = {
+    "sglang": ServerConfig(
+        command="sglang.launch_server",
+        process_names=["sglang.launch_server"],
+        default_port=30000,
+    ),
+    "vllm": ServerConfig(
+        command="vllm.entrypoints.openai.api_server",
+        process_names=["vllm.entrypoints.openai.api_server"],
+        default_port=8000,
+    ),
+}
+
+
+def parse_key_info(output: str) -> str:
+    """Extract and format key information from the client output."""
+    key_info = []
+
+    # Extract the argparse Namespace line (records the benchmark arguments)
+    args_match = re.search(r"Namespace\(.*?\)", output, re.DOTALL)
+    if args_match:
+        key_info.append(args_match.group(0))
+
+    # Extract input/output token counts
+    token_matches = re.findall(r"#(Input|Output) tokens: \d+", output)
+    key_info.extend(token_matches)
+
+    # Extract the benchmark result section
+    result_match = re.search(
+        r"============ Serving Benchmark Result ============.*?={50,}",
+        output,
+        re.DOTALL,
+    )
+    if result_match:
+        key_info.append(result_match.group(0))
+
+    return "\n\n".join(key_info)
+
+
+def extract_port_from_command(cmd: str, server_type: str) -> int:
+    port_match = re.search(r"--port[= ](\d+)", cmd)
+    if port_match:
+        return int(port_match.group(1))
+    return SERVER_DEFAULTS.get(server_type, ServerConfig("", [], 8000)).default_port
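+
+# Example (illustrative, not captured from a real run):
+#   extract_port_from_command("python3 -m sglang.launch_server --port 30123", "sglang")
+#   returns 30123; with no --port flag it falls back to the server type's
+#   default above (30000 for sglang, 8000 for vllm).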
+
+
+def detect_server_type(cmd: str) -> str:
+    for server_type, config in SERVER_DEFAULTS.items():
+        if config.command in cmd:
+            return server_type
+    return "unknown"
+
+
+def stream_output(
+    process: subprocess.Popen, prefix: str, logger: logging.Logger
+) -> Tuple[queue.Queue, Tuple[threading.Thread, threading.Thread]]:
+    output_queue = queue.Queue()
+
+    def stream_pipe(pipe, prefix):
+        # Only client output is collected for the report; everything is
+        # also logged at debug level.
+        for line in iter(pipe.readline, ""):
+            if prefix == "CLIENT":
+                output_queue.put(line.rstrip())
+            logger.debug(f"{prefix} | {line.rstrip()}")
+
+    stdout_thread = threading.Thread(
+        target=stream_pipe, args=(process.stdout, prefix), daemon=True
+    )
+    stderr_thread = threading.Thread(
+        target=stream_pipe, args=(process.stderr, prefix), daemon=True
+    )
+
+    stdout_thread.start()
+    stderr_thread.start()
+    return output_queue, (stdout_thread, stderr_thread)
+
+
+class ProcessManager:
+    def __init__(self):
+        self.server_process: Optional[subprocess.Popen] = None
+        self.client_process: Optional[subprocess.Popen] = None
+        self.logger = logging.getLogger(__name__)
+
+    def start_process(
+        self, command: str, prefix: str
+    ) -> Tuple[subprocess.Popen, queue.Queue, Tuple[threading.Thread, threading.Thread]]:
+        process = subprocess.Popen(
+            command,
+            shell=True,
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE,
+            text=True,
+            bufsize=1,
+        )
+
+        output_queue, threads = stream_output(process, prefix, self.logger)
+        return process, output_queue, threads
+
+    def kill_process_tree(self, process: subprocess.Popen):
+        try:
+            parent = psutil.Process(process.pid)
+            children = parent.children(recursive=True)
+
+            for child in children:
+                try:
+                    child.kill()
+                except psutil.NoSuchProcess:
+                    pass
+
+            parent.kill()
+            gone, alive = psutil.wait_procs(children + [parent], timeout=3)
+
+            for p in alive:
+                try:
+                    p.kill()
+                except psutil.NoSuchProcess:
+                    pass
+
+        except psutil.NoSuchProcess:
+            pass
+
+    def cleanup(self, process_names: List[str]):
+        if self.client_process:
+            self.kill_process_tree(self.client_process)
+            self.client_process = None
+
+        if self.server_process:
+            self.kill_process_tree(self.server_process)
+            self.server_process = None
+
+        # Also kill any stray server/client processes left over from a
+        # previous run.
+        for proc in psutil.process_iter(["pid", "name", "cmdline"]):
+            try:
+                cmdline = " ".join(proc.cmdline())
+                if any(name in cmdline for name in process_names):
+                    proc.kill()
+            except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
+                continue
+
+
+class ExperimentRunner:
+    def __init__(self):
+        self.process_manager = ProcessManager()
+        self.logger = logging.getLogger(__name__)
+
+    def wait_for_server(self, port: int, timeout: int = 300) -> bool:
+        start_time = time.time()
+
+        while time.time() - start_time < timeout:
+            try:
+                response = requests.get(f"http://localhost:{port}/health", timeout=5)
+                if response.status_code == 200:
+                    self.logger.debug(f"Server ready on port {port}")
+                    return True
+            except requests.RequestException:
+                pass
+            time.sleep(2)
+        return False
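+
+    # Readiness probe: both server types used here expose a /health endpoint,
+    # so a 200 response is treated as "server is up". Illustrative exchange:
+    #   GET http://localhost:30000/health  ->  200 OK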
"SERVER") + ) + + if not self.wait_for_server(port): + raise TimeoutError("Server startup timeout") + + time.sleep(10) + + self.logger.debug("Starting client") + self.process_manager.client_process, output_queue, client_threads = ( + self.process_manager.start_process(config.client_cmd, "CLIENT") + ) + + returncode = self.process_manager.client_process.wait() + + while True: + try: + line = output_queue.get_nowait() + client_output.append(line) + except queue.Empty: + break + + if returncode != 0: + raise RuntimeError(f"Client failed with code {returncode}") + + # Parse and format the output + full_output = "\n".join(client_output) + formatted_output = parse_key_info(full_output) + + return TaskResult( + name=config.name, + success=True, + output=formatted_output, + runtime=time.time() - start_time, + timestamp=datetime.now().isoformat(), + ) + + except Exception as e: + return TaskResult( + name=config.name, + success=False, + output=str(e), + runtime=time.time() - start_time, + timestamp=datetime.now().isoformat(), + ) + + finally: + if config.server_type in SERVER_DEFAULTS: + self.process_manager.cleanup( + SERVER_DEFAULTS[config.server_type].process_names + ) + time.sleep(10) + + +def load_config(config_path: str) -> List[TaskConfig]: + with open(config_path, "r") as f: + config_data = yaml.safe_load(f) + + configs = [] + for idx, entry in enumerate(config_data.get("tasks", [])): + if not isinstance(entry, dict): + raise ValueError(f"Invalid entry at index {idx}") + + config = TaskConfig( + server_cmd=entry.get("server_cmd"), + client_cmd=entry.get("client_cmd"), + name=entry.get("name", f"task-{idx+1}"), + server_type=entry.get("server_type"), + ) + + if not config.server_cmd or not config.client_cmd: + raise ValueError(f"Missing commands in {config.name}") + + configs.append(config) + + return configs + + +def setup_logging(debug: bool = False): + level = logging.DEBUG if debug else logging.INFO + logging.basicConfig( + level=level, + format="%(asctime)s - %(levelname)s - %(message)s", + handlers=[logging.StreamHandler(), logging.FileHandler("experiment.log")], + ) + + +def format_results(results: List[TaskResult]) -> str: + """Format experiment results in Markdown for GitHub step summary.""" + output = ["# Experiment Results\n"] + + for result in results: + output.append(f"## {result.name}") + output.append(f"**Status**: {'✅ Success' if result.success else '❌ Failed'}") + output.append(f"**Runtime**: {result.runtime:.2f} seconds") + output.append(f"**Timestamp**: {result.timestamp}") + output.append("\n**Output**:\n```") + output.append(result.output) + output.append("```\n") + + return "\n".join(output) + + +def write_in_github_step_summary(results: List[TaskResult]): + """Write formatted results to GitHub step summary.""" + if not os.environ.get("GITHUB_STEP_SUMMARY"): + logging.warning("GITHUB_STEP_SUMMARY environment variable not set") + return + + formatted_content = format_results(results) + with open(os.environ["GITHUB_STEP_SUMMARY"], "a") as f: + f.write(formatted_content) + + +def main(): + parser = argparse.ArgumentParser(description="Experiment Runner") + parser.add_argument( + "--config", type=str, required=True, help="Path to YAML config file" + ) + parser.add_argument("--debug", action="store_true", help="Enable debug output") + args = parser.parse_args() + + setup_logging(args.debug) + logger = logging.getLogger(__name__) + results = [] + + try: + configs = load_config(args.config) + runner = ExperimentRunner() + + for config in configs: + logger.info(f"Running 
{config.name}") + result = runner.run_task(config) + results.append(result) + + write_in_github_step_summary(results) + except Exception as e: + logger.error(f"Error: {e}") + raise + + +if __name__ == "__main__": + main()