From eea00490d24621fc0a84ee6fd16040f5790ac863 Mon Sep 17 00:00:00 2001 From: mgoin Date: Thu, 26 Dec 2024 14:12:55 +0000 Subject: [PATCH 1/9] Support deepseek_v3 w8a8 fp8 block-wise quantization Signed-off-by: mgoin --- .../layers/fused_moe/fused_moe.py | 129 +++++-- vllm/model_executor/layers/fused_moe/layer.py | 7 +- vllm/model_executor/layers/linear.py | 21 +- .../model_executor/layers/quantization/fp8.py | 199 ++++++++-- .../layers/quantization/utils/fp8_utils.py | 346 ++++++++++++++++++ vllm/model_executor/models/registry.py | 1 + vllm/model_executor/parameter.py | 9 + 7 files changed, 650 insertions(+), 62 deletions(-) create mode 100644 vllm/model_executor/layers/quantization/utils/fp8_utils.py diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index e6f9f01ef0f74..db308b6d4e597 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -2,7 +2,7 @@ import functools import json import os -from typing import Any, Callable, Dict, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple import torch import triton @@ -11,6 +11,8 @@ import vllm.envs as envs from vllm import _custom_ops as ops from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + per_token_group_quant_fp8) from vllm.platforms import current_platform from vllm.utils import direct_register_custom_op @@ -45,8 +47,14 @@ def fused_moe_kernel( stride_bn, stride_cm, stride_cn, + stride_asm, + stride_ask, stride_bse, + stride_bsk, stride_bsn, + # Block size for block-wise quantization + group_n: tl.constexpr, + group_k: tl.constexpr, # Meta-parameters BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, @@ -125,8 +133,14 @@ def fused_moe_kernel( b_scale = tl.load(b_scale_ptrs) if use_fp8_w8a8: - a_scale = tl.load(a_scale_ptr) - b_scale = tl.load(b_scale_ptr + off_experts) + if group_k > 0 and group_n > 0: + a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm + offs_bsn = offs_bn // group_n + b_scale_ptrs = (b_scale_ptr + off_experts * stride_bse + + offs_bsn * stride_bsn) + else: + a_scale = tl.load(a_scale_ptr) + b_scale = tl.load(b_scale_ptr + off_experts) # ----------------------------------------------------------- # Iterate to compute a block of the C matrix. @@ -149,7 +163,18 @@ def fused_moe_kernel( if use_int8_w8a16: accumulator = tl.dot(a, b.to(compute_type), acc=accumulator) elif use_fp8_w8a8: - accumulator = tl.dot(a, b, acc=accumulator) + if group_k > 0 and group_n > 0: + k_start = k * BLOCK_SIZE_K + offs_ks = k_start // group_k + a_scale = tl.load(a_scale_ptrs + offs_ks * stride_ask, + mask=token_mask, + other=0.0) + b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk) + + accumulator += tl.dot(a, b) * a_scale[:, + None] * b_scale[None, :] + else: + accumulator = tl.dot(a, b, acc=accumulator) else: accumulator += tl.dot(a, b) # Advance the ptrs to the next K block. 
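The kernel changes above add block-wise rescaling: when group_n and group_k are non-zero, every K-tile of the fp8 dot product is multiplied by the activation's per-token-group scale (indexed by token and k // group_k) and by the selected expert's per-block weight scale (indexed by n // group_n and k // group_k), instead of a single per-tensor a_scale/b_scale pair. A minimal plain-PyTorch sketch of that arithmetic for one expert follows; it is illustrative only, and the tensor names (a_q, b_q, a_scale, b_scale) are placeholders rather than code from this patch.

    import torch

    def blockwise_scaled_matmul(a_q, b_q, a_scale, b_scale, block_n, block_k):
        # a_q: [M, K] fp8 activations; a_scale: [M, ceil(K / block_k)] per-token-group scales.
        # b_q: [N, K] fp8 weights;     b_scale: [ceil(N / block_n), ceil(K / block_k)] per-block scales.
        M, K = a_q.shape
        N = b_q.shape[0]
        out = torch.zeros(M, N, dtype=torch.float32)
        for k0 in range(0, K, block_k):
            # Partial product for this K-tile, accumulated in fp32 like the Triton accumulator.
            partial = a_q[:, k0:k0 + block_k].float() @ b_q[:, k0:k0 + block_k].float().t()
            a_s = a_scale[:, k0 // block_k].unsqueeze(1)                    # [M, 1]
            b_s = b_scale[:, k0 // block_k].repeat_interleave(block_n)[:N]  # [N], one scale per output column
            out += partial * a_s * b_s.unsqueeze(0)
        return out

The Triton kernel performs the same rescaling inside its accumulation loop, with off_experts and stride_bse selecting the per-expert slice of the weight scales.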
@@ -164,7 +189,10 @@ def fused_moe_kernel( if use_int8_w8a16: accumulator = (accumulator * b_scale).to(compute_type) elif use_fp8_w8a8: - accumulator = (accumulator * a_scale * b_scale).to(compute_type) + if group_k > 0 and group_n > 0: + accumulator = accumulator.to(compute_type) + else: + accumulator = (accumulator * a_scale * b_scale).to(compute_type) else: accumulator = accumulator.to(compute_type) # ----------------------------------------------------------- @@ -233,22 +261,37 @@ def moe_align_block_size( return sorted_ids, expert_ids, num_tokens_post_pad -def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, +def invoke_fused_moe_kernel(A: torch.Tensor, + B: torch.Tensor, + C: torch.Tensor, A_scale: Optional[torch.Tensor], B_scale: Optional[torch.Tensor], - topk_weights: torch.Tensor, topk_ids: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, sorted_token_ids: torch.Tensor, expert_ids: torch.Tensor, num_tokens_post_padded: torch.Tensor, - mul_routed_weight: bool, top_k: int, - config: Dict[str, Any], compute_type: tl.dtype, - use_fp8_w8a8: bool, use_int8_w8a16: bool) -> None: + mul_routed_weight: bool, + top_k: int, + config: Dict[str, Any], + compute_type: tl.dtype, + use_fp8_w8a8: bool, + use_int8_w8a16: bool, + block_shape: Optional[List[int]] = None) -> None: assert topk_weights.stride(1) == 1 assert sorted_token_ids.stride(0) == 1 if use_fp8_w8a8: - A, A_scale = ops.scaled_fp8_quant(A, A_scale) assert B_scale is not None + if block_shape is None: + A, A_scale = ops.scaled_fp8_quant(A, A_scale) + else: + assert len(block_shape) == 2 + block_n, block_k = block_shape[0], block_shape[1] + A, A_scale = per_token_group_quant_fp8(A, block_k) + assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1] + assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2] + assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1] elif use_int8_w8a16: assert B_scale is not None else: @@ -279,8 +322,13 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, B.stride(1), C.stride(1), C.stride(2), - B_scale.stride(0) if B_scale is not None and use_int8_w8a16 else 0, - B_scale.stride(1) if B_scale is not None and use_int8_w8a16 else 0, + A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0, + A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0, + B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0, + B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0, + B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0, + 0 if block_shape is None else block_shape[0], + 0 if block_shape is None else block_shape[1], MUL_ROUTED_WEIGHT=mul_routed_weight, top_k=top_k, compute_type=compute_type, @@ -362,6 +410,7 @@ def try_get_optimal_moe_config( dtype: Optional[str], M: int, is_marlin: bool = False, + block_shape: Optional[List[int]] = None, ): from vllm.model_executor.layers.fused_moe import get_config override_config = get_config() @@ -380,6 +429,12 @@ def try_get_optimal_moe_config( # Else use the default config config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, is_marlin) + # NOTE: For block-wise quant, + # BLOCK_K must be divisible by block_shape[1] + # BLOCK_N and BLOCK_M has no requirements + if block_shape is not None: + config["BLOCK_SIZE_N"] = block_shape[0] + config["BLOCK_SIZE_K"] = block_shape[1] return config @@ -479,10 +534,11 @@ def inplace_fused_experts(hidden_states: torch.Tensor, w1_scale: Optional[torch.Tensor] = None, 
w2_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None) -> None: + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None) -> None: fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True, use_fp8_w8a8, use_int8_w8a16, w1_scale, w2_scale, - a1_scale, a2_scale) + a1_scale, a2_scale, block_shape) def inplace_fused_experts_fake( @@ -496,7 +552,8 @@ def inplace_fused_experts_fake( w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None) -> None: + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None) -> None: pass @@ -519,7 +576,8 @@ def outplace_fused_experts( w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None) -> torch.Tensor: + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None) -> torch.Tensor: return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, False, use_fp8_w8a8, use_int8_w8a16, w1_scale, w2_scale, a1_scale, a2_scale) @@ -536,7 +594,8 @@ def outplace_fused_experts_fake( w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None) -> torch.Tensor: + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None) -> torch.Tensor: return torch.empty_like(hidden_states) @@ -559,18 +618,22 @@ def fused_experts(hidden_states: torch.Tensor, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None): + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None): if inplace: torch.ops.vllm.inplace_fused_experts(hidden_states, w1, w2, topk_weights, topk_ids, use_fp8_w8a8, use_int8_w8a16, w1_scale, w2_scale, a1_scale, - a2_scale) + a2_scale, block_shape) return hidden_states else: - return torch.ops.vllm.outplace_fused_experts( - hidden_states, w1, w2, topk_weights, topk_ids, use_fp8_w8a8, - use_int8_w8a16, w1_scale, w2_scale, a1_scale, a2_scale) + return torch.ops.vllm.outplace_fused_experts(hidden_states, w1, w2, + topk_weights, topk_ids, + use_fp8_w8a8, + use_int8_w8a16, w1_scale, + w2_scale, a1_scale, + a2_scale, block_shape) def fused_experts_impl(hidden_states: torch.Tensor, @@ -584,7 +647,8 @@ def fused_experts_impl(hidden_states: torch.Tensor, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None): + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None): # Check constraints. 
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" @@ -611,6 +675,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, w2.shape, topk_ids.shape[1], config_dtype, + block_shape=block_shape, ) config = get_config_func(M) @@ -674,7 +739,8 @@ def fused_experts_impl(hidden_states: torch.Tensor, config, compute_type=compute_type, use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a16=use_int8_w8a16) + use_int8_w8a16=use_int8_w8a16, + block_shape=block_shape) ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) @@ -693,7 +759,8 @@ def fused_experts_impl(hidden_states: torch.Tensor, config, compute_type=compute_type, use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a16=use_int8_w8a16) + use_int8_w8a16=use_int8_w8a16, + block_shape=block_shape) ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape), out_hidden_states[begin_chunk_idx:end_chunk_idx]) @@ -718,6 +785,7 @@ def fused_moe( w2_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, ) -> torch.Tensor: """ This function computes a Mixture of Experts (MoE) layer using two sets of @@ -745,6 +813,12 @@ def fused_moe( w1. - w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2. + - a1_scale (Optional[torch.Tensor]): Optional scale to be used for + a1. + - a2_scale (Optional[torch.Tensor]): Optional scale to be used for + a2. + - block_shape: (Optional[List[int]]): Optional block size for block-wise + quantization. Returns: - torch.Tensor: The output tensor after applying the MoE layer. @@ -775,4 +849,5 @@ def fused_moe( w1_scale=w1_scale, w2_scale=w2_scale, a1_scale=a1_scale, - a2_scale=a2_scale) + a2_scale=a2_scale, + block_shape=block_shape) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 8c6f7c6e06515..55c0a202920ff 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -29,6 +29,7 @@ class FusedMoeWeightScaleSupported(Enum): TENSOR = "tensor" CHANNEL = "channel" GROUP = "group" + BLOCK = "block" class FusedMoEMethodBase(QuantizeMethodBase): @@ -199,6 +200,7 @@ def __init__( get_tensor_model_parallel_world_size()) self.top_k = top_k self.num_experts = num_experts + assert intermediate_size % self.tp_size == 0 self.intermediate_size_per_partition = intermediate_size // self.tp_size self.reduce_results = reduce_results self.renormalize = renormalize @@ -398,7 +400,10 @@ def weight_loader(self, param: torch.nn.Parameter, loaded_weight=loaded_weight, expert_data=expert_data, tp_rank=tp_rank) - elif quant_method == FusedMoeWeightScaleSupported.GROUP.value: + elif quant_method in [ + FusedMoeWeightScaleSupported.GROUP.value, + FusedMoeWeightScaleSupported.BLOCK.value, + ]: self._load_model_weight_or_group_weight_scale( shard_id=shard_id, shard_dim=shard_dim, diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 46ef11e7d02c6..9686554ba6d0f 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -15,6 +15,7 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.parameter import (BasevLLMParameter, + BlockQuantScaleParameter, PackedColumnParameter, PackedvLLMParameter, PerTensorScaleParameter, @@ -623,8 +624,24 @@ def weight_loader_v2(self, assert 
loaded_shard_id < len(self.output_sizes) tp_size = get_tensor_model_parallel_world_size() - shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size - shard_size = self.output_sizes[loaded_shard_id] // tp_size + + if isinstance(param, BlockQuantScaleParameter): + from vllm.model_executor.layers.quantization.fp8 import ( + Fp8LinearMethod, Fp8MoEMethod) + assert self.quant_method is not None + assert isinstance(self.quant_method, + (Fp8LinearMethod, Fp8MoEMethod)) + weight_block_size = self.quant_method.quant_config.weight_block_size + assert weight_block_size is not None + block_n, _ = weight_block_size[0], weight_block_size[1] + shard_offset = ( + (sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) // + block_n) // tp_size + shard_size = ((self.output_sizes[loaded_shard_id] + block_n - 1) // + block_n // tp_size) + else: + shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size + shard_size = self.output_sizes[loaded_shard_id] // tp_size param.load_merged_column_weight(loaded_weight=loaded_weight, shard_id=loaded_shard_id, diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 978e727bc7cb3..5dfd86727a02a 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -6,6 +6,7 @@ import vllm.envs as envs from vllm import _custom_ops as ops +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) @@ -14,6 +15,8 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + apply_w8a8_block_fp8_linear) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin) from vllm.model_executor.layers.quantization.utils.quant_utils import ( @@ -22,7 +25,8 @@ all_close_1d, apply_fp8_linear, convert_to_channelwise, cutlass_fp8_supported, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize, requantize_with_max_scale) -from vllm.model_executor.parameter import (ModelWeightParameter, +from vllm.model_executor.parameter import (BlockQuantScaleParameter, + ModelWeightParameter, PerTensorScaleParameter) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform @@ -41,6 +45,7 @@ def __init__( is_checkpoint_fp8_serialized: bool = False, activation_scheme: str = "dynamic", ignored_layers: Optional[List[str]] = None, + weight_block_size: Optional[List[int]] = None, ) -> None: self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized if is_checkpoint_fp8_serialized: @@ -51,6 +56,20 @@ def __init__( f"Unsupported activation scheme {activation_scheme}") self.activation_scheme = activation_scheme self.ignored_layers = ignored_layers or [] + if weight_block_size is not None: + if not is_checkpoint_fp8_serialized: + raise ValueError( + "The block-wise quantization only supports fp8-serialized " + "checkpoint for now.") + if len(weight_block_size) != 2: + raise ValueError( + "The quantization block size of weight must have 2 " + f"dimensions, but got {len(weight_block_size)} dimensions") + if activation_scheme != "dynamic": + raise ValueError("The block-wise quantization only supports " + "dynamic activation scheme 
for now, but got " + f"{activation_scheme} activation scheme.") + self.weight_block_size = weight_block_size @classmethod def get_name(cls) -> str: @@ -74,9 +93,12 @@ def from_config(cls, config: Dict[str, Any]) -> "Fp8Config": is_checkpoint_fp8_serialized = ("fp8" in quant_method) activation_scheme = cls.get_from_keys(config, ["activation_scheme"]) ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None) + weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"], + None) return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized, activation_scheme=activation_scheme, - ignored_layers=ignored_layers) + ignored_layers=ignored_layers, + weight_block_size=weight_block_size) def get_quant_method(self, layer: torch.nn.Module, prefix: str) -> Optional["QuantizeMethodBase"]: @@ -123,6 +145,11 @@ def __init__(self, quant_config: Fp8Config): if current_platform.is_rocm(): self.use_marlin = False + self.block_quant = self.quant_config.weight_block_size is not None + if self.block_quant: + # Marlin doesn't support block-wise fp8 + self.use_marlin = False + def create_weights( self, layer: torch.nn.Module, @@ -133,10 +160,34 @@ def create_weights( params_dtype: torch.dtype, **extra_weight_attrs, ): - del input_size, output_size output_size_per_partition = sum(output_partition_sizes) weight_loader = extra_weight_attrs.get("weight_loader") + if self.block_quant: + tp_size = get_tensor_model_parallel_world_size() + assert self.quant_config.weight_block_size is not None + block_n, block_k = ( + self.quant_config.weight_block_size[0], + self.quant_config.weight_block_size[1], + ) + # Required by row parallel + if (tp_size > 1 + and input_size // input_size_per_partition == tp_size + and input_size_per_partition % block_k != 0): + raise ValueError( + f"Weight input_size_per_partition = " + f"{input_size_per_partition} is not divisible by " + f"weight quantization block_k = {block_k}.") + # Required by column parallel or enabling merged weights + if (tp_size > 1 and output_size // output_size_per_partition + == tp_size) or len(output_partition_sizes) > 1: + for output_partition_size in output_partition_sizes: + if output_partition_size % block_n != 0: + raise ValueError( + f"Weight output_partition_size = " + f"{output_partition_size} is not divisible by " + f"weight quantization block_n = {block_n}.") + layer.logical_widths = output_partition_sizes layer.input_size_per_partition = input_size_per_partition @@ -161,12 +212,29 @@ def create_weights( # Otherwise, wait until process_weights_after_loading. 
if self.quant_config.is_checkpoint_fp8_serialized: # WEIGHT SCALE - scale = PerTensorScaleParameter(data=torch.empty( - len(output_partition_sizes), dtype=torch.float32), - weight_loader=weight_loader) - - scale[:] = torch.finfo(torch.float32).min - layer.register_parameter("weight_scale", scale) + if not self.block_quant: + scale = PerTensorScaleParameter( + data=torch.empty(len(output_partition_sizes), + dtype=torch.float32), + weight_loader=weight_loader, + ) + scale[:] = torch.finfo(torch.float32).min + layer.register_parameter("weight_scale", scale) + else: + assert self.quant_config.activation_scheme == "dynamic" + scale = BlockQuantScaleParameter( + data=torch.empty( + (output_size_per_partition + block_n - 1) // block_n, + (input_size_per_partition + block_k - 1) // block_k, + dtype=torch.float32, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + scale[:] = torch.finfo(torch.float32).min + # The weight_scale_inv name is intentional for deepseekv3 + layer.register_parameter("weight_scale_inv", scale) # INPUT ACTIVATION SCALE if self.quant_config.activation_scheme == "static": @@ -180,6 +248,9 @@ def create_weights( layer.register_parameter("input_scale", None) def process_weights_after_loading(self, layer: Module) -> None: + # Block quant doesn't need to process weights after loading + if self.block_quant: + return layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False) # If checkpoint not serialized fp8, quantize the weights. @@ -266,6 +337,17 @@ def apply(self, size_k=layer.input_size_per_partition, bias=bias) + if self.block_quant: + assert self.quant_config.weight_block_size is not None + return apply_w8a8_block_fp8_linear( + input=x, + weight=layer.weight, + block_size=self.quant_config.weight_block_size, + weight_scale=layer.weight_scale_inv, + input_scale=layer.input_scale, + bias=bias, + ) + return apply_fp8_linear( input=x, weight=layer.weight, @@ -291,6 +373,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): def __init__(self, quant_config: Fp8Config): self.quant_config = quant_config + self.block_quant = self.quant_config.weight_block_size is not None def create_weights(self, layer: Module, num_experts: int, hidden_size: int, intermediate_size: int, params_dtype: torch.dtype, @@ -298,6 +381,27 @@ def create_weights(self, layer: Module, num_experts: int, hidden_size: int, if self.quant_config.is_checkpoint_fp8_serialized: params_dtype = torch.float8_e4m3fn + if self.block_quant: + assert self.quant_config.weight_block_size is not None + tp_size = get_tensor_model_parallel_world_size() + block_n, block_k = ( + self.quant_config.weight_block_size[0], + self.quant_config.weight_block_size[1], + ) + # NOTE: To ensure proper alignment of the block-wise quantization + # scales, the output_size of the weights for both the gate and up + # layers must be divisible by block_n. 
+ # Required by column parallel or enabling merged weights + if intermediate_size % block_n != 0: + raise ValueError( + f"The output_size of gate's and up's weight = " + f"{intermediate_size} is not divisible by " + f"weight quantization block_n = {block_n}.") + if (tp_size > 1 and intermediate_size % block_k != 0): + # Required by row parallel + raise ValueError(f"The input_size of down's weight = " + f"{intermediate_size} is not divisible by " + f"weight quantization block_k = {block_k}.") # WEIGHTS w13_weight = torch.nn.Parameter(torch.empty(num_experts, @@ -317,21 +421,45 @@ def create_weights(self, layer: Module, num_experts: int, hidden_size: int, set_weight_attrs(w2_weight, extra_weight_attrs) # WEIGHT_SCALES - # Allocate 2 scales for w1 and w3 respectively. - # They will be combined to a single scale after weight loading. - w13_weight_scale = torch.nn.Parameter(torch.ones(num_experts, - 2, - dtype=torch.float32), - requires_grad=False) - layer.register_parameter("w13_weight_scale", w13_weight_scale) + if not self.block_quant: + # Allocate 2 scales for w1 and w3 respectively. + # They will be combined to a single scale after weight loading. + w13_weight_scale = torch.nn.Parameter(torch.ones( + num_experts, 2, dtype=torch.float32), + requires_grad=False) + w2_weight_scale = torch.nn.Parameter(torch.ones( + num_experts, dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + else: + w13_weight_scale = torch.nn.Parameter( + torch.ones( + num_experts, + 2 * ((intermediate_size + block_n - 1) // block_n), + (hidden_size + block_k - 1) // block_k, + dtype=torch.float32, + ), + requires_grad=False, + ) + w2_weight_scale = torch.nn.Parameter( + torch.ones( + num_experts, + (hidden_size + block_n - 1) // block_n, + (intermediate_size + block_k - 1) // block_k, + dtype=torch.float32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_scale_inv", w13_weight_scale) + layer.register_parameter("w2_weight_scale_inv", w2_weight_scale) + assert self.quant_config.activation_scheme == "dynamic" - w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts, - dtype=torch.float32), - requires_grad=False) - layer.register_parameter("w2_weight_scale", w2_weight_scale) # Add the quantization method used (per tensor/grouped/channel) # to ensure the weight scales are loaded in properly extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.BLOCK. + value} if self.block_quant else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) # If loading fp8 checkpoint, pass the weight loaders. # If loading an fp16 checkpoint, do not (we will quantize in @@ -364,7 +492,9 @@ def create_weights(self, layer: Module, num_experts: int, hidden_size: int, layer.w2_input_scale = None def process_weights_after_loading(self, layer: Module) -> None: - + # Block quant doesn't need to process weights after loading + if self.block_quant: + return # If checkpoint is fp16, quantize in place. 
if not self.quant_config.is_checkpoint_fp8_serialized: # If rocm, use float8_e4m3fnuz as dtype @@ -489,17 +619,22 @@ def apply( num_expert_group=num_expert_group, custom_routing_function=custom_routing_function) - return fused_experts(x, - layer.w13_weight, - layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=True, - use_fp8_w8a8=True, - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - a1_scale=layer.w13_input_scale, - a2_scale=layer.w2_input_scale) + return fused_experts( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + use_fp8_w8a8=True, + w1_scale=(layer.w13_weight_scale_inv + if self.block_quant else layer.w13_weight_scale), + w2_scale=(layer.w2_weight_scale_inv + if self.block_quant else layer.w2_weight_scale), + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + block_shape=self.quant_config.weight_block_size, + ) class Fp8KVCacheMethod(BaseKVCacheMethod): diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py new file mode 100644 index 0000000000000..e6c005bfcb94e --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -0,0 +1,346 @@ +from typing import List, Tuple, Optional + +import torch +import triton +import triton.language as tl + + +def apply_w8a8_block_fp8_linear( + input: torch.Tensor, + weight: torch.Tensor, + block_size: List[int], + weight_scale: torch.Tensor, + input_scale: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: + assert input_scale is None + # View input as 2D matrix for fp8 methods + input_2d = input.view(-1, input.shape[-1]) + output_shape = [*input.shape[:-1], weight.shape[0]] + + q_input, x_scale = per_token_group_quant_fp8(input_2d, block_size[1]) + output = w8a8_block_fp8_matmul( + q_input, weight, x_scale, weight_scale, block_size, output_dtype=input.dtype + ) + + if bias is not None: + output = output + bias + return output.to(dtype=input.dtype).view(*output_shape) + + +def input_to_float8( + x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn +) -> Tuple[torch.Tensor, torch.Tensor]: + """This function quantizes input values to float8 values with tensor-wise quantization.""" + finfo = torch.finfo(dtype) + min_val, max_val = x.aminmax() + amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12) + scale = finfo.max / amax + x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max) + return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal() + + +def block_quant_to_tensor_quant( + x_q_block: torch.Tensor, + x_s: torch.Tensor, + block_size: List[int], +) -> Tuple[torch.Tensor, torch.Tensor]: + """This function converts block-wise quantization to tensor-wise quantization. + The inputs are block-wise quantization tensor `x_q_block`, block-wise quantization scale + and the block size. + The outputs are tensor-wise quantization tensor and tensor-wise quantization scale. + Note only float8 is supported for now. 
+ """ + block_n, block_k = block_size[0], block_size[1] + n, k = x_q_block.shape + n_tiles = (n + block_n - 1) // block_n + k_tiles = (k + block_k - 1) // block_k + assert n_tiles == x_s.shape[0] + assert k_tiles == x_s.shape[1] + + x_dq_block = x_q_block.to(torch.float32) + + x_dq_block_tiles = [ + [ + x_dq_block[ + j * block_n : min((j + 1) * block_n, n), + i * block_k : min((i + 1) * block_k, k), + ] + for i in range(k_tiles) + ] + for j in range(n_tiles) + ] + + for i in range(k_tiles): + for j in range(n_tiles): + x_dq_block_tiles[j][i][:, :] = x_dq_block_tiles[j][i] * x_s[j][i] + + x_q_tensor, scale = input_to_float8(x_dq_block, dtype=x_q_block.dtype) + return x_q_tensor, scale + + + +@triton.jit +def _per_token_group_quant_fp8( + # Pointers to inputs and output + y_ptr, + y_q_ptr, + y_s_ptr, + # Stride of input + y_stride, + # Collums of input + N, + # Avoid to divide zero + eps, + # Information for float8 + fp8_min, + fp8_max, + # Meta-parameters + BLOCK: tl.constexpr, +): + """A Triton-accelerated function to perform per-token-group quantization on a + tensor. + This function converts the tensor values into float8 values. + """ + # Map the program id to the row of X and Y it should compute. + g_id = tl.program_id(0) + y_ptr += g_id * y_stride + y_q_ptr += g_id * y_stride + y_s_ptr += g_id + + cols = tl.arange(0, BLOCK) # N <= BLOCK + mask = cols < N + + y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32) + # Quant + _absmax = tl.maximum(tl.max(tl.abs(y)), eps) + y_s = _absmax / fp8_max + y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) + + tl.store(y_q_ptr + cols, y_q, mask=mask) + tl.store(y_s_ptr, y_s) + + +def per_token_group_quant_fp8( + x: torch.Tensor, + group_size: int, + eps: float = 1e-10, + dtype: torch.dtype = torch.float8_e4m3fn, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Function to perform per-token-group quantization on an input tensor `x`. + It converts the tensor values into signed float8 values and returns the + quantized tensor along with the scaling factor used for quantization. + Args: + x: The input tenosr with ndim >= 2. + group_size: The group size used for quantization. + eps: The minimum to avoid dividing zero. + dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn` is supported for now. + Returns: + Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization. 
+ """ + assert ( + x.shape[-1] % group_size == 0 + ), "the last dimension of `x` cannot be divisible by `group_size`" + assert x.is_contiguous(), "`x` is not contiguous" + + finfo = torch.finfo(dtype) + fp8_min = finfo.min + fp8_max = finfo.max + + x_q = torch.empty_like(x, device=x.device, dtype=dtype) + M = x.numel() // group_size + N = group_size + x_s = torch.empty( + x.shape[:-1] + (x.shape[-1] // group_size,), + device=x.device, + dtype=torch.float32, + ) + + BLOCK = triton.next_power_of_2(N) + # heuristics for number of warps + num_warps = min(max(BLOCK // 256, 1), 8) + num_stages = 1 + _per_token_group_quant_fp8[(M,)]( + x, + x_q, + x_s, + group_size, + N, + eps, + fp8_min=fp8_min, + fp8_max=fp8_max, + BLOCK=BLOCK, + num_warps=num_warps, + num_stages=num_stages, + ) + + return x_q, x_s + + +@triton.jit +def _w8a8_block_fp8_matmul( + # Pointers to inputs and output + A, + B, + C, + As, + Bs, + # Shape for matmul + M, + N, + K, + # Block size for block-wise quantization + group_n, + group_k, + # Stride for inputs and output + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_As_m, + stride_As_k, + stride_Bs_k, + stride_Bs_n, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + """Triton-accelerated function used to perform linear operations (dot + product) on input tensors `A` and `B` with block-wise quantization, and store the result in output + tensor `C`. + """ + + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + As_ptrs = As + offs_am * stride_As_m + offs_bsn = offs_bn // group_n + Bs_ptrs = Bs + offs_bsn * stride_Bs_n + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + + k_start = k * BLOCK_SIZE_K + offs_ks = k_start // group_k + a_s = tl.load(As_ptrs + offs_ks * stride_As_k) + b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k) + + accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :] + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + if C.dtype.element_ty == tl.bfloat16: + c = accumulator.to(tl.bfloat16) + elif C.dtype.element_ty == tl.float16: + c = accumulator.to(tl.float16) + else: + c = accumulator.to(tl.float32) + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + +def w8a8_block_fp8_matmul( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + block_size: List[int], + 
output_dtype: torch.dtype = torch.float16, +) -> torch.Tensor: + """This function performs matrix multiplication with block-wise quantization. + It takes two input tensors `A` and `B` with scales `As` and `Bs`. + The output is returned in the specified `output_dtype`. + Args: + A: The input tensor, e.g., activation. + B: The input tensor, e.g., weight. + As: The per-token-group quantization scale for `A`. + Bs: The per-block quantization scale for `B`. + block_size: The block size for per-block quantization. It should be 2-dim, e.g., [128, 128]. + output_dytpe: The dtype of the returned tensor. + Returns: + torch.Tensor: The result of matmul. + """ + assert len(block_size) == 2 + block_n, block_k = block_size[0], block_size[1] + + assert A.shape[-1] == B.shape[-1] + assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous() + assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1] + M = A.numel() // A.shape[-1] + + assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2 + N, K = B.shape + assert triton.cdiv(N, block_n) == Bs.shape[0] + assert triton.cdiv(K, block_k) == Bs.shape[1] + + C_shape = A.shape[:-1] + (N,) + C = A.new_empty(C_shape, dtype=output_dtype) + + # TODO(HandH1998): + # BLOCK_SIZE_M, BLOCK_SIZE_K, BLOCK_SIZE_N can be optimized. + # BLOCK_SIZE_K must be divisable by block_k + # BLOCK_SIZE_N and BLOCK_SIZE_M has no requirements + BLOCK_SIZE_M = 128 + if M < BLOCK_SIZE_M: + BLOCK_SIZE_M = triton.next_power_of_2(M) + BLOCK_SIZE_M = max(BLOCK_SIZE_M, 16) + BLOCK_SIZE_K = block_k + assert block_k % BLOCK_SIZE_K == 0 + BLOCK_SIZE_N = block_n + + def grid(META): + return ( + triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), + ) + + _w8a8_block_fp8_matmul[grid]( + A, + B, + C, + As, + Bs, + M, + N, + K, + block_n, + block_k, + A.stride(-2), + A.stride(-1), + B.stride(1), + B.stride(0), + C.stride(-2), + C.stride(-1), + As.stride(-2), + As.stride(-1), + Bs.stride(1), + Bs.stride(0), + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + GROUP_SIZE_M=8, + ) + + return C \ No newline at end of file diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index b32a3421d5841..9aa9eb0871fd4 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -45,6 +45,7 @@ "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"), "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"), "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"), + "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"), "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"), "FalconForCausalLM": ("falcon", "FalconForCausalLM"), "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"), diff --git a/vllm/model_executor/parameter.py b/vllm/model_executor/parameter.py index 7a6d7c90f34d5..02d22a5ca62c0 100644 --- a/vllm/model_executor/parameter.py +++ b/vllm/model_executor/parameter.py @@ -328,6 +328,15 @@ def adjust_shard_indexes_for_packing(self, shard_size, shard_offset): marlin_tile_size=self.marlin_tile_size) +class BlockQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter): + """ + Parameter class for weight scales loaded for weights with + block-wise quantization. Uses both column and row parallelism. 
+ """ + + pass + + def permute_param_layout_(param: BasevLLMParameter, input_dim: int, output_dim: int, **kwargs) -> BasevLLMParameter: """ From 083d904c77b2a8947b7b428b6dfb8f5fee491265 Mon Sep 17 00:00:00 2001 From: mgoin Date: Thu, 26 Dec 2024 14:16:46 +0000 Subject: [PATCH 2/9] Format Signed-off-by: mgoin --- vllm/model_executor/layers/linear.py | 9 +-- .../layers/quantization/utils/fp8_utils.py | 60 +++++++++---------- 2 files changed, 33 insertions(+), 36 deletions(-) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 9686554ba6d0f..8d75df15a906e 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -14,12 +14,9 @@ from vllm.logger import init_logger from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) -from vllm.model_executor.parameter import (BasevLLMParameter, - BlockQuantScaleParameter, - PackedColumnParameter, - PackedvLLMParameter, - PerTensorScaleParameter, - RowvLLMParameter) +from vllm.model_executor.parameter import ( + BasevLLMParameter, BlockQuantScaleParameter, PackedColumnParameter, + PackedvLLMParameter, PerTensorScaleParameter, RowvLLMParameter) from vllm.model_executor.utils import set_weight_attrs logger = init_logger(__name__) diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index e6c005bfcb94e..7d08895866288 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -19,9 +19,12 @@ def apply_w8a8_block_fp8_linear( output_shape = [*input.shape[:-1], weight.shape[0]] q_input, x_scale = per_token_group_quant_fp8(input_2d, block_size[1]) - output = w8a8_block_fp8_matmul( - q_input, weight, x_scale, weight_scale, block_size, output_dtype=input.dtype - ) + output = w8a8_block_fp8_matmul(q_input, + weight, + x_scale, + weight_scale, + block_size, + output_dtype=input.dtype) if bias is not None: output = output + bias @@ -29,7 +32,8 @@ def apply_w8a8_block_fp8_linear( def input_to_float8( - x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn + x: torch.Tensor, + dtype: torch.dtype = torch.float8_e4m3fn ) -> Tuple[torch.Tensor, torch.Tensor]: """This function quantizes input values to float8 values with tensor-wise quantization.""" finfo = torch.finfo(dtype) @@ -60,16 +64,11 @@ def block_quant_to_tensor_quant( x_dq_block = x_q_block.to(torch.float32) - x_dq_block_tiles = [ - [ - x_dq_block[ - j * block_n : min((j + 1) * block_n, n), - i * block_k : min((i + 1) * block_k, k), - ] - for i in range(k_tiles) - ] - for j in range(n_tiles) - ] + x_dq_block_tiles = [[ + x_dq_block[j * block_n:min((j + 1) * block_n, n), + i * block_k:min((i + 1) * block_k, k), ] + for i in range(k_tiles) + ] for j in range(n_tiles)] for i in range(k_tiles): for j in range(n_tiles): @@ -79,7 +78,6 @@ def block_quant_to_tensor_quant( return x_q_tensor, scale - @triton.jit def _per_token_group_quant_fp8( # Pointers to inputs and output @@ -88,7 +86,7 @@ def _per_token_group_quant_fp8( y_s_ptr, # Stride of input y_stride, - # Collums of input + # Columns of input N, # Avoid to divide zero eps, @@ -138,9 +136,8 @@ def per_token_group_quant_fp8( Returns: Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization. 
""" - assert ( - x.shape[-1] % group_size == 0 - ), "the last dimension of `x` cannot be divisible by `group_size`" + assert (x.shape[-1] % group_size == 0 + ), "the last dimension of `x` cannot be divisible by `group_size`" assert x.is_contiguous(), "`x` is not contiguous" finfo = torch.finfo(dtype) @@ -151,7 +148,7 @@ def per_token_group_quant_fp8( M = x.numel() // group_size N = group_size x_s = torch.empty( - x.shape[:-1] + (x.shape[-1] // group_size,), + x.shape[:-1] + (x.shape[-1] // group_size, ), device=x.device, dtype=torch.float32, ) @@ -160,7 +157,7 @@ def per_token_group_quant_fp8( # heuristics for number of warps num_warps = min(max(BLOCK // 256, 1), 8) num_stages = 1 - _per_token_group_quant_fp8[(M,)]( + _per_token_group_quant_fp8[(M, )]( x, x_q, x_s, @@ -236,8 +233,12 @@ def _w8a8_block_fp8_matmul( accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): - a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) - b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + a = tl.load(a_ptrs, + mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, + other=0.0) + b = tl.load(b_ptrs, + mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, + other=0.0) k_start = k * BLOCK_SIZE_K offs_ks = k_start // group_k @@ -296,12 +297,12 @@ def w8a8_block_fp8_matmul( assert triton.cdiv(N, block_n) == Bs.shape[0] assert triton.cdiv(K, block_k) == Bs.shape[1] - C_shape = A.shape[:-1] + (N,) + C_shape = A.shape[:-1] + (N, ) C = A.new_empty(C_shape, dtype=output_dtype) - # TODO(HandH1998): + # TODO: # BLOCK_SIZE_M, BLOCK_SIZE_K, BLOCK_SIZE_N can be optimized. - # BLOCK_SIZE_K must be divisable by block_k + # BLOCK_SIZE_K must be divisible by block_k # BLOCK_SIZE_N and BLOCK_SIZE_M has no requirements BLOCK_SIZE_M = 128 if M < BLOCK_SIZE_M: @@ -312,9 +313,8 @@ def w8a8_block_fp8_matmul( BLOCK_SIZE_N = block_n def grid(META): - return ( - triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), - ) + return (triton.cdiv(M, META["BLOCK_SIZE_M"]) * + triton.cdiv(N, META["BLOCK_SIZE_N"]), ) _w8a8_block_fp8_matmul[grid]( A, @@ -343,4 +343,4 @@ def grid(META): GROUP_SIZE_M=8, ) - return C \ No newline at end of file + return C From e71e6aa330e7a6cd4af3dc7005ab8fe9fd18451f Mon Sep 17 00:00:00 2001 From: mgoin Date: Thu, 26 Dec 2024 14:22:21 +0000 Subject: [PATCH 3/9] Format Signed-off-by: mgoin --- tests/kernels/test_block_fp8.py | 249 ++++++++++++++++++ vllm/model_executor/layers/linear.py | 9 +- .../layers/quantization/utils/fp8_utils.py | 35 +-- 3 files changed, 275 insertions(+), 18 deletions(-) create mode 100644 tests/kernels/test_block_fp8.py diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py new file mode 100644 index 0000000000000..11a881812fb6f --- /dev/null +++ b/tests/kernels/test_block_fp8.py @@ -0,0 +1,249 @@ +import itertools + +import pytest +import torch + +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + per_token_group_quant_fp8, w8a8_block_fp8_matmul) + +# Test configurations +DTYPES = [torch.half, torch.bfloat16, torch.float32] +NUM_TOKENS = [7, 83, 2048] +D = [512, 4096, 5120, 13824] +GROUP_SIZE = [64, 128, 256, 512] +M = [1, 7, 83, 512, 2048] +N = [128, 512, 1024, 4096, 7748, 13824] +K = [256, 4096, 5120, 3884, 13824] +BLOCK_SIZE = [[128, 128]] +E = [8, 24] +TOP_KS = [2, 6] +OUT_DTYPES = [torch.float32, 
torch.half, torch.bfloat16] +SEEDS = [0] + + +def native_per_token_group_quant_fp8(x, + group_size, + eps=1e-10, + dtype=torch.float8_e4m3fn): + """Function to perform per-token-group quantization on an input tensor + `x` using native torch.""" + assert x.shape[-1] % group_size == 0, ("the last dimension of `x` cannot " + "be divisible by `group_size`") + assert x.is_contiguous(), "`x` is not contiguous" + + finfo = torch.finfo(dtype) + fp8_min = finfo.min + fp8_max = finfo.max + + x_ = x.reshape(x.numel() // group_size, group_size) + amax = x_.abs().max(dim=-1, + keepdim=True)[0].clamp(min=eps).to(torch.float32) + x_s = amax / fp8_max + x_q = (x_ / x_s).clamp(min=fp8_min, max=fp8_max).to(dtype) + x_q = x_q.reshape(x.shape) + x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size, )) + + return x_q, x_s + + +def native_w8a8_block_fp8_matmul(A, + B, + As, + Bs, + block_size, + output_dtype=torch.float16): + """Matrix multiplication with block-wise quantization using native torch.""" + A = A.to(torch.float32) + B = B.to(torch.float32) + assert A.shape[-1] == B.shape[-1] + assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2 + assert len(block_size) == 2 + block_n, block_k = block_size[0], block_size[1] + assert (A.shape[-1] + block_k - 1) // block_k == As.shape[-1] + assert A.shape[:-1] == As.shape[:-1] + + M = A.numel() // A.shape[-1] + N, K = B.shape + origin_C_shape = A.shape[:-1] + (N, ) + A = A.reshape(M, A.shape[-1]) + As = As.reshape(M, As.shape[-1]) + n_tiles = (N + block_n - 1) // block_n + k_tiles = (K + block_k - 1) // block_k + assert n_tiles == Bs.shape[0] + assert k_tiles == Bs.shape[1] + + C_shape = (M, N) + C = torch.zeros(C_shape, dtype=torch.float32, device=A.device) + + A_tiles = [ + A[:, i * block_k:min((i + 1) * block_k, K)] for i in range(k_tiles) + ] + B_tiles = [[ + B[j * block_n:min((j + 1) * block_n, N), + i * block_k:min((i + 1) * block_k, K), ] for i in range(k_tiles) + ] for j in range(n_tiles)] + C_tiles = [ + C[:, j * block_n:min((j + 1) * block_n, N)] for j in range(n_tiles) + ] + As_tiles = [As[:, i:i + 1] for i in range(k_tiles)] + + for i in range(k_tiles): + for j in range(n_tiles): + a = A_tiles[i] + b = B_tiles[j][i] + c = C_tiles[j] + s = As_tiles[i] * Bs[j][i] + c[:, :] += torch.matmul(a, b.t()) * s + + C = C.reshape(origin_C_shape).to(output_dtype) + return C + + +def torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape): + """Fused moe with block-wise quantization using native torch.""" + B, D = a.shape + a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) + out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device) + score = torch.softmax(score, dim=-1, dtype=torch.float32) + topk_weight, topk_ids = torch.topk(score, topk) + topk_weight = topk_weight.view(-1) + topk_ids = topk_ids.view(-1) + + _, block_k = block_shape[0], block_shape[1] + a_q, a_s = native_per_token_group_quant_fp8(a, block_k) + a_q = a_q.to(torch.float32) + for i in range(w1.shape[0]): + mask = topk_ids == i + if mask.sum(): + inter_out = native_w8a8_block_fp8_matmul(a_q[mask], + w1[i], + a_s[mask], + w1_s[i], + block_shape, + output_dtype=a.dtype) + act_out = SiluAndMul().forward_native(inter_out) + act_out_q, act_out_s = native_per_token_group_quant_fp8( + act_out, block_k) + act_out = act_out.to(torch.float32) + out[mask] = native_w8a8_block_fp8_matmul(act_out_q, + w2[i], + act_out_s, + w2_s[i], + block_shape, + output_dtype=a.dtype) + return (out.view(B, -1, w2.shape[1]) * + topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) + + 
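Besides the comparison against the native reference above, per_token_group_quant_fp8 can be sanity-checked with a dequantization round trip, since the returned scale is the multiplier that maps the fp8 values back to the original range. The snippet below is an illustrative check rather than part of this test file; it assumes a CUDA device (the quantization kernel is Triton) and the scale layout of shape x.shape[:-1] + (d // group_size,).

    import torch
    from vllm.model_executor.layers.quantization.utils.fp8_utils import (
        per_token_group_quant_fp8)

    x = torch.randn(16, 512, dtype=torch.bfloat16, device="cuda")
    x_q, x_s = per_token_group_quant_fp8(x, group_size=128)
    # Broadcast each group's scale back over its group_size columns and dequantize.
    x_dq = x_q.to(torch.float32) * x_s.repeat_interleave(128, dim=-1)
    assert torch.allclose(x_dq, x.to(torch.float32), rtol=0.2, atol=2e-2)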
+# Skip all tests if CUDA is not available +pytest.importorskip("torch.cuda") + + +@pytest.fixture(autouse=True) +def setup_cuda(): + torch.set_default_device("cuda") + + +@pytest.mark.parametrize("num_tokens,d,dtype,group_size,seed", + itertools.product(NUM_TOKENS, D, DTYPES, GROUP_SIZE, + SEEDS)) +@torch.inference_mode() +def test_per_token_group_quant_fp8(num_tokens, d, dtype, group_size, seed): + torch.manual_seed(seed) + x = torch.rand(num_tokens, d, dtype=dtype) + + ref_out, ref_scale = native_per_token_group_quant_fp8(x, group_size) + out, scale = per_token_group_quant_fp8(x, group_size) + + assert torch.allclose(out.to(torch.float32), + ref_out.to(torch.float32), + rtol=0.15) + assert torch.allclose(scale, ref_scale) + + +@pytest.mark.parametrize("M,N,K,block_size,out_dtype,seed", + itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, + SEEDS)) +@torch.inference_mode() +def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed): + torch.manual_seed(seed) + factor_for_scale = 1e-2 + fp8_info = torch.finfo(torch.float8_e4m3fn) + fp8_max, fp8_min = fp8_info.max, fp8_info.min + + A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * fp8_max + A_fp8 = A_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + + B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max + B_fp8 = B_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + + block_n, block_k = block_size[0], block_size[1] + n_tiles = (N + block_n - 1) // block_n + k_tiles = (K + block_k - 1) // block_k + + As = torch.rand(M, k_tiles, dtype=torch.float32) * factor_for_scale + Bs = torch.rand(n_tiles, k_tiles, dtype=torch.float32) * factor_for_scale + + ref_out = native_w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size, + out_dtype) + out = w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype) + + rel_diff = (torch.mean( + torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / + torch.mean(torch.abs(ref_out.to(torch.float32)))) + assert rel_diff < 0.001 + + +@pytest.mark.parametrize("M,N,K,E,topk,block_size,dtype,seed", + itertools.product(M, N, K, E, TOP_KS, BLOCK_SIZE, + DTYPES, SEEDS)) +@torch.inference_mode() +def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): + torch.manual_seed(seed) + factor_for_scale = 1e-2 + fp8_info = torch.finfo(torch.float8_e4m3fn) + fp8_max, fp8_min = fp8_info.max, fp8_info.min + + a = torch.randn((M, K), dtype=dtype) / 10 + + w1_fp32 = (torch.rand( + (E, 2 * N, K), dtype=torch.float32) - 0.5) * 2 * fp8_max + w1 = w1_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + + w2_fp32 = (torch.rand((E, K, N), dtype=torch.float32) - 0.5) * 2 * fp8_max + w2 = w2_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + + block_n, block_k = block_size[0], block_size[1] + n_tiles_w1 = (2 * N + block_n - 1) // block_n + n_tiles_w2 = (K + block_n - 1) // block_n + k_tiles_w1 = (K + block_k - 1) // block_k + k_tiles_w2 = (N + block_k - 1) // block_k + + w1_s = torch.rand( + (E, n_tiles_w1, k_tiles_w1), dtype=torch.float32) * factor_for_scale + w2_s = torch.rand( + (E, n_tiles_w2, k_tiles_w2), dtype=torch.float32) * factor_for_scale + + score = torch.randn((M, E), dtype=dtype) + + out = fused_moe( + a, + w1, + w2, + score, + topk, + renormalize=False, + use_fp8_w8a8=True, + w1_scale=w1_s, + w2_scale=w2_s, + block_shape=block_size, + ) + ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, + block_size) + + rel_diff = (torch.mean( + torch.abs(out.to(torch.float32) - 
ref_out.to(torch.float32))) / + torch.mean(torch.abs(ref_out.to(torch.float32)))) + assert rel_diff < 0.02 diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 8d75df15a906e..9686554ba6d0f 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -14,9 +14,12 @@ from vllm.logger import init_logger from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) -from vllm.model_executor.parameter import ( - BasevLLMParameter, BlockQuantScaleParameter, PackedColumnParameter, - PackedvLLMParameter, PerTensorScaleParameter, RowvLLMParameter) +from vllm.model_executor.parameter import (BasevLLMParameter, + BlockQuantScaleParameter, + PackedColumnParameter, + PackedvLLMParameter, + PerTensorScaleParameter, + RowvLLMParameter) from vllm.model_executor.utils import set_weight_attrs logger = init_logger(__name__) diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 7d08895866288..0d42a8e300e5a 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -1,4 +1,4 @@ -from typing import List, Tuple, Optional +from typing import List, Optional, Tuple import torch import triton @@ -35,7 +35,8 @@ def input_to_float8( x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn ) -> Tuple[torch.Tensor, torch.Tensor]: - """This function quantizes input values to float8 values with tensor-wise quantization.""" + """This function quantizes input values to float8 values " + "with tensor-wise quantization.""" finfo = torch.finfo(dtype) min_val, max_val = x.aminmax() amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12) @@ -49,11 +50,11 @@ def block_quant_to_tensor_quant( x_s: torch.Tensor, block_size: List[int], ) -> Tuple[torch.Tensor, torch.Tensor]: - """This function converts block-wise quantization to tensor-wise quantization. - The inputs are block-wise quantization tensor `x_q_block`, block-wise quantization scale - and the block size. - The outputs are tensor-wise quantization tensor and tensor-wise quantization scale. - Note only float8 is supported for now. + """This function converts block-wise quantization to tensor-wise + quantization. The inputs are block-wise quantization tensor `x_q_block`, + block-wise quantization scale and the block size. + The outputs are tensor-wise quantization tensor and tensor-wise + quantization scale. Note only float8 is supported for now. """ block_n, block_k = block_size[0], block_size[1] n, k = x_q_block.shape @@ -96,8 +97,8 @@ def _per_token_group_quant_fp8( # Meta-parameters BLOCK: tl.constexpr, ): - """A Triton-accelerated function to perform per-token-group quantization on a - tensor. + """A Triton-accelerated function to perform per-token-group + quantization on a tensor. This function converts the tensor values into float8 values. """ # Map the program id to the row of X and Y it should compute. @@ -132,9 +133,11 @@ def per_token_group_quant_fp8( x: The input tenosr with ndim >= 2. group_size: The group size used for quantization. eps: The minimum to avoid dividing zero. - dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn` is supported for now. + dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn` + is supported for now. Returns: - Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization. 
+ Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the + scaling factor for quantization. """ assert (x.shape[-1] % group_size == 0 ), "the last dimension of `x` cannot be divisible by `group_size`" @@ -207,8 +210,8 @@ def _w8a8_block_fp8_matmul( GROUP_SIZE_M: tl.constexpr, ): """Triton-accelerated function used to perform linear operations (dot - product) on input tensors `A` and `B` with block-wise quantization, and store the result in output - tensor `C`. + product) on input tensors `A` and `B` with block-wise quantization, and + store the result in output tensor `C`. """ pid = tl.program_id(axis=0) @@ -271,7 +274,8 @@ def w8a8_block_fp8_matmul( block_size: List[int], output_dtype: torch.dtype = torch.float16, ) -> torch.Tensor: - """This function performs matrix multiplication with block-wise quantization. + """This function performs matrix multiplication with block-wise + quantization. It takes two input tensors `A` and `B` with scales `As` and `Bs`. The output is returned in the specified `output_dtype`. Args: @@ -279,7 +283,8 @@ def w8a8_block_fp8_matmul( B: The input tensor, e.g., weight. As: The per-token-group quantization scale for `A`. Bs: The per-block quantization scale for `B`. - block_size: The block size for per-block quantization. It should be 2-dim, e.g., [128, 128]. + block_size: The block size for per-block quantization. It should + be 2-dim, e.g., [128, 128]. output_dytpe: The dtype of the returned tensor. Returns: torch.Tensor: The result of matmul. From 080725284106054f8ecb945fdadcf9b567ebd16d Mon Sep 17 00:00:00 2001 From: mgoin Date: Thu, 26 Dec 2024 14:34:47 +0000 Subject: [PATCH 4/9] Fix yapf Signed-off-by: mgoin --- vllm/model_executor/layers/linear.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 9686554ba6d0f..33b221b994b2b 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -14,12 +14,14 @@ from vllm.logger import init_logger from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) +# yapf: disable from vllm.model_executor.parameter import (BasevLLMParameter, BlockQuantScaleParameter, PackedColumnParameter, PackedvLLMParameter, PerTensorScaleParameter, RowvLLMParameter) +# yapf: enable from vllm.model_executor.utils import set_weight_attrs logger = init_logger(__name__) From f147947cef44c3f35028b9d1bfea015e7b0e102f Mon Sep 17 00:00:00 2001 From: simon-mo Date: Thu, 26 Dec 2024 12:16:12 -0800 Subject: [PATCH 5/9] remove v3 model Signed-off-by: simon-mo --- vllm/model_executor/models/registry.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 9aa9eb0871fd4..b32a3421d5841 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -45,7 +45,6 @@ "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"), "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"), "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"), - "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"), "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"), "FalconForCausalLM": ("falcon", "FalconForCausalLM"), "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"), From 1cde5cb2dc1f133613f42fcd062e9aabb51d8730 Mon Sep 17 00:00:00 2001 From: simon-mo Date: Thu, 26 Dec 2024 13:14:00 -0800 Subject: [PATCH 6/9] add ack Signed-off-by: simon-mo --- 
tests/kernels/test_block_fp8.py | 3 ++- vllm/model_executor/layers/quantization/utils/fp8_utils.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 11a881812fb6f..608a03e21a590 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -1,3 +1,4 @@ +# Adapted from https://github.com/sgl-project/sglang/pull/2575 import itertools import pytest @@ -27,7 +28,7 @@ def native_per_token_group_quant_fp8(x, group_size, eps=1e-10, dtype=torch.float8_e4m3fn): - """Function to perform per-token-group quantization on an input tensor + """Function to perform per-token-group quantization on an input tensor `x` using native torch.""" assert x.shape[-1] % group_size == 0, ("the last dimension of `x` cannot " "be divisible by `group_size`") diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 0d42a8e300e5a..c91a97652ed0b 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -1,3 +1,4 @@ +# Adapted from https://github.com/sgl-project/sglang/pull/2575 from typing import List, Optional, Tuple import torch @@ -283,7 +284,7 @@ def w8a8_block_fp8_matmul( B: The input tensor, e.g., weight. As: The per-token-group quantization scale for `A`. Bs: The per-block quantization scale for `B`. - block_size: The block size for per-block quantization. It should + block_size: The block size for per-block quantization. It should be 2-dim, e.g., [128, 128]. output_dytpe: The dtype of the returned tensor. Returns: From e854c93efc1b365034562b1637dfb2f2d7fe14ae Mon Sep 17 00:00:00 2001 From: simon-mo Date: Thu, 26 Dec 2024 14:25:28 -0800 Subject: [PATCH 7/9] fix outplace Signed-off-by: simon-mo --- vllm/model_executor/layers/fused_moe/fused_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index db308b6d4e597..92e9ba3c9cebd 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -580,7 +580,7 @@ def outplace_fused_experts( block_shape: Optional[List[int]] = None) -> torch.Tensor: return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, False, use_fp8_w8a8, use_int8_w8a16, w1_scale, - w2_scale, a1_scale, a2_scale) + w2_scale, a1_scale, a2_scale, block_shape) def outplace_fused_experts_fake( From 3e73c3c7070f26efbb8d2a735b906ba293da3d56 Mon Sep 17 00:00:00 2001 From: simon-mo Date: Thu, 26 Dec 2024 15:20:40 -0800 Subject: [PATCH 8/9] fix test Signed-off-by: simon-mo --- tests/kernels/test_block_fp8.py | 60 ++++++++++++------- vllm/config.py | 21 +++---- .../layers/quantization/utils/fp8_utils.py | 14 +++-- 3 files changed, 58 insertions(+), 37 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 608a03e21a590..f28fdf3feedbc 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -8,19 +8,29 @@ from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.quantization.utils.fp8_utils import ( per_token_group_quant_fp8, w8a8_block_fp8_matmul) +from vllm.platforms import current_platform + +if current_platform.get_device_capability() < (9, 0): + pytest.skip("FP8 Triton requires CUDA 9.0 or higher", + allow_module_level=True) # Test 
configurations -DTYPES = [torch.half, torch.bfloat16, torch.float32] +DTYPES = [torch.bfloat16] # [torch.half, torch.bfloat16, torch.float32] NUM_TOKENS = [7, 83, 2048] D = [512, 4096, 5120, 13824] GROUP_SIZE = [64, 128, 256, 512] M = [1, 7, 83, 512, 2048] N = [128, 512, 1024, 4096, 7748, 13824] K = [256, 4096, 5120, 3884, 13824] +# Deepseek-V3's intermediate size 18432, so N is 18432*2/8=4608 at TP8 +# and its hidden size is 7168. +M_moe = [1, 7, 83, 512, 2048] +N_moe = [4608] # [128, 4608, 13824] +K_moe = [7168] # [256, 7168, 13824] BLOCK_SIZE = [[128, 128]] -E = [8, 24] -TOP_KS = [2, 6] -OUT_DTYPES = [torch.float32, torch.half, torch.bfloat16] +E = [256] # [8, 24, 128, 256] +TOP_KS = [1] # [1, 2, 6] +OUT_DTYPES = [torch.bfloat16] # [torch.float32, torch.half, torch.bfloat16] SEEDS = [0] @@ -82,8 +92,10 @@ def native_w8a8_block_fp8_matmul(A, A[:, i * block_k:min((i + 1) * block_k, K)] for i in range(k_tiles) ] B_tiles = [[ - B[j * block_n:min((j + 1) * block_n, N), - i * block_k:min((i + 1) * block_k, K), ] for i in range(k_tiles) + B[ + j * block_n:min((j + 1) * block_n, N), + i * block_k:min((i + 1) * block_k, K), + ] for i in range(k_tiles) ] for j in range(n_tiles)] C_tiles = [ C[:, j * block_n:min((j + 1) * block_n, N)] for j in range(n_tiles) @@ -147,9 +159,9 @@ def setup_cuda(): torch.set_default_device("cuda") -@pytest.mark.parametrize("num_tokens,d,dtype,group_size,seed", - itertools.product(NUM_TOKENS, D, DTYPES, GROUP_SIZE, - SEEDS)) +@pytest.mark.parametrize( + "num_tokens,d,dtype,group_size,seed", + itertools.product(NUM_TOKENS, D, DTYPES, GROUP_SIZE, SEEDS)) @torch.inference_mode() def test_per_token_group_quant_fp8(num_tokens, d, dtype, group_size, seed): torch.manual_seed(seed) @@ -164,9 +176,9 @@ def test_per_token_group_quant_fp8(num_tokens, d, dtype, group_size, seed): assert torch.allclose(scale, ref_scale) -@pytest.mark.parametrize("M,N,K,block_size,out_dtype,seed", - itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, - SEEDS)) +@pytest.mark.parametrize( + "M,N,K,block_size,out_dtype,seed", + itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS)) @torch.inference_mode() def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed): torch.manual_seed(seed) @@ -197,9 +209,10 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed): assert rel_diff < 0.001 -@pytest.mark.parametrize("M,N,K,E,topk,block_size,dtype,seed", - itertools.product(M, N, K, E, TOP_KS, BLOCK_SIZE, - DTYPES, SEEDS)) +@pytest.mark.parametrize( + "M,N,K,E,topk,block_size,dtype,seed", + itertools.product(M_moe, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, + SEEDS)) @torch.inference_mode() def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): torch.manual_seed(seed) @@ -209,12 +222,14 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): a = torch.randn((M, K), dtype=dtype) / 10 - w1_fp32 = (torch.rand( - (E, 2 * N, K), dtype=torch.float32) - 0.5) * 2 * fp8_max - w1 = w1_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + w1_bf16 = (torch.rand( + (E, 2 * N, K), dtype=torch.bfloat16) - 0.5) * 2 * fp8_max + w1 = w1_bf16.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + del w1_bf16 - w2_fp32 = (torch.rand((E, K, N), dtype=torch.float32) - 0.5) * 2 * fp8_max - w2 = w2_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + w2_bf16 = (torch.rand((E, K, N), dtype=torch.bfloat16) - 0.5) * 2 * fp8_max + w2 = w2_bf16.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + del w2_bf16 block_n, block_k = 
block_size[0], block_size[1] n_tiles_w1 = (2 * N + block_n - 1) // block_n @@ -244,7 +259,10 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_size) + print(f"{out.sum()=}") + print(f"{ref_out.sum()=}") + rel_diff = (torch.mean( torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / torch.mean(torch.abs(ref_out.to(torch.float32)))) - assert rel_diff < 0.02 + assert rel_diff < 0.03 diff --git a/vllm/config.py b/vllm/config.py index 17602bda15c69..df7f6167e6bf5 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -67,7 +67,8 @@ _TASK_RUNNER: Dict[_ResolvedTask, RunnerType] = { task: runner - for runner, tasks in _RUNNER_TASKS.items() for task in tasks + for runner, tasks in _RUNNER_TASKS.items() + for task in tasks } HfOverrides = Union[Dict[str, Any], Callable[[PretrainedConfig], @@ -160,7 +161,7 @@ class ModelConfig: override default pooling config for the pooling model. logits_processor_pattern: Optional regex pattern specifying valid logits processor qualified names that can be passed with the - `logits_processors` extra completion argument. Defaults to None, + `logits_processors` extra completion argument. Defaults to None, which allows no processors. generation_config: Configuration parameter file for generation. """ @@ -363,7 +364,7 @@ def __init__(self, def maybe_pull_model_tokenizer_for_s3(self, model: str, tokenizer: str) -> None: """ - Pull the model config or tokenizer to a temporary + Pull the model config or tokenizer to a temporary directory in case of S3. Args: @@ -874,14 +875,14 @@ def try_get_generation_config(self) -> Dict[str, Any]: def get_diff_sampling_param(self) -> Dict[str, Any]: """ - This method returns a dictionary containing the parameters - that differ from the default sampling parameters, but only - if `generation_config` is set. If `generation_config` is not + This method returns a dictionary containing the parameters + that differ from the default sampling parameters, but only + if `generation_config` is set. If `generation_config` is not set, an empty dictionary is returned. Returns: - Dict[str, Any]: A dictionary with the differing sampling - parameters if `generation_config` is set, otherwise an + Dict[str, Any]: A dictionary with the differing sampling + parameters if `generation_config` is set, otherwise an empty dictionary. """ if self.generation_config is None: @@ -1955,8 +1956,8 @@ def _verify_args(self) -> None: "typical_acceptance_sampler.") if (self.draft_token_acceptance_method != 'rejection_sampler' - and self.draft_token_acceptance_method != - 'typical_acceptance_sampler'): + and self.draft_token_acceptance_method + != 'typical_acceptance_sampler'): raise ValueError( "Expected draft_token_acceptance_method to be either " "rejection_sampler or typical_acceptance_sampler. 
Instead it " diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index c91a97652ed0b..89b5233b5b076 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -67,9 +67,10 @@ def block_quant_to_tensor_quant( x_dq_block = x_q_block.to(torch.float32) x_dq_block_tiles = [[ - x_dq_block[j * block_n:min((j + 1) * block_n, n), - i * block_k:min((i + 1) * block_k, k), ] - for i in range(k_tiles) + x_dq_block[ + j * block_n:min((j + 1) * block_n, n), + i * block_k:min((i + 1) * block_k, k), + ] for i in range(k_tiles) ] for j in range(n_tiles)] for i in range(k_tiles): @@ -140,9 +141,10 @@ def per_token_group_quant_fp8( Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization. """ - assert (x.shape[-1] % group_size == 0 - ), "the last dimension of `x` cannot be divisible by `group_size`" - assert x.is_contiguous(), "`x` is not contiguous" + assert (x.shape[-1] % group_size == 0), ( + f"the last dimension of `x` {x.shape[-1]} must be divisible " + f"by `group_size` {group_size}") + assert x.is_contiguous(), "`x` must be contiguous" finfo = torch.finfo(dtype) fp8_min = finfo.min From 25a0b3099973afb513626d4c9359d92ebb9f5ef2 Mon Sep 17 00:00:00 2001 From: simon-mo Date: Thu, 26 Dec 2024 15:28:08 -0800 Subject: [PATCH 9/9] format Signed-off-by: simon-mo --- tests/kernels/test_block_fp8.py | 25 ++++++++----------- vllm/config.py | 7 +++--- .../layers/quantization/utils/fp8_utils.py | 7 +++--- 3 files changed, 17 insertions(+), 22 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index f28fdf3feedbc..a16cc4582a180 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -92,10 +92,8 @@ def native_w8a8_block_fp8_matmul(A, A[:, i * block_k:min((i + 1) * block_k, K)] for i in range(k_tiles) ] B_tiles = [[ - B[ - j * block_n:min((j + 1) * block_n, N), - i * block_k:min((i + 1) * block_k, K), - ] for i in range(k_tiles) + B[j * block_n:min((j + 1) * block_n, N), + i * block_k:min((i + 1) * block_k, K), ] for i in range(k_tiles) ] for j in range(n_tiles)] C_tiles = [ C[:, j * block_n:min((j + 1) * block_n, N)] for j in range(n_tiles) @@ -159,9 +157,9 @@ def setup_cuda(): torch.set_default_device("cuda") -@pytest.mark.parametrize( - "num_tokens,d,dtype,group_size,seed", - itertools.product(NUM_TOKENS, D, DTYPES, GROUP_SIZE, SEEDS)) +@pytest.mark.parametrize("num_tokens,d,dtype,group_size,seed", + itertools.product(NUM_TOKENS, D, DTYPES, GROUP_SIZE, + SEEDS)) @torch.inference_mode() def test_per_token_group_quant_fp8(num_tokens, d, dtype, group_size, seed): torch.manual_seed(seed) @@ -176,9 +174,9 @@ def test_per_token_group_quant_fp8(num_tokens, d, dtype, group_size, seed): assert torch.allclose(scale, ref_scale) -@pytest.mark.parametrize( - "M,N,K,block_size,out_dtype,seed", - itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS)) +@pytest.mark.parametrize("M,N,K,block_size,out_dtype,seed", + itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, + SEEDS)) @torch.inference_mode() def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed): torch.manual_seed(seed) @@ -209,10 +207,9 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed): assert rel_diff < 0.001 -@pytest.mark.parametrize( - "M,N,K,E,topk,block_size,dtype,seed", - itertools.product(M_moe, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, - SEEDS)) 
+@pytest.mark.parametrize("M,N,K,E,topk,block_size,dtype,seed", + itertools.product(M_moe, N_moe, K_moe, E, TOP_KS, + BLOCK_SIZE, DTYPES, SEEDS)) @torch.inference_mode() def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): torch.manual_seed(seed) diff --git a/vllm/config.py b/vllm/config.py index df7f6167e6bf5..bd862b12c89de 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -67,8 +67,7 @@ _TASK_RUNNER: Dict[_ResolvedTask, RunnerType] = { task: runner - for runner, tasks in _RUNNER_TASKS.items() - for task in tasks + for runner, tasks in _RUNNER_TASKS.items() for task in tasks } HfOverrides = Union[Dict[str, Any], Callable[[PretrainedConfig], @@ -1956,8 +1955,8 @@ def _verify_args(self) -> None: "typical_acceptance_sampler.") if (self.draft_token_acceptance_method != 'rejection_sampler' - and self.draft_token_acceptance_method - != 'typical_acceptance_sampler'): + and self.draft_token_acceptance_method != + 'typical_acceptance_sampler'): raise ValueError( "Expected draft_token_acceptance_method to be either " "rejection_sampler or typical_acceptance_sampler. Instead it " diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 89b5233b5b076..f3c3e130e4161 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -67,10 +67,9 @@ def block_quant_to_tensor_quant( x_dq_block = x_q_block.to(torch.float32) x_dq_block_tiles = [[ - x_dq_block[ - j * block_n:min((j + 1) * block_n, n), - i * block_k:min((i + 1) * block_k, k), - ] for i in range(k_tiles) + x_dq_block[j * block_n:min((j + 1) * block_n, n), + i * block_k:min((i + 1) * block_k, k), ] + for i in range(k_tiles) ] for j in range(n_tiles)] for i in range(k_tiles):
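For reference, a minimal usage sketch of the block-wise FP8 helpers this series adds and reformats (not part of the patches themselves): the shapes, block size, and synthetic weight scales below are illustrative assumptions chosen to mirror the new tests, and it assumes a CUDA device that satisfies the compute-capability (9, 0) skip condition added to tests/kernels/test_block_fp8.py.

    # Minimal sketch, assuming the per_token_group_quant_fp8 and
    # w8a8_block_fp8_matmul signatures shown in the fp8_utils.py hunks above.
    # M/N/K follow the M_moe/N_moe/K_moe values used by the tests.
    import torch

    from vllm.model_executor.layers.quantization.utils.fp8_utils import (
        per_token_group_quant_fp8, w8a8_block_fp8_matmul)

    M, N, K = 7, 4608, 7168
    block_size = [128, 128]  # [block_n, block_k]

    # Quantize the activation per token group of block_k elements along K;
    # returns the FP8 tensor plus one float32 scale per group.
    A = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
    A_q, A_s = per_token_group_quant_fp8(A, block_size[1])

    # Build an FP8 weight with one scale per (block_n, block_k) tile, the same
    # way the tests construct w1_s / w2_s (the scale values here are arbitrary).
    fp8_info = torch.finfo(torch.float8_e4m3fn)
    B_bf16 = (torch.rand(N, K, dtype=torch.bfloat16, device="cuda") -
              0.5) * 2 * fp8_info.max
    B_q = B_bf16.clamp(min=fp8_info.min,
                       max=fp8_info.max).to(torch.float8_e4m3fn)
    n_tiles = (N + block_size[0] - 1) // block_size[0]
    k_tiles = (K + block_size[1] - 1) // block_size[1]
    B_s = torch.rand(n_tiles, k_tiles, dtype=torch.float32,
                     device="cuda") * 1e-2

    out = w8a8_block_fp8_matmul(A_q, B_q, A_s, B_s, block_size,
                                output_dtype=torch.bfloat16)
    assert out.shape == (M, N)

The activation-side helper (per_token_group_quant_fp8) is also what the fused MoE path relies on once block_shape is threaded through, which is what the "fix outplace" patch above corrects by forwarding block_shape to fused_experts_impl.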