Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[router] Improve cleanup logic #2411

Merged
merged 4 commits into from
Dec 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 37 additions & 80 deletions rust/py_src/sglang_router/launch_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@
from typing import List

import requests
from setproctitle import setproctitle
from sglang_router.launch_router import RouterArgs, launch_router

from sglang.srt.server import launch_server
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import is_port_available
from sglang.utils import get_exception_traceback


def setup_logger():
Expand All @@ -34,10 +34,12 @@ def setup_logger():
return logger


logger = setup_logger()


# Create new process group
def run_server(server_args, dp_rank):
os.setpgrp() # Create new process group

setproctitle(f"sglang::server")
# Set SGLANG_DP_RANK environment variable
os.environ["SGLANG_DP_RANK"] = str(dp_rank)

Expand All @@ -58,36 +60,6 @@ def launch_server_process(
return proc


def cleanup_processes(processes: List[mp.Process]):
logger = logging.getLogger("router")
logger.info("Cleaning up processes...")
for proc in processes:
if proc.is_alive():
try:
os.killpg(os.getpgid(proc.pid), signal.SIGTERM)
proc.join(timeout=3)
if proc.is_alive():
logger.warning(
f"Process {proc.pid} did not terminate gracefully, force killing..."
)
os.killpg(os.getpgid(proc.pid), signal.SIGKILL)
except ProcessLookupError:
pass


def setup_signal_handlers(cleanup_func):
"""Setup handlers for various termination signals."""

def signal_handler(signum, frame):
cleanup_func()
sys.exit(1)

signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)
if hasattr(signal, "SIGQUIT"):
signal.signal(signal.SIGQUIT, signal_handler)


def wait_for_server_health(host: str, port: int, timeout: int = 300) -> bool:
"""Wait for server to be healthy by checking /health endpoint."""
start_time = time.time()
Expand Down Expand Up @@ -117,8 +89,12 @@ def find_available_ports(base_port: int, count: int) -> List[int]:
return available_ports


def cleanup_processes(processes: List[mp.Process]):
for process in processes:
process.terminate()


def main():
logger = setup_logger()

# CUDA runtime isn't fork-safe, which can lead to subtle bugs or crashes
mp.set_start_method("spawn")
Expand Down Expand Up @@ -148,52 +124,33 @@ def main():
# Start server processes
server_processes = []

try:
for i, worker_port in enumerate(worker_ports):
logger.info(f"Launching DP server process {i} on port {worker_port}")
proc = launch_server_process(server_args, worker_port, i)
server_processes.append(proc)

# Setup cleanup handler
setup_signal_handlers(lambda: cleanup_processes(server_processes))

# Wait for all servers to be healthy
all_healthy = True

for port in worker_ports:
if not wait_for_server_health(server_args.host, port):
logger.error(f"Server on port {port} failed to become healthy")
all_healthy = False
break

if not all_healthy:
logger.error("Not all servers are healthy. Shutting down...")
cleanup_processes(server_processes)
sys.exit(1)

logger.info("All servers are healthy. Starting router...")

# Update router args with worker URLs
router_args.worker_urls = [
f"http://{server_args.host}:{port}" for port in worker_ports
]

# Start the router
router = launch_router(router_args)

if router is None:
logger.error("Failed to start router. Shutting down...")
cleanup_processes(server_processes)
sys.exit(1)

except KeyboardInterrupt:
logger.info("Received shutdown signal...")
except Exception as e:
logger.error(f"Error occurred: {e}")
logger.error(get_exception_traceback())
finally:
logger.info("Cleaning up processes...")
cleanup_processes(server_processes)
for i, worker_port in enumerate(worker_ports):
logger.info(f"Launching DP server process {i} on port {worker_port}")
proc = launch_server_process(server_args, worker_port, i)
server_processes.append(proc)

signal.signal(signal.SIGINT, lambda sig, frame: cleanup_processes(server_processes))
signal.signal(
signal.SIGTERM, lambda sig, frame: cleanup_processes(server_processes)
)
signal.signal(
signal.SIGQUIT, lambda sig, frame: cleanup_processes(server_processes)
)

for port in worker_ports:
if not wait_for_server_health(server_args.host, port):
logger.error(f"Server on port {port} failed to become healthy")
break

logger.info("All servers are healthy. Starting router...")

# Update router args with worker URLs
router_args.worker_urls = [
f"http://{server_args.host}:{port}" for port in worker_ports
]

# Start the router
router = launch_router(router_args)


if __name__ == "__main__":
Expand Down
67 changes: 48 additions & 19 deletions rust/py_test/test_launch_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

import requests

from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST,
Expand Down Expand Up @@ -104,23 +103,52 @@ def popen_launch_server(
return process


def terminate_and_wait(process, timeout=300):
"""Terminate a process and wait until it is terminated.

Args:
process: subprocess.Popen object
timeout: maximum time to wait in seconds

Raises:
TimeoutError: if process does not terminate within timeout
"""
if process is None:
return

process.terminate()
start_time = time.time()

while process.poll() is None:
print(f"Terminating process {process.pid}")
if time.time() - start_time > timeout:
raise TimeoutError(
f"Process {process.pid} failed to terminate within {timeout}s"
)
time.sleep(1)

print(f"Process {process.pid} is successfully terminated")


class TestLaunchServer(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = None
cls.other_process = []

@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
for process in cls.other_process:
kill_process_tree(process.pid)

def test_mmlu(self):
def setUp(self):
self.model = DEFAULT_MODEL_NAME_FOR_TEST
self.base_url = DEFAULT_URL_FOR_TEST
self.process = None
self.other_process = []

def tearDown(self):
print("Running tearDown...")
if self.process:
terminate_and_wait(self.process)
for process in self.other_process:
terminate_and_wait(process)
print("tearDown done")

def test_1_mmlu(self):
print("Running test_1_mmlu...")
# DP size = 2
TestLaunchServer.process = popen_launch_router(
self.process = popen_launch_router(
self.model,
self.base_url,
dp_size=2,
Expand All @@ -144,9 +172,10 @@ def test_mmlu(self):
msg = f"MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})"
self.assertGreaterEqual(score, THRESHOLD, msg)

def test_add_and_remove_worker(self):
def test_2_add_and_remove_worker(self):
print("Running test_2_add_and_remove_worker...")
# DP size = 1
TestLaunchServer.process = popen_launch_router(
self.process = popen_launch_router(
self.model,
self.base_url,
dp_size=1,
Expand All @@ -159,7 +188,7 @@ def test_add_and_remove_worker(self):
worker_process = popen_launch_server(
self.model, worker_url, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
)
TestLaunchServer.other_process.append(worker_process)
self.other_process.append(worker_process)

# 2. use /add_worker api to add it the the router. It will be used by router after it is healthy
with requests.Session() as session:
Expand Down
Loading