Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Model] [Quantization] Support deepseek_v3 w8a8 fp8 block-wise quantization #11523

Merged
merged 9 commits into from
Dec 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
265 changes: 265 additions & 0 deletions tests/kernels/test_block_fp8.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,265 @@
# Adapted from https://github.com/sgl-project/sglang/pull/2575
import itertools

import pytest
import torch

from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8, w8a8_block_fp8_matmul)
from vllm.platforms import current_platform

if current_platform.get_device_capability() < (9, 0):
pytest.skip("FP8 Triton requires CUDA 9.0 or higher",
allow_module_level=True)

# Test configurations
DTYPES = [torch.bfloat16] # [torch.half, torch.bfloat16, torch.float32]
NUM_TOKENS = [7, 83, 2048]
D = [512, 4096, 5120, 13824]
GROUP_SIZE = [64, 128, 256, 512]
M = [1, 7, 83, 512, 2048]
N = [128, 512, 1024, 4096, 7748, 13824]
K = [256, 4096, 5120, 3884, 13824]
# Deepseek-V3's intermediate size 18432, so N is 18432*2/8=4608 at TP8
# and its hidden size is 7168.
M_moe = [1, 7, 83, 512, 2048]
N_moe = [4608] # [128, 4608, 13824]
K_moe = [7168] # [256, 7168, 13824]
BLOCK_SIZE = [[128, 128]]
E = [256] # [8, 24, 128, 256]
TOP_KS = [1] # [1, 2, 6]
OUT_DTYPES = [torch.bfloat16] # [torch.float32, torch.half, torch.bfloat16]
SEEDS = [0]


def native_per_token_group_quant_fp8(x,
group_size,
eps=1e-10,
dtype=torch.float8_e4m3fn):
"""Function to perform per-token-group quantization on an input tensor
`x` using native torch."""
assert x.shape[-1] % group_size == 0, ("the last dimension of `x` cannot "
"be divisible by `group_size`")
assert x.is_contiguous(), "`x` is not contiguous"

finfo = torch.finfo(dtype)
fp8_min = finfo.min
fp8_max = finfo.max

x_ = x.reshape(x.numel() // group_size, group_size)
amax = x_.abs().max(dim=-1,
keepdim=True)[0].clamp(min=eps).to(torch.float32)
x_s = amax / fp8_max
x_q = (x_ / x_s).clamp(min=fp8_min, max=fp8_max).to(dtype)
x_q = x_q.reshape(x.shape)
x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size, ))

return x_q, x_s


def native_w8a8_block_fp8_matmul(A,
B,
As,
Bs,
block_size,
output_dtype=torch.float16):
"""Matrix multiplication with block-wise quantization using native torch."""
A = A.to(torch.float32)
B = B.to(torch.float32)
assert A.shape[-1] == B.shape[-1]
assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
assert len(block_size) == 2
block_n, block_k = block_size[0], block_size[1]
assert (A.shape[-1] + block_k - 1) // block_k == As.shape[-1]
assert A.shape[:-1] == As.shape[:-1]

M = A.numel() // A.shape[-1]
N, K = B.shape
origin_C_shape = A.shape[:-1] + (N, )
A = A.reshape(M, A.shape[-1])
As = As.reshape(M, As.shape[-1])
n_tiles = (N + block_n - 1) // block_n
k_tiles = (K + block_k - 1) // block_k
assert n_tiles == Bs.shape[0]
assert k_tiles == Bs.shape[1]

C_shape = (M, N)
C = torch.zeros(C_shape, dtype=torch.float32, device=A.device)

A_tiles = [
A[:, i * block_k:min((i + 1) * block_k, K)] for i in range(k_tiles)
]
B_tiles = [[
B[j * block_n:min((j + 1) * block_n, N),
i * block_k:min((i + 1) * block_k, K), ] for i in range(k_tiles)
] for j in range(n_tiles)]
C_tiles = [
C[:, j * block_n:min((j + 1) * block_n, N)] for j in range(n_tiles)
]
As_tiles = [As[:, i:i + 1] for i in range(k_tiles)]

for i in range(k_tiles):
for j in range(n_tiles):
a = A_tiles[i]
b = B_tiles[j][i]
c = C_tiles[j]
s = As_tiles[i] * Bs[j][i]
c[:, :] += torch.matmul(a, b.t()) * s

C = C.reshape(origin_C_shape).to(output_dtype)
return C


def torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape):
"""Fused moe with block-wise quantization using native torch."""
B, D = a.shape
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
score = torch.softmax(score, dim=-1, dtype=torch.float32)
topk_weight, topk_ids = torch.topk(score, topk)
topk_weight = topk_weight.view(-1)
topk_ids = topk_ids.view(-1)

_, block_k = block_shape[0], block_shape[1]
a_q, a_s = native_per_token_group_quant_fp8(a, block_k)
a_q = a_q.to(torch.float32)
for i in range(w1.shape[0]):
mask = topk_ids == i
if mask.sum():
inter_out = native_w8a8_block_fp8_matmul(a_q[mask],
w1[i],
a_s[mask],
w1_s[i],
block_shape,
output_dtype=a.dtype)
act_out = SiluAndMul().forward_native(inter_out)
act_out_q, act_out_s = native_per_token_group_quant_fp8(
act_out, block_k)
act_out = act_out.to(torch.float32)
out[mask] = native_w8a8_block_fp8_matmul(act_out_q,
w2[i],
act_out_s,
w2_s[i],
block_shape,
output_dtype=a.dtype)
return (out.view(B, -1, w2.shape[1]) *
topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)


