Fp8 MoE optimizations on AMD (#2388)
HaiShaw authored Dec 7, 2024
1 parent aaac33f commit 95f93f4
Showing 2 changed files with 97 additions and 22 deletions.
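
In brief: on ROCm, setting MOE_PADDING=1 pads the K dimension of the fp8 expert weights by 128 elements to reduce memory-channel contention, and the fused MoE Triton kernel gains an even_Ks fast path that skips K-dimension masking whenever K is a multiple of BLOCK_SIZE_K. A minimal editorial sketch of the opt-in gate (not part of the diff; the launch command is illustrative):

# Editorial sketch: the padding is opt-in and must be set in the environment
# before sglang imports fused_moe.py, since padding_size is fixed at import time.
#   MOE_PADDING=1 python your_serving_script.py
import os
padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0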
85 changes: 64 additions & 21 deletions python/sglang/srt/layers/fused_moe_triton/fused_moe.py
@@ -16,6 +16,7 @@
from sglang.srt.utils import direct_register_custom_op, get_device_name

logger = logging.getLogger(__name__)
padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0


@triton.jit
@@ -58,6 +59,7 @@ def fused_moe_kernel(
compute_type: tl.constexpr,
use_fp8_w8a8: tl.constexpr,
use_int8_w8a16: tl.constexpr,
even_Ks: tl.constexpr,
):
"""
Implements the fused computation for a Mixture of Experts (MOE) using
@@ -143,12 +145,21 @@ def fused_moe_kernel(
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
# Load the next block of A and B, generate a mask by checking the
# K dimension.
a = tl.load(
a_ptrs,
mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
other=0.0,
)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
if even_Ks:
a = tl.load(
a_ptrs,
mask=token_mask[:, None],
other=0.0,
)
b = tl.load(b_ptrs)
else:
a = tl.load(
a_ptrs,
mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
other=0.0,
)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)

# We accumulate along the K dimension.
if use_int8_w8a16:
accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
@@ -254,7 +265,9 @@ def invoke_fused_moe_kernel(
assert topk_weights.stride(1) == 1
assert sorted_token_ids.stride(0) == 1

padded_size = 0
if use_fp8_w8a8:
padded_size = padding_size
A, A_scale = ops.scaled_fp8_quant(A, A_scale)
assert B_scale is not None
elif use_int8_w8a16:
@@ -268,6 +281,12 @@
* triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]),
)

K = B.shape[2] - padded_size
if K % config["BLOCK_SIZE_K"] == 0:
even_Ks = True
else:
even_Ks = False

fused_moe_kernel[grid](
A,
B,
@@ -279,7 +298,7 @@
expert_ids,
num_tokens_post_padded,
B.shape[1],
B.shape[2],
B.shape[2] - padded_size,
sorted_token_ids.shape[0],
topk_ids.numel(),
A.stride(0),
@@ -296,6 +315,7 @@
compute_type=compute_type,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16,
even_Ks=even_Ks,
**config,
)
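
For reference, a standalone toy model (editorial, not part of the commit) of the even_Ks fast path wired up above: when K divides BLOCK_SIZE_K evenly, every block in the K loop is full-width, so the kernel can drop the per-iteration boundary mask. The real kernel uses Triton masked loads; the block size below is arbitrary.

import torch

def blocked_matmul(a, b, block_k):
    # Toy model of the kernel's K loop.
    M, K = a.shape
    _, N = b.shape
    even_ks = K % block_k == 0          # same divisibility check the launcher performs
    acc = torch.zeros(M, N)
    for k0 in range(0, K, block_k):
        if even_ks:
            # Fast path: every block is full-width, no K-boundary mask needed.
            acc += a[:, k0:k0 + block_k] @ b[k0:k0 + block_k, :]
        else:
            # Slow path: the final block may be partial, so clamp (mask) it.
            k1 = min(k0 + block_k, K)
            acc += a[:, k0:k1] @ b[k0:k1, :]
    return acc

# Both paths agree with a plain matmul; the fast path just avoids masking.
a, b = torch.randn(4, 64), torch.randn(64, 8)
assert torch.allclose(blocked_matmul(a, b, 16), a @ b, atol=1e-4)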

@@ -351,20 +371,39 @@ def get_default_config(
dtype: Optional[str],
is_marlin: bool,
) -> Dict[str, int]:
config = {
    "BLOCK_SIZE_M": 64,
    "BLOCK_SIZE_N": 64,
    "BLOCK_SIZE_K": 32,
    "GROUP_SIZE_M": 8,
}
# A heuristic: fused marlin works faster with this config for small M
if M <= E or (is_marlin and M <= 32):
    config = {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
    }
if dtype == "fp8_w8a8":
    config = {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 32,
        "num_warps": 8,
        "num_stages": 4,
    }
    if M <= E:
        config = {
            "BLOCK_SIZE_M": 64,
            "BLOCK_SIZE_N": 128,
            "BLOCK_SIZE_K": 128,
            "GROUP_SIZE_M": 1,
            "num_warps": 4,
            "num_stages": 4,
        }
else:
    config = {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 32,
        "GROUP_SIZE_M": 8,
    }
    # A heuristic: fused marlin works faster with this config for small M
    if M <= E or (is_marlin and M <= 32):
        config = {
            "BLOCK_SIZE_M": 16,
            "BLOCK_SIZE_N": 32,
            "BLOCK_SIZE_K": 64,
            "GROUP_SIZE_M": 1,
        }
return config
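
A quick illustrative check (not part of the commit) of how the new fp8_w8a8 defaults interact with the even_Ks flag computed in invoke_fused_moe_kernel: both fp8 configs use BLOCK_SIZE_K = 128, and typical MoE hidden/intermediate sizes are multiples of 128, so the unmasked fast path is usually taken. The sizes below are examples, not values from the diff.

fp8_block_k = 128  # BLOCK_SIZE_K shared by both fp8_w8a8 default configs above
for logical_k in (2048, 4096, 7168, 14336):
    even_ks = logical_k % fp8_block_k == 0
    print(f"K={logical_k:>6}  ->  even_Ks={even_ks}")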


@@ -645,8 +684,12 @@ def fused_experts_impl(
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
):
padded_size = padding_size
if not use_fp8_w8a8:
padded_size = 0

# Check constraints.
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
assert hidden_states.shape[1] == w1.shape[2] - padded_size, "Hidden size mismatch"
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
@@ -668,7 +711,7 @@
get_config_func = functools.partial(
try_get_optimal_moe_config,
w1.shape,
w2.shape,
(w2.shape[0], w2.shape[1], w2.shape[2] - padded_size),
topk_ids.shape[1],
config_dtype,
)
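
To make the shape bookkeeping in fused_experts_impl concrete, here is a small editorial sketch of the padded (physical) versus logical weight shapes that the updated assert and config lookup compare. Sizes, dtype, and weight layout are illustrative assumptions, not values from the diff (uint8 stands in for fp8 storage).

import torch

padded_size = 128                       # padding_size when MOE_PADDING=1
num_experts, hidden, inter = 4, 256, 512

# Physical (padded) expert weights vs. logical activation shapes.
w1 = torch.empty(num_experts, 2 * inter, hidden + padded_size, dtype=torch.uint8)
w2 = torch.empty(num_experts, hidden, inter + padded_size, dtype=torch.uint8)
hidden_states = torch.empty(16, hidden, dtype=torch.bfloat16)

# The updated assert compares against the unpadded K dimension of w1 ...
assert hidden_states.shape[1] == w1.shape[2] - padded_size
# ... and the config lookup sees w2 with the padding stripped as well.
w2_logical_shape = (w2.shape[0], w2.shape[1], w2.shape[2] - padded_size)
print(w2_logical_shape)  # (4, 256, 512)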
34 changes: 33 additions & 1 deletion python/sglang/srt/layers/quantization/fp8.py
@@ -1,9 +1,11 @@
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py

import logging
import os
from typing import Any, Callable, Dict, List, Optional

import torch
import torch.nn.functional as F
from torch.nn import Module
from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops
@@ -24,6 +26,7 @@
)
from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter

from sglang.srt.layers.fused_moe_triton.fused_moe import padding_size
from sglang.srt.layers.linear import LinearMethodBase, UnquantizedLinearMethod
from sglang.srt.layers.quantization.base_config import (
QuantizationConfig,
@@ -420,7 +423,7 @@ def create_weights(

def process_weights_after_loading(self, layer: Module) -> None:

# If checkpoint is fp16, quantize in place.
# If checkpoint is fp16 or bfloat16, quantize in place.
if not self.quant_config.is_checkpoint_fp8_serialized:
# If ROCm, use float8_e4m3fnuz instead (MI300x HW)
fp8_dtype = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
@@ -444,6 +447,19 @@ def process_weights_after_loading(self, layer: Module) -> None:
)
layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)

# If ROCm, apply weight padding (min. Mem channel contention) only if set
if is_hip() and bool(int(os.getenv("MOE_PADDING", "0"))):
layer.w13_weight = torch.nn.Parameter(
F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
requires_grad=False,
)
torch.cuda.empty_cache()
layer.w2_weight = torch.nn.Parameter(
F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
requires_grad=False,
)
torch.cuda.empty_cache()
return

# If checkpoint is fp8, we need to handle that the
@@ -472,6 +488,7 @@ def process_weights_after_loading(self, layer: Module) -> None:
layer.w2_input_scale = torch.nn.Parameter(
layer.w2_input_scale.max(), requires_grad=False
)

# If ROCm, normalize the weights and scales to e4m3fnuz
if is_hip():
# Normalize the weights and scales
@@ -523,6 +540,19 @@ def process_weights_after_loading(self, layer: Module) -> None:
layer.w13_weight_scale = torch.nn.Parameter(
max_w13_scales, requires_grad=False
)

# If ROCm, apply weight padding (min. Mem channel contention) only if set
if is_hip() and bool(int(os.getenv("MOE_PADDING", "0"))):
layer.w13_weight = torch.nn.Parameter(
F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
requires_grad=False,
)
torch.cuda.empty_cache()
layer.w2_weight = torch.nn.Parameter(
F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
requires_grad=False,
)
torch.cuda.empty_cache()
return

def apply(
@@ -540,6 +570,7 @@ def apply(
from sglang.srt.layers.fused_moe_triton import FusedMoE
from sglang.srt.layers.fused_moe_triton.fused_moe import fused_experts

# Expert selection
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
Expand All @@ -551,6 +582,7 @@ def apply(
custom_routing_function=custom_routing_function,
)

# Expert fusion with FP8 quantization
return fused_experts(
x,
layer.w13_weight,
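
For context, a minimal standalone sketch (editorial; tensor sizes and dtype are assumptions, and uint8 stands in for fp8 storage) of the ROCm-only weight padding that process_weights_after_loading applies when MOE_PADDING=1:

import os
import torch
import torch.nn.functional as F

padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0

w13_weight = torch.zeros(2, 2 * 256, 512, dtype=torch.uint8)  # stand-in expert weight

if padding_size:
    # F.pad with a 2-element pad tuple pads only the last (K) dimension:
    # (left, right) -> append padding_size zero columns.
    w13_weight = F.pad(w13_weight, (0, padding_size), "constant", 0)
    # The commit calls torch.cuda.empty_cache() right after each pad,
    # presumably to release the original unpadded copy eagerly.
    torch.cuda.empty_cache()

print(w13_weight.shape)  # torch.Size([2, 512, 640]) with MOE_PADDING=1, else [2, 512, 512]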