From 5e52843b2f7a11a6fc6000b97cd911c01e630f6f Mon Sep 17 00:00:00 2001
From: HaiShaw
Date: Sat, 7 Dec 2024 01:50:28 -0800
Subject: [PATCH 1/3] fused_moe_triton optimizations on AMD

---
 .../srt/layers/fused_moe_triton/fused_moe.py | 86 ++++++++++++++-----
 python/sglang/srt/layers/quantization/fp8.py | 35 +++++++-
 2 files changed, 96 insertions(+), 25 deletions(-)

diff --git a/python/sglang/srt/layers/fused_moe_triton/fused_moe.py b/python/sglang/srt/layers/fused_moe_triton/fused_moe.py
index 4f92512b2d5..f767b9944b0 100644
--- a/python/sglang/srt/layers/fused_moe_triton/fused_moe.py
+++ b/python/sglang/srt/layers/fused_moe_triton/fused_moe.py
@@ -16,7 +16,7 @@ from sglang.srt.utils import direct_register_custom_op, get_device_name
 
 logger = logging.getLogger(__name__)
 
-
+padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
 
 @triton.jit
 def fused_moe_kernel(
@@ -58,6 +58,7 @@ def fused_moe_kernel(
     compute_type: tl.constexpr,
     use_fp8_w8a8: tl.constexpr,
     use_int8_w8a16: tl.constexpr,
+    even_Ks: tl.constexpr,
 ):
     """
     Implements the fused computation for a Mixture of Experts (MOE) using
@@ -143,12 +144,21 @@ def fused_moe_kernel(
     for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
         # Load the next block of A and B, generate a mask by checking the
         # K dimension.
-        a = tl.load(
-            a_ptrs,
-            mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
-            other=0.0,
-        )
-        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+        if even_Ks:
+            a = tl.load(
+                a_ptrs,
+                mask=token_mask[:, None],
+                other=0.0,
+            )
+            b = tl.load(b_ptrs)
+        else:
+            a = tl.load(
+                a_ptrs,
+                mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
+                other=0.0,
+            )
+            b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+
         # We accumulate along the K dimension.
         if use_int8_w8a16:
             accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
@@ -254,7 +264,9 @@ def invoke_fused_moe_kernel(
     assert topk_weights.stride(1) == 1
     assert sorted_token_ids.stride(0) == 1
 
+    padded_size = 0
     if use_fp8_w8a8:
+        padded_size = padding_size
         A, A_scale = ops.scaled_fp8_quant(A, A_scale)
         assert B_scale is not None
     elif use_int8_w8a16:
@@ -268,6 +280,12 @@
         * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]),
     )
 
+    K = B.shape[2] - padded_size
+    if K % config["BLOCK_SIZE_K"] == 0:
+        even_Ks = True
+    else:
+        even_Ks = False
+
     fused_moe_kernel[grid](
         A,
         B,
@@ -279,7 +297,7 @@
         expert_ids,
         num_tokens_post_padded,
         B.shape[1],
-        B.shape[2],
+        B.shape[2] - padded_size,
         sorted_token_ids.shape[0],
         topk_ids.numel(),
         A.stride(0),
@@ -296,6 +314,7 @@
         compute_type=compute_type,
         use_fp8_w8a8=use_fp8_w8a8,
         use_int8_w8a16=use_int8_w8a16,
+        even_Ks=even_Ks,
         **config,
     )
 
@@ -351,20 +370,39 @@ def get_default_config(
     dtype: Optional[str],
     is_marlin: bool,
 ) -> Dict[str, int]:
-    config = {
-        "BLOCK_SIZE_M": 64,
-        "BLOCK_SIZE_N": 64,
-        "BLOCK_SIZE_K": 32,
-        "GROUP_SIZE_M": 8,
-    }
-    # A heuristic: fused marlin works faster with this config for small M
-    if M <= E or (is_marlin and M <= 32):
+    if dtype == "fp8_w8a8":
         config = {
-            "BLOCK_SIZE_M": 16,
-            "BLOCK_SIZE_N": 32,
-            "BLOCK_SIZE_K": 64,
-            "GROUP_SIZE_M": 1,
+            "BLOCK_SIZE_M": 128,
+            "BLOCK_SIZE_N": 256,
+            "BLOCK_SIZE_K": 128,
+            "GROUP_SIZE_M": 32,
+            "num_warps": 8,
+            "num_stages": 4,
         }
+        if M <= E:
+            config = {
+                "BLOCK_SIZE_M": 64,
+                "BLOCK_SIZE_N": 128,
+                "BLOCK_SIZE_K": 128,
+                "GROUP_SIZE_M": 1,
+                "num_warps": 4,
+                "num_stages": 4,
+            }
+    else:
+        config = {
+            "BLOCK_SIZE_M": 64,
+            "BLOCK_SIZE_N": 64,
+            "BLOCK_SIZE_K": 32,
+            "GROUP_SIZE_M": 8,
+        }
+        # A heuristic: fused marlin works faster with this config for small M
+        if M <= E or (is_marlin and M <= 32):
+            config = {
+                "BLOCK_SIZE_M": 16,
+                "BLOCK_SIZE_N": 32,
+                "BLOCK_SIZE_K": 64,
+                "GROUP_SIZE_M": 1,
+            }
     return config
@@ -645,8 +683,12 @@ def fused_experts_impl(
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
 ):
+    padded_size = padding_size
+    if not use_fp8_w8a8:
+        padded_size = 0
+
     # Check constraints.
-    assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
+    assert hidden_states.shape[1] == w1.shape[2] - padded_size, "Hidden size mismatch"
     assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
     assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
     assert w1.is_contiguous(), "Expert weights1 must be contiguous"
@@ -668,7 +710,7 @@ def fused_experts_impl(
     get_config_func = functools.partial(
         try_get_optimal_moe_config,
         w1.shape,
-        w2.shape,
+        (w2.shape[0], w2.shape[1], w2.shape[2] - padded_size),
         topk_ids.shape[1],
         config_dtype,
     )
diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py
index acdce0b8cbd..f657a25643a 100644
--- a/python/sglang/srt/layers/quantization/fp8.py
+++ b/python/sglang/srt/layers/quantization/fp8.py
@@ -4,6 +4,7 @@
 from typing import Any, Callable, Dict, List, Optional
 
 import torch
+import torch.nn.functional as F
 from torch.nn import Module
 from torch.nn.parameter import Parameter
 from vllm import _custom_ops as ops
@@ -29,6 +30,7 @@
     FusedMoEMethodBase,
     FusedMoeWeightScaleSupported,
 )
+from sglang.srt.layers.fused_moe_triton.fused_moe import fused_experts, padding_size
 from sglang.srt.layers.linear import LinearMethodBase, UnquantizedLinearMethod
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
@@ -404,7 +406,7 @@ def create_weights(
 
     def process_weights_after_loading(self, layer: Module) -> None:
-        # If checkpoint is fp16, quantize in place.
+        # If checkpoint is fp16 or bfloat16, quantize in place.
         if not self.quant_config.is_checkpoint_fp8_serialized:
             # If ROCm, use float8_e4m3fnuz instead (MI300x HW)
             fp8_dtype = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
@@ -428,6 +430,19 @@ def process_weights_after_loading(self, layer: Module) -> None:
                 )
             layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
             layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
+
+            # If ROCm, apply weight padding (min. Mem channel contention) only if set
+            if is_hip() and bool(int(os.getenv("MOE_PADDING", "0"))):
+                layer.w13_weight = torch.nn.Parameter(
+                    F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
+                    requires_grad=False,
+                )
+                torch.cuda.empty_cache()
+                layer.w2_weight = torch.nn.Parameter(
+                    F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
+                    requires_grad=False,
+                )
+                torch.cuda.empty_cache()
             return
 
         # If checkpoint is fp8, we need to handle that the
@@ -456,6 +471,7 @@ def process_weights_after_loading(self, layer: Module) -> None:
                 layer.w2_input_scale = torch.nn.Parameter(
                     layer.w2_input_scale.max(), requires_grad=False
                 )
+
             # If ROCm, normalize the weights and scales to e4m3fnuz
             if is_hip():
                 # Normalize the weights and scales
@@ -507,6 +523,19 @@ def process_weights_after_loading(self, layer: Module) -> None:
             layer.w13_weight_scale = torch.nn.Parameter(
                 max_w13_scales, requires_grad=False
             )
+
+            # If ROCm, apply weight padding (min. Mem channel contention) only if set
+            if is_hip() and bool(int(os.getenv("MOE_PADDING", "0"))):
+                layer.w13_weight = torch.nn.Parameter(
+                    F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
+                    requires_grad=False,
+                )
+                torch.cuda.empty_cache()
+                layer.w2_weight = torch.nn.Parameter(
+                    F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
+                    requires_grad=False,
+                )
+                torch.cuda.empty_cache()
             return
 
     def apply(
@@ -522,8 +551,7 @@ def apply(
         custom_routing_function: Optional[Callable] = None,
     ) -> torch.Tensor:
 
-        from vllm.model_executor.layers.fused_moe import fused_experts
-
+        # Expert selection
        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
@@ -535,6 +563,7 @@ def apply(
             custom_routing_function=custom_routing_function,
         )
 
+        # Expert fusion with FP8 quantization
         return fused_experts(
             x,
             layer.w13_weight,

From cdc023f7924fb540e90cfaafbcafeaed1463e8e0 Mon Sep 17 00:00:00 2001
From: HaiShaw
Date: Sat, 7 Dec 2024 03:01:41 -0800
Subject: [PATCH 2/3] Lint fix and simplification

---
 .../srt/layers/fused_moe_triton/fused_moe.py |  1 +
 .../srt/layers/quantization/__init__.py      | 45 -------------------
 python/sglang/srt/layers/quantization/fp8.py |  1 +
 3 files changed, 2 insertions(+), 45 deletions(-)

diff --git a/python/sglang/srt/layers/fused_moe_triton/fused_moe.py b/python/sglang/srt/layers/fused_moe_triton/fused_moe.py
index f767b9944b0..e6ce9cb4d39 100644
--- a/python/sglang/srt/layers/fused_moe_triton/fused_moe.py
+++ b/python/sglang/srt/layers/fused_moe_triton/fused_moe.py
@@ -18,6 +18,7 @@ logger = logging.getLogger(__name__)
 
 padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
+
 
 @triton.jit
 def fused_moe_kernel(
     # Pointers to matrices
diff --git a/python/sglang/srt/layers/quantization/__init__.py b/python/sglang/srt/layers/quantization/__init__.py
index 3e2078c4a4d..c7be4f927a0 100644
--- a/python/sglang/srt/layers/quantization/__init__.py
+++ b/python/sglang/srt/layers/quantization/__init__.py
@@ -53,50 +53,6 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
     return QUANTIZATION_METHODS[quantization]
 
 
-def fp8_moe_apply(
-    self,
-    layer: torch.nn.Module,
-    x: torch.Tensor,
-    router_logits: torch.Tensor,
-    top_k: int,
-    renormalize: bool,
-    use_grouped_topk: bool,
-    topk_group: Optional[int] = None,
-    num_expert_group: Optional[int] = None,
-    custom_routing_function: Optional[Callable] = None,
-) -> torch.Tensor:
-    """Enhanced apply method for FP8 MoE."""
-    from sglang.srt.layers.fused_moe_triton import FusedMoE
-    from sglang.srt.layers.fused_moe_triton.fused_moe import fused_experts
-
-    # Expert selection
-    topk_weights, topk_ids = FusedMoE.select_experts(
-        hidden_states=x,
-        router_logits=router_logits,
-        use_grouped_topk=use_grouped_topk,
-        top_k=top_k,
-        renormalize=renormalize,
-        topk_group=topk_group,
-        num_expert_group=num_expert_group,
-        custom_routing_function=custom_routing_function,
-    )
-
-    # Expert fusion with FP8 quantization
-    return fused_experts(
-        x,
-        layer.w13_weight,
-        layer.w2_weight,
-        topk_weights=topk_weights,
-        topk_ids=topk_ids,
-        inplace=True,
-        use_fp8_w8a8=True,
-        w1_scale=layer.w13_weight_scale,
-        w2_scale=layer.w2_weight_scale,
-        a1_scale=layer.w13_input_scale,
-        a2_scale=layer.w2_input_scale,
-    )
-
-
 def fp8_get_quant_method(self, layer, prefix):
     """Enhanced get_quant_method for FP8 config."""
     from vllm.model_executor.layers.linear import LinearBase
@@ -151,7 +107,6 @@ def awq_get_quant_method(self, layer, prefix):
 
 
 def apply_monkey_patches():
"""Apply all monkey patches in one place.""" - setattr(Fp8MoEMethod, "apply", fp8_moe_apply) setattr(Fp8Config, "get_quant_method", fp8_get_quant_method) setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method) setattr(AWQMarlinConfig, "get_quant_method", awq_get_quant_method) diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index f657a25643a..6fb6b0347c8 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -1,6 +1,7 @@ # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py import logging +import os from typing import Any, Callable, Dict, List, Optional import torch From 5f0c259b93601d2a60188c12aeba12e5190e8c89 Mon Sep 17 00:00:00 2001 From: HaiShaw Date: Sat, 7 Dec 2024 04:32:04 -0800 Subject: [PATCH 3/3] Lint --- python/sglang/srt/layers/quantization/fp8.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index ca84126b285..c5a254b547e 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -569,7 +569,7 @@ def apply( ) -> torch.Tensor: from sglang.srt.layers.fused_moe_triton import FusedMoE from sglang.srt.layers.fused_moe_triton.fused_moe import fused_experts - + # Expert selection topk_weights, topk_ids = FusedMoE.select_experts( hidden_states=x,