Fp8 MoE optimizations on AMD #2388

Merged: 5 commits, Dec 7, 2024
Changes from all commits
85 changes: 64 additions & 21 deletions python/sglang/srt/layers/fused_moe_triton/fused_moe.py
@@ -16,6 +16,7 @@
from sglang.srt.utils import direct_register_custom_op, get_device_name

logger = logging.getLogger(__name__)
padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0


@triton.jit
@@ -58,6 +59,7 @@ def fused_moe_kernel(
compute_type: tl.constexpr,
use_fp8_w8a8: tl.constexpr,
use_int8_w8a16: tl.constexpr,
even_Ks: tl.constexpr,
):
"""
Implements the fused computation for a Mixture of Experts (MOE) using
@@ -143,12 +145,21 @@ def fused_moe_kernel(
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
# Load the next block of A and B, generate a mask by checking the
# K dimension.
a = tl.load(
a_ptrs,
mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
other=0.0,
)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
if even_Ks:
a = tl.load(
a_ptrs,
mask=token_mask[:, None],
other=0.0,
)
b = tl.load(b_ptrs)
else:
a = tl.load(
a_ptrs,
mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
other=0.0,
)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)

# We accumulate along the K dimension.
if use_int8_w8a16:
accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
@@ -254,7 +265,9 @@ def invoke_fused_moe_kernel(
assert topk_weights.stride(1) == 1
assert sorted_token_ids.stride(0) == 1

padded_size = 0
if use_fp8_w8a8:
padded_size = padding_size
A, A_scale = ops.scaled_fp8_quant(A, A_scale)
assert B_scale is not None
elif use_int8_w8a16:
@@ -268,6 +281,12 @@
* triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]),
)

K = B.shape[2] - padded_size
if K % config["BLOCK_SIZE_K"] == 0:
even_Ks = True
else:
even_Ks = False

fused_moe_kernel[grid](
A,
B,
@@ -279,7 +298,7 @@
expert_ids,
num_tokens_post_padded,
B.shape[1],
B.shape[2],
B.shape[2] - padded_size,
sorted_token_ids.shape[0],
topk_ids.numel(),
A.stride(0),
@@ -296,6 +315,7 @@
compute_type=compute_type,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16,
even_Ks=even_Ks,
**config,
)
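The even_Ks flag computed above lets the kernel skip the per-iteration K-boundary masks whenever the padding-adjusted K dimension divides evenly into BLOCK_SIZE_K tiles. A minimal sketch of that decision in isolation (the helper name and example values are illustrative assumptions, not code from this PR):

# Sketch of the even_Ks decision made in invoke_fused_moe_kernel (names are illustrative).
def can_skip_k_mask(weight_k_dim: int, padded_size: int, block_size_k: int) -> bool:
    # K as seen by the kernel excludes any MOE_PADDING columns.
    k = weight_k_dim - padded_size
    # If K splits into full BLOCK_SIZE_K tiles, every tl.load covers valid data,
    # so the "offs_k < K - k * BLOCK_SIZE_K" masks can be dropped.
    return k % block_size_k == 0

print(can_skip_k_mask(4096 + 128, 128, 128))  # True  -> unmasked load path
print(can_skip_k_mask(4000, 0, 128))          # False -> masked load path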

@@ -351,20 +371,39 @@ def get_default_config(
dtype: Optional[str],
is_marlin: bool,
) -> Dict[str, int]:
config = {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
}
# A heuristic: fused marlin works faster with this config for small M
if M <= E or (is_marlin and M <= 32):
if dtype == "fp8_w8a8":
config = {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4,
}
if M <= E:
config = {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4,
}
else:
config = {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
}
# A heuristic: fused marlin works faster with this config for small M
if M <= E or (is_marlin and M <= 32):
config = {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
}
return config


@@ -645,8 +684,12 @@ def fused_experts_impl(
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
):
padded_size = padding_size
if not use_fp8_w8a8:
padded_size = 0

# Check constraints.
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
assert hidden_states.shape[1] == w1.shape[2] - padded_size, "Hidden size mismatch"
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
@@ -668,7 +711,7 @@
get_config_func = functools.partial(
try_get_optimal_moe_config,
w1.shape,
w2.shape,
(w2.shape[0], w2.shape[1], w2.shape[2] - padded_size),
topk_ids.shape[1],
config_dtype,
)
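The kernel above is launched with K = B.shape[2] - padded_size, so the zero columns appended under MOE_PADDING never enter the reduction; padding only changes the memory layout (row stride) of the expert weights. A quick sketch checking that intuition with plain tensors (shapes and names are illustrative assumptions):

import torch
import torch.nn.functional as F

# Illustrative shapes: 8 experts, N = 512 output features, K = 4096 hidden size.
w = torch.randn(8, 512, 4096)
padding_size = 128

# Pad the last (K) dimension with zeros, as the fp8.py change below does for w13/w2 on ROCm.
w_padded = F.pad(w, (0, padding_size), "constant", 0)
assert w_padded.shape == (8, 512, 4096 + padding_size)

# The kernel only reads K = shape[2] - padding_size columns, so results match.
x = torch.randn(16, 4096)
ref = x @ w[0].t()
out = x @ w_padded[0, :, : w_padded.shape[2] - padding_size].t()
torch.testing.assert_close(ref, out)

# What changes is the row stride: each row now spans K + 128 elements, which is
# the layout property the padding relies on to reduce memory-channel contention.
print(w[0].stride(), w_padded[0].stride())  # (4096, 1) vs (4224, 1)
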
34 changes: 33 additions & 1 deletion python/sglang/srt/layers/quantization/fp8.py
@@ -1,9 +1,11 @@
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py

import logging
import os
from typing import Any, Callable, Dict, List, Optional

import torch
import torch.nn.functional as F
from torch.nn import Module
from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops
@@ -24,6 +26,7 @@
)
from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter

from sglang.srt.layers.fused_moe_triton.fused_moe import padding_size
from sglang.srt.layers.linear import LinearMethodBase, UnquantizedLinearMethod
from sglang.srt.layers.quantization.base_config import (
QuantizationConfig,
@@ -420,7 +423,7 @@ def create_weights(

def process_weights_after_loading(self, layer: Module) -> None:

# If checkpoint is fp16, quantize in place.
# If checkpoint is fp16 or bfloat16, quantize in place.
if not self.quant_config.is_checkpoint_fp8_serialized:
# If ROCm, use float8_e4m3fnuz instead (MI300x HW)
fp8_dtype = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
@@ -444,6 +447,19 @@ def process_weights_after_loading(self, layer: Module) -> None:
)
layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)

# On ROCm, pad the weights (to minimize memory-channel contention), only if MOE_PADDING is set
if is_hip() and bool(int(os.getenv("MOE_PADDING", "0"))):
layer.w13_weight = torch.nn.Parameter(
F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
requires_grad=False,
)
torch.cuda.empty_cache()
layer.w2_weight = torch.nn.Parameter(
F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
requires_grad=False,
)
torch.cuda.empty_cache()
return

# If checkpoint is fp8, we need to handle that the
@@ -472,6 +488,7 @@ def process_weights_after_loading(self, layer: Module) -> None:
layer.w2_input_scale = torch.nn.Parameter(
layer.w2_input_scale.max(), requires_grad=False
)

# If ROCm, normalize the weights and scales to e4m3fnuz
if is_hip():
# Normalize the weights and scales
@@ -523,6 +540,19 @@ def process_weights_after_loading(self, layer: Module) -> None:
layer.w13_weight_scale = torch.nn.Parameter(
max_w13_scales, requires_grad=False
)

# On ROCm, pad the weights (to minimize memory-channel contention), only if MOE_PADDING is set
if is_hip() and bool(int(os.getenv("MOE_PADDING", "0"))):
layer.w13_weight = torch.nn.Parameter(
F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
requires_grad=False,
)
torch.cuda.empty_cache()
layer.w2_weight = torch.nn.Parameter(
F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
requires_grad=False,
)
torch.cuda.empty_cache()
return

def apply(
@@ -540,6 +570,7 @@ def apply(
from sglang.srt.layers.fused_moe_triton import FusedMoE
from sglang.srt.layers.fused_moe_triton.fused_moe import fused_experts

# Expert selection
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
@@ -551,6 +582,7 @@
custom_routing_function=custom_routing_function,
)

# Expert fusion with FP8 quantization
return fused_experts(
x,
layer.w13_weight,
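Because padding_size in fused_moe.py is evaluated at import time from the MOE_PADDING environment variable, the variable must be set before any sglang module is imported (or exported in the shell before launching the server). A hedged usage sketch; the model path and the commented-out launch call are placeholders, not taken from this PR:

import os

# MOE_PADDING must be visible before sglang is imported, because
# fused_moe.py reads it at import time to set padding_size (128 or 0).
os.environ["MOE_PADDING"] = "1"

import sglang  # import only after the environment variable is set

# Hypothetical launch; the exact entry point and arguments are not part of this PR.
# runtime = sglang.Runtime(model_path="path/to/fp8-moe-checkpoint", tp_size=8)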