Reorg moe code #2563

Merged on Dec 23, 2024 (4 commits)
@@ -5,7 +5,9 @@
from torch.nn import functional as F
from transformers import AutoConfig

from sglang.srt.layers.fused_moe_triton.fused_moe import fused_moe as fused_moe_triton
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
fused_moe as fused_moe_triton,
)
from sglang.srt.model_executor.cuda_graph_runner import set_torch_compile_config


@@ -5,7 +5,9 @@
from transformers import AutoConfig
from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe as fused_moe_vllm

from sglang.srt.layers.fused_moe_triton.fused_moe import fused_moe as fused_moe_sglang
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
fused_moe as fused_moe_sglang,
)


def get_model_config(model_name: str, tp_size: int):
4 changes: 2 additions & 2 deletions benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py
@@ -11,7 +11,7 @@
from ray.experimental.tqdm_ray import tqdm
from transformers import AutoConfig

from sglang.srt.layers.fused_moe_triton.fused_moe import (
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
fused_moe,
get_config_dtype_str,
get_config_file_name,
@@ -97,7 +97,7 @@ def prepare(i: int):
input_gating.copy_(gating_output[i])

def run():
from sglang.srt.layers.fused_moe_triton import override_config
from sglang.srt.layers.moe.fused_moe_triton import override_config

with override_config(config):
fused_moe(
133 changes: 0 additions & 133 deletions python/sglang/srt/layers/fused_moe_patch.py

This file was deleted.

@@ -12,15 +12,15 @@
from vllm.model_executor.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod

from sglang.srt.layers.custom_op_util import register_custom_op
from sglang.srt.layers.ep_moe.kernels import (
from sglang.srt.layers.moe.ep_moe.kernels import (
grouped_gemm_triton,
post_reorder_triton_kernel,
pre_reorder_triton_kernel,
run_moe_ep_preproess,
silu_and_mul_triton_kernel,
)
from sglang.srt.layers.fused_moe_triton.fused_moe import fused_topk, grouped_topk
from sglang.srt.layers.fused_moe_triton.layer import FusedMoEMethodBase
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoEMethodBase
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
@@ -113,6 +113,7 @@ def __init__(
quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None,
prefix: str = "",
correction_bias: Optional[torch.Tensor] = None,
):
super().__init__()

@@ -138,6 +139,7 @@ def __init__(
assert num_expert_group is not None and topk_group is not None
self.num_expert_group = num_expert_group
self.topk_group = topk_group
self.correction_bias = correction_bias

if quant_config is None:
self.quant_method: Optional[QuantizeMethodBase] = UnquantizedEPMoEMethod()
@@ -170,13 +172,15 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
hidden_states.device, use_flashinfer=False # TODO: use flashinfer
)

topk_weights, topk_ids = self.select_experts(
hidden_states,
router_logits,
self.top_k,
self.renormalize,
self.topk_group,
self.num_expert_group,
topk_weights, topk_ids = select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=self.top_k,
use_grouped_topk=self.use_grouped_topk,
renormalize=self.renormalize,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
correction_bias=self.correction_bias,
)

reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(
@@ -297,35 +301,6 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
)
return output

def select_experts(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
):
if self.use_grouped_topk:
assert topk_group is not None
assert num_expert_group is not None
topk_weights, topk_ids = grouped_topk(
hidden_states=hidden_states,
gating_output=router_logits,
topk=top_k,
renormalize=renormalize,
num_expert_group=num_expert_group,
topk_group=topk_group,
)
else:
topk_weights, topk_ids = fused_topk(
hidden_states=hidden_states,
gating_output=router_logits,
topk=top_k,
renormalize=renormalize,
)
return topk_weights, topk_ids.to(torch.int32)

@classmethod
def make_expert_params_mapping(
cls,
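The EPMoE hunks above delete the class-local `select_experts` method and switch to the shared helper imported from `sglang.srt.layers.moe.topk`, which also accepts the new `correction_bias` argument. That shared helper is not part of this diff, so the snippet below is only a simplified, plain-torch sketch of the same dispatch (grouped top-k vs. plain top-k); the `correction_bias` handling shown here (bias the selection only, not the returned weights) is an assumption, and arguments such as `custom_routing_function` and `torch_native` are omitted.

```python
# Simplified, plain-torch stand-in for the shared select_experts helper.
# The real helper in sglang.srt.layers.moe.topk dispatches to the fused_topk /
# grouped_topk kernels; this sketch only mirrors the control flow visible in
# the deleted EPMoE.select_experts above.
from typing import Optional

import torch


def select_experts_sketch(
    hidden_states: torch.Tensor,   # kept for signature parity; unused here
    router_logits: torch.Tensor,   # [num_tokens, num_experts]
    top_k: int,
    use_grouped_topk: bool,
    renormalize: bool,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    correction_bias: Optional[torch.Tensor] = None,
):
    scores = torch.softmax(router_logits, dim=-1)
    # Assumed semantics: correction_bias only biases which experts are picked;
    # the returned weights still come from the unbiased scores.
    select_scores = scores if correction_bias is None else scores + correction_bias

    if use_grouped_topk:
        assert topk_group is not None and num_expert_group is not None
        t, e = select_scores.shape
        # Score each expert group by its best expert, keep the top `topk_group` groups.
        group_scores = select_scores.view(t, num_expert_group, -1).amax(dim=-1)
        group_idx = torch.topk(group_scores, topk_group, dim=-1).indices
        group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1.0)
        expert_mask = group_mask.repeat_interleave(e // num_expert_group, dim=-1)
        select_scores = select_scores.masked_fill(expert_mask == 0, float("-inf"))

    topk_ids = torch.topk(select_scores, top_k, dim=-1).indices
    topk_weights = scores.gather(1, topk_ids)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids.to(torch.int32)
```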
46 changes: 46 additions & 0 deletions python/sglang/srt/layers/moe/fused_moe_native.py
@@ -0,0 +1,46 @@
"""
Torch-native implementation for FusedMoE. This is used for torch.compile.
It is based on https://github.com/pytorch-labs/gpt-fast/blob/32971d3129541c5bfb4f715abc33d1c5f408d204/mixtral-moe/model.py#L204
"""

from typing import Callable, Optional

import torch
from torch.nn import functional as F

from sglang.srt.layers.moe.topk import select_experts


def fused_moe_forward_native(
layer: torch.nn.Module,
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
torch_native=True,
)

w13_weights = layer.w13_weight[topk_ids]
w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)
w2_weights = layer.w2_weight[topk_ids]
x1 = torch.einsum("ti,taoi -> tao", x, w1_weights)
x1 = F.silu(x1)
x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))
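Because the new `fused_moe_forward_native` is pure tensor algebra, its einsum path can be checked standalone. The snippet below is a minimal toy-shape walkthrough; the softmax plus top-k routing is only a stand-in for `select_experts` (not shown in this diff), and the weight layouts are inferred from how `w13_weight` and `w2_weight` are indexed above.

```python
# Toy-shape walkthrough of the einsum path in fused_moe_forward_native.
# Assumed layouts (inferred from the indexing above):
#   w13_weight: [num_experts, 2 * intermediate, hidden]  (gate and up stacked)
#   w2_weight:  [num_experts, hidden, intermediate]      (down projection)
import torch
from torch.nn import functional as F

T, H, I, E, K = 4, 8, 16, 6, 2  # tokens, hidden, intermediate, experts, top-k
x = torch.randn(T, H)
w13_weight = torch.randn(E, 2 * I, H)
w2_weight = torch.randn(E, H, I)
router_logits = torch.randn(T, E)

# Stand-in routing in place of select_experts: softmax, top-k, renormalize.
probs = F.softmax(router_logits, dim=-1)
topk_weights, topk_ids = torch.topk(probs, K, dim=-1)
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

# Same computation as the function above.
w13 = w13_weight[topk_ids]                         # [T, K, 2I, H]
w1, w3 = torch.chunk(w13, 2, dim=2)                # each [T, K, I, H]
w2 = w2_weight[topk_ids]                           # [T, K, H, I]
x1 = F.silu(torch.einsum("ti,taoi->tao", x, w1))   # gated branch
x3 = torch.einsum("ti,taoi->tao", x, w3)           # up branch
expert_outs = torch.einsum("tao,taio->tai", x1 * x3, w2)
out = torch.einsum("tai,ta->ti", expert_outs, topk_weights.to(expert_outs.dtype))
assert out.shape == (T, H)
```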
@@ -1,14 +1,12 @@
from contextlib import contextmanager
from typing import Any, Dict, Optional

import sglang.srt.layers.fused_moe_triton.fused_moe # noqa
from sglang.srt.layers.fused_moe_triton.fused_moe import (
import sglang.srt.layers.moe.fused_moe_triton.fused_moe # noqa
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
fused_experts,
fused_topk,
get_config_file_name,
grouped_topk,
)
from sglang.srt.layers.fused_moe_triton.layer import (
from sglang.srt.layers.moe.fused_moe_triton.layer import (
FusedMoE,
FusedMoEMethodBase,
FusedMoeWeightScaleSupported,
@@ -37,8 +35,6 @@ def get_config() -> Optional[Dict[str, Any]]:
"override_config",
"get_config",
"fused_moe",
"fused_topk",
"fused_experts",
"get_config_file_name",
"grouped_topk",
]
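Taken together, the hunks above define the reorganized import layout. A short sketch of the new paths follows; it assumes an sglang build that includes this PR, and the `moe.topk` and `moe.fused_moe_native` module paths are inferred from the imports and the new file path shown in this diff.

```python
# MoE import paths after the reorg, as used in the hunks above.
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE, override_config
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
from sglang.srt.layers.moe.ep_moe.kernels import run_moe_ep_preproess
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.moe.fused_moe_native import fused_moe_forward_native
```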