Skip to content

Commit

Permalink
add ep-size args
Browse files Browse the repository at this point in the history
  • Loading branch information
xiaobochen123 committed Dec 2, 2024
1 parent 35db97c commit 3aa1e25
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 2 deletions.
17 changes: 17 additions & 0 deletions python/sglang/srt/server_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ class ServerArgs:
# Data parallelism
dp_size: int = 1
load_balance_method: str = "round_robin"
# Expert parallelism
ep_size: int = 1

# Multi-node distributed serving
dist_init_addr: Optional[str] = None
Expand Down Expand Up @@ -199,6 +201,12 @@ def __post_init__(self):
"Data parallel size is adjusted to be the same as tensor parallel size. "
"Overlap schedule is disabled."
)
# Expert parallelism
if self.enable_ep_moe:
self.ep_size = self.tp_size
logger.info(
f"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
)

@staticmethod
def add_cli_args(parser: argparse.ArgumentParser):
Expand Down Expand Up @@ -493,6 +501,14 @@ def add_cli_args(parser: argparse.ArgumentParser):
"shortest_queue",
],
)
# Expert parallelism
parser.add_argument(
"--expert-parallel-size",
"--ep-size",
type=int,
default=ServerArgs.ep_size,
help="The expert parallelism size.",
)

# Multi-node distributed serving
parser.add_argument(
Expand Down Expand Up @@ -727,6 +743,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
def from_cli_args(cls, args: argparse.Namespace):
args.tp_size = args.tensor_parallel_size
args.dp_size = args.data_parallel_size
args.ep_size = args.expert_parallel_size
attrs = [attr.name for attr in dataclasses.fields(cls)]
return cls(**{attr: getattr(args, attr) for attr in attrs})

Expand Down
13 changes: 11 additions & 2 deletions test/srt/test_moe_ep.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,14 @@ def setUpClass(cls):
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=["--tp", "2", "--trust-remote-code", "--enable-ep-moe"],
other_args=[
"--trust-remote-code",
"--tp",
"2",
"--ep-size",
"2",
"--enable-ep-moe",
],
)

@classmethod
Expand Down Expand Up @@ -62,9 +69,11 @@ def setUpClass(cls):
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--trust-remote-code",
"--tp",
"2",
"--trust-remote-code",
"--ep-size",
"2",
"--enable-ep-moe",
"--quantization",
"fp8",
Expand Down

0 comments on commit 3aa1e25

Please sign in to comment.