Improve torch compile for fused moe (#2327)
merrymercy authored Dec 3, 2024
1 parent 83b340e commit 07ec07a
Showing 6 changed files with 45 additions and 24 deletions.
@@ -6,6 +6,7 @@
 from transformers import AutoConfig
 
 from sglang.srt.layers.fused_moe_triton.fused_moe import fused_moe as fused_moe_triton
+from sglang.srt.model_executor.cuda_graph_runner import set_torch_compile_config
 
 
 def get_model_config(model_name: str, tp_size: int):
@@ -64,7 +65,7 @@ def fused_topk_native(
     return topk_weights, topk_ids
 
 
-@torch.compile
+@torch.compile(dynamic=False)
 def fused_moe_torch(
     x,
     w1,
@@ -88,7 +89,8 @@ def fused_moe_torch(
     w13_weights = w1[topk_ids]
     w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)
     w2_weights = w2[topk_ids]
-    x1 = F.gelu(torch.einsum("ti,taoi -> tao", x, w1_weights))
+    x1 = torch.einsum("ti,taoi -> tao", x, w1_weights)
+    x1 = F.silu(x1)
     x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
     expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
     return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))
@@ -174,6 +176,7 @@ def benchmark(batch_size, provider, model_config, use_fp8=False):
     print(f"benchmark {provider} with batch_size={batch_size}")
     torch.set_default_device("cuda")
     torch.cuda.manual_seed_all(0)
+    set_torch_compile_config()
 
     num_tokens = batch_size
     num_experts = model_config["num_experts"]
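
For orientation, the einsum notation in the torch-native path above implements a SwiGLU MoE in plain PyTorch. The following is a minimal, self-contained sketch with made-up shapes, using separate w1/w3 tensors instead of the benchmark's fused w13 weight (illustrative only, not part of the commit):

```python
import torch
import torch.nn.functional as F

# Illustrative shapes: t = tokens, e = experts, a = top-k, o = intermediate dim, i = hidden dim
t, e, a, o, i = 4, 8, 2, 32, 16
x = torch.randn(t, i)
w1 = torch.randn(e, o, i)            # per-expert gate projection
w3 = torch.randn(e, o, i)            # per-expert up projection
w2 = torch.randn(e, i, o)            # per-expert down projection
topk_weights = torch.rand(t, a)      # routing weights of the selected experts
topk_ids = torch.randint(0, e, (t, a))

# Gather each token's top-k expert weights, then apply SwiGLU and the weighted sum.
w1_w = w1[topk_ids]                                        # (t, a, o, i)
w3_w = w3[topk_ids]                                        # (t, a, o, i)
w2_w = w2[topk_ids]                                        # (t, a, i, o)
x1 = F.silu(torch.einsum("ti,taoi -> tao", x, w1_w))
x3 = torch.einsum("ti,taoi -> tao", x, w3_w)
expert_outs = torch.einsum("tao,taio -> tai", x1 * x3, w2_w)
out = torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))
print(out.shape)  # torch.Size([4, 16])
```

The commit also splits the activation onto its own line in the benchmark; the math is unchanged.
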
python/sglang/srt/layers/fused_moe_patch.py (31 changes: 20 additions & 11 deletions)
@@ -105,20 +105,29 @@ def fused_moe_forward_native(
     num_expert_group: Optional[int] = None,
     custom_routing_function: Optional[Callable] = None,
 ) -> torch.Tensor:
-    assert custom_routing_function is None
-    topk_weights, topk_ids = select_experts_native(
-        hidden_states=x,
-        router_logits=router_logits,
-        use_grouped_topk=use_grouped_topk,
-        top_k=top_k,
-        renormalize=renormalize,
-        topk_group=topk_group,
-        num_expert_group=num_expert_group,
-    )
-
+    if use_grouped_topk:
+        assert num_expert_group is not None and topk_group is not None
+        topk_weights, topk_ids = grouped_topk(
+            x,
+            router_logits,
+            top_k,
+            renormalize,
+            num_expert_group,
+            topk_group,
+        )
+    elif custom_routing_function is None:
+        topk_weights, topk_ids = fused_topk_native(x, router_logits, top_k, renormalize)
+    else:
+        topk_weights, topk_ids = custom_routing_function(
+            x, router_logits, top_k, renormalize
+        )
+
     w13_weights = layer.w13_weight[topk_ids]
     w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)
     w2_weights = layer.w2_weight[topk_ids]
-    x1 = F.silu(torch.einsum("ti,taoi -> tao", x, w1_weights))
+    x1 = torch.einsum("ti,taoi -> tao", x, w1_weights)
+    x1 = F.silu(x1)
     x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
     expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
     return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))
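
The rewrite above inlines expert selection instead of delegating to select_experts_native: grouped top-k routing when use_grouped_topk is set, the plain fused top-k path by default, or a caller-supplied custom_routing_function. For reference, here is a minimal sketch of what plain softmax-then-top-k routing computes (illustrative only; the actual fused_topk_native implementation lives in the library):

```python
import torch

def topk_softmax_routing(router_logits: torch.Tensor, top_k: int, renormalize: bool):
    # Softmax over experts, keep the top-k probabilities per token,
    # optionally renormalize so the kept weights sum to 1.
    probs = torch.softmax(router_logits, dim=-1, dtype=torch.float32)
    topk_weights, topk_ids = torch.topk(probs, top_k, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids

logits = torch.randn(4, 8)  # 4 tokens, 8 experts
weights, ids = topk_softmax_routing(logits, top_k=2, renormalize=True)
print(weights.shape, ids.shape)  # torch.Size([4, 2]) torch.Size([4, 2])
```
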
python/sglang/srt/model_executor/cuda_graph_runner.py (23 changes: 16 additions & 7 deletions)
@@ -36,7 +36,7 @@
 from sglang.srt.model_executor.model_runner import ModelRunner
 
 
-def _to_torch(model: torch.nn.Module, reverse: bool = False):
+def _to_torch(model: torch.nn.Module, reverse: bool, batch_size: int):
     for sub in model._modules.values():
         if isinstance(sub, CustomOp):
             if reverse:
@@ -45,38 +45,46 @@ def _to_torch(model: torch.nn.Module, reverse: bool = False):
             else:
                 # NOTE: Temporarily workaround MoE
                 if "FusedMoE" in sub.__class__.__name__:
-                    sub._forward_method = fused_moe_forward_native
+                    if batch_size == 1:
+                        # The performance of torch.compile on this layer is not always good when bs > 1,
+                        # so we decide to skip it for now.
+                        sub._forward_method = fused_moe_forward_native
+                    else:
+                        sub._forward_method = sub.forward_native
                 setattr(sub, "is_torch_compile", True)
         if isinstance(sub, torch.nn.Module):
-            _to_torch(sub, reverse)
+            _to_torch(sub, reverse, batch_size)
 
 
 @contextmanager
 def patch_model(
-    model: torch.nn.Module, enable_compile: bool, tp_group: "GroupCoordinator"
+    model: torch.nn.Module,
+    enable_compile: bool,
+    batch_size: int,
+    tp_group: "GroupCoordinator",
 ):
     """Patch the model to make it compatible with with torch.compile"""
     backup_ca_comm = None
 
     try:
         if enable_compile:
-            _to_torch(model)
+            _to_torch(model, reverse=False, batch_size=batch_size)
             monkey_patch_vllm_all_gather()
             backup_ca_comm = tp_group.ca_comm
             # Use custom-allreduce here.
             # We found the custom allreduce is much faster than the built-in allreduce in torch,
             # even with ENABLE_INTRA_NODE_COMM=1.
             # tp_group.ca_comm = None
             yield torch.compile(
-                torch.no_grad()(model.forward), mode="max-autotune-no-cudagraphs"
+                torch.no_grad()(model.forward),
+                mode="max-autotune-no-cudagraphs",
+                dynamic=False,
             )
         else:
             yield model.forward
     finally:
         if enable_compile:
-            _to_torch(model, reverse=True)
+            _to_torch(model, reverse=True, batch_size=batch_size)
             monkey_patch_vllm_all_gather(reverse=True)
             tp_group.ca_comm = backup_ca_comm
 
@@ -237,6 +245,7 @@ def capture(self):
                 with patch_model(
                     self.model_runner.model,
                     bs in self.compile_bs,
+                    bs,
                     self.model_runner.tp_group,
                 ) as forward:
                     (
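
A note on the dynamic=False argument added to torch.compile above: it tells the compiler to specialize on concrete input shapes rather than tracing with symbolic sizes, which fits the CUDA graph runner's pattern of compiling once per captured batch size. A minimal standalone sketch of the behavior (the toy function is made up, not part of the commit):

```python
import torch
import torch.nn.functional as F

@torch.compile(dynamic=False)  # specialize on each concrete input shape
def toy_forward(x: torch.Tensor) -> torch.Tensor:
    return F.silu(x) * x

# Each distinct batch size triggers its own specialized compilation, mirroring how the
# graph runner compiles per captured batch size. With dynamic shapes enabled, torch.compile
# would instead try to build a single graph with symbolic sizes.
for bs in (1, 2, 4):
    _ = toy_forward(torch.randn(bs, 16))
```
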
python/sglang/srt/model_executor/model_runner.py (2 changes: 1 addition & 1 deletion)
@@ -622,7 +622,7 @@ def init_cuda_graphs(self):
         tic = time.time()
         logger.info("Capture cuda graph begin. This can take up to several minutes.")
         self.cuda_graph_runner = CudaGraphRunner(self)
-        logger.info(f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f}s")
+        logger.info(f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s")
 
     def apply_torch_tp(self):
         logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
test/srt/test_srt_engine.py (2 changes: 1 addition & 1 deletion)
@@ -188,7 +188,7 @@ def test_8_engine_offline_throughput(self):
         )
         bench_args = BenchArgs(num_prompts=10)
         result = throughput_test(server_args=server_args, bench_args=bench_args)
-        self.assertGreater(result["total_throughput"], 3500)
+        self.assertGreater(result["total_throughput"], 3000)
 
 
 if __name__ == "__main__":
test/srt/test_torch_compile_moe.py (4 changes: 2 additions & 2 deletions)
@@ -14,7 +14,7 @@
 )
 
 
-class TestTorchCompile(unittest.TestCase):
+class TestTorchCompileMoe(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
         cls.model = DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST
@@ -23,7 +23,7 @@ def setUpClass(cls):
             cls.model,
             cls.base_url,
             timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
-            other_args=["--enable-torch-compile", "--torch-compile-max-bs", "1"],
+            other_args=["--enable-torch-compile", "--torch-compile-max-bs", "8"],
         )
 
     @classmethod
