From 1ba776f8da8ff5fd380c0102aa09fd02f1cf8947 Mon Sep 17 00:00:00 2001 From: xiaobo Date: Wed, 20 Nov 2024 08:53:35 +0000 Subject: [PATCH 01/10] first add ep moe impl --- python/sglang/srt/layers/ep_moe/__init__.py | 0 python/sglang/srt/layers/ep_moe/kernels.py | 349 ++++++++ python/sglang/srt/layers/ep_moe/layer.py | 761 ++++++++++++++++++ python/sglang/srt/managers/schedule_batch.py | 1 + .../sglang/srt/model_executor/model_runner.py | 1 + python/sglang/srt/models/deepseek_v2.py | 8 +- python/sglang/srt/server_args.py | 6 + test/srt/test_moe_ep.py | 104 +++ 8 files changed, 1227 insertions(+), 3 deletions(-) create mode 100644 python/sglang/srt/layers/ep_moe/__init__.py create mode 100644 python/sglang/srt/layers/ep_moe/kernels.py create mode 100644 python/sglang/srt/layers/ep_moe/layer.py create mode 100644 test/srt/test_moe_ep.py diff --git a/python/sglang/srt/layers/ep_moe/__init__.py b/python/sglang/srt/layers/ep_moe/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/python/sglang/srt/layers/ep_moe/kernels.py b/python/sglang/srt/layers/ep_moe/kernels.py new file mode 100644 index 00000000000..8ec01bb340b --- /dev/null +++ b/python/sglang/srt/layers/ep_moe/kernels.py @@ -0,0 +1,349 @@ +from typing import Any, Dict, Optional, Tuple + +import torch +import triton +import triton.language as tl +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +@triton.jit +def compute_seg_indptr_triton_kernel(reorder_topk_ids, seg_indptr, num_toks): + expert = tl.program_id(0) + low = 0 + high = num_toks - 1 + target_location = -1 + while low <= high: + mid = (low + high) // 2 + + if tl.load(reorder_topk_ids + mid) > expert: + high = mid - 1 + else: + low = mid + 1 + target_location = mid + tl.store(seg_indptr + expert + 1, target_location + 1) + + +@triton.jit +def compute_src2dst_triton_kernel( + reorder_ids, src2dst, num_toks, BLOCK_SIZE: tl.constexpr +): + pid = tl.program_id(axis=0) + dst_id = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = dst_id < num_toks + src_id = tl.load(reorder_ids + dst_id, mask=mask) + tl.store(src2dst + src_id, dst_id, mask=mask) + + +def run_moe_ep_preproess(topk_ids: torch.Tensor, num_experts: int): + reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True) + seg_indptr = torch.zeros(num_experts + 1, device=topk_ids.device, dtype=torch.int64) + src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int32) + + # + compute_seg_indptr_triton_kernel[(num_experts,)]( + reorder_topk_ids, seg_indptr, topk_ids.numel() + ) + # + BLOCK_SIZE = 512 + grid = (triton.cdiv(topk_ids.numel(), BLOCK_SIZE),) + compute_src2dst_triton_kernel[grid]( + reorder_ids, src2dst, topk_ids.numel(), BLOCK_SIZE + ) + return reorder_topk_ids, src2dst, seg_indptr + + +@triton.jit +def pre_reorder_triton_kernel( + input_ptr, + gateup_input_ptr, + src2dst_ptr, + topk_ids_ptr, + a1_scales_ptr, + start_expert_id, + end_expert_id, + topk, + hidden_size, +): + OutDtype = gateup_input_ptr.dtype.element_ty + + BLOCK_SIZE: tl.constexpr = 512 + + src_idx = tl.program_id(0) + src2dst_ptr = src2dst_ptr + src_idx * topk + topk_ids_ptr = topk_ids_ptr + src_idx * topk + + src_ptr = input_ptr + src_idx * hidden_size + for idx in range(topk): + expert_id = tl.load(topk_ids_ptr + idx) + if expert_id >= start_expert_id and expert_id <= end_expert_id: + if a1_scales_ptr is not None: + scale = 1.0 / tl.load(a1_scales_ptr + expert_id - start_expert_id) + else: + scale = 1.0 + + dst_idx = tl.load(src2dst_ptr + idx) + dst_ptr = 
gateup_input_ptr + dst_idx * hidden_size + for start_offset in tl.range(0, hidden_size, BLOCK_SIZE): + offset = start_offset + tl.arange(0, BLOCK_SIZE) + mask = offset < hidden_size + in_data = tl.load(src_ptr + offset, mask=mask).to(tl.float32) + out_data = (in_data * scale).to(OutDtype) + tl.store(dst_ptr + offset, out_data, mask=mask) + + +@triton.jit +def silu_and_mul_triton_kernel( + gateup_output, + down_input, + hidden_size, + reorder_topk_ids, + scales, + start_expert_id, + end_expert_id, +): + InDtype = gateup_output.dtype.element_ty + OutDtype = down_input.dtype.element_ty + + BLOCK_SIZE: tl.constexpr = 512 + half_hidden_size = hidden_size // 2 + + pid = tl.program_id(0) + expert_id = tl.load(reorder_topk_ids + pid) + if expert_id >= start_expert_id and expert_id <= end_expert_id: + gateup_output_ptr = gateup_output + pid * hidden_size + gate_output_ptr = gateup_output_ptr + up_output_ptr = gateup_output_ptr + half_hidden_size + down_input_ptr = down_input + pid * half_hidden_size + + if scales is not None: + scale = tl.load(scales + expert_id - start_expert_id) + scale = (1 / scale).to(InDtype) + else: + scale = 1 + + for start_offset in tl.range(0, half_hidden_size, BLOCK_SIZE): + offset = start_offset + tl.arange(0, BLOCK_SIZE) + mask = offset < half_hidden_size + + gate_output = tl.load(gate_output_ptr + offset, mask=mask).to(tl.float32) + up_output = tl.load(up_output_ptr + offset, mask=mask) + + # silu & mul & quantize + gate_output = gate_output * tl.sigmoid(gate_output) + gate_output = gate_output.to(InDtype) + + silu_mul_output = gate_output * up_output * scale + silu_mul_output = silu_mul_output.to(OutDtype) + tl.store(down_input_ptr + offset, silu_mul_output, mask=mask) + + +@triton.jit +def post_reorder_triton_kernel( + down_output_ptr, + output_ptr, + src2dst_ptr, + topk_ids_ptr, + topk_weights_ptr, + start_expert_id, + end_expert_id, + topk, + hidden_size, +): + InDtype = down_output_ptr.dtype.element_ty + + BLOCK_SIZE: tl.constexpr = 512 + + src_idx = tl.program_id(0) + src2dst_ptr = src2dst_ptr + src_idx * topk + topk_ids_ptr = topk_ids_ptr + src_idx * topk + topk_weights_ptr = topk_weights_ptr + src_idx * topk + + computed = False + store_ptr = output_ptr + src_idx * hidden_size + for start_offset in tl.range(0, hidden_size, BLOCK_SIZE): + offset = start_offset + tl.arange(0, BLOCK_SIZE) + mask = offset < hidden_size + + sum_vec = tl.zeros([BLOCK_SIZE], dtype=InDtype) + for idx in range(topk): + expert_id = tl.load(topk_ids_ptr + idx) + if expert_id >= start_expert_id and expert_id <= end_expert_id: + computed = True + dst_idx = tl.load(src2dst_ptr + idx) + weigh_scale = tl.load(topk_weights_ptr + idx).to(InDtype) + load_ptr = down_output_ptr + dst_idx * hidden_size + in_data = tl.load(load_ptr + offset, mask=mask) + sum_vec += in_data * weigh_scale + tl.store(store_ptr + offset, sum_vec, mask=mask) + + if computed == False: + for start_offset in tl.range(0, hidden_size, BLOCK_SIZE): + offset = start_offset + tl.arange(0, BLOCK_SIZE) + mask = offset < hidden_size + tl.store( + store_ptr + offset, tl.zeros([BLOCK_SIZE], dtype=InDtype), mask=mask + ) + + +@triton.jit +def compute_m_range( + pid, + batch_size, + seg_indptr, + weight_indices, + m_num_tiles_indptr, + BLOCK_SIZE_M: tl.constexpr, +): + idx = 0 + for bs in range(batch_size): + tiles = tl.load(m_num_tiles_indptr + bs) + if pid >= tiles: + idx = bs + + idx_start = tl.load(m_num_tiles_indptr + idx) + + m_range_start = tl.load(seg_indptr + idx) + (pid - idx_start) * BLOCK_SIZE_M + m_range_end = 
min(tl.load(seg_indptr + idx + 1), m_range_start + BLOCK_SIZE_M) + expert_id = tl.load(weight_indices + idx) + return m_range_start, m_range_end, expert_id + + +@triton.jit +def grouped_gemm_triton_kernel( + a, + b, + c, + batch_size, + N, + K, + seg_indptr, + weight_indices, + m_num_tiles_indptr, + scale_a, + scale_b, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + a_stride_0: tl.constexpr, + b_stride_0: tl.constexpr, + b_stride_1: tl.constexpr, +): + c_dtype = c.dtype.element_ty + + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + total_m_block = tl.load(m_num_tiles_indptr + batch_size) + if pid_m >= total_m_block: + return + + m_range_start, m_range_end, expert_id = compute_m_range( + pid_m, batch_size, seg_indptr, weight_indices, m_num_tiles_indptr, BLOCK_SIZE_M + ) + if m_range_end - m_range_start == 0: + return + + n_range_start = pid_n * BLOCK_SIZE_N + n_range_end = min(n_range_start + BLOCK_SIZE_N, N) + + # + offs_am = tl.arange(0, BLOCK_SIZE_M) + offs_bn = tl.arange(0, BLOCK_SIZE_N) + + offs_am = tl.where(offs_am < m_range_end - m_range_start, offs_am, 0) + offs_bn = tl.where(offs_bn < n_range_end - n_range_start, offs_bn, 0) + offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M) + offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + + a_ptr = a + (m_range_start + offs_am[:, None]) * a_stride_0 + offs_k[None, :] + b_ptr = b + ( + (expert_id * b_stride_0) + + (n_range_start + offs_bn[:, None]) * b_stride_1 + + offs_k[None, :] + ) + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a_tile = tl.load( + a_ptr, mask=offs_k[None, :] < (K - k * BLOCK_SIZE_K), other=0.0 + ) + b_tile = tl.load( + b_ptr, mask=offs_k[None, :] < (K - k * BLOCK_SIZE_K), other=0.0 + ) + accumulator = tl.dot(a_tile, b_tile.T, accumulator) + a_ptr += BLOCK_SIZE_K + b_ptr += BLOCK_SIZE_K + + if scale_a is not None and scale_b is not None: + scale_a_value = tl.load(scale_a + expert_id) + scale_b_value = tl.load(scale_b + expert_id) + accumulator *= scale_a_value * scale_b_value + c_tile = accumulator.to(c_dtype) + + offs_cm = m_range_start + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_range_start + tl.arange(0, BLOCK_SIZE_N) + c_ptr = c + offs_cm[:, None] * N + offs_cn[None, :] + c_mask = (offs_cm[:, None] < m_range_end) & (offs_cn[None, :] < n_range_end) + tl.store(c_ptr, c_tile, mask=c_mask) + + +@triton.jit +def compute_m_num_tiles_indptr( + m_num_tiles_indptr, seg_indptr, batch_size: tl.constexpr, BLOCK_SIZE_M: tl.constexpr +): + for bs in range(batch_size): + m = tl.load(seg_indptr + bs + 1) - tl.load(seg_indptr + bs) + cur_num_tiles = tl.cdiv(m, BLOCK_SIZE_M) + pre_num_tiles = tl.load(m_num_tiles_indptr + bs) + tl.store(m_num_tiles_indptr + bs + 1, pre_num_tiles + cur_num_tiles) + + +def grouped_gemm_triton( + a: torch.Tensor, + b: torch.Tensor, + c: torch.Tensor, + batch_size: int, + weight_column_major: bool, + seg_indptr: Optional[torch.Tensor] = None, + weight_indices: Optional[torch.Tensor] = None, + enable_fp8: bool = False, + scale_a: torch.Tensor = None, + scale_b: torch.Tensor = None, +): + assert weight_column_major == True # TODO: more + if enable_fp8: + assert scale_a is not None and scale_b is not None + + BLOCK_SIZE_M = 128 + BLOCK_SIZE_N = 128 + BLOCK_SIZE_K = 128 + + m_num_tiles_indptr = torch.zeros(batch_size + 1, device=a.device, dtype=torch.int64) + compute_m_num_tiles_indptr[(1,)]( 
+ m_num_tiles_indptr, seg_indptr, batch_size, BLOCK_SIZE_M + ) + + num_m_tiles = (a.size(0) + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M + batch_size + num_n_tiles = (b.size(1) + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N + + grouped_gemm_triton_kernel[(num_m_tiles, num_n_tiles)]( + a, + b, + c, + batch_size, + b.size(1), + b.size(2), + seg_indptr, + weight_indices, + m_num_tiles_indptr, + scale_a, + scale_b, + BLOCK_SIZE_M, + BLOCK_SIZE_N, + BLOCK_SIZE_K, + a.stride(0), + b.stride(0), + b.stride(1), + ) + return c diff --git a/python/sglang/srt/layers/ep_moe/layer.py b/python/sglang/srt/layers/ep_moe/layer.py new file mode 100644 index 00000000000..0ee88cab6dd --- /dev/null +++ b/python/sglang/srt/layers/ep_moe/layer.py @@ -0,0 +1,761 @@ +from typing import Callable, List, Optional, Tuple + +import torch +from torch.nn import Module +from vllm import _custom_ops as ops +from vllm.distributed import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) +from vllm.logger import init_logger +from vllm.model_executor.custom_op import CustomOp +from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, + QuantizeMethodBase, +) +from vllm.model_executor.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + all_close_1d, + normalize_e4m3fn_to_e4m3fnuz, + per_tensor_dequantize, +) +from vllm.model_executor.utils import set_weight_attrs +from vllm.utils import is_hip, print_warning_once + +from sglang.srt.layers.ep_moe.kernels import ( + grouped_gemm_triton, + post_reorder_triton_kernel, + pre_reorder_triton_kernel, + run_moe_ep_preproess, + silu_and_mul_triton_kernel, +) +from sglang.srt.layers.fused_moe_triton.fused_moe import fused_topk, grouped_topk + +logger = init_logger(__name__) + + +class GroupedGemmRunner(torch.nn.Module): + flashinfer_gemm_warpper = None + + def __init__(self, device, use_flashinfer: bool = False): + super().__init__() + self.device = device + self.use_flashinfer = use_flashinfer + if self.use_flashinfer and GroupedGemmRunner.flashinfer_gemm_warpper is None: + GroupedGemmRunner._init_flashinfer_wrapper(device) + + @classmethod + def _init_flashinfer_wrapper(cls, device): + from flashinfer import SegmentGEMMWrapper + + workspace_buffer = torch.empty( + 128 * 1024 * 1024, dtype=torch.int8, device=device + ) + cls.flashinfer_gemm_warpper = SegmentGEMMWrapper(workspace_buffer) + + # c = a * b + def forward( + self, + a: torch.Tensor, + b: torch.Tensor, + c: torch.Tensor, + batch_size: int, + weight_column_major: bool, + seg_indptr: Optional[torch.Tensor] = None, + weight_indices: Optional[torch.Tensor] = None, + enable_fp8: bool = False, + scale_a: torch.Tensor = None, + scale_b: torch.Tensor = None, + ): + if self.use_flashinfer: + # TODO: flashinfer + assert False + assert GroupedGemmRunner.flashinfer_gemm_warpper is not None + c = GroupedGemmRunner.flashinfer_gemm_warpper.run( + x=a, + weights=b, + batch_size=batch_size, + weight_column_major=weight_column_major, + seg_indptr=seg_indptr, + weight_indices=weight_indices, + ) + # else: + # assert weight_column_major == True + # for bs in range(batch_size): + # start_offset = seg_indptr[bs] + # end_offset = seg_indptr[bs + 1] + # if (end_offset - start_offset) <= 0: + # continue + # if enable_fp8: + # c[start_offset:end_offset], _ = torch._scaled_mm( + # a[start_offset:end_offset], + # b[bs].T, + # out_dtype=c.dtype, + # scale_a=scale_a[bs], + # 
scale_b=scale_b[bs], + # ) + # else: + # c[start_offset:end_offset] = torch.matmul( + # a[start_offset:end_offset], b[bs].T + # ) + else: + assert weight_column_major == True + c = grouped_gemm_triton( + a, + b, + c, + batch_size, + weight_column_major, + seg_indptr, + weight_indices, + enable_fp8, + scale_a, + scale_b, + ) + return c + + +class EPMoE(torch.nn.Module): + """ + MoE Expert Parallel Impl + + + """ + + def __init__( + self, + num_experts: int, + top_k: int, + hidden_size: int, + intermediate_size: int, + params_dtype: Optional[torch.dtype] = None, + renormalize: bool = True, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None, + quant_config: Optional[QuantizationConfig] = None, + tp_size: Optional[int] = None, + prefix: str = "", + ): + super().__init__() + + if params_dtype is None: + params_dtype = torch.get_default_dtype() + + self.tp_size = ( + tp_size if tp_size is not None else get_tensor_model_parallel_world_size() + ) + self.tp_rank = get_tensor_model_parallel_rank() + + self.num_experts = num_experts + assert self.num_experts % self.tp_size == 0 + self.num_experts_per_partition = self.num_experts // self.tp_size + self.start_expert_id = self.tp_rank * self.num_experts_per_partition + self.end_expert_id = self.start_expert_id + self.num_experts_per_partition - 1 + + self.top_k = top_k + self.intermediate_size = intermediate_size + self.renormalize = renormalize + self.use_grouped_topk = use_grouped_topk + if self.use_grouped_topk: + assert num_expert_group is not None and topk_group is not None + self.num_expert_group = num_expert_group + self.topk_group = topk_group + + if quant_config is None: + self.quant_method: Optional[QuantizeMethodBase] = UnquantizedEPMoEMethod( + quant_config + ) + self.enable_fp8 = False + self.activation_scheme = None + else: + self.quant_method: Optional[QuantizeMethodBase] = Fp8EPMoEMethod( + quant_config + ) + self.enable_fp8 = True + self.fp8_dtype = torch.float8_e4m3fn + self.activation_scheme = quant_config.activation_scheme + + self.quant_method.create_weights( + layer=self, + num_experts_per_partition=self.num_experts_per_partition, + hidden_size=hidden_size, + intermediate_size=self.intermediate_size, + params_dtype=params_dtype, + weight_loader=self.weight_loader, + ) + + self.grouped_gemm_runner = None + + def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): + assert self.quant_method is not None + + if self.grouped_gemm_runner is None: + self.grouped_gemm_runner = GroupedGemmRunner( + hidden_states.device, use_flashinfer=False # TODO: use flashinfer + ) + + topk_weights, topk_ids = self.select_experts( + hidden_states, + router_logits, + self.top_k, + self.renormalize, + self.topk_group, + self.num_expert_group, + ) + + reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess( + topk_ids, self.num_experts + ) + + gateup_input = torch.empty( + (int(hidden_states.shape[0] * self.top_k), hidden_states.shape[1]), + device=hidden_states.device, + dtype=self.fp8_dtype if self.enable_fp8 else hidden_states.dtype, + ) + if self.activation_scheme == "dynamic": + max_value = ( + torch.max(hidden_states) + .repeat(self.num_experts_per_partition) + .to(torch.float32) + ) + self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max + + # PreReorder + pre_reorder_triton_kernel[(hidden_states.shape[0],)]( + hidden_states, + gateup_input, + src2dst, + topk_ids, + self.w13_input_scale, + self.start_expert_id, + self.end_expert_id, + self.top_k, + 
hidden_states.shape[1], + ) + + seg_indptr_cur_rank = seg_indptr[self.start_expert_id : self.end_expert_id + 2] + weight_indices_cur_rank = torch.arange( + 0, + self.num_experts_per_partition, + device=hidden_states.device, + dtype=torch.int64, + ) + # GroupGemm-0 + gateup_output = torch.empty( + gateup_input.shape[0], + self.w13_weight.shape[1], + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + gateup_output = self.grouped_gemm_runner( + a=gateup_input, + b=self.w13_weight, + c=gateup_output, + batch_size=self.num_experts_per_partition, + weight_column_major=True, + seg_indptr=seg_indptr_cur_rank, + weight_indices=weight_indices_cur_rank, + enable_fp8=self.enable_fp8, + scale_a=self.w13_input_scale, + scale_b=self.w13_weight_scale, + ) + + # Act + down_input = torch.empty( + gateup_output.shape[0], + gateup_output.shape[1] // 2, + device=gateup_output.device, + dtype=self.fp8_dtype if self.enable_fp8 else hidden_states.dtype, + ) + if self.w2_input_scale is None: + self.w2_input_scale = torch.ones( + self.num_experts_per_partition, + dtype=torch.float32, + device=hidden_states.device, + ) + silu_and_mul_triton_kernel[(gateup_output.shape[0],)]( + gateup_output, + down_input, + gateup_output.shape[1], + reorder_topk_ids, + self.w2_input_scale, + self.start_expert_id, + self.end_expert_id, + ) + + # GroupGemm-1 + down_output = torch.empty( + down_input.shape[0], + self.w2_weight.shape[1], + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + down_output = self.grouped_gemm_runner( + a=down_input, + b=self.w2_weight, + c=down_output, + batch_size=self.num_experts_per_partition, + weight_column_major=True, + seg_indptr=seg_indptr_cur_rank, + weight_indices=weight_indices_cur_rank, + enable_fp8=self.enable_fp8, + scale_a=self.w2_input_scale, + scale_b=self.w2_weight_scale, + ) + + # PostReorder + output = torch.empty_like(hidden_states) + post_reorder_triton_kernel[(hidden_states.size(0),)]( + down_output, + output, + src2dst, + topk_ids, + topk_weights, + self.start_expert_id, + self.end_expert_id, + self.top_k, + hidden_states.size(1), + ) + return output + + def select_experts( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + ): + if self.use_grouped_topk: + assert topk_group is not None + assert num_expert_group is not None + topk_weights, topk_ids = grouped_topk( + hidden_states=hidden_states, + gating_output=router_logits, + topk=top_k, + renormalize=renormalize, + num_expert_group=num_expert_group, + topk_group=topk_group, + ) + else: + topk_weights, topk_ids = fused_topk( + hidden_states=hidden_states, + gating_output=router_logits, + topk=top_k, + renormalize=renormalize, + ) + return topk_weights, topk_ids.to(torch.int32) + + @classmethod + def make_expert_params_mapping( + cls, + ckpt_gate_proj_name: str, + ckpt_down_proj_name: str, + ckpt_up_proj_name: str, + num_experts: int, + ) -> List[Tuple[str, str, int, str]]: + + return [ + # (param_name, weight_name, expert_id, shard_id) + ( + ( + "experts.w13_" + if weight_name in [ckpt_gate_proj_name, ckpt_up_proj_name] + else "experts.w2_" + ), + f"experts.{expert_id}.{weight_name}.", + expert_id, + shard_id, + ) + for expert_id in range(num_experts) + for shard_id, weight_name in [ + ("w1", ckpt_gate_proj_name), + ("w2", ckpt_down_proj_name), + ("w3", ckpt_up_proj_name), + ] + ] + + def weight_loader( + self, + param: torch.nn.Parameter, + loaded_weight: torch.Tensor, + 
weight_name: str, + shard_id: str, + expert_id: int, + ) -> None: + if expert_id < self.start_expert_id or expert_id > self.end_expert_id: + return + expert_id = expert_id - self.start_expert_id + + if shard_id not in ("w1", "w2", "w3"): + raise ValueError( + f"shard_id must be ['w1','w2','w3'] but " f"got {shard_id}." + ) + + # Special case for fp8 scales. + if "scale" in weight_name: + self._load_fp8_scale( + param.data, loaded_weight, weight_name, shard_id, expert_id + ) + return + + expert_data = param.data[expert_id] + if shard_id == "w2": + param.data[expert_id] = loaded_weight + elif shard_id == "w1": + param.data[expert_id][: self.intermediate_size, :] = loaded_weight + elif shard_id == "w3": + param.data[expert_id][self.intermediate_size :, :] = loaded_weight + else: + raise ValueError(f"Expected shard_id w1,w2 or w3 but got {shard_id}") + + def _load_fp8_scale( + self, + param: torch.nn.Parameter, + loaded_weight: torch.Tensor, + weight_name: str, + shard_id: str, + expert_id: int, + ) -> None: + param_data = param.data + + # Input scales can be loaded directly and should be equal. + if "input_scale" in weight_name: + if ( + param_data[expert_id] != 1 + and (param_data[expert_id] - loaded_weight).abs() > 1e-5 + ): + raise ValueError( + "input_scales of w1 and w3 of a layer " + f"must be equal. But got {param_data[expert_id]} " + f"vs. {loaded_weight}" + ) + param_data[expert_id] = loaded_weight + # Weight scales + elif "weight_scale" in weight_name: + # If we are in merged column case (gate_up_proj) + if shard_id in ("w1", "w3"): + # We have to keep the weight scales of w1 and w3 because + # we need to re-quantize w1/w3 weights after weight loading. + idx = 0 if shard_id == "w1" else 1 + param_data[expert_id][idx] = loaded_weight + # If we are in the row parallel case (down_proj) + else: + param_data[expert_id] = loaded_weight + + +class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp): + def create_weights( + self, + layer: torch.nn.Module, + num_experts_per_partition: int, + hidden_size: int, + intermediate_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + # Fused gate_up_proj (column parallel) + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts_per_partition, + 2 * intermediate_size, + hidden_size, + dtype=params_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + # down_proj (row parallel) + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts_per_partition, + hidden_size, + intermediate_size, + dtype=params_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # scale + ones_tensor = torch.ones(num_experts_per_partition, dtype=torch.float32) + w13_input_scale = torch.nn.Parameter( + ones_tensor, + requires_grad=False, + ) + layer.register_parameter("w13_input_scale", w13_input_scale) + set_weight_attrs(w13_input_scale, extra_weight_attrs) + + w2_input_scale = torch.nn.Parameter( + ones_tensor, + requires_grad=False, + ) + layer.register_parameter("w2_input_scale", w2_input_scale) + set_weight_attrs(w2_input_scale, extra_weight_attrs) + + w13_weight_scale = torch.nn.Parameter( + ones_tensor, + requires_grad=False, + ) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + + w2_weight_scale = torch.nn.Parameter( + ones_tensor, + requires_grad=False, + ) + 
layer.register_parameter("w2_weight_scale", w2_weight_scale) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + ) -> torch.Tensor: + raise NotImplementedError + + +class Fp8EPMoEMethod(Fp8MoEMethod): + """MoE method for FP8. + Supports loading FP8 checkpoints with static weight scale and + dynamic/static activation scale. + + Args: + quant_config: The quantization config. + """ + + def __init__(self, quant_config: Fp8Config): + self.quant_config = quant_config + + def create_weights( + self, + layer: Module, + num_experts_per_partition: int, + hidden_size: int, + intermediate_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + + if self.quant_config.is_checkpoint_fp8_serialized: + params_dtype = torch.float8_e4m3fn + + # WEIGHTS + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts_per_partition, + 2 * intermediate_size, + hidden_size, + dtype=params_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts_per_partition, + hidden_size, + intermediate_size, + dtype=params_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # WEIGHT_SCALES + # Allocate 2 scales for w1 and w3 respectively. + # They will be combined to a single scale after weight loading. + w13_weight_scale = torch.nn.Parameter( + torch.ones(num_experts_per_partition, 2, dtype=torch.float32), + requires_grad=False, + ) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + + w2_weight_scale = torch.nn.Parameter( + torch.ones(num_experts_per_partition, dtype=torch.float32), + requires_grad=False, + ) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + # Add the quantization method used (per tensor/grouped/channel) + # to ensure the weight scales are loaded in properly + extra_weight_attrs.update({"quant_method": "tensor"}) + # If loading fp8 checkpoint, pass the weight loaders. + # If loading an fp16 checkpoint, do not (we will quantize in + # process_weights_after_loading() + if self.quant_config.is_checkpoint_fp8_serialized: + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + # INPUT_SCALES + if self.quant_config.activation_scheme == "static": + if not self.quant_config.is_checkpoint_fp8_serialized: + raise ValueError( + "Found static activation scheme for checkpoint that " + "was not serialized fp8." + ) + + w13_input_scale = torch.nn.Parameter( + torch.ones(num_experts_per_partition, dtype=torch.float32), + requires_grad=False, + ) + layer.register_parameter("w13_input_scale", w13_input_scale) + set_weight_attrs(w13_input_scale, extra_weight_attrs) + + w2_input_scale = torch.nn.Parameter( + torch.ones(num_experts_per_partition, dtype=torch.float32), + requires_grad=False, + ) + layer.register_parameter("w2_input_scale", w2_input_scale) + set_weight_attrs(w2_input_scale, extra_weight_attrs) + + else: + layer.w13_input_scale = None + layer.w2_input_scale = None + + def process_weights_after_loading(self, layer: Module) -> None: + + # If checkpoint is fp16, quantize in place. 
+ if not self.quant_config.is_checkpoint_fp8_serialized: + # If rocm, use float8_e4m3fnuz as dtype + fp8_dtype = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn + w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype) + w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype) + + # Re-initialize w13_scale because we directly quantize + # merged w13 weights and generate a single scaling factor. + layer.w13_weight_scale = torch.nn.Parameter( + torch.ones( + layer.num_experts_per_partition, + dtype=torch.float32, + device=w13_weight.device, + ), + requires_grad=False, + ) + + for expert in range(layer.num_experts_per_partition): + w13_weight[expert, :, :], layer.w13_weight_scale[expert] = ( + ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :]) + ) + w2_weight[expert, :, :], layer.w2_weight_scale[expert] = ( + ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :]) + ) + layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False) + layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False) + return + + # If checkpoint is fp8, we need to handle that the + # MoE kernels require single activation scale and single weight + # scale for w13 per expert. + else: + # Fp8 moe kernels require a single activation scale. + # We take the max of all the scales in case they differ. + if self.quant_config.activation_scheme == "static": + if layer.w13_input_scale is None or layer.w2_input_scale is None: + raise ValueError( + "QuantConfig has static quantization, but found " + "activation scales are None." + ) + if not all_close_1d(layer.w13_input_scale) or not all_close_1d( + layer.w2_input_scale + ): + print_warning_once( + "Found input_scales that are not equal for " + "fp8 MoE layer. Using the maximum across experts " + "for each layer. " + ) + # layer.w13_input_scale = torch.nn.Parameter( + # layer.w13_input_scale.max().repeat(layer.num_experts_per_partition), + # requires_grad=False) + # layer.w2_input_scale = torch.nn.Parameter( + # layer.w2_input_scale.max().repeat(layer.num_experts_per_partition), + # requires_grad=False) + + layer.w13_input_scale = torch.nn.Parameter( + layer.w13_input_scale, requires_grad=False + ) + layer.w2_input_scale = torch.nn.Parameter( + layer.w2_input_scale, requires_grad=False + ) + + # If rocm, normalize the weights and scales to e4m3fnuz + if is_hip(): + # Normalize the weights and scales + w13_weight, w13_weight_scale, w13_input_scale = ( + normalize_e4m3fn_to_e4m3fnuz( + layer.w13_weight, layer.w13_weight_scale, layer.w13_input_scale + ) + ) + w2_weight, w2_weight_scale, w2_input_scale = ( + normalize_e4m3fn_to_e4m3fnuz( + layer.w2_weight, layer.w2_weight_scale, layer.w2_input_scale + ) + ) + # Reset the parameter + layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False) + layer.w13_weight_scale = torch.nn.Parameter( + w13_weight_scale, requires_grad=False + ) + if w13_input_scale is not None: + layer.w13_input_scale = torch.nn.Parameter( + w13_input_scale, requires_grad=False + ) + layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False) + layer.w2_weight_scale = torch.nn.Parameter( + w2_weight_scale, requires_grad=False + ) + if w2_input_scale is not None: + layer.w2_input_scale = torch.nn.Parameter( + w2_input_scale, requires_grad=False + ) + + # Fp8 moe kernel needs single weight scale for w13 per expert. + # We take the max then dequant and requant each expert. 
+ assert layer.w13_weight_scale is not None + shard_size = layer.intermediate_size + max_w13_scales = layer.w13_weight_scale.max(dim=1).values + for expert_id in range(layer.num_experts_per_partition): + start = 0 + for shard_id in range(2): + dq_weight = per_tensor_dequantize( + layer.w13_weight[expert_id][start : start + shard_size, :], + layer.w13_weight_scale[expert_id][shard_id], + ) + layer.w13_weight[expert_id][start : start + shard_size, :], _ = ( + ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id]) + ) + start += shard_size + + layer.w13_weight_scale = torch.nn.Parameter( + max_w13_scales, requires_grad=False + ) + return + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + ) -> torch.Tensor: + raise NotImplementedError diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 28677efeac4..5855d4248ff 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -58,6 +58,7 @@ "torchao_config": ServerArgs.torchao_config, "enable_nan_detection": ServerArgs.enable_nan_detection, "enable_dp_attention": ServerArgs.enable_dp_attention, + "enable_ep_moe": ServerArgs.enable_ep_moe, } diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 4eaedbccbff..3f0cbecac15 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -141,6 +141,7 @@ def __init__( "torchao_config": server_args.torchao_config, "enable_nan_detection": server_args.enable_nan_detection, "enable_dp_attention": server_args.enable_dp_attention, + "enable_ep_moe": server_args.enable_ep_moe, } ) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 80db9a35c71..63cea92c289 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -31,6 +31,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope from sglang.srt.layers.activation import SiluAndMul +from sglang.srt.layers.ep_moe.layer import EPMoE from sglang.srt.layers.fused_moe_triton import FusedMoE from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.linear import ( @@ -113,12 +114,12 @@ def __init__( "Only silu is supported for now." 
) - self.experts = FusedMoE( + MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE + self.experts = MoEImpl( num_experts=config.n_routed_experts, top_k=config.num_experts_per_tok, hidden_size=config.hidden_size, intermediate_size=config.moe_intermediate_size, - reduce_results=False, renormalize=config.norm_topk_prob, quant_config=quant_config, use_grouped_topk=True, @@ -834,7 +835,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) - expert_params_mapping = FusedMoE.make_expert_params_mapping( + MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE + expert_params_mapping = MoEImpl.make_expert_params_mapping( ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 7b337500fd7..536e9ba175d 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -130,6 +130,7 @@ class ServerArgs: disable_overlap_schedule: bool = False enable_mixed_chunk: bool = False enable_dp_attention: bool = False + enable_ep_moe: bool = (False,) enable_torch_compile: bool = False torch_compile_max_bs: int = 32 cuda_graph_max_bs: Optional[int] = None @@ -681,6 +682,11 @@ def add_cli_args(parser: argparse.ArgumentParser): action="store_true", help="Enabling data parallelism for attention and tensor parallelism for FFN. The dp size should be equal to the tp size. Currently only DeepSeek-V2 is supported.", ) + parser.add_argument( + "--enable-ep-moe", + action="store_true", + help="Enabling expert parallelism for moe. The ep size is equal to the tp size.", + ) parser.add_argument( "--enable-torch-compile", action="store_true", diff --git a/test/srt/test_moe_ep.py b/test/srt/test_moe_ep.py new file mode 100644 index 00000000000..3aeca8aee52 --- /dev/null +++ b/test/srt/test_moe_ep.py @@ -0,0 +1,104 @@ +import unittest +from types import SimpleNamespace + +from sglang.srt.utils import kill_child_process +from sglang.test.run_eval import run_eval +from sglang.test.test_utils import ( + DEFAULT_MLA_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + popen_launch_server, +) + + +class TestEpMoE(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=["--tp", "2", "--trust-remote-code", "--enable-ep-moe"], + ) + + @classmethod + def tearDownClass(cls): + kill_child_process(cls.process.pid, include_self=True) + + def test_mmlu(self): + args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="mmlu", + num_examples=64, + num_threads=32, + ) + + metrics = run_eval(args) + assert metrics["score"] >= 0.5 + + def test_mgsm_en(self): + args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="mgsm_en", + num_examples=None, + num_threads=1024, + ) + + metrics = run_eval(args) + assert metrics["score"] >= 0.8 + + +class TestEpMoEFP8(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--tp", + "2", + 
"--trust-remote-code", + "--enable-ep-moe", + "--quantization", + "fp8", + ], + ) + + @classmethod + def tearDownClass(cls): + kill_child_process(cls.process.pid, include_self=True) + + def test_mmlu(self): + args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="mmlu", + num_examples=64, + num_threads=32, + ) + + metrics = run_eval(args) + assert metrics["score"] >= 0.5 + + def test_mgsm_en(self): + args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="mgsm_en", + num_examples=None, + num_threads=1024, + ) + + metrics = run_eval(args) + assert metrics["score"] >= 0.8 + + +if __name__ == "__main__": + unittest.main() From d02b725e1300a93271ddb99e8fdd69673434465a Mon Sep 17 00:00:00 2001 From: xiaobo Date: Tue, 26 Nov 2024 22:55:27 +0800 Subject: [PATCH 02/10] fix bug --- python/sglang/srt/layers/ep_moe/layer.py | 9 +++++---- python/sglang/srt/server_args.py | 2 +- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/layers/ep_moe/layer.py b/python/sglang/srt/layers/ep_moe/layer.py index 0ee88cab6dd..24a32c13b57 100644 --- a/python/sglang/srt/layers/ep_moe/layer.py +++ b/python/sglang/srt/layers/ep_moe/layer.py @@ -21,8 +21,9 @@ per_tensor_dequantize, ) from vllm.model_executor.utils import set_weight_attrs -from vllm.utils import is_hip, print_warning_once +from vllm.utils import print_warning_once +from sglang.srt.layers.custom_op_util import register_custom_op from sglang.srt.layers.ep_moe.kernels import ( grouped_gemm_triton, post_reorder_triton_kernel, @@ -31,6 +32,7 @@ silu_and_mul_triton_kernel, ) from sglang.srt.layers.fused_moe_triton.fused_moe import fused_topk, grouped_topk +from sglang.srt.utils import is_hip logger = init_logger(__name__) @@ -164,9 +166,7 @@ def __init__( self.topk_group = topk_group if quant_config is None: - self.quant_method: Optional[QuantizeMethodBase] = UnquantizedEPMoEMethod( - quant_config - ) + self.quant_method: Optional[QuantizeMethodBase] = UnquantizedEPMoEMethod() self.enable_fp8 = False self.activation_scheme = None else: @@ -447,6 +447,7 @@ def _load_fp8_scale( param_data[expert_id] = loaded_weight +@register_custom_op("sglang_unquantized_ep_moe") class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp): def create_weights( self, diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 536e9ba175d..7a0d91c0709 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -130,7 +130,7 @@ class ServerArgs: disable_overlap_schedule: bool = False enable_mixed_chunk: bool = False enable_dp_attention: bool = False - enable_ep_moe: bool = (False,) + enable_ep_moe: bool = False enable_torch_compile: bool = False torch_compile_max_bs: int = 32 cuda_graph_max_bs: Optional[int] = None From 23716e8adf18e459af8f863ce79e9f87de24f7ae Mon Sep 17 00:00:00 2001 From: xiaobo Date: Mon, 2 Dec 2024 12:58:33 +0800 Subject: [PATCH 03/10] fix some problem --- .github/workflows/pr-test.yml | 6 +++ python/sglang/srt/layers/ep_moe/kernels.py | 8 ++- python/sglang/srt/layers/ep_moe/layer.py | 61 +++------------------- test/srt/test_moe_ep.py | 6 +-- 4 files changed, 20 insertions(+), 61 deletions(-) diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml index 59f0006e128..49c6ec88327 100644 --- a/.github/workflows/pr-test.yml +++ b/.github/workflows/pr-test.yml @@ -105,6 +105,12 @@ jobs: cd test/srt python3 test_update_weights_from_distributed.py + - name: Evaluate MoE EP accuracy (TP=2) + timeout-minutes: 10 + 
run: | + cd test/srt + python3 test_moe_ep.py + performance-test-1-gpu-part-1: if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' runs-on: 1-gpu-runner diff --git a/python/sglang/srt/layers/ep_moe/kernels.py b/python/sglang/srt/layers/ep_moe/kernels.py index 8ec01bb340b..9dd0851057e 100644 --- a/python/sglang/srt/layers/ep_moe/kernels.py +++ b/python/sglang/srt/layers/ep_moe/kernels.py @@ -1,11 +1,11 @@ +import logging from typing import Any, Dict, Optional, Tuple import torch import triton import triton.language as tl -from vllm.logger import init_logger -logger = init_logger(__name__) +logger = logging.getLogger(__name__) @triton.jit @@ -41,11 +41,10 @@ def run_moe_ep_preproess(topk_ids: torch.Tensor, num_experts: int): seg_indptr = torch.zeros(num_experts + 1, device=topk_ids.device, dtype=torch.int64) src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int32) - # compute_seg_indptr_triton_kernel[(num_experts,)]( reorder_topk_ids, seg_indptr, topk_ids.numel() ) - # + BLOCK_SIZE = 512 grid = (triton.cdiv(topk_ids.numel(), BLOCK_SIZE),) compute_src2dst_triton_kernel[grid]( @@ -247,7 +246,6 @@ def grouped_gemm_triton_kernel( n_range_start = pid_n * BLOCK_SIZE_N n_range_end = min(n_range_start + BLOCK_SIZE_N, N) - # offs_am = tl.arange(0, BLOCK_SIZE_M) offs_bn = tl.arange(0, BLOCK_SIZE_N) diff --git a/python/sglang/srt/layers/ep_moe/layer.py b/python/sglang/srt/layers/ep_moe/layer.py index 24a32c13b57..aee54843eed 100644 --- a/python/sglang/srt/layers/ep_moe/layer.py +++ b/python/sglang/srt/layers/ep_moe/layer.py @@ -1,3 +1,4 @@ +import logging from typing import Callable, List, Optional, Tuple import torch @@ -7,21 +8,12 @@ get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, ) -from vllm.logger import init_logger from vllm.model_executor.custom_op import CustomOp -from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig, - QuantizeMethodBase, -) from vllm.model_executor.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize, ) -from vllm.model_executor.utils import set_weight_attrs -from vllm.utils import print_warning_once from sglang.srt.layers.custom_op_util import register_custom_op from sglang.srt.layers.ep_moe.kernels import ( @@ -32,9 +24,14 @@ silu_and_mul_triton_kernel, ) from sglang.srt.layers.fused_moe_triton.fused_moe import fused_topk, grouped_topk -from sglang.srt.utils import is_hip +from sglang.srt.layers.fused_moe_triton.layer import FusedMoEMethodBase +from sglang.srt.layers.quantization.base_config import ( + QuantizationConfig, + QuantizeMethodBase, +) +from sglang.srt.utils import is_hip, set_weight_attrs -logger = init_logger(__name__) +logger = logging.getLogger(__name__) class GroupedGemmRunner(torch.nn.Module): @@ -82,25 +79,6 @@ def forward( seg_indptr=seg_indptr, weight_indices=weight_indices, ) - # else: - # assert weight_column_major == True - # for bs in range(batch_size): - # start_offset = seg_indptr[bs] - # end_offset = seg_indptr[bs + 1] - # if (end_offset - start_offset) <= 0: - # continue - # if enable_fp8: - # c[start_offset:end_offset], _ = torch._scaled_mm( - # a[start_offset:end_offset], - # b[bs].T, - # out_dtype=c.dtype, - # scale_a=scale_a[bs], - # scale_b=scale_b[bs], - # ) - # else: - # c[start_offset:end_offset] = torch.matmul( - # 
a[start_offset:end_offset], b[bs].T - # ) else: assert weight_column_major == True c = grouped_gemm_triton( @@ -664,35 +642,12 @@ def process_weights_after_loading(self, layer: Module) -> None: # MoE kernels require single activation scale and single weight # scale for w13 per expert. else: - # Fp8 moe kernels require a single activation scale. - # We take the max of all the scales in case they differ. if self.quant_config.activation_scheme == "static": if layer.w13_input_scale is None or layer.w2_input_scale is None: raise ValueError( "QuantConfig has static quantization, but found " "activation scales are None." ) - if not all_close_1d(layer.w13_input_scale) or not all_close_1d( - layer.w2_input_scale - ): - print_warning_once( - "Found input_scales that are not equal for " - "fp8 MoE layer. Using the maximum across experts " - "for each layer. " - ) - # layer.w13_input_scale = torch.nn.Parameter( - # layer.w13_input_scale.max().repeat(layer.num_experts_per_partition), - # requires_grad=False) - # layer.w2_input_scale = torch.nn.Parameter( - # layer.w2_input_scale.max().repeat(layer.num_experts_per_partition), - # requires_grad=False) - - layer.w13_input_scale = torch.nn.Parameter( - layer.w13_input_scale, requires_grad=False - ) - layer.w2_input_scale = torch.nn.Parameter( - layer.w2_input_scale, requires_grad=False - ) # If rocm, normalize the weights and scales to e4m3fnuz if is_hip(): diff --git a/test/srt/test_moe_ep.py b/test/srt/test_moe_ep.py index 3aeca8aee52..349d36225a4 100644 --- a/test/srt/test_moe_ep.py +++ b/test/srt/test_moe_ep.py @@ -1,7 +1,7 @@ import unittest from types import SimpleNamespace -from sglang.srt.utils import kill_child_process +from sglang.srt.utils import kill_process_tree from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_MLA_MODEL_NAME_FOR_TEST, @@ -25,7 +25,7 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): - kill_child_process(cls.process.pid, include_self=True) + kill_process_tree(cls.process.pid) def test_mmlu(self): args = SimpleNamespace( @@ -73,7 +73,7 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): - kill_child_process(cls.process.pid, include_self=True) + kill_process_tree(cls.process.pid) def test_mmlu(self): args = SimpleNamespace( From c508ccbb477288d8b8c1bf50939ed243c0e4a368 Mon Sep 17 00:00:00 2001 From: xiaobo Date: Mon, 2 Dec 2024 13:14:37 +0800 Subject: [PATCH 04/10] BLOCK_SIZE as args --- python/sglang/srt/layers/ep_moe/kernels.py | 8 +++----- python/sglang/srt/layers/ep_moe/layer.py | 3 +++ 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/layers/ep_moe/kernels.py b/python/sglang/srt/layers/ep_moe/kernels.py index 9dd0851057e..43b5b602864 100644 --- a/python/sglang/srt/layers/ep_moe/kernels.py +++ b/python/sglang/srt/layers/ep_moe/kernels.py @@ -64,11 +64,10 @@ def pre_reorder_triton_kernel( end_expert_id, topk, hidden_size, + BLOCK_SIZE: tl.constexpr, ): OutDtype = gateup_input_ptr.dtype.element_ty - BLOCK_SIZE: tl.constexpr = 512 - src_idx = tl.program_id(0) src2dst_ptr = src2dst_ptr + src_idx * topk topk_ids_ptr = topk_ids_ptr + src_idx * topk @@ -101,11 +100,11 @@ def silu_and_mul_triton_kernel( scales, start_expert_id, end_expert_id, + BLOCK_SIZE: tl.constexpr, ): InDtype = gateup_output.dtype.element_ty OutDtype = down_input.dtype.element_ty - BLOCK_SIZE: tl.constexpr = 512 half_hidden_size = hidden_size // 2 pid = tl.program_id(0) @@ -149,11 +148,10 @@ def post_reorder_triton_kernel( end_expert_id, topk, hidden_size, + BLOCK_SIZE: 
tl.constexpr, ): InDtype = down_output_ptr.dtype.element_ty - BLOCK_SIZE: tl.constexpr = 512 - src_idx = tl.program_id(0) src2dst_ptr = src2dst_ptr + src_idx * topk topk_ids_ptr = topk_ids_ptr + src_idx * topk diff --git a/python/sglang/srt/layers/ep_moe/layer.py b/python/sglang/srt/layers/ep_moe/layer.py index aee54843eed..eca34a60749 100644 --- a/python/sglang/srt/layers/ep_moe/layer.py +++ b/python/sglang/srt/layers/ep_moe/layer.py @@ -211,6 +211,7 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): self.end_expert_id, self.top_k, hidden_states.shape[1], + BLOCK_SIZE=512, ) seg_indptr_cur_rank = seg_indptr[self.start_expert_id : self.end_expert_id + 2] @@ -261,6 +262,7 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): self.w2_input_scale, self.start_expert_id, self.end_expert_id, + BLOCK_SIZE=512, ) # GroupGemm-1 @@ -295,6 +297,7 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): self.end_expert_id, self.top_k, hidden_states.size(1), + BLOCK_SIZE=512, ) return output From 5ac339984aa3b6cdcfe7344a4e0c530095fec1c3 Mon Sep 17 00:00:00 2001 From: xiaobo Date: Mon, 2 Dec 2024 13:53:23 +0800 Subject: [PATCH 05/10] clean code --- python/sglang/srt/layers/ep_moe/layer.py | 59 ------------------------ 1 file changed, 59 deletions(-) diff --git a/python/sglang/srt/layers/ep_moe/layer.py b/python/sglang/srt/layers/ep_moe/layer.py index eca34a60749..8be61efd822 100644 --- a/python/sglang/srt/layers/ep_moe/layer.py +++ b/python/sglang/srt/layers/ep_moe/layer.py @@ -10,10 +10,6 @@ ) from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod -from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - normalize_e4m3fn_to_e4m3fnuz, - per_tensor_dequantize, -) from sglang.srt.layers.custom_op_util import register_custom_op from sglang.srt.layers.ep_moe.kernels import ( @@ -562,7 +558,6 @@ def create_weights( # WEIGHT_SCALES # Allocate 2 scales for w1 and w3 respectively. - # They will be combined to a single scale after weight loading. w13_weight_scale = torch.nn.Parameter( torch.ones(num_experts_per_partition, 2, dtype=torch.float32), requires_grad=False, @@ -619,8 +614,6 @@ def process_weights_after_loading(self, layer: Module) -> None: w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype) w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype) - # Re-initialize w13_scale because we directly quantize - # merged w13 weights and generate a single scaling factor. layer.w13_weight_scale = torch.nn.Parameter( torch.ones( layer.num_experts_per_partition, @@ -651,58 +644,6 @@ def process_weights_after_loading(self, layer: Module) -> None: "QuantConfig has static quantization, but found " "activation scales are None." 
) - - # If rocm, normalize the weights and scales to e4m3fnuz - if is_hip(): - # Normalize the weights and scales - w13_weight, w13_weight_scale, w13_input_scale = ( - normalize_e4m3fn_to_e4m3fnuz( - layer.w13_weight, layer.w13_weight_scale, layer.w13_input_scale - ) - ) - w2_weight, w2_weight_scale, w2_input_scale = ( - normalize_e4m3fn_to_e4m3fnuz( - layer.w2_weight, layer.w2_weight_scale, layer.w2_input_scale - ) - ) - # Reset the parameter - layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False) - layer.w13_weight_scale = torch.nn.Parameter( - w13_weight_scale, requires_grad=False - ) - if w13_input_scale is not None: - layer.w13_input_scale = torch.nn.Parameter( - w13_input_scale, requires_grad=False - ) - layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False) - layer.w2_weight_scale = torch.nn.Parameter( - w2_weight_scale, requires_grad=False - ) - if w2_input_scale is not None: - layer.w2_input_scale = torch.nn.Parameter( - w2_input_scale, requires_grad=False - ) - - # Fp8 moe kernel needs single weight scale for w13 per expert. - # We take the max then dequant and requant each expert. - assert layer.w13_weight_scale is not None - shard_size = layer.intermediate_size - max_w13_scales = layer.w13_weight_scale.max(dim=1).values - for expert_id in range(layer.num_experts_per_partition): - start = 0 - for shard_id in range(2): - dq_weight = per_tensor_dequantize( - layer.w13_weight[expert_id][start : start + shard_size, :], - layer.w13_weight_scale[expert_id][shard_id], - ) - layer.w13_weight[expert_id][start : start + shard_size, :], _ = ( - ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id]) - ) - start += shard_size - - layer.w13_weight_scale = torch.nn.Parameter( - max_w13_scales, requires_grad=False - ) return def apply( From 41eafe5cd097875aa05447af28ae9f35f954a9a2 Mon Sep 17 00:00:00 2001 From: xiaobo Date: Mon, 2 Dec 2024 14:39:50 +0800 Subject: [PATCH 06/10] use use_fp8_w8a8 --- python/sglang/srt/layers/ep_moe/kernels.py | 10 ++++++---- python/sglang/srt/layers/ep_moe/layer.py | 16 ++++++++-------- 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/python/sglang/srt/layers/ep_moe/kernels.py b/python/sglang/srt/layers/ep_moe/kernels.py index 43b5b602864..50d1f315440 100644 --- a/python/sglang/srt/layers/ep_moe/kernels.py +++ b/python/sglang/srt/layers/ep_moe/kernels.py @@ -1,5 +1,5 @@ import logging -from typing import Any, Dict, Optional, Tuple +from typing import Optional import torch import triton @@ -218,6 +218,7 @@ def grouped_gemm_triton_kernel( seg_indptr, weight_indices, m_num_tiles_indptr, + use_fp8_w8a8, scale_a, scale_b, BLOCK_SIZE_M: tl.constexpr, @@ -271,7 +272,7 @@ def grouped_gemm_triton_kernel( a_ptr += BLOCK_SIZE_K b_ptr += BLOCK_SIZE_K - if scale_a is not None and scale_b is not None: + if use_fp8_w8a8: scale_a_value = tl.load(scale_a + expert_id) scale_b_value = tl.load(scale_b + expert_id) accumulator *= scale_a_value * scale_b_value @@ -303,12 +304,12 @@ def grouped_gemm_triton( weight_column_major: bool, seg_indptr: Optional[torch.Tensor] = None, weight_indices: Optional[torch.Tensor] = None, - enable_fp8: bool = False, + use_fp8_w8a8: bool = False, scale_a: torch.Tensor = None, scale_b: torch.Tensor = None, ): assert weight_column_major == True # TODO: more - if enable_fp8: + if use_fp8_w8a8: assert scale_a is not None and scale_b is not None BLOCK_SIZE_M = 128 @@ -333,6 +334,7 @@ def grouped_gemm_triton( seg_indptr, weight_indices, m_num_tiles_indptr, + use_fp8_w8a8, scale_a, scale_b, BLOCK_SIZE_M, diff 
--git a/python/sglang/srt/layers/ep_moe/layer.py b/python/sglang/srt/layers/ep_moe/layer.py index 8be61efd822..eca119845a7 100644 --- a/python/sglang/srt/layers/ep_moe/layer.py +++ b/python/sglang/srt/layers/ep_moe/layer.py @@ -59,7 +59,7 @@ def forward( weight_column_major: bool, seg_indptr: Optional[torch.Tensor] = None, weight_indices: Optional[torch.Tensor] = None, - enable_fp8: bool = False, + use_fp8_w8a8: bool = False, scale_a: torch.Tensor = None, scale_b: torch.Tensor = None, ): @@ -85,7 +85,7 @@ def forward( weight_column_major, seg_indptr, weight_indices, - enable_fp8, + use_fp8_w8a8, scale_a, scale_b, ) @@ -141,13 +141,13 @@ def __init__( if quant_config is None: self.quant_method: Optional[QuantizeMethodBase] = UnquantizedEPMoEMethod() - self.enable_fp8 = False + self.use_fp8_w8a8 = False self.activation_scheme = None else: self.quant_method: Optional[QuantizeMethodBase] = Fp8EPMoEMethod( quant_config ) - self.enable_fp8 = True + self.use_fp8_w8a8 = True self.fp8_dtype = torch.float8_e4m3fn self.activation_scheme = quant_config.activation_scheme @@ -186,7 +186,7 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): gateup_input = torch.empty( (int(hidden_states.shape[0] * self.top_k), hidden_states.shape[1]), device=hidden_states.device, - dtype=self.fp8_dtype if self.enable_fp8 else hidden_states.dtype, + dtype=self.fp8_dtype if self.use_fp8_w8a8 else hidden_states.dtype, ) if self.activation_scheme == "dynamic": max_value = ( @@ -232,7 +232,7 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): weight_column_major=True, seg_indptr=seg_indptr_cur_rank, weight_indices=weight_indices_cur_rank, - enable_fp8=self.enable_fp8, + use_fp8_w8a8=self.use_fp8_w8a8, scale_a=self.w13_input_scale, scale_b=self.w13_weight_scale, ) @@ -242,7 +242,7 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): gateup_output.shape[0], gateup_output.shape[1] // 2, device=gateup_output.device, - dtype=self.fp8_dtype if self.enable_fp8 else hidden_states.dtype, + dtype=self.fp8_dtype if self.use_fp8_w8a8 else hidden_states.dtype, ) if self.w2_input_scale is None: self.w2_input_scale = torch.ones( @@ -276,7 +276,7 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): weight_column_major=True, seg_indptr=seg_indptr_cur_rank, weight_indices=weight_indices_cur_rank, - enable_fp8=self.enable_fp8, + use_fp8_w8a8=self.use_fp8_w8a8, scale_a=self.w2_input_scale, scale_b=self.w2_weight_scale, ) From 6db8cf6296e3dc49f0602ffa9dde4bb40c1040c0 Mon Sep 17 00:00:00 2001 From: xiaobo Date: Mon, 2 Dec 2024 15:08:54 +0800 Subject: [PATCH 07/10] add ep-size args --- python/sglang/srt/server_args.py | 17 +++++++++++++++++ test/srt/test_moe_ep.py | 13 +++++++++++-- 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 7a0d91c0709..8719d919068 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -93,6 +93,8 @@ class ServerArgs: # Data parallelism dp_size: int = 1 load_balance_method: str = "round_robin" + # Expert parallelism + ep_size: int = 1 # Multi-node distributed serving dist_init_addr: Optional[str] = None @@ -217,6 +219,12 @@ def __post_init__(self): "Data parallel size is adjusted to be the same as tensor parallel size. " "Overlap scheduler is disabled." ) + # Expert parallelism + if self.enable_ep_moe: + self.ep_size = self.tp_size + logger.info( + f"EP MoE is enabled. 
From 6db8cf6296e3dc49f0602ffa9dde4bb40c1040c0 Mon Sep 17 00:00:00 2001
From: xiaobo
Date: Mon, 2 Dec 2024 15:08:54 +0800
Subject: [PATCH 07/10] add ep-size args

---
 python/sglang/srt/server_args.py | 17 +++++++++++++++++
 test/srt/test_moe_ep.py          | 13 +++++++++++--
 2 files changed, 28 insertions(+), 2 deletions(-)

diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 7a0d91c0709..8719d919068 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -93,6 +93,8 @@ class ServerArgs:
     # Data parallelism
     dp_size: int = 1
     load_balance_method: str = "round_robin"
+    # Expert parallelism
+    ep_size: int = 1

     # Multi-node distributed serving
     dist_init_addr: Optional[str] = None
@@ -217,6 +219,12 @@ def __post_init__(self):
                 "Data parallel size is adjusted to be the same as tensor parallel size. "
                 "Overlap scheduler is disabled."
             )
+        # Expert parallelism
+        if self.enable_ep_moe:
+            self.ep_size = self.tp_size
+            logger.info(
+                f"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
+            )

         # GGUF
         if (
@@ -527,6 +535,14 @@ def add_cli_args(parser: argparse.ArgumentParser):
                 "shortest_queue",
             ],
         )
+        # Expert parallelism
+        parser.add_argument(
+            "--expert-parallel-size",
+            "--ep-size",
+            type=int,
+            default=ServerArgs.ep_size,
+            help="The expert parallelism size.",
+        )

         # Multi-node distributed serving
         parser.add_argument(
@@ -766,6 +782,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
     def from_cli_args(cls, args: argparse.Namespace):
         args.tp_size = args.tensor_parallel_size
         args.dp_size = args.data_parallel_size
+        args.ep_size = args.expert_parallel_size
         attrs = [attr.name for attr in dataclasses.fields(cls)]
         return cls(**{attr: getattr(args, attr) for attr in attrs})

diff --git a/test/srt/test_moe_ep.py b/test/srt/test_moe_ep.py
index 349d36225a4..4d9fd435edb 100644
--- a/test/srt/test_moe_ep.py
+++ b/test/srt/test_moe_ep.py
@@ -20,7 +20,14 @@ def setUpClass(cls):
             cls.model,
             cls.base_url,
             timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
-            other_args=["--tp", "2", "--trust-remote-code", "--enable-ep-moe"],
+            other_args=[
+                "--trust-remote-code",
+                "--tp",
+                "2",
+                "--ep-size",
+                "2",
+                "--enable-ep-moe",
+            ],
         )

     @classmethod
@@ -62,9 +69,11 @@ def setUpClass(cls):
             cls.base_url,
             timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
             other_args=[
+                "--trust-remote-code",
                 "--tp",
                 "2",
-                "--trust-remote-code",
+                "--ep-size",
+                "2",
                 "--enable-ep-moe",
                 "--quantization",
                 "fp8",

From f1ccb91a12f7e5731e04f74f6b1de8aec19681eb Mon Sep 17 00:00:00 2001
From: xiaobo
Date: Tue, 3 Dec 2024 16:44:11 +0800
Subject: [PATCH 08/10] update

---
 python/sglang/srt/layers/ep_moe/kernels.py | 28 ++++++++++++----------
 1 file changed, 15 insertions(+), 13 deletions(-)

diff --git a/python/sglang/srt/layers/ep_moe/kernels.py b/python/sglang/srt/layers/ep_moe/kernels.py
index 50d1f315440..e0486891aa7 100644
--- a/python/sglang/srt/layers/ep_moe/kernels.py
+++ b/python/sglang/srt/layers/ep_moe/kernels.py
@@ -221,12 +221,12 @@ def grouped_gemm_triton_kernel(
     use_fp8_w8a8,
     scale_a,
     scale_b,
-    BLOCK_SIZE_M: tl.constexpr,
-    BLOCK_SIZE_N: tl.constexpr,
-    BLOCK_SIZE_K: tl.constexpr,
     a_stride_0: tl.constexpr,
     b_stride_0: tl.constexpr,
     b_stride_1: tl.constexpr,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
 ):
     c_dtype = c.dtype.element_ty

@@ -312,19 +312,23 @@ def grouped_gemm_triton(
     if use_fp8_w8a8:
         assert scale_a is not None and scale_b is not None

-    BLOCK_SIZE_M = 128
-    BLOCK_SIZE_N = 128
-    BLOCK_SIZE_K = 128
+    config = {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+    }

     m_num_tiles_indptr = torch.zeros(batch_size + 1, device=a.device, dtype=torch.int64)
     compute_m_num_tiles_indptr[(1,)](
-        m_num_tiles_indptr, seg_indptr, batch_size, BLOCK_SIZE_M
+        m_num_tiles_indptr, seg_indptr, batch_size, config["BLOCK_SIZE_M"]
    )

-    num_m_tiles = (a.size(0) + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M + batch_size
-    num_n_tiles = (b.size(1) + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N
+    grid = lambda META: (
+        triton.cdiv(a.size(0), META["BLOCK_SIZE_M"]) + batch_size,
+        triton.cdiv(b.size(1), META["BLOCK_SIZE_N"]),
+    )

-    grouped_gemm_triton_kernel[(num_m_tiles, num_n_tiles)](
+    grouped_gemm_triton_kernel[grid](
         a,
         b,
         c,
         seg_indptr,
         weight_indices,
         m_num_tiles_indptr,
         use_fp8_w8a8,
         scale_a,
         scale_b,
-        BLOCK_SIZE_M,
-        BLOCK_SIZE_N,
-        BLOCK_SIZE_K,
         a.stride(0),
         b.stride(0),
         b.stride(1),
+        **config,
     )
     return c
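The "update" patch above is a mechanical but useful cleanup: the three block sizes move into a single config dict and the launch grid becomes a lambda over the kernel meta-parameters, the usual Triton idiom once block sizes may come from a config or autotuner. A plain-Python sketch of the resulting grid arithmetic (no GPU or Triton required; the sizes below are made up for illustration):

def cdiv(a, b):
    # Same ceiling division that triton.cdiv performs.
    return (a + b - 1) // b


def launch_grid(num_rows, num_cols, batch_size, config):
    # One extra M-tile per expert covers segments that do not start on a
    # BLOCK_SIZE_M boundary, matching `cdiv(a.size(0), BLOCK_SIZE_M) + batch_size`.
    return (
        cdiv(num_rows, config["BLOCK_SIZE_M"]) + batch_size,
        cdiv(num_cols, config["BLOCK_SIZE_N"]),
    )


# 4096 expert-sorted tokens, N = 14336, 8 local experts, 128 x 128 tiles
print(launch_grid(4096, 14336, 8, {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128}))
# -> (40, 112)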
From 71bda0d68c3900e75a4e1b9f5ef33474a5ab3223 Mon Sep 17 00:00:00 2001
From: xiaobo
Date: Thu, 5 Dec 2024 10:18:16 +0800
Subject: [PATCH 09/10] mixtral support moe-ep

---
 python/sglang/srt/models/mixtral.py | 17 ++++++++++++-----
 1 file changed, 12 insertions(+), 5 deletions(-)

diff --git a/python/sglang/srt/models/mixtral.py b/python/sglang/srt/models/mixtral.py
index f1ae1f57a3d..65e4587ecc1 100644
--- a/python/sglang/srt/models/mixtral.py
+++ b/python/sglang/srt/models/mixtral.py
@@ -21,9 +21,13 @@
 import torch
 from torch import nn
 from transformers import MixtralConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed import (
+    get_tensor_model_parallel_world_size,
+    tensor_model_parallel_all_reduce,
+)
 from vllm.model_executor.layers.rotary_embedding import get_rope

+from sglang.srt.layers.ep_moe.layer import EPMoE
 from sglang.srt.layers.fused_moe_triton import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -63,6 +67,7 @@ def __init__(
         prefix: str = "",
     ):
         super().__init__()
+        self.tp_size = get_tensor_model_parallel_world_size()
         self.hidden_size = hidden_size

         # Gate always runs at half / full precision for now.
@@ -74,14 +79,13 @@ def __init__(
             quant_config=None,
             prefix=f"{prefix}.gate",
         )
-
-        self.experts = FusedMoE(
+        MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
+        self.experts = MoEImpl(
             num_experts=num_experts,
             top_k=top_k,
             hidden_size=hidden_size,
             intermediate_size=intermediate_size,
             params_dtype=params_dtype,
-            reduce_results=True,
             renormalize=True,
             quant_config=quant_config,
             tp_size=tp_size,
@@ -95,6 +99,8 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
         final_hidden_states = self.experts(hidden_states, router_logits)
+        if self.tp_size > 1:
+            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states.view(orig_shape)


@@ -319,7 +325,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
-        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+        MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
+        expert_params_mapping = MoEImpl.make_expert_params_mapping(
             ckpt_gate_proj_name="w1",
             ckpt_down_proj_name="w2",
             ckpt_up_proj_name="w3",

From 420939b6b0d365c9aa3b9a41959cdcceb8eb76d8 Mon Sep 17 00:00:00 2001
From: xiaobo
Date: Thu, 5 Dec 2024 17:47:30 +0800
Subject: [PATCH 10/10] fix bug

---
 python/sglang/srt/models/mixtral.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/python/sglang/srt/models/mixtral.py b/python/sglang/srt/models/mixtral.py
index 65e4587ecc1..f3fad226091 100644
--- a/python/sglang/srt/models/mixtral.py
+++ b/python/sglang/srt/models/mixtral.py
@@ -42,6 +42,7 @@
     ParallelLMHead,
     VocabParallelEmbedding,
 )
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
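Patches 09 and 10 above let Mixtral pick EPMoE when --enable-ep-moe is set, drop reduce_results=True, and instead all-reduce in MixtralMoE.forward. A single-process sketch of why that reduction is needed: with expert parallelism each rank only evaluates its own slice of the experts, so per-rank outputs are partial sums, and summing them across ranks reproduces the full MoE output. The expert-to-rank mapping and shapes below are illustrative assumptions.

import torch

torch.manual_seed(0)
num_tokens, hidden, num_experts, ep_size, top_k = 4, 8, 4, 2, 2
x = torch.randn(num_tokens, hidden)
experts = [torch.randn(hidden, hidden) for _ in range(num_experts)]
topk_ids = torch.randint(0, num_experts, (num_tokens, top_k))
topk_weights = torch.full((num_tokens, top_k), 1.0 / top_k)
experts_per_rank = num_experts // ep_size


def moe(rank=None):
    out = torch.zeros(num_tokens, hidden)
    for t in range(num_tokens):
        for k in range(top_k):
            e = int(topk_ids[t, k])
            if rank is not None and e // experts_per_rank != rank:
                continue  # this expert lives on another EP rank
            out[t] += topk_weights[t, k] * (x[t] @ experts[e])
    return out


# Summing the per-rank partial outputs (what the all-reduce does at runtime)
# matches the unpartitioned result.
full = moe()
reduced = sum(moe(rank=r) for r in range(ep_size))
assert torch.allclose(full, reduced, atol=1e-5)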