From 53aed988cbaa7433c59c070d72d5aad3815cb286 Mon Sep 17 00:00:00 2001 From: HandH1998 <1335248067@qq.com> Date: Thu, 26 Dec 2024 00:02:14 +0800 Subject: [PATCH] Refactor MoE (#2575) Co-authored-by: zhyncs --- python/sglang/srt/configs/model_config.py | 5 +- python/sglang/srt/layers/linear.py | 22 +- .../layers/moe/fused_moe_triton/fused_moe.py | 86 ++++- .../srt/layers/moe/fused_moe_triton/layer.py | 7 +- python/sglang/srt/layers/quantization/fp8.py | 184 ++++++++-- .../srt/layers/quantization/fp8_kernel.py | 278 ++++++++++++++ .../srt/layers/quantization/fp8_utils.py | 91 ++++- python/sglang/srt/models/deepseek_v2.py | 47 ++- python/sglang/test/test_block_fp8.py | 341 ++++++++++++++++++ 9 files changed, 1012 insertions(+), 49 deletions(-) create mode 100644 python/sglang/srt/layers/quantization/fp8_kernel.py create mode 100644 python/sglang/test/test_block_fp8.py diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index f43706d43f7..69a61737120 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -94,7 +94,10 @@ def __init__( ) # FIXME: temporary special judge for MLA architecture - if "DeepseekV2ForCausalLM" in self.hf_config.architectures: + if ( + "DeepseekV2ForCausalLM" in self.hf_config.architectures + or "DeepseekV3ForCausalLM" in self.hf_config.architectures + ): self.head_dim = 256 self.attention_arch = AttentionArch.MLA self.kv_lora_rank = self.hf_config.kv_lora_rank diff --git a/python/sglang/srt/layers/linear.py b/python/sglang/srt/layers/linear.py index f69058ff319..d5dfb871847 100644 --- a/python/sglang/srt/layers/linear.py +++ b/python/sglang/srt/layers/linear.py @@ -30,6 +30,7 @@ QuantizationConfig, QuantizeMethodBase, ) +from sglang.srt.layers.quantization.fp8_utils import BlockQuantScaleParameter from sglang.srt.utils import set_weight_attrs logger = logging.getLogger(__name__) @@ -628,8 +629,19 @@ def weight_loader_v2( assert loaded_shard_id < len(self.output_sizes) tp_size = get_tensor_model_parallel_world_size() - shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size - shard_size = self.output_sizes[loaded_shard_id] // tp_size + + if isinstance(param, BlockQuantScaleParameter): + weight_block_size = self.quant_method.quant_config.weight_block_size + block_n, _ = weight_block_size[0], weight_block_size[1] + shard_offset = ( + (sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) // block_n + ) // tp_size + shard_size = ( + (self.output_sizes[loaded_shard_id] + block_n - 1) // block_n // tp_size + ) + else: + shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size + shard_size = self.output_sizes[loaded_shard_id] // tp_size param.load_merged_column_weight( loaded_weight=loaded_weight, @@ -795,6 +807,12 @@ def weight_loader_v2( shard_offset = self._get_shard_offset_mapping(loaded_shard_id) shard_size = self._get_shard_size_mapping(loaded_shard_id) + if isinstance(param, BlockQuantScaleParameter): + weight_block_size = self.quant_method.quant_config.weight_block_size + block_n, _ = weight_block_size[0], weight_block_size[1] + shard_offset = (shard_offset + block_n - 1) // block_n + shard_size = (shard_size + block_n - 1) // block_n + param.load_qkv_weight( loaded_weight=loaded_weight, num_heads=self.num_kv_head_replicas, diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py index 24e0133a121..a645e5f7d8f 100644 --- 
a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py @@ -6,7 +6,7 @@ import json import logging import os -from typing import Any, Callable, Dict, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple import torch import triton @@ -14,6 +14,7 @@ from vllm import _custom_ops as ops from sglang.srt.layers.moe.topk import select_experts +from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8 from sglang.srt.utils import direct_register_custom_op, get_device_name logger = logging.getLogger(__name__) @@ -48,8 +49,14 @@ def fused_moe_kernel( stride_bn, stride_cm, stride_cn, + stride_asm, + stride_ask, stride_bse, + stride_bsk, stride_bsn, + # Block size for block-wise quantization + group_n: tl.constexpr, + group_k: tl.constexpr, # Meta-parameters BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, @@ -133,8 +140,15 @@ def fused_moe_kernel( b_scale = tl.load(b_scale_ptrs) if use_fp8_w8a8: - a_scale = tl.load(a_scale_ptr) - b_scale = tl.load(b_scale_ptr + off_experts) + if group_k > 0 and group_n > 0: + a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm + offs_bsn = offs_bn // group_n + b_scale_ptrs = ( + b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn + ) + else: + a_scale = tl.load(a_scale_ptr) + b_scale = tl.load(b_scale_ptr + off_experts) # ----------------------------------------------------------- # Iterate to compute a block of the C matrix. @@ -165,7 +179,17 @@ def fused_moe_kernel( if use_int8_w8a16: accumulator = tl.dot(a, b.to(compute_type), acc=accumulator) elif use_fp8_w8a8: - accumulator = tl.dot(a, b, acc=accumulator) + if group_k > 0 and group_n > 0: + k_start = k * BLOCK_SIZE_K + offs_ks = k_start // group_k + a_scale = tl.load( + a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0 + ) + b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk) + + accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :] + else: + accumulator = tl.dot(a, b, acc=accumulator) else: accumulator += tl.dot(a, b) # Advance the ptrs to the next K block. 
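For reference, the block-wise branch above folds the scales into the accumulator per K-tile: each token row is scaled by its per-token-group activation scale and each output column by the weight scale of its (block_n, block_k) tile. Below is a minimal PyTorch sketch of the same semantics, written against the scale layouts this patch uses (the helper names are illustrative, not part of the patch); the Triton kernel computes this without materializing the dequantized tensors:

import torch

def dequant_block_fp8_weight(w_q: torch.Tensor, w_s: torch.Tensor, block_n: int, block_k: int) -> torch.Tensor:
    # w_q: (N, K) fp8 weight; w_s: (ceil(N / block_n), ceil(K / block_k)) fp32 scales.
    n, k = w_q.shape
    # Expand each scale over its (block_n, block_k) tile, then crop back to (N, K).
    s = w_s.repeat_interleave(block_n, dim=0)[:n].repeat_interleave(block_k, dim=1)[:, :k]
    return w_q.to(torch.float32) * s

def dequant_per_token_group(x_q: torch.Tensor, x_s: torch.Tensor, group_k: int) -> torch.Tensor:
    # x_q: (M, K) fp8 activation; x_s: (M, K // group_k) per-token-group fp32 scales.
    return x_q.to(torch.float32) * x_s.repeat_interleave(group_k, dim=1)

def w8a8_block_matmul_reference(a_q, a_s, w_q, w_s, block_n, block_k):
    # Same math as the fused kernel's block-wise path, in dequantize-then-matmul form.
    a = dequant_per_token_group(a_q, a_s, block_k)
    w = dequant_block_fp8_weight(w_q, w_s, block_n, block_k)
    return a @ w.t()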
@@ -178,7 +202,10 @@ def fused_moe_kernel( if use_int8_w8a16: accumulator = (accumulator * b_scale).to(compute_type) elif use_fp8_w8a8: - accumulator = (accumulator * a_scale * b_scale).to(compute_type) + if group_k > 0 and group_n > 0: + accumulator = accumulator.to(compute_type) + else: + accumulator = (accumulator * a_scale * b_scale).to(compute_type) else: accumulator = accumulator.to(compute_type) # ----------------------------------------------------------- @@ -262,6 +289,7 @@ def invoke_fused_moe_kernel( compute_type: tl.dtype, use_fp8_w8a8: bool, use_int8_w8a16: bool, + block_shape: Optional[List[int]] = None, ) -> None: assert topk_weights.stride(1) == 1 assert sorted_token_ids.stride(0) == 1 @@ -269,8 +297,16 @@ def invoke_fused_moe_kernel( padded_size = 0 if use_fp8_w8a8: padded_size = padding_size - A, A_scale = ops.scaled_fp8_quant(A, A_scale) assert B_scale is not None + if block_shape is None: + A, A_scale = ops.scaled_fp8_quant(A, A_scale) + else: + assert len(block_shape) == 2 + block_n, block_k = block_shape[0], block_shape[1] + A, A_scale = per_token_group_quant_fp8(A, block_k) + assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1] + assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2] + assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1] elif use_int8_w8a16: assert B_scale is not None else: @@ -309,8 +345,13 @@ def invoke_fused_moe_kernel( B.stride(1), C.stride(1), C.stride(2), - B_scale.stride(0) if B_scale is not None and use_int8_w8a16 else 0, - B_scale.stride(1) if B_scale is not None and use_int8_w8a16 else 0, + A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0, + A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0, + B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0, + B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0, + B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0, + 0 if block_shape is None else block_shape[0], + 0 if block_shape is None else block_shape[1], MUL_ROUTED_WEIGHT=mul_routed_weight, top_k=top_k, compute_type=compute_type, @@ -415,6 +456,7 @@ def try_get_optimal_moe_config( dtype: Optional[str], M: int, is_marlin: bool = False, + block_shape: Optional[List[int]] = None, ): from sglang.srt.layers.moe.fused_moe_triton import get_config @@ -433,6 +475,13 @@ def try_get_optimal_moe_config( else: # Else use the default config config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, is_marlin) + # TODO(HandH1998): Optimize the configs of block-wise quant. 
+ # NOTE(HandH1998): For block-wise quant, + # BLOCK_K must be divisable by block_shape[1] + # BLOCK_N and BLOCK_M has no requirements + if block_shape is not None: + config["BLOCK_SIZE_N"] = block_shape[0] + config["BLOCK_SIZE_K"] = block_shape[1] return config @@ -464,6 +513,7 @@ def inplace_fused_experts( w2_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, ) -> None: fused_experts_impl( hidden_states, @@ -478,6 +528,7 @@ def inplace_fused_experts( w2_scale, a1_scale, a2_scale, + block_shape, ) @@ -493,6 +544,7 @@ def inplace_fused_experts_fake( w2_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, ) -> None: pass @@ -517,6 +569,7 @@ def outplace_fused_experts( w2_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, ) -> torch.Tensor: return fused_experts_impl( hidden_states, @@ -531,6 +584,7 @@ def outplace_fused_experts( w2_scale, a1_scale, a2_scale, + block_shape, ) @@ -546,6 +600,7 @@ def outplace_fused_experts_fake( w2_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, ) -> torch.Tensor: return torch.empty_like(hidden_states) @@ -571,6 +626,7 @@ def fused_experts( w2_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, ): if inplace: torch.ops.sglang.inplace_fused_experts( @@ -585,6 +641,7 @@ def fused_experts( w2_scale, a1_scale, a2_scale, + block_shape, ) return hidden_states else: @@ -600,6 +657,7 @@ def fused_experts( w2_scale, a1_scale, a2_scale, + block_shape, ) @@ -616,6 +674,7 @@ def fused_experts_impl( w2_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, ): padded_size = padding_size if not use_fp8_w8a8: @@ -647,6 +706,7 @@ def fused_experts_impl( (w2.shape[0], w2.shape[1], w2.shape[2] - padded_size), topk_ids.shape[1], config_dtype, + block_shape=block_shape, ) config = get_config_func(M) @@ -719,6 +779,7 @@ def fused_experts_impl( compute_type=compute_type, use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, + block_shape=block_shape, ) ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) @@ -740,6 +801,7 @@ def fused_experts_impl( compute_type=compute_type, use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, + block_shape=block_shape, ) torch.sum( @@ -768,6 +830,7 @@ def fused_moe( w2_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, ) -> torch.Tensor: """ This function computes a Mixture of Experts (MoE) layer using two sets of @@ -795,6 +858,12 @@ def fused_moe( w1. - w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2. + - a1_scale (Optional[torch.Tensor]): Optional scale to be used for + a1. + - a2_scale (Optional[torch.Tensor]): Optional scale to be used for + a2. + - block_shape: (Optional[List[int]]): Optional block size for block-wise + quantization. Returns: - torch.Tensor: The output tensor after applying the MoE layer. 
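As a usage note, the `block_shape` threaded through the wrappers above is the checkpoint's `weight_block_size`; weight scales are then expected per (expert, block_n x block_k) tile, and activation scales are computed on the fly, so `a1_scale`/`a2_scale` stay `None`. A hedged sketch of such a call, mirroring the shapes exercised by the new `python/sglang/test/test_block_fp8.py` (the tensors are random stand-ins, not a real checkpoint):

import torch
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe

E, M, K, N, topk = 8, 16, 4096, 2048, 2
block_n, block_k = 128, 128  # must match the checkpoint's weight_block_size

hidden = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
gating = torch.randn(M, E, dtype=torch.bfloat16, device="cuda")
w13 = torch.randn(E, 2 * N, K, device="cuda").to(torch.float8_e4m3fn)
w2 = torch.randn(E, K, N, device="cuda").to(torch.float8_e4m3fn)
# One fp32 scale per (expert, block_n x block_k) weight tile.
w13_s = torch.rand(E, (2 * N + block_n - 1) // block_n, (K + block_k - 1) // block_k, device="cuda")
w2_s = torch.rand(E, (K + block_n - 1) // block_n, (N + block_k - 1) // block_k, device="cuda")

out = fused_moe(
    hidden, w13, w2, gating, topk,
    renormalize=False,
    use_fp8_w8a8=True,
    w1_scale=w13_s,
    w2_scale=w2_s,
    block_shape=[block_n, block_k],
)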
@@ -826,4 +895,5 @@ def fused_moe( w2_scale=w2_scale, a1_scale=a1_scale, a2_scale=a2_scale, + block_shape=block_shape, ) diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index 2548ca16330..b7f87a9a98f 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -34,6 +34,7 @@ class FusedMoeWeightScaleSupported(Enum): TENSOR = "tensor" CHANNEL = "channel" GROUP = "group" + BLOCK = "block" class FusedMoEMethodBase(QuantizeMethodBase): @@ -214,6 +215,7 @@ def __init__( ) self.top_k = top_k self.num_experts = num_experts + assert intermediate_size % self.tp_size == 0 self.intermediate_size_per_partition = intermediate_size // self.tp_size self.reduce_results = reduce_results self.renormalize = renormalize @@ -470,7 +472,10 @@ def weight_loader( expert_data=expert_data, tp_rank=tp_rank, ) - elif quant_method == FusedMoeWeightScaleSupported.GROUP.value: + elif quant_method in [ + FusedMoeWeightScaleSupported.GROUP.value, + FusedMoeWeightScaleSupported.BLOCK.value, + ]: self._load_model_weight_or_group_weight_scale( shard_id=shard_id, shard_dim=shard_dim, diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index b12815c665f..989e5cd5cc4 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -9,6 +9,7 @@ from torch.nn import Module from torch.nn.parameter import Parameter from vllm import _custom_ops as ops +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.linear import LinearBase from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( @@ -32,7 +33,11 @@ QuantizationConfig, QuantizeMethodBase, ) -from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz +from sglang.srt.layers.quantization.fp8_utils import ( + BlockQuantScaleParameter, + apply_w8a8_block_fp8_linear, + normalize_e4m3fn_to_e4m3fnuz, +) from sglang.srt.utils import ( get_bool_env_var, is_hip, @@ -53,6 +58,7 @@ def __init__( is_checkpoint_fp8_serialized: bool = False, activation_scheme: str = "dynamic", ignored_layers: Optional[List[str]] = None, + weight_block_size: List[int] = None, ) -> None: self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized if is_checkpoint_fp8_serialized: @@ -64,6 +70,20 @@ def __init__( raise ValueError(f"Unsupported activation scheme {activation_scheme}") self.activation_scheme = activation_scheme self.ignored_layers = ignored_layers or [] + if weight_block_size is not None: + if not is_checkpoint_fp8_serialized: + raise ValueError( + f"The block-wise quantization only supports fp8-serialized checkpoint for now." + ) + if len(weight_block_size) != 2: + raise ValueError( + f"The quantization block size of weight must have 2 dimensions, but got {len(weight_block_size)} dimensions." + ) + if activation_scheme != "dynamic": + raise ValueError( + f"The block-wise quantization only supports dynamic activation scheme for now, but got {activation_scheme} activation scheme." 
+ ) + self.weight_block_size = weight_block_size @classmethod def get_name(cls) -> str: @@ -87,10 +107,12 @@ def from_config(cls, config: Dict[str, Any]) -> "Fp8Config": is_checkpoint_fp8_serialized = "fp8" in quant_method activation_scheme = cls.get_from_keys(config, ["activation_scheme"]) ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None) + weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"], None) return cls( is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized, activation_scheme=activation_scheme, ignored_layers=ignored_layers, + weight_block_size=weight_block_size, ) def get_quant_method( @@ -143,6 +165,11 @@ def __init__(self, quant_config: Fp8Config): if is_hip(): self.use_marlin = False + self.block_quant = self.quant_config.weight_block_size is not None + if self.block_quant: + # Marlin doesn't support block-wise fp8 + self.use_marlin = False + def create_weights( self, layer: torch.nn.Module, @@ -153,10 +180,35 @@ def create_weights( params_dtype: torch.dtype, **extra_weight_attrs, ): - del input_size, output_size output_size_per_partition = sum(output_partition_sizes) weight_loader = extra_weight_attrs.get("weight_loader") + tp_size = get_tensor_model_parallel_world_size() + if self.block_quant: + block_n, block_k = ( + self.quant_config.weight_block_size[0], + self.quant_config.weight_block_size[1], + ) + # Required by row parallel + if tp_size > 1 and input_size // input_size_per_partition == tp_size: + if input_size_per_partition % block_k != 0: + raise ValueError( + f"Weight input_size_per_partition = " + f"{input_size_per_partition} is not divisible by " + f"weight quantization block_k = {block_k}." + ) + # Required by collum parallel or enabling merged weights + if ( + tp_size > 1 and output_size // output_size_per_partition == tp_size + ) or len(output_partition_sizes) > 1: + for output_partition_size in output_partition_sizes: + if output_partition_size % block_n != 0: + raise ValueError( + f"Weight output_partition_size = " + f"{output_partition_size} is not divisible by " + f"weight quantization block_n = {block_n}." + ) + layer.logical_widths = output_partition_sizes layer.input_size_per_partition = input_size_per_partition @@ -184,13 +236,27 @@ def create_weights( # Otherwise, wait until process_weights_after_loading. 
if self.quant_config.is_checkpoint_fp8_serialized: # WEIGHT SCALE - scale = PerTensorScaleParameter( - data=torch.empty(len(output_partition_sizes), dtype=torch.float32), - weight_loader=weight_loader, - ) - - scale[:] = torch.finfo(torch.float32).min - layer.register_parameter("weight_scale", scale) + if self.block_quant: + assert self.quant_config.activation_scheme == "dynamic" + scale = BlockQuantScaleParameter( + data=torch.empty( + (output_size_per_partition + block_n - 1) // block_n, + (input_size_per_partition + block_k - 1) // block_k, + dtype=torch.float32, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + scale[:] = torch.finfo(torch.float32).min + layer.register_parameter("weight_scale_inv", scale) + else: + scale = PerTensorScaleParameter( + data=torch.empty(len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader, + ) + scale[:] = torch.finfo(torch.float32).min + layer.register_parameter("weight_scale", scale) # INPUT ACTIVATION SCALE if self.quant_config.activation_scheme == "static": @@ -205,6 +271,9 @@ def create_weights( layer.register_parameter("input_scale", None) def process_weights_after_loading(self, layer: Module) -> None: + # Block quant doesn't need to process weights after loading + if self.block_quant: + return layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False) # If checkpoint not serialized fp8, quantize the weights. if not self.quant_config.is_checkpoint_fp8_serialized: @@ -295,6 +364,16 @@ def apply( bias=bias, ) + if self.block_quant: + return apply_w8a8_block_fp8_linear( + input=x, + weight=layer.weight, + block_size=self.quant_config.weight_block_size, + weight_scale=layer.weight_scale_inv, + input_scale=layer.input_scale, + bias=bias, + ) + return apply_fp8_linear( input=x, weight=layer.weight, @@ -339,6 +418,7 @@ def __new__(cls, *args, **kwargs): def __init__(self, quant_config): self.quant_config = quant_config + self.block_quant = self.quant_config.weight_block_size is not None def create_weights( self, @@ -353,6 +433,28 @@ def create_weights( if self.quant_config.is_checkpoint_fp8_serialized: params_dtype = torch.float8_e4m3fn + tp_size = get_tensor_model_parallel_world_size() + if self.block_quant: + block_n, block_k = ( + self.quant_config.weight_block_size[0], + self.quant_config.weight_block_size[1], + ) + # NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n. + # Required by collum parallel or enabling merged weights + if intermediate_size % block_n != 0: + raise ValueError( + f"The output_size of gate's and up's weight = " + f"{intermediate_size} is not divisible by " + f"weight quantization block_n = {block_n}." + ) + if tp_size > 1: + # Required by row parallel + if intermediate_size % block_k != 0: + raise ValueError( + f"The input_size of down's weight = " + f"{intermediate_size} is not divisible by " + f"weight quantization block_k = {block_k}." + ) # WEIGHTS w13_weight = torch.nn.Parameter( @@ -374,21 +476,45 @@ def create_weights( set_weight_attrs(w2_weight, extra_weight_attrs) # WEIGHT_SCALES - # Allocate 2 scales for w1 and w3 respectively. - # They will be combined to a single scale after weight loading. 
- w13_weight_scale = torch.nn.Parameter( - torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False - ) - layer.register_parameter("w13_weight_scale", w13_weight_scale) - - w2_weight_scale = torch.nn.Parameter( - torch.ones(num_experts, dtype=torch.float32), requires_grad=False - ) - layer.register_parameter("w2_weight_scale", w2_weight_scale) + if self.block_quant: + w13_weight_scale = torch.nn.Parameter( + torch.ones( + num_experts, + 2 * ((intermediate_size + block_n - 1) // block_n), + (hidden_size + block_k - 1) // block_k, + dtype=torch.float32, + ), + requires_grad=False, + ) + w2_weight_scale = torch.nn.Parameter( + torch.ones( + num_experts, + (hidden_size + block_n - 1) // block_n, + (intermediate_size + block_k - 1) // block_k, + dtype=torch.float32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_scale_inv", w13_weight_scale) + layer.register_parameter("w2_weight_scale_inv", w2_weight_scale) + assert self.quant_config.activation_scheme == "dynamic" + else: + # Allocate 2 scales for w1 and w3 respectively. + # They will be combined to a single scale after weight loading. + w13_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False + ) + w2_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, dtype=torch.float32), requires_grad=False + ) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + layer.register_parameter("w2_weight_scale", w2_weight_scale) # Add the quantization method used (per tensor/grouped/channel) # to ensure the weight scales are loaded in properly extra_weight_attrs.update( - {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value} + {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value} + if self.block_quant + else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value} ) # If loading fp8 checkpoint, pass the weight loaders. # If loading an fp16 checkpoint, do not (we will quantize in @@ -422,7 +548,9 @@ def create_weights( layer.w2_input_scale = None def process_weights_after_loading(self, layer: Module) -> None: - + # Block quant doesn't need to process weights after loading + if self.block_quant: + return # If checkpoint is fp16 or bfloat16, quantize in place. if not self.quant_config.is_checkpoint_fp8_serialized: # If ROCm, use float8_e4m3fnuz instead (MI300x HW) @@ -519,7 +647,6 @@ def process_weights_after_loading(self, layer: Module) -> None: layer.w2_input_scale = torch.nn.Parameter( w2_input_scale, requires_grad=False ) - # Fp8 moe kernel needs single weight scale for w13 per expert. # We take the max then dequant and requant each expert. 
            assert layer.w13_weight_scale is not None
@@ -594,10 +721,17 @@ def apply(
             topk_ids=topk_ids,
             inplace=True,
             use_fp8_w8a8=True,
-            w1_scale=layer.w13_weight_scale,
-            w2_scale=layer.w2_weight_scale,
+            w1_scale=(
+                layer.w13_weight_scale_inv
+                if self.block_quant
+                else layer.w13_weight_scale
+            ),
+            w2_scale=(
+                layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
+            ),
             a1_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
+            block_shape=self.quant_config.weight_block_size,
         )
diff --git a/python/sglang/srt/layers/quantization/fp8_kernel.py b/python/sglang/srt/layers/quantization/fp8_kernel.py
new file mode 100644
index 00000000000..4080560f19c
--- /dev/null
+++ b/python/sglang/srt/layers/quantization/fp8_kernel.py
@@ -0,0 +1,278 @@
+from typing import List, Tuple
+
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _per_token_group_quant_fp8(
+    # Pointers to inputs and output
+    y_ptr,
+    y_q_ptr,
+    y_s_ptr,
+    # Stride of input
+    y_stride,
+    # Columns of input
+    N,
+    # Avoid dividing by zero
+    eps,
+    # Information for float8
+    fp8_min,
+    fp8_max,
+    # Meta-parameters
+    BLOCK: tl.constexpr,
+):
+    """A Triton-accelerated function to perform per-token-group quantization on a
+    tensor.
+
+    This function converts the tensor values into float8 values.
+    """
+    # Map the program id to the row of X and Y it should compute.
+    g_id = tl.program_id(0)
+    y_ptr += g_id * y_stride
+    y_q_ptr += g_id * y_stride
+    y_s_ptr += g_id
+
+    cols = tl.arange(0, BLOCK)  # N <= BLOCK
+    mask = cols < N
+
+    y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
+    # Quant
+    _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
+    y_s = _absmax / fp8_max
+    y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
+
+    tl.store(y_q_ptr + cols, y_q, mask=mask)
+    tl.store(y_s_ptr, y_s)
+
+
+def per_token_group_quant_fp8(
+    x: torch.Tensor,
+    group_size: int,
+    eps: float = 1e-10,
+    dtype: torch.dtype = torch.float8_e4m3fn,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Function to perform per-token-group quantization on an input tensor `x`.
+
+    It converts the tensor values into signed float8 values and returns the
+    quantized tensor along with the scaling factor used for quantization.
+
+    Args:
+        x: The input tensor with ndim >= 2.
+        group_size: The group size used for quantization.
+        eps: The minimum value to avoid dividing by zero.
+        dtype: The dtype of the output tensor. Note that only `torch.float8_e4m3fn` is supported for now.
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization.
+ """ + assert ( + x.shape[-1] % group_size == 0 + ), "the last dimension of `x` cannot be divisible by `group_size`" + assert x.is_contiguous(), "`x` is not contiguous" + + finfo = torch.finfo(dtype) + fp8_min = finfo.min + fp8_max = finfo.max + + x_q = torch.empty_like(x, device=x.device, dtype=dtype) + M = x.numel() // group_size + N = group_size + x_s = torch.empty( + x.shape[:-1] + (x.shape[-1] // group_size,), + device=x.device, + dtype=torch.float32, + ) + + BLOCK = triton.next_power_of_2(N) + # heuristics for number of warps + num_warps = min(max(BLOCK // 256, 1), 8) + num_stages = 1 + _per_token_group_quant_fp8[(M,)]( + x, + x_q, + x_s, + group_size, + N, + eps, + fp8_min=fp8_min, + fp8_max=fp8_max, + BLOCK=BLOCK, + num_warps=num_warps, + num_stages=num_stages, + ) + + return x_q, x_s + + +@triton.jit +def _w8a8_block_fp8_matmul( + # Pointers to inputs and output + A, + B, + C, + As, + Bs, + # Shape for matmul + M, + N, + K, + # Block size for block-wise quantization + group_n, + group_k, + # Stride for inputs and output + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_As_m, + stride_As_k, + stride_Bs_k, + stride_Bs_n, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + """Triton-accelerated function used to perform linear operations (dot + product) on input tensors `A` and `B` with block-wise quantization, and store the result in output + tensor `C`. + """ + + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + As_ptrs = As + offs_am * stride_As_m + offs_bsn = offs_bn // group_n + Bs_ptrs = Bs + offs_bsn * stride_Bs_n + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + + k_start = k * BLOCK_SIZE_K + offs_ks = k_start // group_k + a_s = tl.load(As_ptrs + offs_ks * stride_As_k) + b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k) + + accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :] + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + if C.dtype.element_ty == tl.bfloat16: + c = accumulator.to(tl.bfloat16) + elif C.dtype.element_ty == tl.float16: + c = accumulator.to(tl.float16) + else: + c = accumulator.to(tl.float32) + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + +def w8a8_block_fp8_matmul( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + block_size: List[int], + 
output_dtype: torch.dtype = torch.float16, +) -> torch.Tensor: + """This function performs matrix multiplication with block-wise quantization. + + It takes two input tensors `A` and `B` with scales `As` and `Bs`. + The output is returned in the specified `output_dtype`. + + Args: + A: The input tensor, e.g., activation. + B: The input tensor, e.g., weight. + As: The per-token-group quantization scale for `A`. + Bs: The per-block quantization scale for `B`. + block_size: The block size for per-block quantization. It should be 2-dim, e.g., [128, 128]. + output_dytpe: The dtype of the returned tensor. + + Returns: + torch.Tensor: The result of matmul. + """ + assert len(block_size) == 2 + block_n, block_k = block_size[0], block_size[1] + + assert A.shape[-1] == B.shape[-1] + assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous() + assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1] + M = A.numel() // A.shape[-1] + + assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2 + N, K = B.shape + assert triton.cdiv(N, block_n) == Bs.shape[0] + assert triton.cdiv(K, block_k) == Bs.shape[1] + + C_shape = A.shape[:-1] + (N,) + C = A.new_empty(C_shape, dtype=output_dtype) + + # TODO(HandH1998): + # BLOCK_SIZE_M, BLOCK_SIZE_K, BLOCK_SIZE_N can be optimized. + # BLOCK_SIZE_K must be divisable by block_k + # BLOCK_SIZE_N and BLOCK_SIZE_M has no requirements + BLOCK_SIZE_M = 128 + if M < BLOCK_SIZE_M: + BLOCK_SIZE_M = triton.next_power_of_2(M) + BLOCK_SIZE_M = max(BLOCK_SIZE_M, 16) + BLOCK_SIZE_K = block_k + assert block_k % BLOCK_SIZE_K == 0 + BLOCK_SIZE_N = block_n + + def grid(META): + return ( + triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), + ) + + _w8a8_block_fp8_matmul[grid]( + A, + B, + C, + As, + Bs, + M, + N, + K, + block_n, + block_k, + A.stride(-2), + A.stride(-1), + B.stride(1), + B.stride(0), + C.stride(-2), + C.stride(-1), + As.stride(-2), + As.stride(-1), + Bs.stride(1), + Bs.stride(0), + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + GROUP_SIZE_M=8, + ) + + return C diff --git a/python/sglang/srt/layers/quantization/fp8_utils.py b/python/sglang/srt/layers/quantization/fp8_utils.py index 3ba381a373f..deb3c91e854 100644 --- a/python/sglang/srt/layers/quantization/fp8_utils.py +++ b/python/sglang/srt/layers/quantization/fp8_utils.py @@ -1,6 +1,12 @@ -from typing import Optional, Tuple +from typing import List, Optional, Tuple import torch +from vllm.model_executor.parameter import RowvLLMParameter, _ColumnvLLMParameter + +from sglang.srt.layers.quantization.fp8_kernel import ( + per_token_group_quant_fp8, + w8a8_block_fp8_matmul, +) def normalize_e4m3fn_to_e4m3fnuz( @@ -25,3 +31,86 @@ def normalize_e4m3fn_to_e4m3fnuz( if input_scale is not None: input_scale = input_scale * 2.0 return weight, weight_scale, input_scale + + +def apply_w8a8_block_fp8_linear( + input: torch.Tensor, + weight: torch.Tensor, + block_size: List[int], + weight_scale: torch.Tensor, + input_scale: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: + assert input_scale is None + # View input as 2D matrix for fp8 methods + input_2d = input.view(-1, input.shape[-1]) + output_shape = [*input.shape[:-1], weight.shape[0]] + + q_input, x_scale = per_token_group_quant_fp8(input_2d, block_size[1]) + output = w8a8_block_fp8_matmul( + q_input, weight, x_scale, weight_scale, block_size, output_dtype=input.dtype + ) + + if bias is not None: + output = output + bias + return 
output.to(dtype=input.dtype).view(*output_shape) + + +def input_to_float8( + x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn +) -> Tuple[torch.Tensor, torch.Tensor]: + """This function quantizes input values to float8 values with tensor-wise quantization.""" + finfo = torch.finfo(dtype) + min_val, max_val = x.aminmax() + amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12) + scale = finfo.max / amax + x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max) + return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal() + + +def block_quant_to_tensor_quant( + x_q_block: torch.Tensor, + x_s: torch.Tensor, + block_size: List[int], +) -> Tuple[torch.Tensor, torch.Tensor]: + """This function converts block-wise quantization to tensor-wise quantization. + The inputs are block-wise quantization tensor `x_q_block`, block-wise quantization scale + and the block size. + The outputs are tensor-wise quantization tensor and tensor-wise quantization scale. + Note only float8 is supported for now. + """ + block_n, block_k = block_size[0], block_size[1] + n, k = x_q_block.shape + n_tiles = (n + block_n - 1) // block_n + k_tiles = (k + block_k - 1) // block_k + assert n_tiles == x_s.shape[0] + assert k_tiles == x_s.shape[1] + + x_dq_block = x_q_block.to(torch.float32) + + x_dq_block_tiles = [ + [ + x_dq_block[ + j * block_n : min((j + 1) * block_n, n), + i * block_k : min((i + 1) * block_k, k), + ] + for i in range(k_tiles) + ] + for j in range(n_tiles) + ] + + for i in range(k_tiles): + for j in range(n_tiles): + x_dq_block_tiles[j][i][:, :] = x_dq_block_tiles[j][i] * x_s[j][i] + + x_q_tensor, scale = input_to_float8(x_dq_block, dtype=x_q_block.dtype) + return x_q_tensor, scale + + +class BlockQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter): + """ + Parameter class for weight scales loaded for weights with + block-wise quantization. Uses both column and row parallelism. + """ + + pass diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 92b987a23a3..c56430ce0a0 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -43,6 +43,10 @@ from sglang.srt.layers.moe.ep_moe.layer import EPMoE from sglang.srt.layers.moe.fused_moe_triton import FusedMoE from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.layers.quantization.fp8_utils import ( + block_quant_to_tensor_quant, + input_to_float8, +) from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, @@ -186,15 +190,6 @@ def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float: return 0.1 * mscale * math.log(scale) + 1.0 -def input_to_float8(x, dtype=torch.float8_e4m3fn): - finfo = torch.finfo(dtype) - min_val, max_val = x.aminmax() - amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12) - scale = finfo.max / amax - x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max) - return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal() - - class DeepseekV2Attention(nn.Module): def __init__( @@ -869,6 +864,16 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: + # TODO(HandH1998): Modify it when nextn is supported. 
+ if hasattr(self.config, "num_nextn_predict_layers"): + num_nextn_layers = self.config.num_nextn_predict_layers + if num_nextn_layers > 0 and name.startswith("model.layers"): + name_list = name.split(".") + if ( + len(name_list) >= 3 + and int(name_list[2]) >= self.config.num_hidden_layers + ): + continue if "rotary_emb.inv_freq" in name: continue for param_name, weight_name, shard_id in stacked_params_mapping: @@ -933,13 +938,33 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ).T else: w = self_attn.kv_b_proj.weight + # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`. + # This may affect the accuracy of fp8 model. + if ( + hasattr(self.quant_config, "weight_block_size") + and w.dtype == torch.float8_e4m3fn + ): + weight_block_size = self.quant_config.weight_block_size + if weight_block_size is not None: + assert hasattr(self_attn.kv_b_proj, "weight_scale_inv") + w, scale = block_quant_to_tensor_quant( + w, self_attn.kv_b_proj.weight_scale_inv, weight_block_size + ) + self_attn.w_scale = scale w_kc, w_vc = w.unflatten( 0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim) ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1) self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2) self_attn.w_vc = w_vc.contiguous().transpose(1, 2) - if hasattr(self_attn.kv_b_proj, "weight_scale"): + if ( + hasattr(self_attn.kv_b_proj, "weight_scale") + and self_attn.w_scale is None + ): self_attn.w_scale = self_attn.kv_b_proj.weight_scale -EntryClass = DeepseekV2ForCausalLM +class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM): + pass + + +EntryClass = [DeepseekV2ForCausalLM, DeepseekV3ForCausalLM] diff --git a/python/sglang/test/test_block_fp8.py b/python/sglang/test/test_block_fp8.py new file mode 100644 index 00000000000..3a02531e695 --- /dev/null +++ b/python/sglang/test/test_block_fp8.py @@ -0,0 +1,341 @@ +import itertools +import unittest + +import torch + +from sglang.srt.layers.activation import SiluAndMul +from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe +from sglang.srt.layers.quantization.fp8_kernel import ( + per_token_group_quant_fp8, + w8a8_block_fp8_matmul, +) + + +# For test +def native_per_token_group_quant_fp8( + x, group_size, eps=1e-10, dtype=torch.float8_e4m3fn +): + """Function to perform per-token-group quantization on an input tensor `x` using native torch. + + It converts the tensor values into float8 values and returns the + quantized tensor along with the scaling factor used for quantization. + Note that only `torch.float8_e4m3fn` is supported for now. 
+ """ + assert ( + x.shape[-1] % group_size == 0 + ), "the last dimension of `x` cannot be divisible by `group_size`" + assert x.is_contiguous(), "`x` is not contiguous" + + finfo = torch.finfo(dtype) + fp8_min = finfo.min + fp8_max = finfo.max + + x_ = x.reshape(x.numel() // group_size, group_size) + amax = x_.abs().max(dim=-1, keepdim=True)[0].clamp(min=eps).to(torch.float32) + x_s = amax / fp8_max + x_q = (x_ / x_s).clamp(min=fp8_min, max=fp8_max).to(dtype) + x_q = x_q.reshape(x.shape) + x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size,)) + + return x_q, x_s + + +class TestPerTokenGroupQuantFP8(unittest.TestCase): + DTYPES = [torch.half, torch.bfloat16, torch.float32] + NUM_TOKENS = [7, 83, 2048] + D = [512, 4096, 5120, 13824] + GROUP_SIZE = [64, 128, 256, 512] + SEEDS = [0] + + @classmethod + def setUpClass(cls): + if not torch.cuda.is_available(): + raise unittest.SkipTest("CUDA is not available") + torch.set_default_device("cuda") + + def _per_token_group_quant_fp8(self, num_tokens, d, dtype, group_size, seed): + torch.manual_seed(seed) + + x = torch.rand(num_tokens, d, dtype=dtype) + + with torch.inference_mode(): + ref_out, ref_scale = native_per_token_group_quant_fp8(x, group_size) + out, scale = per_token_group_quant_fp8(x, group_size) + + self.assertTrue( + torch.allclose(out.to(torch.float32), ref_out.to(torch.float32), rtol=0.15) + ) + self.assertTrue(torch.allclose(scale, ref_scale)) + + def test_per_token_group_quant_fp8(self): + for params in itertools.product( + self.NUM_TOKENS, + self.D, + self.DTYPES, + self.GROUP_SIZE, + self.SEEDS, + ): + with self.subTest( + num_tokens=params[0], + d=params[1], + dtype=params[2], + group_size=params[3], + seed=params[4], + ): + self._per_token_group_quant_fp8(*params) + + +# For test +def native_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype=torch.float16): + """This function performs matrix multiplication with block-wise quantization using native torch. + + It takes two input tensors `A` and `B` with scales `As` and `Bs`. + The output is returned in the specified `output_dtype`. 
+ """ + + A = A.to(torch.float32) + B = B.to(torch.float32) + assert A.shape[-1] == B.shape[-1] + assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2 + assert len(block_size) == 2 + block_n, block_k = block_size[0], block_size[1] + assert (A.shape[-1] + block_k - 1) // block_k == As.shape[-1] + assert A.shape[:-1] == As.shape[:-1] + + M = A.numel() // A.shape[-1] + N, K = B.shape + origin_C_shape = A.shape[:-1] + (N,) + A = A.reshape(M, A.shape[-1]) + As = As.reshape(M, As.shape[-1]) + n_tiles = (N + block_n - 1) // block_n + k_tiles = (K + block_k - 1) // block_k + assert n_tiles == Bs.shape[0] + assert k_tiles == Bs.shape[1] + + C_shape = (M, N) + C = torch.zeros(C_shape, dtype=torch.float32, device=A.device) + + A_tiles = [A[:, i * block_k : min((i + 1) * block_k, K)] for i in range(k_tiles)] + B_tiles = [ + [ + B[ + j * block_n : min((j + 1) * block_n, N), + i * block_k : min((i + 1) * block_k, K), + ] + for i in range(k_tiles) + ] + for j in range(n_tiles) + ] + C_tiles = [C[:, j * block_n : min((j + 1) * block_n, N)] for j in range(n_tiles)] + As_tiles = [As[:, i : i + 1] for i in range(k_tiles)] + + for i in range(k_tiles): + for j in range(n_tiles): + a = A_tiles[i] + b = B_tiles[j][i] + c = C_tiles[j] + s = As_tiles[i] * Bs[j][i] + c[:, :] += torch.matmul(a, b.t()) * s + + C = C.reshape(origin_C_shape).to(output_dtype) + return C + + +class TestW8A8BlockFP8Matmul(unittest.TestCase): + OUT_DTYPES = [torch.float32, torch.half, torch.bfloat16] + M = [1, 7, 83, 512, 2048] + N = [128, 512, 1024, 4096, 7748, 13824] + K = [256, 4096, 5120, 3884, 13824] + # BLOCK_SIZE = [[64, 64], [64, 128], [128, 64], [128, 128]] + BLOCK_SIZE = [[128, 128]] + SEEDS = [0] + + @classmethod + def setUpClass(cls): + if not torch.cuda.is_available(): + raise unittest.SkipTest("CUDA is not available") + torch.set_default_device("cuda") + + def _w8a8_block_fp8_matmul(self, M, N, K, block_size, out_dtype, seed): + torch.manual_seed(seed) + # NOTE(HandH1998): to avoid overflow when out_dtype = torch.half + factor_for_scale = 1e-2 + fp8_info = torch.finfo(torch.float8_e4m3fn) + fp8_max, fp8_min = fp8_info.max, fp8_info.min + + A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * fp8_max + A_fp8 = A_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + + B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max + B_fp8 = B_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + + block_n, block_k = block_size[0], block_size[1] + n_tiles = (N + block_n - 1) // block_n + k_tiles = (K + block_k - 1) // block_k + + As = torch.rand(M, k_tiles, dtype=torch.float32) * factor_for_scale + Bs = torch.rand(n_tiles, k_tiles, dtype=torch.float32) * factor_for_scale + + with torch.inference_mode(): + ref_out = native_w8a8_block_fp8_matmul( + A_fp8, B_fp8, As, Bs, block_size, out_dtype + ) + out = w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype) + + self.assertTrue( + torch.mean(torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) + / torch.mean(torch.abs(ref_out.to(torch.float32))) + < 0.001 + ) + + def test_w8a8_block_fp8_matmul(self): + for params in itertools.product( + self.M, + self.N, + self.K, + self.BLOCK_SIZE, + self.OUT_DTYPES, + self.SEEDS, + ): + with self.subTest( + M=params[0], + N=params[1], + K=params[2], + block_size=params[3], + out_dtype=params[4], + seed=params[5], + ): + self._w8a8_block_fp8_matmul(*params) + + +# For test +def torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape): + """This function performs fused moe 
with block-wise quantization using native torch.""" + + B, D = a.shape + a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) + out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device) + score = torch.softmax(score, dim=-1, dtype=torch.float32) + topk_weight, topk_ids = torch.topk(score, topk) + topk_weight = topk_weight.view(-1) + topk_ids = topk_ids.view(-1) + + _, block_k = block_shape[0], block_shape[1] + a_q, a_s = native_per_token_group_quant_fp8(a, block_k) + # NOTE(HandH1998): Since "index_cuda" not implemented for 'Float8_e4m3fn', we need to cast `float8`` to `float32``. + a_q = a_q.to(torch.float32) + for i in range(w1.shape[0]): + mask = topk_ids == i + if mask.sum(): + inter_out = native_w8a8_block_fp8_matmul( + a_q[mask], w1[i], a_s[mask], w1_s[i], block_shape, output_dtype=a.dtype + ) + act_out = SiluAndMul().forward_native(inter_out) + act_out_q, act_out_s = native_per_token_group_quant_fp8(act_out, block_k) + act_out = act_out.to(torch.float32) + out[mask] = native_w8a8_block_fp8_matmul( + act_out_q, w2[i], act_out_s, w2_s[i], block_shape, output_dtype=a.dtype + ) + return ( + out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype) + ).sum(dim=1) + + +class TestW8A8BlockFP8FusedMoE(unittest.TestCase): + DTYPES = [torch.float32, torch.half, torch.bfloat16] + M = [1, 33, 64, 222, 1024 * 128] + N = [128, 1024, 2048] + K = [256, 4096, 5120] + E = [8, 24] + TOP_KS = [2, 6] + BLOCK_SIZE = [[64, 64], [64, 128], [128, 64], [128, 128]] + # BLOCK_SIZE = [[128, 128]] + SEEDS = [0] + + @classmethod + def setUpClass(cls): + if not torch.cuda.is_available(): + raise unittest.SkipTest("CUDA is not available") + torch.set_default_device("cuda") + + def _w8a8_block_fp8_fused_moe(self, M, N, K, E, topk, block_size, dtype, seed): + torch.manual_seed(seed) + # NOTE(HandH1998): to avoid overflow when out_dtype = torch.half + factor_for_scale = 1e-2 + fp8_info = torch.finfo(torch.float8_e4m3fn) + fp8_max, fp8_min = fp8_info.max, fp8_info.min + + a = torch.randn((M, K), dtype=dtype) / 10 + + w1_fp32 = (torch.rand((E, 2 * N, K), dtype=torch.float32) - 0.5) * 2 * fp8_max + w1 = w1_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + + w2_fp32 = (torch.rand((E, K, N), dtype=torch.float32) - 0.5) * 2 * fp8_max + w2 = w2_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + + block_n, block_k = block_size[0], block_size[1] + n_tiles_w1 = (2 * N + block_n - 1) // block_n + n_tiles_w2 = (K + block_n - 1) // block_n + k_tiles_w1 = (K + block_k - 1) // block_k + k_tiles_w2 = (N + block_k - 1) // block_k + + w1_s = ( + torch.rand((E, n_tiles_w1, k_tiles_w1), dtype=torch.float32) + * factor_for_scale + ) + w2_s = ( + torch.rand((E, n_tiles_w2, k_tiles_w2), dtype=torch.float32) + * factor_for_scale + ) + + score = torch.randn((M, E), dtype=dtype) + + with torch.inference_mode(): + out = fused_moe( + a, + w1, + w2, + score, + topk, + renormalize=False, + use_fp8_w8a8=True, + w1_scale=w1_s, + w2_scale=w2_s, + block_shape=block_size, + ) + ref_out = torch_w8a8_block_fp8_moe( + a, w1, w2, w1_s, w2_s, score, topk, block_size + ) + + self.assertTrue( + torch.mean(torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) + / torch.mean(torch.abs(ref_out.to(torch.float32))) + < 0.02 + ) + + def test_w8a8_block_fp8_fused_moe(self): + for params in itertools.product( + self.M, + self.N, + self.K, + self.E, + self.TOP_KS, + self.BLOCK_SIZE, + self.DTYPES, + self.SEEDS, + ): + with self.subTest( + M=params[0], + N=params[1], + K=params[2], + E=params[3], + 
topk=params[4], + block_size=params[5], + dtype=params[6], + seed=params[7], + ): + self._w8a8_block_fp8_fused_moe(*params) + + +if __name__ == "__main__": + unittest.main(verbosity=2)
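For completeness, here is a minimal sketch of the dense (non-MoE) block-FP8 path that `apply_w8a8_block_fp8_linear` wires up: activations are quantized per token group with `per_token_group_quant_fp8`, and the checkpoint's `weight_scale_inv` is consumed directly by `w8a8_block_fp8_matmul`. The tensors below are random stand-ins (assuming a CUDA device and a [128, 128] `weight_block_size`), not checkpoint data:

import torch

from sglang.srt.layers.quantization.fp8_kernel import (
    per_token_group_quant_fp8,
    w8a8_block_fp8_matmul,
)

block_n, block_k = 128, 128
M, K, N = 32, 4096, 6144

x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
w_q = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn)  # stand-in for an fp8 checkpoint weight
# weight_scale_inv: one fp32 scale per (block_n, block_k) weight tile.
w_s = torch.rand((N + block_n - 1) // block_n, (K + block_k - 1) // block_k, device="cuda")

# Dynamic activation quantization: one scale per token per block_k-sized group.
x_q, x_s = per_token_group_quant_fp8(x, block_k)
y = w8a8_block_fp8_matmul(x_q, w_q, x_s, w_s, [block_n, block_k], output_dtype=x.dtype)
assert y.shape == (M, N)

The new unit tests should be runnable on a CUDA machine with `python -m unittest sglang.test.test_block_fp8`; they skip themselves when CUDA is unavailable.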