diff --git a/benchmark/kernels/fused_moe_triton/benchmark_torch_compile_fused_moe.py b/benchmark/kernels/fused_moe_triton/benchmark_torch_compile_fused_moe.py
index 1bd6eec1645..10e258f16e5 100644
--- a/benchmark/kernels/fused_moe_triton/benchmark_torch_compile_fused_moe.py
+++ b/benchmark/kernels/fused_moe_triton/benchmark_torch_compile_fused_moe.py
@@ -5,7 +5,9 @@
 from torch.nn import functional as F
 from transformers import AutoConfig
 
-from sglang.srt.layers.fused_moe_triton.fused_moe import fused_moe as fused_moe_triton
+from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
+    fused_moe as fused_moe_triton,
+)
 from sglang.srt.model_executor.cuda_graph_runner import set_torch_compile_config
 
diff --git a/benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py b/benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py
index 7bfb2731b98..1d16b256855 100644
--- a/benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py
+++ b/benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py
@@ -5,7 +5,9 @@
 from transformers import AutoConfig
 from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe as fused_moe_vllm
 
-from sglang.srt.layers.fused_moe_triton.fused_moe import fused_moe as fused_moe_sglang
+from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
+    fused_moe as fused_moe_sglang,
+)
 
 
 def get_model_config(model_name: str, tp_size: int):
diff --git a/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py b/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py
index 6f6a57be1de..da6af2fb863 100644
--- a/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py
+++ b/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py
@@ -11,7 +11,7 @@
 from ray.experimental.tqdm_ray import tqdm
 from transformers import AutoConfig
 
-from sglang.srt.layers.fused_moe_triton.fused_moe import (
+from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
     fused_moe,
     get_config_dtype_str,
     get_config_file_name,
@@ -97,7 +97,7 @@ def prepare(i: int):
         input_gating.copy_(gating_output[i])
 
     def run():
-        from sglang.srt.layers.fused_moe_triton import override_config
+        from sglang.srt.layers.moe.fused_moe_triton import override_config
 
         with override_config(config):
             fused_moe(
diff --git a/python/sglang/srt/layers/fused_moe_patch.py b/python/sglang/srt/layers/fused_moe_patch.py
deleted file mode 100644
index baca2581150..00000000000
--- a/python/sglang/srt/layers/fused_moe_patch.py
+++ /dev/null
@@ -1,133 +0,0 @@
-"""
-Torch-native implementation for FusedMoE. This is used for torch.compile.
-It is based on https://github.com/pytorch-labs/gpt-fast/blob/32971d3129541c5bfb4f715abc33d1c5f408d204/mixtral-moe/model.py#L204
-"""
-
-from typing import Callable, Optional
-
-import torch
-from torch.nn import functional as F
-
-
-def fused_topk_native(
-    hidden_states: torch.Tensor,
-    gating_output: torch.Tensor,
-    topk: int,
-    renormalize: bool,
-):
-    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
-    M, _ = hidden_states.shape
-    topk_weights = torch.empty(
-        M, topk, dtype=torch.float32, device=hidden_states.device
-    )
-    topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
-    topk_weights = F.softmax(gating_output.float(), dim=-1)
-    topk_weights, topk_ids = torch.topk(topk_weights, topk, dim=-1)
-    if renormalize:
-        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
-    return topk_weights, topk_ids
-
-
-# This is used by the Deepseek-V2 model
-def grouped_topk(
-    hidden_states: torch.Tensor,
-    gating_output: torch.Tensor,
-    topk: int,
-    renormalize: bool,
-    num_expert_group: int = 0,
-    topk_group: int = 0,
-):
-
-    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
-
-    scores = torch.softmax(gating_output, dim=-1)
-    num_token = scores.shape[0]
-    group_scores = (
-        scores.view(num_token, num_expert_group, -1).max(dim=-1).values
-    )  # [n, n_group]
-    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
-        1
-    ]  # [n, top_k_group]
-    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
-    group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
-    score_mask = (
-        group_mask.unsqueeze(-1)
-        .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
-        .reshape(num_token, -1)
-    )  # [n, e]
-    tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
-    topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
-
-    if renormalize:
-        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
-    return topk_weights, topk_ids
-
-
-def select_experts_native(
-    hidden_states: torch.Tensor,
-    router_logits: torch.Tensor,
-    top_k: int,
-    use_grouped_topk: bool,
-    renormalize: bool,
-    topk_group: Optional[int] = None,
-    num_expert_group: Optional[int] = None,
-):
-    # DeekSeekv2 uses grouped_top_k
-    if use_grouped_topk:
-        assert topk_group is not None
-        assert num_expert_group is not None
-        topk_weights, topk_ids = grouped_topk(
-            hidden_states=hidden_states,
-            gating_output=router_logits,
-            topk=top_k,
-            renormalize=renormalize,
-            num_expert_group=num_expert_group,
-            topk_group=topk_group,
-        )
-    else:
-        topk_weights, topk_ids = fused_topk_native(
-            hidden_states=hidden_states,
-            gating_output=router_logits,
-            topk=top_k,
-            renormalize=renormalize,
-        )
-    return topk_weights, topk_ids
-
-
-def fused_moe_forward_native(
-    layer: torch.nn.Module,
-    x: torch.Tensor,
-    use_grouped_topk: bool,
-    top_k: int,
-    router_logits: torch.Tensor,
-    renormalize: bool,
-    topk_group: Optional[int] = None,
-    num_expert_group: Optional[int] = None,
-    custom_routing_function: Optional[Callable] = None,
-) -> torch.Tensor:
-
-    if use_grouped_topk:
-        assert num_expert_group is not None and topk_group is not None
-        topk_weights, topk_ids = grouped_topk(
-            x,
-            router_logits,
-            top_k,
-            renormalize,
-            num_expert_group,
-            topk_group,
-        )
-    elif custom_routing_function is None:
-        topk_weights, topk_ids = fused_topk_native(x, router_logits, top_k, renormalize)
-    else:
-        topk_weights, topk_ids = custom_routing_function(
-            x, router_logits, top_k, renormalize
-        )
-
-    w13_weights = layer.w13_weight[topk_ids]
-    w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)
-    w2_weights = layer.w2_weight[topk_ids]
-    x1 = torch.einsum("ti,taoi -> tao", x, w1_weights)
-    x1 = F.silu(x1)
-    x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
-    expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
-    return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))
diff --git a/python/sglang/srt/layers/ep_moe/__init__.py b/python/sglang/srt/layers/moe/ep_moe/__init__.py
similarity index 100%
rename from python/sglang/srt/layers/ep_moe/__init__.py
rename to python/sglang/srt/layers/moe/ep_moe/__init__.py
diff --git a/python/sglang/srt/layers/ep_moe/kernels.py b/python/sglang/srt/layers/moe/ep_moe/kernels.py
similarity index 100%
rename from python/sglang/srt/layers/ep_moe/kernels.py
rename to python/sglang/srt/layers/moe/ep_moe/kernels.py
diff --git a/python/sglang/srt/layers/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py
similarity index 93%
rename from python/sglang/srt/layers/ep_moe/layer.py
rename to python/sglang/srt/layers/moe/ep_moe/layer.py
index 3c477fdc2ef..96e02e31278 100644
--- a/python/sglang/srt/layers/ep_moe/layer.py
+++ b/python/sglang/srt/layers/moe/ep_moe/layer.py
@@ -12,15 +12,15 @@
 from vllm.model_executor.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
 
 from sglang.srt.layers.custom_op_util import register_custom_op
-from sglang.srt.layers.ep_moe.kernels import (
+from sglang.srt.layers.moe.ep_moe.kernels import (
     grouped_gemm_triton,
     post_reorder_triton_kernel,
     pre_reorder_triton_kernel,
     run_moe_ep_preproess,
     silu_and_mul_triton_kernel,
 )
-from sglang.srt.layers.fused_moe_triton.fused_moe import fused_topk, grouped_topk
-from sglang.srt.layers.fused_moe_triton.layer import FusedMoEMethodBase
+from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoEMethodBase
+from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
@@ -113,6 +113,7 @@ def __init__(
         quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,
         prefix: str = "",
+        correction_bias: Optional[torch.Tensor] = None,
     ):
         super().__init__()
 
@@ -138,6 +139,7 @@ def __init__(
             assert num_expert_group is not None and topk_group is not None
         self.num_expert_group = num_expert_group
         self.topk_group = topk_group
+        self.correction_bias = correction_bias
 
         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = UnquantizedEPMoEMethod()
@@ -170,13 +172,15 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
             hidden_states.device, use_flashinfer=False  # TODO: use flashinfer
         )
 
-        topk_weights, topk_ids = self.select_experts(
-            hidden_states,
-            router_logits,
-            self.top_k,
-            self.renormalize,
-            self.topk_group,
-            self.num_expert_group,
+        topk_weights, topk_ids = select_experts(
+            hidden_states=hidden_states,
+            router_logits=router_logits,
+            top_k=self.top_k,
+            use_grouped_topk=self.use_grouped_topk,
+            renormalize=self.renormalize,
+            topk_group=self.topk_group,
+            num_expert_group=self.num_expert_group,
+            correction_bias=self.correction_bias,
         )
 
         reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(
@@ -297,35 +301,6 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
         )
         return output
 
-    def select_experts(
-        self,
-        hidden_states: torch.Tensor,
-        router_logits: torch.Tensor,
-        top_k: int,
-        renormalize: bool,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
-    ):
-        if self.use_grouped_topk:
-            assert topk_group is not None
-            assert num_expert_group is not None
-            topk_weights, topk_ids = grouped_topk(
-                hidden_states=hidden_states,
-                gating_output=router_logits,
-                topk=top_k,
-                renormalize=renormalize,
-                num_expert_group=num_expert_group,
-                topk_group=topk_group,
-            )
-        else:
-            topk_weights, topk_ids = fused_topk(
-                hidden_states=hidden_states,
-                gating_output=router_logits,
-                topk=top_k,
-                renormalize=renormalize,
-            )
-        return topk_weights, topk_ids.to(torch.int32)
-
     @classmethod
     def make_expert_params_mapping(
         cls,
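Note: both MoE backends now route through the shared dispatcher in sglang/srt/layers/moe/topk.py instead of keeping a private select_experts copy per layer class. A minimal sketch of a call through the new entry point follows; the tensor shapes and values are illustrative assumptions, not code from this patch, and torch_native=True keeps the example on the pure-torch path so it does not require the vLLM topk_softmax kernel:

import torch

from sglang.srt.layers.moe.topk import select_experts

# Toy routing problem: 4 tokens, hidden size 16, 8 experts, top-2 routing.
hidden_states = torch.randn(4, 16)
router_logits = torch.randn(4, 8)

topk_weights, topk_ids = select_experts(
    hidden_states=hidden_states,
    router_logits=router_logits,
    top_k=2,
    use_grouped_topk=False,
    renormalize=True,
    torch_native=True,  # softmax + topk in plain torch
)
# topk_weights: [4, 2] routing weights; topk_ids: [4, 2] expert indices.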
diff --git a/python/sglang/srt/layers/moe/fused_moe_native.py b/python/sglang/srt/layers/moe/fused_moe_native.py
new file mode 100644
index 00000000000..638173b647d
--- /dev/null
+++ b/python/sglang/srt/layers/moe/fused_moe_native.py
@@ -0,0 +1,46 @@
+"""
+Torch-native implementation for FusedMoE. This is used for torch.compile.
+It is based on https://github.com/pytorch-labs/gpt-fast/blob/32971d3129541c5bfb4f715abc33d1c5f408d204/mixtral-moe/model.py#L204
+"""
+
+from typing import Callable, Optional
+
+import torch
+from torch.nn import functional as F
+
+from sglang.srt.layers.moe.topk import select_experts
+
+
+def fused_moe_forward_native(
+    layer: torch.nn.Module,
+    x: torch.Tensor,
+    use_grouped_topk: bool,
+    top_k: int,
+    router_logits: torch.Tensor,
+    renormalize: bool,
+    topk_group: Optional[int] = None,
+    num_expert_group: Optional[int] = None,
+    custom_routing_function: Optional[Callable] = None,
+    correction_bias: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    topk_weights, topk_ids = select_experts(
+        hidden_states=x,
+        router_logits=router_logits,
+        use_grouped_topk=use_grouped_topk,
+        top_k=top_k,
+        renormalize=renormalize,
+        topk_group=topk_group,
+        num_expert_group=num_expert_group,
+        custom_routing_function=custom_routing_function,
+        correction_bias=correction_bias,
+        torch_native=True,
+    )
+
+    w13_weights = layer.w13_weight[topk_ids]
+    w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)
+    w2_weights = layer.w2_weight[topk_ids]
+    x1 = torch.einsum("ti,taoi -> tao", x, w1_weights)
+    x1 = F.silu(x1)
+    x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
+    expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
+    return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))
diff --git a/python/sglang/srt/layers/fused_moe_triton/__init__.py b/python/sglang/srt/layers/moe/fused_moe_triton/__init__.py
similarity index 72%
rename from python/sglang/srt/layers/fused_moe_triton/__init__.py
rename to python/sglang/srt/layers/moe/fused_moe_triton/__init__.py
index b895b9e4836..b68961931d5 100644
--- a/python/sglang/srt/layers/fused_moe_triton/__init__.py
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/__init__.py
@@ -1,14 +1,12 @@
 from contextlib import contextmanager
 from typing import Any, Dict, Optional
 
-import sglang.srt.layers.fused_moe_triton.fused_moe  # noqa
-from sglang.srt.layers.fused_moe_triton.fused_moe import (
+import sglang.srt.layers.moe.fused_moe_triton.fused_moe  # noqa
+from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
     fused_experts,
-    fused_topk,
     get_config_file_name,
-    grouped_topk,
 )
-from sglang.srt.layers.fused_moe_triton.layer import (
+from sglang.srt.layers.moe.fused_moe_triton.layer import (
     FusedMoE,
     FusedMoEMethodBase,
     FusedMoeWeightScaleSupported,
@@ -37,8 +35,6 @@ def get_config() -> Optional[Dict[str, Any]]:
     "override_config",
     "get_config",
     "fused_moe",
-
"fused_topk", "fused_experts", "get_config_file_name", - "grouped_topk", ] diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=1,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=1,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3.json rename to 
python/sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json diff --git 
a/python/sglang/srt/layers/fused_moe_triton/configs/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=16,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=16,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json diff --git 
a/python/sglang/srt/layers/fused_moe_triton/configs/E=16,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=16,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=16,N=3200,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=3200,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=16,N=3200,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=3200,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=16,N=6400,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=6400,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=16,N=6400,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=6400,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json rename to 
python/sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=16,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=16,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=16,N=800,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=800,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=16,N=800,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=800,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_A100-SXM4-80GB.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_A100-SXM4-80GB.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_A100-SXM4-80GB.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_A100-SXM4-80GB.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_A800-SXM4-80GB.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_A800-SXM4-80GB.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_A800-SXM4-80GB.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_A800-SXM4-80GB.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_A100-SXM4-80GB.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_A100-SXM4-80GB.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_A100-SXM4-80GB.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_A100-SXM4-80GB.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_A800-SXM4-80GB.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_A800-SXM4-80GB.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_A800-SXM4-80GB.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_A800-SXM4-80GB.json diff --git 
a/python/sglang/srt/layers/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json 
b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json similarity index 100% rename from 
python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_L40S.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_L40S.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_L40S.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_L40S.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json rename to 
python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=8192,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=8192,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/README b/python/sglang/srt/layers/moe/fused_moe_triton/configs/README similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/README rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/README diff --git a/python/sglang/srt/layers/fused_moe_triton/fused_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py similarity index 90% rename from 
python/sglang/srt/layers/fused_moe_triton/fused_moe.py
rename to python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
index e6ce9cb4d39..24e0133a121 100644
--- a/python/sglang/srt/layers/fused_moe_triton/fused_moe.py
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
@@ -13,6 +13,7 @@
 import triton.language as tl
 from vllm import _custom_ops as ops
 
+from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.utils import direct_register_custom_op, get_device_name
 
 logger = logging.getLogger(__name__)
@@ -415,7 +416,7 @@ def try_get_optimal_moe_config(
     M: int,
     is_marlin: bool = False,
 ):
-    from sglang.srt.layers.fused_moe_triton import get_config
+    from sglang.srt.layers.moe.fused_moe_triton import get_config
 
     override_config = get_config()
     if override_config:
@@ -435,74 +436,6 @@ def try_get_optimal_moe_config(
     return config
 
 
-def fused_topk(
-    hidden_states: torch.Tensor,
-    gating_output: torch.Tensor,
-    topk: int,
-    renormalize: bool,
-):
-    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
-
-    M, _ = hidden_states.shape
-
-    topk_weights = torch.empty(
-        M, topk, dtype=torch.float32, device=hidden_states.device
-    )
-    topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
-    token_expert_indicies = torch.empty(
-        M, topk, dtype=torch.int32, device=hidden_states.device
-    )
-
-    ops.topk_softmax(
-        topk_weights,
-        topk_ids,
-        token_expert_indicies,
-        gating_output.float(),  # TODO(woosuk): Optimize this.
-    )
-    del token_expert_indicies  # Not used. Will be used in the future.
-
-    if renormalize:
-        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
-
-    return topk_weights, topk_ids
-
-
-# This is used by the Deepseek-V2 model
-def grouped_topk(
-    hidden_states: torch.Tensor,
-    gating_output: torch.Tensor,
-    topk: int,
-    renormalize: bool,
-    num_expert_group: int = 0,
-    topk_group: int = 0,
-):
-
-    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
-
-    scores = torch.softmax(gating_output, dim=-1)
-    num_token = scores.shape[0]
-    group_scores = (
-        scores.view(num_token, num_expert_group, -1).max(dim=-1).values
-    )  # [n, n_group]
-    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
-        1
-    ]  # [n, top_k_group]
-    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
-    group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
-    score_mask = (
-        group_mask.unsqueeze(-1)
-        .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
-        .reshape(num_token, -1)
-    )  # [n, e]
-    tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
-    topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
-
-    if renormalize:
-        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
-
-    return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
-
-
 def get_config_dtype_str(
     dtype: torch.dtype,
     use_int8_w8a16: Optional[bool] = False,
@@ -869,24 +802,16 @@ def fused_moe(
     # Check constraints.
     assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
 
-    if use_grouped_topk:
-        assert num_expert_group is not None and topk_group is not None
-        topk_weights, topk_ids = grouped_topk(
-            hidden_states,
-            gating_output,
-            topk,
-            renormalize,
-            num_expert_group,
-            topk_group,
-        )
-    elif custom_routing_function is None:
-        topk_weights, topk_ids = fused_topk(
-            hidden_states, gating_output, topk, renormalize
-        )
-    else:
-        topk_weights, topk_ids = custom_routing_function(
-            hidden_states, gating_output, topk, renormalize
-        )
+    topk_weights, topk_ids = select_experts(
+        hidden_states=hidden_states,
+        router_logits=gating_output,
+        use_grouped_topk=use_grouped_topk,
+        top_k=topk,
+        renormalize=renormalize,
+        topk_group=topk_group,
+        num_expert_group=num_expert_group,
+        custom_routing_function=custom_routing_function,
+    )
 
     return fused_experts(
         hidden_states,
diff --git a/python/sglang/srt/layers/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
similarity index 92%
rename from python/sglang/srt/layers/fused_moe_triton/layer.py
rename to python/sglang/srt/layers/moe/fused_moe_triton/layer.py
index 0785583cd94..2548ca16330 100644
--- a/python/sglang/srt/layers/fused_moe_triton/layer.py
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
@@ -13,6 +13,7 @@
 from vllm.model_executor.custom_op import CustomOp
 
 from sglang.srt.layers.custom_op_util import register_custom_op
+from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
@@ -20,7 +21,7 @@
 from sglang.srt.utils import set_weight_attrs
 
 if torch.cuda.is_available():
-    from sglang.srt.layers.fused_moe_triton.fused_moe import fused_experts
+    from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
 else:
     fused_experts = None  # type: ignore
 
@@ -106,6 +107,7 @@ def apply(
         topk_group: Optional[int] = None,
         num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
+        correction_bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         return self.forward(
             x=x,
@@ -117,6 +119,7 @@ def apply(
             topk_group=topk_group,
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
+            correction_bias=correction_bias,
         )
 
     def forward_cuda(
@@ -130,8 +133,9 @@ def forward_cuda(
         topk_group: Optional[int] = None,
         num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
+        correction_bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        topk_weights, topk_ids = FusedMoE.select_experts(
+        topk_weights, topk_ids = select_experts(
             hidden_states=x,
             router_logits=router_logits,
             use_grouped_topk=use_grouped_topk,
@@ -140,6 +144,7 @@ def forward_cuda(
             topk_group=topk_group,
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
+            correction_bias=correction_bias,
         )
 
         return fused_experts(
@@ -197,6 +202,7 @@ def __init__(
         tp_size: Optional[int] = None,
         prefix: str = "",
         custom_routing_function: Optional[Callable] = None,
+        correction_bias: Optional[torch.Tensor] = None,
     ):
         super().__init__()
 
@@ -217,6 +223,7 @@ def __init__(
         self.num_expert_group = num_expert_group
         self.topk_group = topk_group
         self.custom_routing_function = custom_routing_function
+        self.correction_bias = correction_bias
 
         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = (
@@ -503,51 +510,6 @@ def weight_loader(
             )
             return
 
-    @staticmethod
-    def select_experts(
-        hidden_states: torch.Tensor,
-        router_logits: torch.Tensor,
-        top_k: int,
-        use_grouped_topk: bool,
-        renormalize: bool,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
-        custom_routing_function: Optional[Callable] = None,
-    ):
-        from sglang.srt.layers.fused_moe_triton.fused_moe import (
-            fused_topk,
-            grouped_topk,
-        )
-
-        # DeekSeekv2 uses grouped_top_k
-        if use_grouped_topk:
-            assert topk_group is not None
-            assert num_expert_group is not None
-            topk_weights, topk_ids = grouped_topk(
-                hidden_states=hidden_states,
-                gating_output=router_logits,
-                topk=top_k,
-                renormalize=renormalize,
-                num_expert_group=num_expert_group,
-                topk_group=topk_group,
-            )
-        elif custom_routing_function is None:
-            topk_weights, topk_ids = fused_topk(
-                hidden_states=hidden_states,
-                gating_output=router_logits,
-                topk=top_k,
-                renormalize=renormalize,
-            )
-        else:
-            topk_weights, topk_ids = custom_routing_function(
-                hidden_states=hidden_states,
-                gating_output=router_logits,
-                topk=top_k,
-                renormalize=renormalize,
-            )
-
-        return topk_weights, topk_ids
-
     def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
         assert self.quant_method is not None
 
@@ -562,6 +524,7 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
             topk_group=self.topk_group,
             num_expert_group=self.num_expert_group,
             custom_routing_function=self.custom_routing_function,
+            correction_bias=self.correction_bias,
         )
 
         if self.reduce_results and self.tp_size > 1:
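With correction_bias stored on the layer and forwarded into select_experts, a DeepSeek-V3-style model can hand its gate's expert-score correction bias straight to FusedMoE. A hypothetical call site follows; constructor arguments other than those visible in this hunk (e.g. num_experts, hidden_size, intermediate_size) and the config field names are assumptions for illustration, not code from this patch:

# Sketch of wiring inside a model's MoE block (names assumed):
self.experts = FusedMoE(
    num_experts=config.n_routed_experts,          # assumed config field
    top_k=config.num_experts_per_tok,             # assumed config field
    hidden_size=config.hidden_size,
    intermediate_size=config.moe_intermediate_size,
    renormalize=config.norm_topk_prob,
    use_grouped_topk=True,
    num_expert_group=config.n_group,
    topk_group=config.topk_group,
    correction_bias=self.gate.e_score_correction_bias,  # assumed gate attribute
)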
diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py
new file mode 100644
index 00000000000..459ba6fc9af
--- /dev/null
+++ b/python/sglang/srt/layers/moe/topk.py
@@ -0,0 +1,191 @@
+from typing import Callable, Optional
+
+import torch
+import torch.nn.functional as F
+
+
+def fused_topk_native(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+):
+    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
+    M, _ = hidden_states.shape
+    topk_weights = torch.empty(
+        M, topk, dtype=torch.float32, device=hidden_states.device
+    )
+    topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
+    topk_weights = F.softmax(gating_output.float(), dim=-1)
+    topk_weights, topk_ids = torch.topk(topk_weights, topk, dim=-1)
+    if renormalize:
+        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+    return topk_weights, topk_ids
+
+
+def fused_topk(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+):
+    from vllm import _custom_ops as ops
+
+    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
+
+    M, _ = hidden_states.shape
+
+    topk_weights = torch.empty(
+        M, topk, dtype=torch.float32, device=hidden_states.device
+    )
+    topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
+    token_expert_indicies = torch.empty(
+        M, topk, dtype=torch.int32, device=hidden_states.device
+    )
+
+    ops.topk_softmax(
+        topk_weights,
+        topk_ids,
+        token_expert_indicies,
+        gating_output.float(),
+    )
+    del token_expert_indicies
+
+    if renormalize:
+        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+
+    return topk_weights, topk_ids
+
+
+# This is used by the Deepseek-V2 model
+def grouped_topk(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+    num_expert_group: int = 0,
+    topk_group: int = 0,
+):
+    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
+
+    scores = torch.softmax(gating_output, dim=-1)
+    num_token = scores.shape[0]
+    group_scores = (
+        scores.view(num_token, num_expert_group, -1).max(dim=-1).values
+    )  # [n, n_group]
+    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
+        1
+    ]  # [n, top_k_group]
+    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
+    group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
+    score_mask = (
+        group_mask.unsqueeze(-1)
+        .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
+        .reshape(num_token, -1)
+    )  # [n, e]
+    tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
+    topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
+
+    if renormalize:
+        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+
+    return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
+
+
+def biased_grouped_topk(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    correction_bias: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+    num_expert_group: int = 0,
+    topk_group: int = 0,
+):
+    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
+
+    scores = gating_output.sigmoid()
+    num_token = scores.shape[0]
+    scores_for_choice = scores.view(num_token, -1) + correction_bias.unsqueeze(0)
+    group_scores = (
+        scores_for_choice.view(num_token, num_expert_group, -1)
+        .topk(2, dim=-1)[0]
+        .sum(dim=-1)
+    )  # [n, n_group]
+    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
+        1
+    ]  # [n, top_k_group]
+    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
+    group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
+    score_mask = (
+        group_mask.unsqueeze(-1)
+        .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
+        .reshape(num_token, -1)
+    )  # [n, e]
+    tmp_scores = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
+    _, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
+    topk_weights = scores.gather(1, topk_ids)
+
+    if renormalize:
+        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+
+    return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
+
+
+def select_experts(
+    hidden_states: torch.Tensor,
+    router_logits: torch.Tensor,
+    top_k: int,
+    use_grouped_topk: bool,
+    renormalize: bool,
+    topk_group: Optional[int] = None,
+    num_expert_group: Optional[int] = None,
+    custom_routing_function: Optional[Callable] = None,
+    correction_bias: Optional[torch.Tensor] = None,
+    torch_native: bool = False,
+):
+    # DeepSeek-V2 uses grouped_top_k
+    if use_grouped_topk:
+        assert topk_group is not None
+        assert num_expert_group is not None
+        if correction_bias is None:
+            topk_weights, topk_ids = grouped_topk(
+                hidden_states=hidden_states,
+                gating_output=router_logits,
+                topk=top_k,
+                renormalize=renormalize,
+                num_expert_group=num_expert_group,
+                topk_group=topk_group,
+            )
+        else:
+            topk_weights, topk_ids = biased_grouped_topk(
+                hidden_states=hidden_states,
+                gating_output=router_logits,
+                correction_bias=correction_bias,
+                topk=top_k,
+                renormalize=renormalize,
+                num_expert_group=num_expert_group,
+                topk_group=topk_group,
+            )
+    elif torch_native:
+        topk_weights, topk_ids = fused_topk_native(
+            hidden_states=hidden_states,
+            gating_output=router_logits,
+            topk=top_k,
+            renormalize=renormalize,
+        )
+    elif custom_routing_function is None:
+        topk_weights, topk_ids = fused_topk(
+            hidden_states=hidden_states,
+            gating_output=router_logits,
+            topk=top_k,
+            renormalize=renormalize,
+        )
+    else:
+        topk_weights, topk_ids = custom_routing_function(
+            hidden_states=hidden_states,
+            gating_output=router_logits,
+            topk=top_k,
+            renormalize=renormalize,
+        )
+
+    return topk_weights, topk_ids
+        topk_weights, topk_ids = custom_routing_function(
+            hidden_states=hidden_states,
+            gating_output=router_logits,
+            topk=top_k,
+            renormalize=renormalize,
+        )
+
+    return topk_weights, topk_ids
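
For intuition, a standalone sketch of the group-limited, bias-corrected selection that biased_grouped_topk above implements (the DeepSeek-V3-style "noaux_tc" scheme). All sizes are toy values and the bias is zero, so the sketch only illustrates shapes and flow:

import torch

# 2 tokens, 8 experts arranged as 4 groups of 2; route within the top-2
# groups and pick the top-2 experts overall.
num_token, num_expert_group, topk_group, top_k = 2, 4, 2, 2
experts_per_group = 2
scores = torch.rand(num_token, num_expert_group * experts_per_group)  # sigmoid scores
correction_bias = torch.zeros(num_expert_group * experts_per_group)

scores_for_choice = scores + correction_bias  # the bias shifts *which* experts win
group_scores = (
    scores_for_choice.view(num_token, num_expert_group, -1)
    .topk(2, dim=-1)[0]
    .sum(dim=-1)
)  # [n, n_group]
group_idx = group_scores.topk(topk_group, dim=-1, sorted=False)[1]
group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1)
score_mask = (
    group_mask.unsqueeze(-1)
    .expand(num_token, num_expert_group, experts_per_group)
    .reshape(num_token, -1)
)
_, topk_ids = scores_for_choice.masked_fill(~score_mask.bool(), 0.0).topk(top_k, dim=-1)
topk_weights = scores.gather(1, topk_ids)  # weights come from the unbiased scores
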
diff --git a/python/sglang/srt/layers/quantization/__init__.py b/python/sglang/srt/layers/quantization/__init__.py
index 48b733fdb78..ae9319f2824 100644
--- a/python/sglang/srt/layers/quantization/__init__.py
+++ b/python/sglang/srt/layers/quantization/__init__.py
@@ -60,8 +60,8 @@ def fp8_get_quant_method(self, layer, prefix):
         is_layer_skipped,
     )
 
-    from sglang.srt.layers.fused_moe_triton.layer import FusedMoE
     from sglang.srt.layers.linear import UnquantizedLinearMethod
+    from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
     from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod, Fp8MoEMethod
 
     if isinstance(layer, LinearBase):
@@ -80,7 +80,7 @@ def gptq_get_quant_method(self, layer, prefix):
         GPTQMarlinMoEMethod,
     )
 
-    from sglang.srt.layers.fused_moe_triton.layer import FusedMoE
+    from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 
     if isinstance(layer, LinearBase):
         return GPTQMarlinLinearMethod(self)
@@ -96,7 +96,7 @@ def awq_get_quant_method(self, layer, prefix):
         AWQMoEMethod,
     )
 
-    from sglang.srt.layers.fused_moe_triton.layer import FusedMoE
+    from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 
     if isinstance(layer, LinearBase):
         return AWQMarlinLinearMethod(self)
diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py
index c5a254b547e..b12815c665f 100644
--- a/python/sglang/srt/layers/quantization/fp8.py
+++ b/python/sglang/srt/layers/quantization/fp8.py
@@ -26,8 +26,8 @@
 )
 from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter
 
-from sglang.srt.layers.fused_moe_triton.fused_moe import padding_size
 from sglang.srt.layers.linear import LinearMethodBase, UnquantizedLinearMethod
+from sglang.srt.layers.moe.fused_moe_triton.fused_moe import padding_size
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
@@ -98,7 +98,7 @@ def get_quant_method(
     ) -> Optional["QuantizeMethodBase"]:
         from vllm.attention.layer import Attention  # Avoid circular import
 
-        from sglang.srt.layers.fused_moe_triton import FusedMoE
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 
         if isinstance(layer, LinearBase):
             if is_layer_skipped(prefix, self.ignored_layers):
@@ -320,7 +320,7 @@ class Fp8MoEMethod:
     """
 
     def __new__(cls, *args, **kwargs):
-        from sglang.srt.layers.fused_moe_triton import FusedMoEMethodBase
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
 
         if not hasattr(cls, "_initialized"):
             original_init = cls.__init__
@@ -349,7 +349,7 @@ def create_weights(
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
-        from sglang.srt.layers.fused_moe_triton import FusedMoeWeightScaleSupported
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
 
         if self.quant_config.is_checkpoint_fp8_serialized:
             params_dtype = torch.float8_e4m3fn
@@ -566,12 +566,14 @@ def apply(
         topk_group: Optional[int] = None,
         num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
+        correction_bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        from sglang.srt.layers.fused_moe_triton import FusedMoE
-        from sglang.srt.layers.fused_moe_triton.fused_moe import fused_experts
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
+        from sglang.srt.layers.moe.topk import select_experts
 
         # Expert selection
-        topk_weights, topk_ids = FusedMoE.select_experts(
+        topk_weights, topk_ids = select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
@@ -580,6 +582,7 @@ def apply(
             topk_group=topk_group,
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
+            correction_bias=correction_bias,
         )
 
         # Expert fusion with FP8 quantization
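
A minimal sketch of calling the relocated select_experts entry point directly, assuming an sglang build that includes this patch; torch_native=True keeps the sketch on the pure-torch path so no vllm custom op or GPU is required:

import torch
from sglang.srt.layers.moe.topk import select_experts

hidden_states = torch.randn(4, 16)
router_logits = torch.randn(4, 8)
topk_weights, topk_ids = select_experts(
    hidden_states=hidden_states,
    router_logits=router_logits,
    top_k=2,
    use_grouped_topk=False,
    renormalize=True,
    torch_native=True,  # dispatches to fused_topk_native
)
# Passing correction_bias (with use_grouped_topk=True, topk_group and
# num_expert_group set) dispatches to biased_grouped_topk instead.
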
diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py
index 1efd65e577d..65418c703ec 100644
--- a/python/sglang/srt/model_executor/cuda_graph_runner.py
+++ b/python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -25,12 +25,12 @@
 from vllm.distributed.parallel_state import graph_capture
 from vllm.model_executor.custom_op import CustomOp
 
-from sglang.srt.layers.fused_moe_patch import fused_moe_forward_native
 from sglang.srt.layers.logits_processor import (
     LogitsMetadata,
     LogitsProcessor,
     LogitsProcessorOutput,
 )
+from sglang.srt.layers.moe.fused_moe_native import fused_moe_forward_native
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.utils import maybe_torch_compile, monkey_patch_vllm_all_gather
 
diff --git a/python/sglang/srt/models/dbrx.py b/python/sglang/srt/models/dbrx.py
index 45561d1dbc0..852f58a710d 100644
--- a/python/sglang/srt/models/dbrx.py
+++ b/python/sglang/srt/models/dbrx.py
@@ -27,13 +27,13 @@
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.transformers_utils.configs.dbrx import DbrxConfig
 
-from sglang.srt.layers.fused_moe_triton import fused_moe
 from sglang.srt.layers.linear import (
     QKVParallelLinear,
     ReplicatedLinear,
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.fused_moe_triton import fused_moe
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import (
diff --git a/python/sglang/srt/models/deepseek.py b/python/sglang/srt/models/deepseek.py
index ce1b152fbc7..d840cb866bd 100644
--- a/python/sglang/srt/models/deepseek.py
+++ b/python/sglang/srt/models/deepseek.py
@@ -29,7 +29,6 @@
 from vllm.model_executor.layers.rotary_embedding import get_rope
 
 from sglang.srt.layers.activation import SiluAndMul
-from sglang.srt.layers.fused_moe_triton import fused_moe
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     MergedColumnParallelLinear,
@@ -38,6 +37,7 @@
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.fused_moe_triton import fused_moe
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import (
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index 63cea92c289..92b987a23a3 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -19,6 +19,7 @@
 from typing import Any, Dict, Iterable, Optional, Tuple
 
 import torch
+import torch.nn.functional as F
 from torch import nn
 from transformers import PretrainedConfig
 from vllm import _custom_ops as ops
@@ -31,8 +32,6 @@
 from vllm.model_executor.layers.rotary_embedding import get_rope
 
 from sglang.srt.layers.activation import SiluAndMul
-from sglang.srt.layers.ep_moe.layer import EPMoE
-from sglang.srt.layers.fused_moe_triton import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     ColumnParallelLinear,
@@ -41,6 +40,8 @@
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.ep_moe.layer import EPMoE
+from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import (
@@ -90,6 +91,24 @@ def forward(self, x):
         return x
 
 
+class MoEGate(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.weight = nn.Parameter(
+            torch.empty((config.n_routed_experts, config.hidden_size))
+        )
+        if config.topk_method == "noaux_tc":
+            self.e_score_correction_bias = nn.Parameter(
+                torch.empty((config.n_routed_experts))
+            )
+        else:
+            self.e_score_correction_bias = None
+
+    def forward(self, hidden_states):
+        logits = F.linear(hidden_states, self.weight, None)
+        return logits
+
+
 class DeepseekV2MoE(nn.Module):
 
     def __init__(
@@ -114,6 +133,8 @@ def __init__(
                 "Only silu is supported for now."
             )
 
+        self.gate = MoEGate(config=config)
+
         MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
         self.experts = MoEImpl(
             num_experts=config.n_routed_experts,
@@ -125,11 +146,9 @@ def __init__(
             use_grouped_topk=True,
             num_expert_group=config.n_group,
             topk_group=config.topk_group,
+            correction_bias=self.gate.e_score_correction_bias,
         )
 
-        self.gate = ReplicatedLinear(
-            config.hidden_size, config.n_routed_experts, bias=False, quant_config=None
-        )
         if config.n_shared_experts is not None:
             intermediate_size = config.moe_intermediate_size * config.n_shared_experts
             self.shared_experts = DeepseekV2MLP(
@@ -146,7 +165,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         if self.n_shared_experts is not None:
             shared_output = self.shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
-        router_logits, _ = self.gate(hidden_states)
+        router_logits = self.gate(hidden_states)
         final_hidden_states = (
             self.experts(hidden_states=hidden_states, router_logits=router_logits)
             * self.routed_scaling_factor
@@ -439,7 +458,10 @@ def __init__(
             quant_config=quant_config,
         )
         self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
-        rope_scaling["rope_type"] = "deepseek_yarn"
+
+        if rope_scaling:
+            rope_scaling["rope_type"] = "deepseek_yarn"
+
         self.rotary_emb = get_rope(
             qk_rope_head_dim,
             rotary_dim=qk_rope_head_dim,
@@ -454,6 +476,8 @@ def __init__(
             scaling_factor = rope_scaling["factor"]
             mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
             self.scaling = self.scaling * mscale * mscale
+        else:
+            self.rotary_emb.forward = self.rotary_emb.forward_native
 
         self.attn_mqa = RadixAttention(
             self.num_local_heads,
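
For intuition, a self-contained sketch of the gating scheme the new MoEGate enables (all sizes invented): the gate returns raw linear logits, and e_score_correction_bias is handed to the experts layer for selection rather than added to the logits themselves:

import torch
import torch.nn.functional as F

# Toy stand-in for the new gate; this mirrors the wiring, not the real weights.
n_routed_experts, hidden_size = 8, 16
gate_weight = torch.randn(n_routed_experts, hidden_size)
e_score_correction_bias = torch.zeros(n_routed_experts)  # only for "noaux_tc"

hidden_states = torch.randn(3, hidden_size)
router_logits = F.linear(hidden_states, gate_weight, None)  # what MoEGate.forward returns
# The experts layer is constructed with correction_bias=e_score_correction_bias,
# so selection takes the biased_grouped_topk path while the logits stay raw:
# out = experts(hidden_states=hidden_states, router_logits=router_logits)
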
diff --git a/python/sglang/srt/models/grok.py b/python/sglang/srt/models/grok.py
index e74a0bc5a54..cb6a72a3f60 100644
--- a/python/sglang/srt/models/grok.py
+++ b/python/sglang/srt/models/grok.py
@@ -26,7 +26,6 @@
 from vllm.model_executor.layers.rotary_embedding import get_rope
 
 from sglang.srt.layers.activation import GeluAndMul
-from sglang.srt.layers.fused_moe_triton import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     MergedColumnParallelLinear,
@@ -35,6 +34,7 @@
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import (
diff --git a/python/sglang/srt/models/mixtral.py b/python/sglang/srt/models/mixtral.py
index f3fad226091..9dbdb46ff97 100644
--- a/python/sglang/srt/models/mixtral.py
+++ b/python/sglang/srt/models/mixtral.py
@@ -27,8 +27,6 @@
 )
 from vllm.model_executor.layers.rotary_embedding import get_rope
 
-from sglang.srt.layers.ep_moe.layer import EPMoE
-from sglang.srt.layers.fused_moe_triton import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     QKVParallelLinear,
@@ -36,6 +34,8 @@
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.ep_moe.layer import EPMoE
+from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import (
diff --git a/python/sglang/srt/models/olmoe.py b/python/sglang/srt/models/olmoe.py
index 859f4135c4b..df96be3bc94 100644
--- a/python/sglang/srt/models/olmoe.py
+++ b/python/sglang/srt/models/olmoe.py
@@ -36,9 +36,9 @@
 from vllm.model_executor.layers.rotary_embedding import get_rope
 
 from sglang.srt.layers.activation import SiluAndMul
-from sglang.srt.layers.fused_moe_triton import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
+from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import (
diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py
index 62cd3281d03..9db2d538234 100644
--- a/python/sglang/srt/models/qwen2_moe.py
+++ b/python/sglang/srt/models/qwen2_moe.py
@@ -29,7 +29,6 @@
 from vllm.model_executor.layers.rotary_embedding import get_rope
 
 from sglang.srt.layers.activation import SiluAndMul
-from sglang.srt.layers.fused_moe_triton import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     MergedColumnParallelLinear,
@@ -38,6 +37,7 @@
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import (
diff --git a/python/sglang/srt/models/xverse_moe.py b/python/sglang/srt/models/xverse_moe.py
index e1f3288753b..9b4b27f07d2 100644
--- a/python/sglang/srt/models/xverse_moe.py
+++ b/python/sglang/srt/models/xverse_moe.py
@@ -33,8 +33,8 @@
 )
 from vllm.model_executor.layers.rotary_embedding import get_rope
 
-from sglang.srt.layers.fused_moe_triton import fused_moe
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.fused_moe_triton import fused_moe
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import (
diff --git a/test/srt/test_fused_moe.py b/test/srt/test_fused_moe.py
index 7b50c551a82..80aeab257c3 100644
--- a/test/srt/test_fused_moe.py
+++ b/test/srt/test_fused_moe.py
@@ -4,7 +4,7 @@
 from vllm.model_executor.layers.fused_moe import fused_moe as fused_moe_vllm
 
 from sglang.srt.layers.activation import SiluAndMul
-from sglang.srt.layers.fused_moe_triton.fused_moe import fused_moe
+from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
 
 
 class TestFusedMOE(unittest.TestCase):
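
The module moves in this patch follow one pattern: MoE-related code now lives under sglang.srt.layers.moe. A migration sketch for downstream imports, with the old paths (as removed in the hunks above) shown in comments:

# sglang.srt.layers.fused_moe_triton  -> sglang.srt.layers.moe.fused_moe_triton
# sglang.srt.layers.ep_moe.layer      -> sglang.srt.layers.moe.ep_moe.layer
# sglang.srt.layers.fused_moe_patch   -> sglang.srt.layers.moe.fused_moe_native
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE, fused_moe
from sglang.srt.layers.moe.fused_moe_native import fused_moe_forward_native
from sglang.srt.layers.moe.topk import select_experts
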