diff --git a/python/sglang/srt/layers/fused_moe_triton/fused_moe.py b/python/sglang/srt/layers/fused_moe_triton/fused_moe.py
index 4f92512b2d5..e6ce9cb4d39 100644
--- a/python/sglang/srt/layers/fused_moe_triton/fused_moe.py
+++ b/python/sglang/srt/layers/fused_moe_triton/fused_moe.py
@@ -16,6 +16,7 @@ from sglang.srt.utils import direct_register_custom_op, get_device_name
 
 logger = logging.getLogger(__name__)
 
+padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
 
 @triton.jit
@@ -58,6 +59,7 @@ def fused_moe_kernel(
     compute_type: tl.constexpr,
     use_fp8_w8a8: tl.constexpr,
     use_int8_w8a16: tl.constexpr,
+    even_Ks: tl.constexpr,
 ):
     """
     Implements the fused computation for a Mixture of Experts (MOE) using
@@ -143,12 +145,21 @@ def fused_moe_kernel(
     for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
         # Load the next block of A and B, generate a mask by checking the
         # K dimension.
-        a = tl.load(
-            a_ptrs,
-            mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
-            other=0.0,
-        )
-        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+        if even_Ks:
+            a = tl.load(
+                a_ptrs,
+                mask=token_mask[:, None],
+                other=0.0,
+            )
+            b = tl.load(b_ptrs)
+        else:
+            a = tl.load(
+                a_ptrs,
+                mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
+                other=0.0,
+            )
+            b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+
         # We accumulate along the K dimension.
         if use_int8_w8a16:
             accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
@@ -254,7 +265,9 @@ def invoke_fused_moe_kernel(
     assert topk_weights.stride(1) == 1
     assert sorted_token_ids.stride(0) == 1
 
+    padded_size = 0
     if use_fp8_w8a8:
+        padded_size = padding_size
         A, A_scale = ops.scaled_fp8_quant(A, A_scale)
         assert B_scale is not None
     elif use_int8_w8a16:
@@ -268,6 +281,12 @@ def invoke_fused_moe_kernel(
         * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]),
     )
 
+    K = B.shape[2] - padded_size
+    if K % config["BLOCK_SIZE_K"] == 0:
+        even_Ks = True
+    else:
+        even_Ks = False
+
     fused_moe_kernel[grid](
         A,
         B,
@@ -279,7 +298,7 @@ def invoke_fused_moe_kernel(
         expert_ids,
         num_tokens_post_padded,
         B.shape[1],
-        B.shape[2],
+        B.shape[2] - padded_size,
         sorted_token_ids.shape[0],
         topk_ids.numel(),
         A.stride(0),
@@ -296,6 +315,7 @@ def invoke_fused_moe_kernel(
         compute_type=compute_type,
         use_fp8_w8a8=use_fp8_w8a8,
         use_int8_w8a16=use_int8_w8a16,
+        even_Ks=even_Ks,
         **config,
     )
 
@@ -351,20 +371,39 @@ def get_default_config(
     dtype: Optional[str],
     is_marlin: bool,
 ) -> Dict[str, int]:
-    config = {
-        "BLOCK_SIZE_M": 64,
-        "BLOCK_SIZE_N": 64,
-        "BLOCK_SIZE_K": 32,
-        "GROUP_SIZE_M": 8,
-    }
-    # A heuristic: fused marlin works faster with this config for small M
-    if M <= E or (is_marlin and M <= 32):
+    if dtype == "fp8_w8a8":
         config = {
-            "BLOCK_SIZE_M": 16,
-            "BLOCK_SIZE_N": 32,
-            "BLOCK_SIZE_K": 64,
-            "GROUP_SIZE_M": 1,
+            "BLOCK_SIZE_M": 128,
+            "BLOCK_SIZE_N": 256,
+            "BLOCK_SIZE_K": 128,
+            "GROUP_SIZE_M": 32,
+            "num_warps": 8,
+            "num_stages": 4,
         }
+        if M <= E:
+            config = {
+                "BLOCK_SIZE_M": 64,
+                "BLOCK_SIZE_N": 128,
+                "BLOCK_SIZE_K": 128,
+                "GROUP_SIZE_M": 1,
+                "num_warps": 4,
+                "num_stages": 4,
+            }
+    else:
+        config = {
+            "BLOCK_SIZE_M": 64,
+            "BLOCK_SIZE_N": 64,
+            "BLOCK_SIZE_K": 32,
+            "GROUP_SIZE_M": 8,
+        }
+        # A heuristic: fused marlin works faster with this config for small M
+        if M <= E or (is_marlin and M <= 32):
+            config = {
+                "BLOCK_SIZE_M": 16,
+                "BLOCK_SIZE_N": 32,
+                "BLOCK_SIZE_K": 64,
+                "GROUP_SIZE_M": 1,
+            }
     return config
 
 
@@ -645,8 +684,12 @@ def fused_experts_impl(
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
 ):
+    padded_size = padding_size
+    if not use_fp8_w8a8:
+        padded_size = 0
+
     # Check constraints.
-    assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
+    assert hidden_states.shape[1] == w1.shape[2] - padded_size, "Hidden size mismatch"
     assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
     assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
     assert w1.is_contiguous(), "Expert weights1 must be contiguous"
@@ -668,7 +711,7 @@ def fused_experts_impl(
     get_config_func = functools.partial(
         try_get_optimal_moe_config,
         w1.shape,
-        w2.shape,
+        (w2.shape[0], w2.shape[1], w2.shape[2] - padded_size),
         topk_ids.shape[1],
         config_dtype,
     )
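Note: the host-side arithmetic above is easy to misread, so here is a minimal, self-contained sketch of it (not part of the patch; the shapes are invented, and only the names padding_size, K, and even_Ks mirror the diff):

import os

import torch
import torch.nn.functional as F

# Mirrors the module-level switch in fused_moe.py: pad K by 128 elements
# only when the MOE_PADDING environment variable is set to 1.
padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0

# Hypothetical expert weights: 4 experts, N = 64, K = 512.
w2 = torch.randn(4, 64, 512)
w2_padded = F.pad(w2, (0, padding_size), "constant", 0)  # pad last dim on the right

# The kernel must be given the logical K, not the padded storage width,
# so padded_size is subtracted back out before launch:
BLOCK_SIZE_K = 128
K = w2_padded.shape[2] - padding_size  # logical K = 512
even_Ks = K % BLOCK_SIZE_K == 0        # True here: K-boundary masks can be skipped

When even_Ks holds, the kernel's inner loop takes the branch without the offs_k boundary masks, so every K-tile load is a full, unmasked block; the masked branch remains as the fallback for ragged K.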
diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py
index 0e3c7abd924..c5a254b547e 100644
--- a/python/sglang/srt/layers/quantization/fp8.py
+++ b/python/sglang/srt/layers/quantization/fp8.py
@@ -1,9 +1,11 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py
 
 import logging
+import os
 from typing import Any, Callable, Dict, List, Optional
 
 import torch
+import torch.nn.functional as F
 from torch.nn import Module
 from torch.nn.parameter import Parameter
 from vllm import _custom_ops as ops
@@ -24,6 +26,7 @@
 )
 from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter
 
+from sglang.srt.layers.fused_moe_triton.fused_moe import padding_size
 from sglang.srt.layers.linear import LinearMethodBase, UnquantizedLinearMethod
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
@@ -420,7 +423,7 @@ def create_weights(
 
     def process_weights_after_loading(self, layer: Module) -> None:
 
-        # If checkpoint is fp16, quantize in place.
+        # If checkpoint is fp16 or bfloat16, quantize in place.
         if not self.quant_config.is_checkpoint_fp8_serialized:
             # If ROCm, use float8_e4m3fnuz instead (MI300x HW)
             fp8_dtype = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
@@ -444,6 +447,19 @@ def process_weights_after_loading(self, layer: Module) -> None:
             )
             layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
             layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
+
+            # If ROCm, apply weight padding (to minimize memory channel contention) only if set
+            if is_hip() and bool(int(os.getenv("MOE_PADDING", "0"))):
+                layer.w13_weight = torch.nn.Parameter(
+                    F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
+                    requires_grad=False,
+                )
+                torch.cuda.empty_cache()
+                layer.w2_weight = torch.nn.Parameter(
+                    F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
+                    requires_grad=False,
+                )
+                torch.cuda.empty_cache()
             return
 
         # If checkpoint is fp8, we need to handle that the
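An aside on the F.pad call added above: torch.nn.functional.pad consumes padding in (left, right) pairs starting from the last dimension, so (0, padding_size) grows only the final K dimension and leaves the expert and intermediate dimensions untouched. A quick self-check (illustrative only):

import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 5)
y = F.pad(x, (0, 128), "constant", 0)  # (left, right) padding for the last dim only

assert y.shape == (2, 3, 5 + 128)
assert torch.equal(y[..., :5], x)      # original values preserved
assert torch.all(y[..., 5:] == 0)      # the tail is zero-filled

The torch.cuda.empty_cache() after each re-wrap releases the storage of the replaced, unpadded tensor, since the padded copy and the original are briefly alive at the same time.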
@@ -472,6 +488,7 @@ def process_weights_after_loading(self, layer: Module) -> None:
                 layer.w2_input_scale = torch.nn.Parameter(
                     layer.w2_input_scale.max(), requires_grad=False
                 )
+
             # If ROCm, normalize the weights and scales to e4m3fnuz
             if is_hip():
                 # Normalize the weights and scales
@@ -523,6 +540,19 @@ def process_weights_after_loading(self, layer: Module) -> None:
             layer.w13_weight_scale = torch.nn.Parameter(
                 max_w13_scales, requires_grad=False
             )
+
+            # If ROCm, apply weight padding (to minimize memory channel contention) only if set
+            if is_hip() and bool(int(os.getenv("MOE_PADDING", "0"))):
+                layer.w13_weight = torch.nn.Parameter(
+                    F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
+                    requires_grad=False,
+                )
+                torch.cuda.empty_cache()
+                layer.w2_weight = torch.nn.Parameter(
+                    F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
+                    requires_grad=False,
+                )
+                torch.cuda.empty_cache()
         return
 
     def apply(
@@ -540,6 +570,7 @@ def apply(
         from sglang.srt.layers.fused_moe_triton import FusedMoE
         from sglang.srt.layers.fused_moe_triton.fused_moe import fused_experts
 
+        # Expert selection
        topk_weights, topk_ids = FusedMoE.select_experts(
             hidden_states=x,
             router_logits=router_logits,
@@ -551,6 +582,7 @@ def apply(
             custom_routing_function=custom_routing_function,
         )
 
+        # Expert fusion with FP8 quantization
         return fused_experts(
             x,
             layer.w13_weight,
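Taken together, the change is opt-in: padding_size stays 0 unless MOE_PADDING=1 is in the environment, and the weights are only padded on ROCm (is_hip()) for fp8 checkpoints, so the variable should only be set on ROCm fp8 MoE deployments. Because fused_moe.py reads the variable once at import time, it must be set before sglang is imported; the server command below is an assumed invocation, not something this diff adds:

# Enable ROCm MoE weight padding; must be in the environment before any
# sglang module is imported, because fused_moe.py reads it at import time.
import os

os.environ["MOE_PADDING"] = "1"

# import sglang  # from here on, padding_size == 128 in fused_moe.py

# Shell equivalent (assumed invocation):
#   MOE_PADDING=1 python -m sglang.launch_server --model-path <fp8 MoE checkpoint>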