# Skip all tests if CUDA is not available
pytest.importorskip("torch.cuda")


@pytest.fixture(autouse=True)
def setup_cuda():
torch.set_default_device("cuda")


@pytest.mark.parametrize("num_tokens,d,dtype,group_size,seed",
itertools.product(NUM_TOKENS, D, DTYPES, GROUP_SIZE,
SEEDS))
@torch.inference_mode()
def test_per_token_group_quant_fp8(num_tokens, d, dtype, group_size, seed):
torch.manual_seed(seed)
x = torch.rand(num_tokens, d, dtype=dtype)

ref_out, ref_scale = native_per_token_group_quant_fp8(x, group_size)
out, scale = per_token_group_quant_fp8(x, group_size)

assert torch.allclose(out.to(torch.float32),
ref_out.to(torch.float32),
rtol=0.15)
assert torch.allclose(scale, ref_scale)


@pytest.mark.parametrize("M,N,K,block_size,out_dtype,seed",
itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES,
SEEDS))
@torch.inference_mode()
def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed):
torch.manual_seed(seed)
factor_for_scale = 1e-2
fp8_info = torch.finfo(torch.float8_e4m3fn)
fp8_max, fp8_min = fp8_info.max, fp8_info.min

A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
A_fp8 = A_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
B_fp8 = B_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

block_n, block_k = block_size[0], block_size[1]
n_tiles = (N + block_n - 1) // block_n
k_tiles = (K + block_k - 1) // block_k

As = torch.rand(M, k_tiles, dtype=torch.float32) * factor_for_scale
Bs = torch.rand(n_tiles, k_tiles, dtype=torch.float32) * factor_for_scale

ref_out = native_w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size,
out_dtype)
out = w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype)

rel_diff = (torch.mean(
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
torch.mean(torch.abs(ref_out.to(torch.float32))))
assert rel_diff < 0.001


@pytest.mark.parametrize("M,N,K,E,topk,block_size,dtype,seed",
itertools.product(M_moe, N_moe, K_moe, E, TOP_KS,
BLOCK_SIZE, DTYPES, SEEDS))
@torch.inference_mode()
def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
torch.manual_seed(seed)
factor_for_scale = 1e-2
fp8_info = torch.finfo(torch.float8_e4m3fn)
fp8_max, fp8_min = fp8_info.max, fp8_info.min

a = torch.randn((M, K), dtype=dtype) / 10

w1_bf16 = (torch.rand(
(E, 2 * N, K), dtype=torch.bfloat16) - 0.5) * 2 * fp8_max
w1 = w1_bf16.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
del w1_bf16

w2_bf16 = (torch.rand((E, K, N), dtype=torch.bfloat16) - 0.5) * 2 * fp8_max
w2 = w2_bf16.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
del w2_bf16

block_n, block_k = block_size[0], block_size[1]
n_tiles_w1 = (2 * N + block_n - 1) // block_n
n_tiles_w2 = (K + block_n - 1) // block_n
k_tiles_w1 = (K + block_k - 1) // block_k
k_tiles_w2 = (N + block_k - 1) // block_k

w1_s = torch.rand(
(E, n_tiles_w1, k_tiles_w1), dtype=torch.float32) * factor_for_scale
w2_s = torch.rand(
(E, n_tiles_w2, k_tiles_w2), dtype=torch.float32) * factor_for_scale

score = torch.randn((M, E), dtype=dtype)

out = fused_moe(
a,
w1,
w2,
score,
topk,
renormalize=False,
use_fp8_w8a8=True,
w1_scale=w1_s,
w2_scale=w2_s,
block_shape=block_size,
)
ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk,
block_size)

print(f"{out.sum()=}")
print(f"{ref_out.sum()=}")

rel_diff = (torch.mean(
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
torch.mean(torch.abs(ref_out.to(torch.float32))))
assert rel_diff < 0.03
14 changes: 7 additions & 7 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ class ModelConfig:
override default pooling config for the pooling model.
logits_processor_pattern: Optional regex pattern specifying valid
logits processor qualified names that can be passed with the
`logits_processors` extra completion argument. Defaults to None,
`logits_processors` extra completion argument. Defaults to None,
which allows no processors.
generation_config: Configuration parameter file for generation.
"""
Expand Down Expand Up @@ -363,7 +363,7 @@ def __init__(self,
def maybe_pull_model_tokenizer_for_s3(self, model: str,
tokenizer: str) -> None:
"""
Pull the model config or tokenizer to a temporary
Pull the model config or tokenizer to a temporary
directory in case of S3.

Args:
Expand Down Expand Up @@ -874,14 +874,14 @@ def try_get_generation_config(self) -> Dict[str, Any]:

def get_diff_sampling_param(self) -> Dict[str, Any]:
"""
This method returns a dictionary containing the parameters
that differ from the default sampling parameters, but only
if `generation_config` is set. If `generation_config` is not
This method returns a dictionary containing the parameters
that differ from the default sampling parameters, but only
if `generation_config` is set. If `generation_config` is not
set, an empty dictionary is returned.

Returns:
Dict[str, Any]: A dictionary with the differing sampling
parameters if `generation_config` is set, otherwise an
Dict[str, Any]: A dictionary with the differing sampling
parameters if `generation_config` is set, otherwise an
empty dictionary.
"""
if self.generation_config is None:
Expand Down
Loading
Loading