
Commit

Reorg moe code (#2563)
ispobock authored Dec 23, 2024
1 parent 23e5e50 commit e835a50
Showing 88 changed files with 338 additions and 344 deletions.

@@ -5,7 +5,9 @@
 from torch.nn import functional as F
 from transformers import AutoConfig

-from sglang.srt.layers.fused_moe_triton.fused_moe import fused_moe as fused_moe_triton
+from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
+    fused_moe as fused_moe_triton,
+)
 from sglang.srt.model_executor.cuda_graph_runner import set_torch_compile_config

@@ -5,7 +5,9 @@
 from transformers import AutoConfig
 from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe as fused_moe_vllm

-from sglang.srt.layers.fused_moe_triton.fused_moe import fused_moe as fused_moe_sglang
+from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
+    fused_moe as fused_moe_sglang,
+)


 def get_model_config(model_name: str, tp_size: int):

4 changes: 2 additions & 2 deletions benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py
@@ -11,7 +11,7 @@
 from ray.experimental.tqdm_ray import tqdm
 from transformers import AutoConfig

-from sglang.srt.layers.fused_moe_triton.fused_moe import (
+from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
     fused_moe,
     get_config_dtype_str,
     get_config_file_name,
@@ -97,7 +97,7 @@ def prepare(i: int):
         input_gating.copy_(gating_output[i])

     def run():
-        from sglang.srt.layers.fused_moe_triton import override_config
+        from sglang.srt.layers.moe.fused_moe_triton import override_config

         with override_config(config):
             fused_moe(

133 changes: 0 additions & 133 deletions python/sglang/srt/layers/fused_moe_patch.py

This file was deleted.

File renamed without changes.
File renamed without changes.

@@ -12,15 +12,15 @@
 from vllm.model_executor.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod

 from sglang.srt.layers.custom_op_util import register_custom_op
-from sglang.srt.layers.ep_moe.kernels import (
+from sglang.srt.layers.moe.ep_moe.kernels import (
     grouped_gemm_triton,
     post_reorder_triton_kernel,
     pre_reorder_triton_kernel,
     run_moe_ep_preproess,
     silu_and_mul_triton_kernel,
 )
-from sglang.srt.layers.fused_moe_triton.fused_moe import fused_topk, grouped_topk
-from sglang.srt.layers.fused_moe_triton.layer import FusedMoEMethodBase
+from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoEMethodBase
+from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
@@ -113,6 +113,7 @@ def __init__(
         quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,
         prefix: str = "",
+        correction_bias: Optional[torch.Tensor] = None,
     ):
         super().__init__()

@@ -138,6 +139,7 @@ def __init__(
             assert num_expert_group is not None and topk_group is not None
         self.num_expert_group = num_expert_group
         self.topk_group = topk_group
+        self.correction_bias = correction_bias

         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = UnquantizedEPMoEMethod()
@@ -170,13 +172,15 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
             hidden_states.device, use_flashinfer=False  # TODO: use flashinfer
         )

-        topk_weights, topk_ids = self.select_experts(
-            hidden_states,
-            router_logits,
-            self.top_k,
-            self.renormalize,
-            self.topk_group,
-            self.num_expert_group,
+        topk_weights, topk_ids = select_experts(
+            hidden_states=hidden_states,
+            router_logits=router_logits,
+            top_k=self.top_k,
+            use_grouped_topk=self.use_grouped_topk,
+            renormalize=self.renormalize,
+            topk_group=self.topk_group,
+            num_expert_group=self.num_expert_group,
+            correction_bias=self.correction_bias,
         )

         reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(
@@ -297,35 +301,6 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
         )
         return output

-    def select_experts(
-        self,
-        hidden_states: torch.Tensor,
-        router_logits: torch.Tensor,
-        top_k: int,
-        renormalize: bool,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
-    ):
-        if self.use_grouped_topk:
-            assert topk_group is not None
-            assert num_expert_group is not None
-            topk_weights, topk_ids = grouped_topk(
-                hidden_states=hidden_states,
-                gating_output=router_logits,
-                topk=top_k,
-                renormalize=renormalize,
-                num_expert_group=num_expert_group,
-                topk_group=topk_group,
-            )
-        else:
-            topk_weights, topk_ids = fused_topk(
-                hidden_states=hidden_states,
-                gating_output=router_logits,
-                topk=top_k,
-                renormalize=renormalize,
-            )
-        return topk_weights, topk_ids.to(torch.int32)
-
     @classmethod
     def make_expert_params_mapping(
         cls,
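
Note: the per-layer select_experts method deleted above only dispatched between grouped_topk and fused_topk; the forward pass now goes through the shared sglang.srt.layers.moe.topk.select_experts helper and additionally threads the new correction_bias argument through. As a rough, hypothetical sketch written only from the removed code, the non-grouped routing path looks roughly like the following; the actual shared helper also covers the grouped path, custom_routing_function, and torch_native mode, and may apply correction_bias differently.

from typing import Optional, Tuple

import torch


def select_experts_sketch(
    hidden_states: torch.Tensor,  # unused here; kept to mirror the call-site signature
    router_logits: torch.Tensor,
    top_k: int,
    use_grouped_topk: bool,
    renormalize: bool,
    topk_group: Optional[int] = None,        # only needed by the grouped path
    num_expert_group: Optional[int] = None,  # only needed by the grouped path
    correction_bias: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    if use_grouped_topk:
        # Group-limited selection (the removed grouped_topk branch) is omitted
        # from this sketch.
        raise NotImplementedError("sketch covers only the non-grouped path")
    # Mirrors the removed fused_topk branch: softmax over experts, optional
    # score bias, top-k selection, optional renormalization of the weights.
    scores = torch.softmax(router_logits.float(), dim=-1)
    if correction_bias is not None:
        # Assumption: the bias is added to the scores used for expert selection
        # (DeepSeek-V3-style); the real helper may handle it differently.
        scores = scores + correction_bias
    topk_weights, topk_ids = torch.topk(scores, top_k, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids.to(torch.int32)
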
46 changes: 46 additions & 0 deletions python/sglang/srt/layers/moe/fused_moe_native.py
@@ -0,0 +1,46 @@
"""
Torch-native implementation for FusedMoE. This is used for torch.compile.
It is based on https://github.com/pytorch-labs/gpt-fast/blob/32971d3129541c5bfb4f715abc33d1c5f408d204/mixtral-moe/model.py#L204
"""

from typing import Callable, Optional

import torch
from torch.nn import functional as F

from sglang.srt.layers.moe.topk import select_experts


def fused_moe_forward_native(
layer: torch.nn.Module,
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
torch_native=True,
)

w13_weights = layer.w13_weight[topk_ids]
w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)
w2_weights = layer.w2_weight[topk_ids]
x1 = torch.einsum("ti,taoi -> tao", x, w1_weights)
x1 = F.silu(x1)
x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))
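
The einsums above index stacked expert weights: w13_weight is expected to have shape [num_experts, 2 * intermediate_size, hidden_size] (the gate and up projections stacked) and w2_weight shape [num_experts, hidden_size, intermediate_size]. A minimal, hypothetical usage sketch with toy shapes follows; it assumes sglang is importable and that select_experts accepts the keyword arguments shown in the file above.

import torch

from sglang.srt.layers.moe.fused_moe_native import fused_moe_forward_native


class ToyMoELayer(torch.nn.Module):
    # Hypothetical container for the stacked expert weights the function indexes.
    def __init__(self, num_experts=8, hidden_size=64, intermediate_size=128):
        super().__init__()
        self.w13_weight = torch.nn.Parameter(
            torch.randn(num_experts, 2 * intermediate_size, hidden_size)
        )
        self.w2_weight = torch.nn.Parameter(
            torch.randn(num_experts, hidden_size, intermediate_size)
        )


layer = ToyMoELayer()
x = torch.randn(4, 64)             # [num_tokens, hidden_size]
router_logits = torch.randn(4, 8)  # [num_tokens, num_experts]

out = fused_moe_forward_native(
    layer,
    x,
    use_grouped_topk=False,
    top_k=2,
    router_logits=router_logits,
    renormalize=True,
)
print(out.shape)  # expected: torch.Size([4, 64])
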
@@ -1,14 +1,12 @@
 from contextlib import contextmanager
 from typing import Any, Dict, Optional

-import sglang.srt.layers.fused_moe_triton.fused_moe  # noqa
-from sglang.srt.layers.fused_moe_triton.fused_moe import (
+import sglang.srt.layers.moe.fused_moe_triton.fused_moe  # noqa
+from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
     fused_experts,
-    fused_topk,
     get_config_file_name,
-    grouped_topk,
 )
-from sglang.srt.layers.fused_moe_triton.layer import (
+from sglang.srt.layers.moe.fused_moe_triton.layer import (
     FusedMoE,
     FusedMoEMethodBase,
     FusedMoeWeightScaleSupported,
@@ -37,8 +35,6 @@ def get_config() -> Optional[Dict[str, Any]]:
     "override_config",
     "get_config",
     "fused_moe",
-    "fused_topk",
     "fused_experts",
     "get_config_file_name",
-    "grouped_topk",
 ]
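
After this reorg, downstream code imports these symbols from the new sglang.srt.layers.moe package paths, as the tuning-script hunk above already shows. An illustrative snippet (not part of the diff) of the paths this commit establishes:

# Package-level re-exports (per the __all__ above):
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE, override_config

# Triton kernel entry points moved with the package:
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts, fused_moe

# Top-k routing helpers now live in the moe.topk module:
from sglang.srt.layers.moe.topk import select_experts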