diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py
new file mode 100644
index 0000000000000..a16cc4582a180
--- /dev/null
+++ b/tests/kernels/test_block_fp8.py
@@ -0,0 +1,265 @@
+# Adapted from https://github.com/sgl-project/sglang/pull/2575
+import itertools
+
+import pytest
+import torch
+
+from vllm.model_executor.layers.activation import SiluAndMul
+from vllm.model_executor.layers.fused_moe import fused_moe
+from vllm.model_executor.layers.quantization.utils.fp8_utils import (
+    per_token_group_quant_fp8, w8a8_block_fp8_matmul)
+from vllm.platforms import current_platform
+
+if current_platform.get_device_capability() < (9, 0):
+    pytest.skip("FP8 Triton requires device compute capability 9.0 or higher",
+                allow_module_level=True)
+
+# Test configurations
+DTYPES = [torch.bfloat16]  # [torch.half, torch.bfloat16, torch.float32]
+NUM_TOKENS = [7, 83, 2048]
+D = [512, 4096, 5120, 13824]
+GROUP_SIZE = [64, 128, 256, 512]
+M = [1, 7, 83, 512, 2048]
+N = [128, 512, 1024, 4096, 7748, 13824]
+K = [256, 4096, 5120, 3884, 13824]
+# Deepseek-V3's intermediate size is 18432, so N is 18432*2/8=4608 at TP8
+# and its hidden size is 7168.
+M_moe = [1, 7, 83, 512, 2048]
+N_moe = [4608]  # [128, 4608, 13824]
+K_moe = [7168]  # [256, 7168, 13824]
+BLOCK_SIZE = [[128, 128]]
+E = [256]  # [8, 24, 128, 256]
+TOP_KS = [1]  # [1, 2, 6]
+OUT_DTYPES = [torch.bfloat16]  # [torch.float32, torch.half, torch.bfloat16]
+SEEDS = [0]
+
+
+def native_per_token_group_quant_fp8(x,
+                                     group_size,
+                                     eps=1e-10,
+                                     dtype=torch.float8_e4m3fn):
+    """Function to perform per-token-group quantization on an input tensor
+    `x` using native torch."""
+    assert x.shape[-1] % group_size == 0, ("the last dimension of `x` must "
+                                           "be divisible by `group_size`")
+    assert x.is_contiguous(), "`x` is not contiguous"
+
+    finfo = torch.finfo(dtype)
+    fp8_min = finfo.min
+    fp8_max = finfo.max
+
+    x_ = x.reshape(x.numel() // group_size, group_size)
+    amax = x_.abs().max(dim=-1,
+                        keepdim=True)[0].clamp(min=eps).to(torch.float32)
+    x_s = amax / fp8_max
+    x_q = (x_ / x_s).clamp(min=fp8_min, max=fp8_max).to(dtype)
+    x_q = x_q.reshape(x.shape)
+    x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size, ))
+
+    return x_q, x_s
+
+
+def native_w8a8_block_fp8_matmul(A,
+                                 B,
+                                 As,
+                                 Bs,
+                                 block_size,
+                                 output_dtype=torch.float16):
+    """Matrix multiplication with block-wise quantization using native torch."""
+    A = A.to(torch.float32)
+    B = B.to(torch.float32)
+    assert A.shape[-1] == B.shape[-1]
+    assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
+    assert len(block_size) == 2
+    block_n, block_k = block_size[0], block_size[1]
+    assert (A.shape[-1] + block_k - 1) // block_k == As.shape[-1]
+    assert A.shape[:-1] == As.shape[:-1]
+
+    M = A.numel() // A.shape[-1]
+    N, K = B.shape
+    origin_C_shape = A.shape[:-1] + (N, )
+    A = A.reshape(M, A.shape[-1])
+    As = As.reshape(M, As.shape[-1])
+    n_tiles = (N + block_n - 1) // block_n
+    k_tiles = (K + block_k - 1) // block_k
+    assert n_tiles == Bs.shape[0]
+    assert k_tiles == Bs.shape[1]
+
+    C_shape = (M, N)
+    C = torch.zeros(C_shape, dtype=torch.float32, device=A.device)
+
+    A_tiles = [
+        A[:, i * block_k:min((i + 1) * block_k, K)] for i in range(k_tiles)
+    ]
+    B_tiles = [[
+        B[j * block_n:min((j + 1) * block_n, N),
+          i * block_k:min((i + 1) * block_k, K), ] for i in range(k_tiles)
+    ] for j in range(n_tiles)]
+    C_tiles = [
+        C[:, j * block_n:min((j + 1) * block_n, N)] for j in range(n_tiles)
+    ]
+    As_tiles = [As[:, i:i + 1] for i in range(k_tiles)]
+
+    for i in range(k_tiles):
+        for j in range(n_tiles):
+            a
= A_tiles[i] + b = B_tiles[j][i] + c = C_tiles[j] + s = As_tiles[i] * Bs[j][i] + c[:, :] += torch.matmul(a, b.t()) * s + + C = C.reshape(origin_C_shape).to(output_dtype) + return C + + +def torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape): + """Fused moe with block-wise quantization using native torch.""" + B, D = a.shape + a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) + out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device) + score = torch.softmax(score, dim=-1, dtype=torch.float32) + topk_weight, topk_ids = torch.topk(score, topk) + topk_weight = topk_weight.view(-1) + topk_ids = topk_ids.view(-1) + + _, block_k = block_shape[0], block_shape[1] + a_q, a_s = native_per_token_group_quant_fp8(a, block_k) + a_q = a_q.to(torch.float32) + for i in range(w1.shape[0]): + mask = topk_ids == i + if mask.sum(): + inter_out = native_w8a8_block_fp8_matmul(a_q[mask], + w1[i], + a_s[mask], + w1_s[i], + block_shape, + output_dtype=a.dtype) + act_out = SiluAndMul().forward_native(inter_out) + act_out_q, act_out_s = native_per_token_group_quant_fp8( + act_out, block_k) + act_out = act_out.to(torch.float32) + out[mask] = native_w8a8_block_fp8_matmul(act_out_q, + w2[i], + act_out_s, + w2_s[i], + block_shape, + output_dtype=a.dtype) + return (out.view(B, -1, w2.shape[1]) * + topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) + + +# Skip all tests if CUDA is not available +pytest.importorskip("torch.cuda") + + +@pytest.fixture(autouse=True) +def setup_cuda(): + torch.set_default_device("cuda") + + +@pytest.mark.parametrize("num_tokens,d,dtype,group_size,seed", + itertools.product(NUM_TOKENS, D, DTYPES, GROUP_SIZE, + SEEDS)) +@torch.inference_mode() +def test_per_token_group_quant_fp8(num_tokens, d, dtype, group_size, seed): + torch.manual_seed(seed) + x = torch.rand(num_tokens, d, dtype=dtype) + + ref_out, ref_scale = native_per_token_group_quant_fp8(x, group_size) + out, scale = per_token_group_quant_fp8(x, group_size) + + assert torch.allclose(out.to(torch.float32), + ref_out.to(torch.float32), + rtol=0.15) + assert torch.allclose(scale, ref_scale) + + +@pytest.mark.parametrize("M,N,K,block_size,out_dtype,seed", + itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, + SEEDS)) +@torch.inference_mode() +def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed): + torch.manual_seed(seed) + factor_for_scale = 1e-2 + fp8_info = torch.finfo(torch.float8_e4m3fn) + fp8_max, fp8_min = fp8_info.max, fp8_info.min + + A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * fp8_max + A_fp8 = A_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + + B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max + B_fp8 = B_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + + block_n, block_k = block_size[0], block_size[1] + n_tiles = (N + block_n - 1) // block_n + k_tiles = (K + block_k - 1) // block_k + + As = torch.rand(M, k_tiles, dtype=torch.float32) * factor_for_scale + Bs = torch.rand(n_tiles, k_tiles, dtype=torch.float32) * factor_for_scale + + ref_out = native_w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size, + out_dtype) + out = w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype) + + rel_diff = (torch.mean( + torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / + torch.mean(torch.abs(ref_out.to(torch.float32)))) + assert rel_diff < 0.001 + + +@pytest.mark.parametrize("M,N,K,E,topk,block_size,dtype,seed", + itertools.product(M_moe, N_moe, K_moe, E, TOP_KS, + BLOCK_SIZE, DTYPES, SEEDS)) 
+@torch.inference_mode() +def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): + torch.manual_seed(seed) + factor_for_scale = 1e-2 + fp8_info = torch.finfo(torch.float8_e4m3fn) + fp8_max, fp8_min = fp8_info.max, fp8_info.min + + a = torch.randn((M, K), dtype=dtype) / 10 + + w1_bf16 = (torch.rand( + (E, 2 * N, K), dtype=torch.bfloat16) - 0.5) * 2 * fp8_max + w1 = w1_bf16.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + del w1_bf16 + + w2_bf16 = (torch.rand((E, K, N), dtype=torch.bfloat16) - 0.5) * 2 * fp8_max + w2 = w2_bf16.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + del w2_bf16 + + block_n, block_k = block_size[0], block_size[1] + n_tiles_w1 = (2 * N + block_n - 1) // block_n + n_tiles_w2 = (K + block_n - 1) // block_n + k_tiles_w1 = (K + block_k - 1) // block_k + k_tiles_w2 = (N + block_k - 1) // block_k + + w1_s = torch.rand( + (E, n_tiles_w1, k_tiles_w1), dtype=torch.float32) * factor_for_scale + w2_s = torch.rand( + (E, n_tiles_w2, k_tiles_w2), dtype=torch.float32) * factor_for_scale + + score = torch.randn((M, E), dtype=dtype) + + out = fused_moe( + a, + w1, + w2, + score, + topk, + renormalize=False, + use_fp8_w8a8=True, + w1_scale=w1_s, + w2_scale=w2_s, + block_shape=block_size, + ) + ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, + block_size) + + print(f"{out.sum()=}") + print(f"{ref_out.sum()=}") + + rel_diff = (torch.mean( + torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / + torch.mean(torch.abs(ref_out.to(torch.float32)))) + assert rel_diff < 0.03 diff --git a/vllm/config.py b/vllm/config.py index de8ba029ddc23..58649236b4225 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -161,7 +161,7 @@ class ModelConfig: override default pooling config for the pooling model. logits_processor_pattern: Optional regex pattern specifying valid logits processor qualified names that can be passed with the - `logits_processors` extra completion argument. Defaults to None, + `logits_processors` extra completion argument. Defaults to None, which allows no processors. generation_config: Configuration parameter file for generation. """ @@ -364,7 +364,7 @@ def __init__(self, def maybe_pull_model_tokenizer_for_s3(self, model: str, tokenizer: str) -> None: """ - Pull the model config or tokenizer to a temporary + Pull the model config or tokenizer to a temporary directory in case of S3. Args: @@ -866,14 +866,14 @@ def try_get_generation_config(self) -> Dict[str, Any]: def get_diff_sampling_param(self) -> Dict[str, Any]: """ - This method returns a dictionary containing the parameters - that differ from the default sampling parameters, but only - if `generation_config` is set. If `generation_config` is not + This method returns a dictionary containing the parameters + that differ from the default sampling parameters, but only + if `generation_config` is set. If `generation_config` is not set, an empty dictionary is returned. Returns: - Dict[str, Any]: A dictionary with the differing sampling - parameters if `generation_config` is set, otherwise an + Dict[str, Any]: A dictionary with the differing sampling + parameters if `generation_config` is set, otherwise an empty dictionary. 
""" if self.generation_config is None: diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index e6f9f01ef0f74..92e9ba3c9cebd 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -2,7 +2,7 @@ import functools import json import os -from typing import Any, Callable, Dict, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple import torch import triton @@ -11,6 +11,8 @@ import vllm.envs as envs from vllm import _custom_ops as ops from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + per_token_group_quant_fp8) from vllm.platforms import current_platform from vllm.utils import direct_register_custom_op @@ -45,8 +47,14 @@ def fused_moe_kernel( stride_bn, stride_cm, stride_cn, + stride_asm, + stride_ask, stride_bse, + stride_bsk, stride_bsn, + # Block size for block-wise quantization + group_n: tl.constexpr, + group_k: tl.constexpr, # Meta-parameters BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, @@ -125,8 +133,14 @@ def fused_moe_kernel( b_scale = tl.load(b_scale_ptrs) if use_fp8_w8a8: - a_scale = tl.load(a_scale_ptr) - b_scale = tl.load(b_scale_ptr + off_experts) + if group_k > 0 and group_n > 0: + a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm + offs_bsn = offs_bn // group_n + b_scale_ptrs = (b_scale_ptr + off_experts * stride_bse + + offs_bsn * stride_bsn) + else: + a_scale = tl.load(a_scale_ptr) + b_scale = tl.load(b_scale_ptr + off_experts) # ----------------------------------------------------------- # Iterate to compute a block of the C matrix. @@ -149,7 +163,18 @@ def fused_moe_kernel( if use_int8_w8a16: accumulator = tl.dot(a, b.to(compute_type), acc=accumulator) elif use_fp8_w8a8: - accumulator = tl.dot(a, b, acc=accumulator) + if group_k > 0 and group_n > 0: + k_start = k * BLOCK_SIZE_K + offs_ks = k_start // group_k + a_scale = tl.load(a_scale_ptrs + offs_ks * stride_ask, + mask=token_mask, + other=0.0) + b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk) + + accumulator += tl.dot(a, b) * a_scale[:, + None] * b_scale[None, :] + else: + accumulator = tl.dot(a, b, acc=accumulator) else: accumulator += tl.dot(a, b) # Advance the ptrs to the next K block. 
@@ -164,7 +189,10 @@ def fused_moe_kernel( if use_int8_w8a16: accumulator = (accumulator * b_scale).to(compute_type) elif use_fp8_w8a8: - accumulator = (accumulator * a_scale * b_scale).to(compute_type) + if group_k > 0 and group_n > 0: + accumulator = accumulator.to(compute_type) + else: + accumulator = (accumulator * a_scale * b_scale).to(compute_type) else: accumulator = accumulator.to(compute_type) # ----------------------------------------------------------- @@ -233,22 +261,37 @@ def moe_align_block_size( return sorted_ids, expert_ids, num_tokens_post_pad -def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, +def invoke_fused_moe_kernel(A: torch.Tensor, + B: torch.Tensor, + C: torch.Tensor, A_scale: Optional[torch.Tensor], B_scale: Optional[torch.Tensor], - topk_weights: torch.Tensor, topk_ids: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, sorted_token_ids: torch.Tensor, expert_ids: torch.Tensor, num_tokens_post_padded: torch.Tensor, - mul_routed_weight: bool, top_k: int, - config: Dict[str, Any], compute_type: tl.dtype, - use_fp8_w8a8: bool, use_int8_w8a16: bool) -> None: + mul_routed_weight: bool, + top_k: int, + config: Dict[str, Any], + compute_type: tl.dtype, + use_fp8_w8a8: bool, + use_int8_w8a16: bool, + block_shape: Optional[List[int]] = None) -> None: assert topk_weights.stride(1) == 1 assert sorted_token_ids.stride(0) == 1 if use_fp8_w8a8: - A, A_scale = ops.scaled_fp8_quant(A, A_scale) assert B_scale is not None + if block_shape is None: + A, A_scale = ops.scaled_fp8_quant(A, A_scale) + else: + assert len(block_shape) == 2 + block_n, block_k = block_shape[0], block_shape[1] + A, A_scale = per_token_group_quant_fp8(A, block_k) + assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1] + assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2] + assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1] elif use_int8_w8a16: assert B_scale is not None else: @@ -279,8 +322,13 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, B.stride(1), C.stride(1), C.stride(2), - B_scale.stride(0) if B_scale is not None and use_int8_w8a16 else 0, - B_scale.stride(1) if B_scale is not None and use_int8_w8a16 else 0, + A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0, + A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0, + B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0, + B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0, + B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0, + 0 if block_shape is None else block_shape[0], + 0 if block_shape is None else block_shape[1], MUL_ROUTED_WEIGHT=mul_routed_weight, top_k=top_k, compute_type=compute_type, @@ -362,6 +410,7 @@ def try_get_optimal_moe_config( dtype: Optional[str], M: int, is_marlin: bool = False, + block_shape: Optional[List[int]] = None, ): from vllm.model_executor.layers.fused_moe import get_config override_config = get_config() @@ -380,6 +429,12 @@ def try_get_optimal_moe_config( # Else use the default config config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, is_marlin) + # NOTE: For block-wise quant, + # BLOCK_K must be divisible by block_shape[1] + # BLOCK_N and BLOCK_M has no requirements + if block_shape is not None: + config["BLOCK_SIZE_N"] = block_shape[0] + config["BLOCK_SIZE_K"] = block_shape[1] return config @@ -479,10 +534,11 @@ def inplace_fused_experts(hidden_states: torch.Tensor, w1_scale: Optional[torch.Tensor] = None, 
w2_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None) -> None: + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None) -> None: fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True, use_fp8_w8a8, use_int8_w8a16, w1_scale, w2_scale, - a1_scale, a2_scale) + a1_scale, a2_scale, block_shape) def inplace_fused_experts_fake( @@ -496,7 +552,8 @@ def inplace_fused_experts_fake( w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None) -> None: + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None) -> None: pass @@ -519,10 +576,11 @@ def outplace_fused_experts( w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None) -> torch.Tensor: + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None) -> torch.Tensor: return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, False, use_fp8_w8a8, use_int8_w8a16, w1_scale, - w2_scale, a1_scale, a2_scale) + w2_scale, a1_scale, a2_scale, block_shape) def outplace_fused_experts_fake( @@ -536,7 +594,8 @@ def outplace_fused_experts_fake( w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None) -> torch.Tensor: + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None) -> torch.Tensor: return torch.empty_like(hidden_states) @@ -559,18 +618,22 @@ def fused_experts(hidden_states: torch.Tensor, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None): + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None): if inplace: torch.ops.vllm.inplace_fused_experts(hidden_states, w1, w2, topk_weights, topk_ids, use_fp8_w8a8, use_int8_w8a16, w1_scale, w2_scale, a1_scale, - a2_scale) + a2_scale, block_shape) return hidden_states else: - return torch.ops.vllm.outplace_fused_experts( - hidden_states, w1, w2, topk_weights, topk_ids, use_fp8_w8a8, - use_int8_w8a16, w1_scale, w2_scale, a1_scale, a2_scale) + return torch.ops.vllm.outplace_fused_experts(hidden_states, w1, w2, + topk_weights, topk_ids, + use_fp8_w8a8, + use_int8_w8a16, w1_scale, + w2_scale, a1_scale, + a2_scale, block_shape) def fused_experts_impl(hidden_states: torch.Tensor, @@ -584,7 +647,8 @@ def fused_experts_impl(hidden_states: torch.Tensor, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None): + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None): # Check constraints. 
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" @@ -611,6 +675,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, w2.shape, topk_ids.shape[1], config_dtype, + block_shape=block_shape, ) config = get_config_func(M) @@ -674,7 +739,8 @@ def fused_experts_impl(hidden_states: torch.Tensor, config, compute_type=compute_type, use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a16=use_int8_w8a16) + use_int8_w8a16=use_int8_w8a16, + block_shape=block_shape) ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) @@ -693,7 +759,8 @@ def fused_experts_impl(hidden_states: torch.Tensor, config, compute_type=compute_type, use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a16=use_int8_w8a16) + use_int8_w8a16=use_int8_w8a16, + block_shape=block_shape) ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape), out_hidden_states[begin_chunk_idx:end_chunk_idx]) @@ -718,6 +785,7 @@ def fused_moe( w2_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, ) -> torch.Tensor: """ This function computes a Mixture of Experts (MoE) layer using two sets of @@ -745,6 +813,12 @@ def fused_moe( w1. - w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2. + - a1_scale (Optional[torch.Tensor]): Optional scale to be used for + a1. + - a2_scale (Optional[torch.Tensor]): Optional scale to be used for + a2. + - block_shape: (Optional[List[int]]): Optional block size for block-wise + quantization. Returns: - torch.Tensor: The output tensor after applying the MoE layer. @@ -775,4 +849,5 @@ def fused_moe( w1_scale=w1_scale, w2_scale=w2_scale, a1_scale=a1_scale, - a2_scale=a2_scale) + a2_scale=a2_scale, + block_shape=block_shape) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 8c6f7c6e06515..55c0a202920ff 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -29,6 +29,7 @@ class FusedMoeWeightScaleSupported(Enum): TENSOR = "tensor" CHANNEL = "channel" GROUP = "group" + BLOCK = "block" class FusedMoEMethodBase(QuantizeMethodBase): @@ -199,6 +200,7 @@ def __init__( get_tensor_model_parallel_world_size()) self.top_k = top_k self.num_experts = num_experts + assert intermediate_size % self.tp_size == 0 self.intermediate_size_per_partition = intermediate_size // self.tp_size self.reduce_results = reduce_results self.renormalize = renormalize @@ -398,7 +400,10 @@ def weight_loader(self, param: torch.nn.Parameter, loaded_weight=loaded_weight, expert_data=expert_data, tp_rank=tp_rank) - elif quant_method == FusedMoeWeightScaleSupported.GROUP.value: + elif quant_method in [ + FusedMoeWeightScaleSupported.GROUP.value, + FusedMoeWeightScaleSupported.BLOCK.value, + ]: self._load_model_weight_or_group_weight_scale( shard_id=shard_id, shard_dim=shard_dim, diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 46ef11e7d02c6..33b221b994b2b 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -14,11 +14,14 @@ from vllm.logger import init_logger from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) +# yapf: disable from vllm.model_executor.parameter import (BasevLLMParameter, + BlockQuantScaleParameter, PackedColumnParameter, PackedvLLMParameter, PerTensorScaleParameter, 
RowvLLMParameter) +# yapf: enable from vllm.model_executor.utils import set_weight_attrs logger = init_logger(__name__) @@ -623,8 +626,24 @@ def weight_loader_v2(self, assert loaded_shard_id < len(self.output_sizes) tp_size = get_tensor_model_parallel_world_size() - shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size - shard_size = self.output_sizes[loaded_shard_id] // tp_size + + if isinstance(param, BlockQuantScaleParameter): + from vllm.model_executor.layers.quantization.fp8 import ( + Fp8LinearMethod, Fp8MoEMethod) + assert self.quant_method is not None + assert isinstance(self.quant_method, + (Fp8LinearMethod, Fp8MoEMethod)) + weight_block_size = self.quant_method.quant_config.weight_block_size + assert weight_block_size is not None + block_n, _ = weight_block_size[0], weight_block_size[1] + shard_offset = ( + (sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) // + block_n) // tp_size + shard_size = ((self.output_sizes[loaded_shard_id] + block_n - 1) // + block_n // tp_size) + else: + shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size + shard_size = self.output_sizes[loaded_shard_id] // tp_size param.load_merged_column_weight(loaded_weight=loaded_weight, shard_id=loaded_shard_id, diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 978e727bc7cb3..5dfd86727a02a 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -6,6 +6,7 @@ import vllm.envs as envs from vllm import _custom_ops as ops +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) @@ -14,6 +15,8 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + apply_w8a8_block_fp8_linear) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin) from vllm.model_executor.layers.quantization.utils.quant_utils import ( @@ -22,7 +25,8 @@ all_close_1d, apply_fp8_linear, convert_to_channelwise, cutlass_fp8_supported, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize, requantize_with_max_scale) -from vllm.model_executor.parameter import (ModelWeightParameter, +from vllm.model_executor.parameter import (BlockQuantScaleParameter, + ModelWeightParameter, PerTensorScaleParameter) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform @@ -41,6 +45,7 @@ def __init__( is_checkpoint_fp8_serialized: bool = False, activation_scheme: str = "dynamic", ignored_layers: Optional[List[str]] = None, + weight_block_size: Optional[List[int]] = None, ) -> None: self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized if is_checkpoint_fp8_serialized: @@ -51,6 +56,20 @@ def __init__( f"Unsupported activation scheme {activation_scheme}") self.activation_scheme = activation_scheme self.ignored_layers = ignored_layers or [] + if weight_block_size is not None: + if not is_checkpoint_fp8_serialized: + raise ValueError( + "The block-wise quantization only supports fp8-serialized " + "checkpoint for now.") + if len(weight_block_size) != 2: + raise ValueError( + "The quantization block size of weight must have 2 " + f"dimensions, 
but got {len(weight_block_size)} dimensions") + if activation_scheme != "dynamic": + raise ValueError("The block-wise quantization only supports " + "dynamic activation scheme for now, but got " + f"{activation_scheme} activation scheme.") + self.weight_block_size = weight_block_size @classmethod def get_name(cls) -> str: @@ -74,9 +93,12 @@ def from_config(cls, config: Dict[str, Any]) -> "Fp8Config": is_checkpoint_fp8_serialized = ("fp8" in quant_method) activation_scheme = cls.get_from_keys(config, ["activation_scheme"]) ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None) + weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"], + None) return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized, activation_scheme=activation_scheme, - ignored_layers=ignored_layers) + ignored_layers=ignored_layers, + weight_block_size=weight_block_size) def get_quant_method(self, layer: torch.nn.Module, prefix: str) -> Optional["QuantizeMethodBase"]: @@ -123,6 +145,11 @@ def __init__(self, quant_config: Fp8Config): if current_platform.is_rocm(): self.use_marlin = False + self.block_quant = self.quant_config.weight_block_size is not None + if self.block_quant: + # Marlin doesn't support block-wise fp8 + self.use_marlin = False + def create_weights( self, layer: torch.nn.Module, @@ -133,10 +160,34 @@ def create_weights( params_dtype: torch.dtype, **extra_weight_attrs, ): - del input_size, output_size output_size_per_partition = sum(output_partition_sizes) weight_loader = extra_weight_attrs.get("weight_loader") + if self.block_quant: + tp_size = get_tensor_model_parallel_world_size() + assert self.quant_config.weight_block_size is not None + block_n, block_k = ( + self.quant_config.weight_block_size[0], + self.quant_config.weight_block_size[1], + ) + # Required by row parallel + if (tp_size > 1 + and input_size // input_size_per_partition == tp_size + and input_size_per_partition % block_k != 0): + raise ValueError( + f"Weight input_size_per_partition = " + f"{input_size_per_partition} is not divisible by " + f"weight quantization block_k = {block_k}.") + # Required by column parallel or enabling merged weights + if (tp_size > 1 and output_size // output_size_per_partition + == tp_size) or len(output_partition_sizes) > 1: + for output_partition_size in output_partition_sizes: + if output_partition_size % block_n != 0: + raise ValueError( + f"Weight output_partition_size = " + f"{output_partition_size} is not divisible by " + f"weight quantization block_n = {block_n}.") + layer.logical_widths = output_partition_sizes layer.input_size_per_partition = input_size_per_partition @@ -161,12 +212,29 @@ def create_weights( # Otherwise, wait until process_weights_after_loading. 
if self.quant_config.is_checkpoint_fp8_serialized: # WEIGHT SCALE - scale = PerTensorScaleParameter(data=torch.empty( - len(output_partition_sizes), dtype=torch.float32), - weight_loader=weight_loader) - - scale[:] = torch.finfo(torch.float32).min - layer.register_parameter("weight_scale", scale) + if not self.block_quant: + scale = PerTensorScaleParameter( + data=torch.empty(len(output_partition_sizes), + dtype=torch.float32), + weight_loader=weight_loader, + ) + scale[:] = torch.finfo(torch.float32).min + layer.register_parameter("weight_scale", scale) + else: + assert self.quant_config.activation_scheme == "dynamic" + scale = BlockQuantScaleParameter( + data=torch.empty( + (output_size_per_partition + block_n - 1) // block_n, + (input_size_per_partition + block_k - 1) // block_k, + dtype=torch.float32, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + scale[:] = torch.finfo(torch.float32).min + # The weight_scale_inv name is intentional for deepseekv3 + layer.register_parameter("weight_scale_inv", scale) # INPUT ACTIVATION SCALE if self.quant_config.activation_scheme == "static": @@ -180,6 +248,9 @@ def create_weights( layer.register_parameter("input_scale", None) def process_weights_after_loading(self, layer: Module) -> None: + # Block quant doesn't need to process weights after loading + if self.block_quant: + return layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False) # If checkpoint not serialized fp8, quantize the weights. @@ -266,6 +337,17 @@ def apply(self, size_k=layer.input_size_per_partition, bias=bias) + if self.block_quant: + assert self.quant_config.weight_block_size is not None + return apply_w8a8_block_fp8_linear( + input=x, + weight=layer.weight, + block_size=self.quant_config.weight_block_size, + weight_scale=layer.weight_scale_inv, + input_scale=layer.input_scale, + bias=bias, + ) + return apply_fp8_linear( input=x, weight=layer.weight, @@ -291,6 +373,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): def __init__(self, quant_config: Fp8Config): self.quant_config = quant_config + self.block_quant = self.quant_config.weight_block_size is not None def create_weights(self, layer: Module, num_experts: int, hidden_size: int, intermediate_size: int, params_dtype: torch.dtype, @@ -298,6 +381,27 @@ def create_weights(self, layer: Module, num_experts: int, hidden_size: int, if self.quant_config.is_checkpoint_fp8_serialized: params_dtype = torch.float8_e4m3fn + if self.block_quant: + assert self.quant_config.weight_block_size is not None + tp_size = get_tensor_model_parallel_world_size() + block_n, block_k = ( + self.quant_config.weight_block_size[0], + self.quant_config.weight_block_size[1], + ) + # NOTE: To ensure proper alignment of the block-wise quantization + # scales, the output_size of the weights for both the gate and up + # layers must be divisible by block_n. 
+ # Required by column parallel or enabling merged weights + if intermediate_size % block_n != 0: + raise ValueError( + f"The output_size of gate's and up's weight = " + f"{intermediate_size} is not divisible by " + f"weight quantization block_n = {block_n}.") + if (tp_size > 1 and intermediate_size % block_k != 0): + # Required by row parallel + raise ValueError(f"The input_size of down's weight = " + f"{intermediate_size} is not divisible by " + f"weight quantization block_k = {block_k}.") # WEIGHTS w13_weight = torch.nn.Parameter(torch.empty(num_experts, @@ -317,21 +421,45 @@ def create_weights(self, layer: Module, num_experts: int, hidden_size: int, set_weight_attrs(w2_weight, extra_weight_attrs) # WEIGHT_SCALES - # Allocate 2 scales for w1 and w3 respectively. - # They will be combined to a single scale after weight loading. - w13_weight_scale = torch.nn.Parameter(torch.ones(num_experts, - 2, - dtype=torch.float32), - requires_grad=False) - layer.register_parameter("w13_weight_scale", w13_weight_scale) + if not self.block_quant: + # Allocate 2 scales for w1 and w3 respectively. + # They will be combined to a single scale after weight loading. + w13_weight_scale = torch.nn.Parameter(torch.ones( + num_experts, 2, dtype=torch.float32), + requires_grad=False) + w2_weight_scale = torch.nn.Parameter(torch.ones( + num_experts, dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + else: + w13_weight_scale = torch.nn.Parameter( + torch.ones( + num_experts, + 2 * ((intermediate_size + block_n - 1) // block_n), + (hidden_size + block_k - 1) // block_k, + dtype=torch.float32, + ), + requires_grad=False, + ) + w2_weight_scale = torch.nn.Parameter( + torch.ones( + num_experts, + (hidden_size + block_n - 1) // block_n, + (intermediate_size + block_k - 1) // block_k, + dtype=torch.float32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_scale_inv", w13_weight_scale) + layer.register_parameter("w2_weight_scale_inv", w2_weight_scale) + assert self.quant_config.activation_scheme == "dynamic" - w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts, - dtype=torch.float32), - requires_grad=False) - layer.register_parameter("w2_weight_scale", w2_weight_scale) # Add the quantization method used (per tensor/grouped/channel) # to ensure the weight scales are loaded in properly extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.BLOCK. + value} if self.block_quant else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) # If loading fp8 checkpoint, pass the weight loaders. # If loading an fp16 checkpoint, do not (we will quantize in @@ -364,7 +492,9 @@ def create_weights(self, layer: Module, num_experts: int, hidden_size: int, layer.w2_input_scale = None def process_weights_after_loading(self, layer: Module) -> None: - + # Block quant doesn't need to process weights after loading + if self.block_quant: + return # If checkpoint is fp16, quantize in place. 
if not self.quant_config.is_checkpoint_fp8_serialized: # If rocm, use float8_e4m3fnuz as dtype @@ -489,17 +619,22 @@ def apply( num_expert_group=num_expert_group, custom_routing_function=custom_routing_function) - return fused_experts(x, - layer.w13_weight, - layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=True, - use_fp8_w8a8=True, - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - a1_scale=layer.w13_input_scale, - a2_scale=layer.w2_input_scale) + return fused_experts( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + use_fp8_w8a8=True, + w1_scale=(layer.w13_weight_scale_inv + if self.block_quant else layer.w13_weight_scale), + w2_scale=(layer.w2_weight_scale_inv + if self.block_quant else layer.w2_weight_scale), + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + block_shape=self.quant_config.weight_block_size, + ) class Fp8KVCacheMethod(BaseKVCacheMethod): diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py new file mode 100644 index 0000000000000..f3c3e130e4161 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -0,0 +1,353 @@ +# Adapted from https://github.com/sgl-project/sglang/pull/2575 +from typing import List, Optional, Tuple + +import torch +import triton +import triton.language as tl + + +def apply_w8a8_block_fp8_linear( + input: torch.Tensor, + weight: torch.Tensor, + block_size: List[int], + weight_scale: torch.Tensor, + input_scale: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: + assert input_scale is None + # View input as 2D matrix for fp8 methods + input_2d = input.view(-1, input.shape[-1]) + output_shape = [*input.shape[:-1], weight.shape[0]] + + q_input, x_scale = per_token_group_quant_fp8(input_2d, block_size[1]) + output = w8a8_block_fp8_matmul(q_input, + weight, + x_scale, + weight_scale, + block_size, + output_dtype=input.dtype) + + if bias is not None: + output = output + bias + return output.to(dtype=input.dtype).view(*output_shape) + + +def input_to_float8( + x: torch.Tensor, + dtype: torch.dtype = torch.float8_e4m3fn +) -> Tuple[torch.Tensor, torch.Tensor]: + """This function quantizes input values to float8 values " + "with tensor-wise quantization.""" + finfo = torch.finfo(dtype) + min_val, max_val = x.aminmax() + amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12) + scale = finfo.max / amax + x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max) + return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal() + + +def block_quant_to_tensor_quant( + x_q_block: torch.Tensor, + x_s: torch.Tensor, + block_size: List[int], +) -> Tuple[torch.Tensor, torch.Tensor]: + """This function converts block-wise quantization to tensor-wise + quantization. The inputs are block-wise quantization tensor `x_q_block`, + block-wise quantization scale and the block size. + The outputs are tensor-wise quantization tensor and tensor-wise + quantization scale. Note only float8 is supported for now. 
+ """ + block_n, block_k = block_size[0], block_size[1] + n, k = x_q_block.shape + n_tiles = (n + block_n - 1) // block_n + k_tiles = (k + block_k - 1) // block_k + assert n_tiles == x_s.shape[0] + assert k_tiles == x_s.shape[1] + + x_dq_block = x_q_block.to(torch.float32) + + x_dq_block_tiles = [[ + x_dq_block[j * block_n:min((j + 1) * block_n, n), + i * block_k:min((i + 1) * block_k, k), ] + for i in range(k_tiles) + ] for j in range(n_tiles)] + + for i in range(k_tiles): + for j in range(n_tiles): + x_dq_block_tiles[j][i][:, :] = x_dq_block_tiles[j][i] * x_s[j][i] + + x_q_tensor, scale = input_to_float8(x_dq_block, dtype=x_q_block.dtype) + return x_q_tensor, scale + + +@triton.jit +def _per_token_group_quant_fp8( + # Pointers to inputs and output + y_ptr, + y_q_ptr, + y_s_ptr, + # Stride of input + y_stride, + # Columns of input + N, + # Avoid to divide zero + eps, + # Information for float8 + fp8_min, + fp8_max, + # Meta-parameters + BLOCK: tl.constexpr, +): + """A Triton-accelerated function to perform per-token-group + quantization on a tensor. + This function converts the tensor values into float8 values. + """ + # Map the program id to the row of X and Y it should compute. + g_id = tl.program_id(0) + y_ptr += g_id * y_stride + y_q_ptr += g_id * y_stride + y_s_ptr += g_id + + cols = tl.arange(0, BLOCK) # N <= BLOCK + mask = cols < N + + y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32) + # Quant + _absmax = tl.maximum(tl.max(tl.abs(y)), eps) + y_s = _absmax / fp8_max + y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) + + tl.store(y_q_ptr + cols, y_q, mask=mask) + tl.store(y_s_ptr, y_s) + + +def per_token_group_quant_fp8( + x: torch.Tensor, + group_size: int, + eps: float = 1e-10, + dtype: torch.dtype = torch.float8_e4m3fn, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Function to perform per-token-group quantization on an input tensor `x`. + It converts the tensor values into signed float8 values and returns the + quantized tensor along with the scaling factor used for quantization. + Args: + x: The input tenosr with ndim >= 2. + group_size: The group size used for quantization. + eps: The minimum to avoid dividing zero. + dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn` + is supported for now. + Returns: + Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the + scaling factor for quantization. 
+ """ + assert (x.shape[-1] % group_size == 0), ( + f"the last dimension of `x` {x.shape[-1]} must be divisible " + f"by `group_size` {group_size}") + assert x.is_contiguous(), "`x` must be contiguous" + + finfo = torch.finfo(dtype) + fp8_min = finfo.min + fp8_max = finfo.max + + x_q = torch.empty_like(x, device=x.device, dtype=dtype) + M = x.numel() // group_size + N = group_size + x_s = torch.empty( + x.shape[:-1] + (x.shape[-1] // group_size, ), + device=x.device, + dtype=torch.float32, + ) + + BLOCK = triton.next_power_of_2(N) + # heuristics for number of warps + num_warps = min(max(BLOCK // 256, 1), 8) + num_stages = 1 + _per_token_group_quant_fp8[(M, )]( + x, + x_q, + x_s, + group_size, + N, + eps, + fp8_min=fp8_min, + fp8_max=fp8_max, + BLOCK=BLOCK, + num_warps=num_warps, + num_stages=num_stages, + ) + + return x_q, x_s + + +@triton.jit +def _w8a8_block_fp8_matmul( + # Pointers to inputs and output + A, + B, + C, + As, + Bs, + # Shape for matmul + M, + N, + K, + # Block size for block-wise quantization + group_n, + group_k, + # Stride for inputs and output + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_As_m, + stride_As_k, + stride_Bs_k, + stride_Bs_n, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + """Triton-accelerated function used to perform linear operations (dot + product) on input tensors `A` and `B` with block-wise quantization, and + store the result in output tensor `C`. + """ + + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + As_ptrs = As + offs_am * stride_As_m + offs_bsn = offs_bn // group_n + Bs_ptrs = Bs + offs_bsn * stride_Bs_n + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, + mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, + other=0.0) + b = tl.load(b_ptrs, + mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, + other=0.0) + + k_start = k * BLOCK_SIZE_K + offs_ks = k_start // group_k + a_s = tl.load(As_ptrs + offs_ks * stride_As_k) + b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k) + + accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :] + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + if C.dtype.element_ty == tl.bfloat16: + c = accumulator.to(tl.bfloat16) + elif C.dtype.element_ty == tl.float16: + c = accumulator.to(tl.float16) + else: + c = accumulator.to(tl.float32) + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + +def w8a8_block_fp8_matmul( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: 
torch.Tensor, + block_size: List[int], + output_dtype: torch.dtype = torch.float16, +) -> torch.Tensor: + """This function performs matrix multiplication with block-wise + quantization. + It takes two input tensors `A` and `B` with scales `As` and `Bs`. + The output is returned in the specified `output_dtype`. + Args: + A: The input tensor, e.g., activation. + B: The input tensor, e.g., weight. + As: The per-token-group quantization scale for `A`. + Bs: The per-block quantization scale for `B`. + block_size: The block size for per-block quantization. It should + be 2-dim, e.g., [128, 128]. + output_dytpe: The dtype of the returned tensor. + Returns: + torch.Tensor: The result of matmul. + """ + assert len(block_size) == 2 + block_n, block_k = block_size[0], block_size[1] + + assert A.shape[-1] == B.shape[-1] + assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous() + assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1] + M = A.numel() // A.shape[-1] + + assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2 + N, K = B.shape + assert triton.cdiv(N, block_n) == Bs.shape[0] + assert triton.cdiv(K, block_k) == Bs.shape[1] + + C_shape = A.shape[:-1] + (N, ) + C = A.new_empty(C_shape, dtype=output_dtype) + + # TODO: + # BLOCK_SIZE_M, BLOCK_SIZE_K, BLOCK_SIZE_N can be optimized. + # BLOCK_SIZE_K must be divisible by block_k + # BLOCK_SIZE_N and BLOCK_SIZE_M has no requirements + BLOCK_SIZE_M = 128 + if M < BLOCK_SIZE_M: + BLOCK_SIZE_M = triton.next_power_of_2(M) + BLOCK_SIZE_M = max(BLOCK_SIZE_M, 16) + BLOCK_SIZE_K = block_k + assert block_k % BLOCK_SIZE_K == 0 + BLOCK_SIZE_N = block_n + + def grid(META): + return (triton.cdiv(M, META["BLOCK_SIZE_M"]) * + triton.cdiv(N, META["BLOCK_SIZE_N"]), ) + + _w8a8_block_fp8_matmul[grid]( + A, + B, + C, + As, + Bs, + M, + N, + K, + block_n, + block_k, + A.stride(-2), + A.stride(-1), + B.stride(1), + B.stride(0), + C.stride(-2), + C.stride(-1), + As.stride(-2), + As.stride(-1), + Bs.stride(1), + Bs.stride(0), + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + GROUP_SIZE_M=8, + ) + + return C diff --git a/vllm/model_executor/parameter.py b/vllm/model_executor/parameter.py index 7a6d7c90f34d5..02d22a5ca62c0 100644 --- a/vllm/model_executor/parameter.py +++ b/vllm/model_executor/parameter.py @@ -328,6 +328,15 @@ def adjust_shard_indexes_for_packing(self, shard_size, shard_offset): marlin_tile_size=self.marlin_tile_size) +class BlockQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter): + """ + Parameter class for weight scales loaded for weights with + block-wise quantization. Uses both column and row parallelism. + """ + + pass + + def permute_param_layout_(param: BasevLLMParameter, input_dim: int, output_dim: int, **kwargs) -> BasevLLMParameter: """