Skip to content

Commit

Permalink
Refactor MoE (#2575)
Browse files Browse the repository at this point in the history
Co-authored-by: zhyncs <[email protected]>
  • Loading branch information
HandH1998 and zhyncs authored Dec 25, 2024
1 parent 8a56b43 commit 53aed98
Show file tree
Hide file tree
Showing 9 changed files with 1,012 additions and 49 deletions.
5 changes: 4 additions & 1 deletion python/sglang/srt/configs/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,10 @@ def __init__(
)

# FIXME: temporary special judge for MLA architecture
if "DeepseekV2ForCausalLM" in self.hf_config.architectures:
if (
"DeepseekV2ForCausalLM" in self.hf_config.architectures
or "DeepseekV3ForCausalLM" in self.hf_config.architectures
):
self.head_dim = 256
self.attention_arch = AttentionArch.MLA
self.kv_lora_rank = self.hf_config.kv_lora_rank
Expand Down
22 changes: 20 additions & 2 deletions python/sglang/srt/layers/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
QuantizationConfig,
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.fp8_utils import BlockQuantScaleParameter
from sglang.srt.utils import set_weight_attrs

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -628,8 +629,19 @@ def weight_loader_v2(
assert loaded_shard_id < len(self.output_sizes)

tp_size = get_tensor_model_parallel_world_size()
shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
shard_size = self.output_sizes[loaded_shard_id] // tp_size

if isinstance(param, BlockQuantScaleParameter):
weight_block_size = self.quant_method.quant_config.weight_block_size
block_n, _ = weight_block_size[0], weight_block_size[1]
shard_offset = (
(sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) // block_n
) // tp_size
shard_size = (
(self.output_sizes[loaded_shard_id] + block_n - 1) // block_n // tp_size
)
else:
shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
shard_size = self.output_sizes[loaded_shard_id] // tp_size

param.load_merged_column_weight(
loaded_weight=loaded_weight,
Expand Down Expand Up @@ -795,6 +807,12 @@ def weight_loader_v2(
shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
shard_size = self._get_shard_size_mapping(loaded_shard_id)

if isinstance(param, BlockQuantScaleParameter):
weight_block_size = self.quant_method.quant_config.weight_block_size
block_n, _ = weight_block_size[0], weight_block_size[1]
shard_offset = (shard_offset + block_n - 1) // block_n
shard_size = (shard_size + block_n - 1) // block_n

param.load_qkv_weight(
loaded_weight=loaded_weight,
num_heads=self.num_kv_head_replicas,
Expand Down
86 changes: 78 additions & 8 deletions python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@
import json
import logging
import os
from typing import Any, Callable, Dict, Optional, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
import triton
import triton.language as tl
from vllm import _custom_ops as ops

from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
from sglang.srt.utils import direct_register_custom_op, get_device_name

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -48,8 +49,14 @@ def fused_moe_kernel(
stride_bn,
stride_cm,
stride_cn,
stride_asm,
stride_ask,
stride_bse,
stride_bsk,
stride_bsn,
# Block size for block-wise quantization
group_n: tl.constexpr,
group_k: tl.constexpr,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
Expand Down Expand Up @@ -133,8 +140,15 @@ def fused_moe_kernel(
b_scale = tl.load(b_scale_ptrs)

if use_fp8_w8a8:
a_scale = tl.load(a_scale_ptr)
b_scale = tl.load(b_scale_ptr + off_experts)
if group_k > 0 and group_n > 0:
a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
offs_bsn = offs_bn // group_n
b_scale_ptrs = (
b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
)
else:
a_scale = tl.load(a_scale_ptr)
b_scale = tl.load(b_scale_ptr + off_experts)

# -----------------------------------------------------------
# Iterate to compute a block of the C matrix.
Expand Down Expand Up @@ -165,7 +179,17 @@ def fused_moe_kernel(
if use_int8_w8a16:
accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
elif use_fp8_w8a8:
accumulator = tl.dot(a, b, acc=accumulator)
if group_k > 0 and group_n > 0:
k_start = k * BLOCK_SIZE_K
offs_ks = k_start // group_k
a_scale = tl.load(
a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0
)
b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk)

accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
else:
accumulator = tl.dot(a, b, acc=accumulator)
else:
accumulator += tl.dot(a, b)
# Advance the ptrs to the next K block.
Expand All @@ -178,7 +202,10 @@ def fused_moe_kernel(
if use_int8_w8a16:
accumulator = (accumulator * b_scale).to(compute_type)
elif use_fp8_w8a8:
accumulator = (accumulator * a_scale * b_scale).to(compute_type)
if group_k > 0 and group_n > 0:
accumulator = accumulator.to(compute_type)
else:
accumulator = (accumulator * a_scale * b_scale).to(compute_type)
else:
accumulator = accumulator.to(compute_type)
# -----------------------------------------------------------
Expand Down Expand Up @@ -262,15 +289,24 @@ def invoke_fused_moe_kernel(
compute_type: tl.dtype,
use_fp8_w8a8: bool,
use_int8_w8a16: bool,
block_shape: Optional[List[int]] = None,
) -> None:
assert topk_weights.stride(1) == 1
assert sorted_token_ids.stride(0) == 1

padded_size = 0
if use_fp8_w8a8:
padded_size = padding_size
A, A_scale = ops.scaled_fp8_quant(A, A_scale)
assert B_scale is not None
if block_shape is None:
A, A_scale = ops.scaled_fp8_quant(A, A_scale)
else:
assert len(block_shape) == 2
block_n, block_k = block_shape[0], block_shape[1]
A, A_scale = per_token_group_quant_fp8(A, block_k)
assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
elif use_int8_w8a16:
assert B_scale is not None
else:
Expand Down Expand Up @@ -309,8 +345,13 @@ def invoke_fused_moe_kernel(
B.stride(1),
C.stride(1),
C.stride(2),
B_scale.stride(0) if B_scale is not None and use_int8_w8a16 else 0,
B_scale.stride(1) if B_scale is not None and use_int8_w8a16 else 0,
A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0,
A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0,
B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0,
B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0,
B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0,
0 if block_shape is None else block_shape[0],
0 if block_shape is None else block_shape[1],
MUL_ROUTED_WEIGHT=mul_routed_weight,
top_k=top_k,
compute_type=compute_type,
Expand Down Expand Up @@ -415,6 +456,7 @@ def try_get_optimal_moe_config(
dtype: Optional[str],
M: int,
is_marlin: bool = False,
block_shape: Optional[List[int]] = None,
):
from sglang.srt.layers.moe.fused_moe_triton import get_config

Expand All @@ -433,6 +475,13 @@ def try_get_optimal_moe_config(
else:
# Else use the default config
config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, is_marlin)
# TODO(HandH1998): Optimize the configs of block-wise quant.
# NOTE(HandH1998): For block-wise quant,
# BLOCK_K must be divisable by block_shape[1]
# BLOCK_N and BLOCK_M has no requirements
if block_shape is not None:
config["BLOCK_SIZE_N"] = block_shape[0]
config["BLOCK_SIZE_K"] = block_shape[1]
return config


Expand Down Expand Up @@ -464,6 +513,7 @@ def inplace_fused_experts(
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
) -> None:
fused_experts_impl(
hidden_states,
Expand All @@ -478,6 +528,7 @@ def inplace_fused_experts(
w2_scale,
a1_scale,
a2_scale,
block_shape,
)


Expand All @@ -493,6 +544,7 @@ def inplace_fused_experts_fake(
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
) -> None:
pass

Expand All @@ -517,6 +569,7 @@ def outplace_fused_experts(
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
) -> torch.Tensor:
return fused_experts_impl(
hidden_states,
Expand All @@ -531,6 +584,7 @@ def outplace_fused_experts(
w2_scale,
a1_scale,
a2_scale,
block_shape,
)


Expand All @@ -546,6 +600,7 @@ def outplace_fused_experts_fake(
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
) -> torch.Tensor:
return torch.empty_like(hidden_states)

Expand All @@ -571,6 +626,7 @@ def fused_experts(
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
):
if inplace:
torch.ops.sglang.inplace_fused_experts(
Expand All @@ -585,6 +641,7 @@ def fused_experts(
w2_scale,
a1_scale,
a2_scale,
block_shape,
)
return hidden_states
else:
Expand All @@ -600,6 +657,7 @@ def fused_experts(
w2_scale,
a1_scale,
a2_scale,
block_shape,
)


Expand All @@ -616,6 +674,7 @@ def fused_experts_impl(
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
):
padded_size = padding_size
if not use_fp8_w8a8:
Expand Down Expand Up @@ -647,6 +706,7 @@ def fused_experts_impl(
(w2.shape[0], w2.shape[1], w2.shape[2] - padded_size),
topk_ids.shape[1],
config_dtype,
block_shape=block_shape,
)

config = get_config_func(M)
Expand Down Expand Up @@ -719,6 +779,7 @@ def fused_experts_impl(
compute_type=compute_type,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16,
block_shape=block_shape,
)

ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
Expand All @@ -740,6 +801,7 @@ def fused_experts_impl(
compute_type=compute_type,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16,
block_shape=block_shape,
)

torch.sum(
Expand Down Expand Up @@ -768,6 +830,7 @@ def fused_moe(
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
Expand Down Expand Up @@ -795,6 +858,12 @@ def fused_moe(
w1.
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for
w2.
- a1_scale (Optional[torch.Tensor]): Optional scale to be used for
a1.
- a2_scale (Optional[torch.Tensor]): Optional scale to be used for
a2.
- block_shape: (Optional[List[int]]): Optional block size for block-wise
quantization.
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
Expand Down Expand Up @@ -826,4 +895,5 @@ def fused_moe(
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape,
)
7 changes: 6 additions & 1 deletion python/sglang/srt/layers/moe/fused_moe_triton/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class FusedMoeWeightScaleSupported(Enum):
TENSOR = "tensor"
CHANNEL = "channel"
GROUP = "group"
BLOCK = "block"


class FusedMoEMethodBase(QuantizeMethodBase):
Expand Down Expand Up @@ -214,6 +215,7 @@ def __init__(
)
self.top_k = top_k
self.num_experts = num_experts
assert intermediate_size % self.tp_size == 0
self.intermediate_size_per_partition = intermediate_size // self.tp_size
self.reduce_results = reduce_results
self.renormalize = renormalize
Expand Down Expand Up @@ -470,7 +472,10 @@ def weight_loader(
expert_data=expert_data,
tp_rank=tp_rank,
)
elif quant_method == FusedMoeWeightScaleSupported.GROUP.value:
elif quant_method in [
FusedMoeWeightScaleSupported.GROUP.value,
FusedMoeWeightScaleSupported.BLOCK.value,
]:
self._load_model_weight_or_group_weight_scale(
shard_id=shard_id,
shard_dim=shard_dim,
Expand Down
Loading

0 comments on commit 53aed98

Please sign in to comment.