mixtral support moe-ep
xiaobochen123 committed Dec 5, 2024
1 parent be071f3 commit d624adf
Showing 1 changed file with 12 additions and 5 deletions.
17 changes: 12 additions & 5 deletions python/sglang/srt/models/mixtral.py
@@ -21,9 +21,13 @@
 import torch
 from torch import nn
 from transformers import MixtralConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed import (
+    get_tensor_model_parallel_world_size,
+    tensor_model_parallel_all_reduce,
+)
 from vllm.model_executor.layers.rotary_embedding import get_rope
 
+from sglang.srt.layers.ep_moe.layer import EPMoE
 from sglang.srt.layers.fused_moe_triton import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -65,6 +69,7 @@ def __init__(
         prefix: str = "",
     ):
         super().__init__()
+        self.tp_size = get_tensor_model_parallel_world_size()
         self.hidden_size = hidden_size
 
         # Gate always runs at half / full precision for now.
@@ -76,14 +81,13 @@ def __init__(
             quant_config=None,
             prefix=f"{prefix}.gate",
         )
-
-        self.experts = FusedMoE(
+        MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
+        self.experts = MoEImpl(
             num_experts=num_experts,
             top_k=top_k,
             hidden_size=hidden_size,
             intermediate_size=intermediate_size,
             params_dtype=params_dtype,
-            reduce_results=True,
             renormalize=True,
             quant_config=quant_config,
             tp_size=tp_size,
@@ -97,6 +101,8 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
         final_hidden_states = self.experts(hidden_states, router_logits)
+        if self.tp_size > 1:
+            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states.view(orig_shape)
 
 
@@ -322,7 +328,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
 
         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
-        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+        MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
+        expert_params_mapping = MoEImpl.make_expert_params_mapping(
             ckpt_gate_proj_name="w1",
             ckpt_down_proj_name="w2",
             ckpt_up_proj_name="w3",
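
For readers unfamiliar with the pattern this diff introduces, the standalone sketch below mirrors its control flow in isolation: a flag picks the MoE implementation class, and, because reduce_results=True is dropped from the layer itself, the block performs the all-reduce explicitly in forward, which appears to let one forward path serve both the tensor-parallel and expert-parallel implementations. This is a minimal sketch only; ToyFusedMoE, ToyEPMoE, ToyMixtralMoE, and all_reduce_stub are hypothetical stand-ins, not sglang's FusedMoE/EPMoE or vLLM's tensor_model_parallel_all_reduce.

import torch
from torch import nn


def all_reduce_stub(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for vLLM's tensor_model_parallel_all_reduce; a no-op on a single process.
    return x


class ToyFusedMoE(nn.Module):
    """Stand-in for a tensor-parallel fused MoE layer (every rank holds all experts)."""

    def __init__(self, hidden_size: int, num_experts: int):
        super().__init__()
        self.experts = nn.ModuleList(
            nn.Linear(hidden_size, hidden_size) for _ in range(num_experts)
        )

    def forward(self, x: torch.Tensor, router_logits: torch.Tensor) -> torch.Tensor:
        # Toy top-1 routing: each token goes to its highest-scoring expert.
        top1 = router_logits.argmax(dim=-1)
        out = torch.zeros_like(x)
        for i, expert in enumerate(self.experts):
            mask = top1 == i
            if mask.any():
                out[mask] = expert(x[mask])
        # No reduction here; the caller all-reduces, mirroring the removal of
        # reduce_results=True in the diff above.
        return out


class ToyEPMoE(ToyFusedMoE):
    """Stand-in for an expert-parallel MoE layer (each rank would own a slice of experts)."""


class ToyMixtralMoE(nn.Module):
    def __init__(self, hidden_size: int, num_experts: int, enable_ep_moe: bool, tp_size: int = 1):
        super().__init__()
        self.tp_size = tp_size
        self.gate = nn.Linear(hidden_size, num_experts, bias=False)
        # Same dispatch shape as the diff: a flag selects the EP implementation,
        # otherwise the fused tensor-parallel one.
        MoEImpl = ToyEPMoE if enable_ep_moe else ToyFusedMoE
        self.experts = MoEImpl(hidden_size, num_experts)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        orig_shape = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
        router_logits = self.gate(hidden_states)
        out = self.experts(hidden_states, router_logits)
        if self.tp_size > 1:
            # Explicit reduction in the block, so one forward path serves both impls.
            out = all_reduce_stub(out)
        return out.view(orig_shape)


if __name__ == "__main__":
    moe = ToyMixtralMoE(hidden_size=16, num_experts=4, enable_ep_moe=True)
    x = torch.randn(2, 3, 16)
    print(moe(x).shape)  # torch.Size([2, 3, 16])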
