Fix zero_start_index_M argument for triton rowwise quantize (#3639)
Summary:
Pull Request resolved: #3639

X-link: facebookresearch/FBGEMM#714

D68797978 introduced partial rowwise quantization for jagged tensors, aiming to improve MoE (mixture-of-experts) performance. However, it indexed the wrong dimension (oops). This update moves the indexing to the correct dimension, so zero_start_index_M now supplies one nonzero-row count per group.
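For reference, the intended semantics can be sketched in plain PyTorch. This is an illustrative reference implementation, not FBGEMM code; the function name, the e4m3 max of 448.0, and the eps floor are assumptions of the sketch. With grouped input of shape [G, M, K], zero_start_index_M carries one count of valid (nonzero) rows per group, and rows at or beyond that count are treated as padding:

import torch

def reference_jagged_rowwise_quantize(a, zero_start_index_M, eps=1e-12):
    # a: [G, M, K]; zero_start_index_M: [G] = count of valid rows per group.
    G, M, K = a.shape
    max_fp8 = 448.0  # assumed e4m3 max
    valid = torch.arange(M).unsqueeze(0) < zero_start_index_M.unsqueeze(1)  # [G, M]
    a = torch.where(valid.unsqueeze(-1), a, torch.zeros_like(a))  # blank padded rows
    row_max = a.abs().amax(dim=-1).clamp(min=eps)                 # per-row absmax
    scale = max_fp8 / row_max                                     # quantization scale
    a_fp8 = (a * scale.unsqueeze(-1)).clamp(-max_fp8, max_fp8).to(torch.float8_e4m3fn)
    return a_fp8, (1.0 / scale)  # dequant scale, one per row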

Reviewed By: jasonjk-park, jiawenliu64

Differential Revision: D68872138

fbshipit-source-id: 92afb61da24cd9d85603b30e0baaa9685ba51c8d
jwfromm authored and facebook-github-bot committed Feb 3, 2025
1 parent 8cb1476 commit 26eeef0
Showing 2 changed files with 26 additions and 25 deletions.
17 changes: 12 additions & 5 deletions fbgemm_gpu/experimental/gemm/test/fp8_gemm_test.py
@@ -60,16 +60,23 @@ def _test_quantize_fp8_row(
             # Apply sparsification if specified.
             zero_start_index_M = None
             if use_jagged:
+                # View input as [G, M, K] where G is the number of groups.
+                grouped_input = input_a.view(
+                    -1, input_a.shape[-2], input_a.shape[-1]
+                )
                 m_vals = torch.randint(
-                    0, input_a.shape[-1] + 1, (input_a.shape[:-1])
+                    0, grouped_input.shape[1] + 1, (grouped_input.shape[0],)
                 )
-                mask = torch.arange(input_a.shape[-1]).expand(
-                    input_a.shape[:-1] + (input_a.shape[-1],)
+                mask = torch.arange(grouped_input.shape[-2]).expand(
+                    (grouped_input.shape[0], grouped_input.shape[1])
                 ) >= m_vals.unsqueeze(-1)
                 # Set corresponding values to 0.
-                input_a[mask] = 0.0
+                grouped_input[mask] = 0.0
                 # Generate nonzero tensor in same layout as input.
-                zero_start_index_M = torch.count_nonzero(input_a, dim=-1)
+                zero_start_index_M = torch.count_nonzero(
+                    torch.sum(grouped_input, dim=-1), dim=-1
+                )
 
             a_fp8, a_scale = quantize_fp8_row(
                 input_a,
                 scale_ub=scale_ub,
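A standalone sketch of what the updated test now does (illustrative sizes; the logic mirrors the diff above):

import torch

G, M, K = 4, 8, 16  # illustrative group/row/column sizes
input_a = torch.randn(G, M, K)
grouped_input = input_a.view(-1, input_a.shape[-2], input_a.shape[-1])  # [G, M, K]
# Draw a random count of valid rows per group, in [0, M].
m_vals = torch.randint(0, grouped_input.shape[1] + 1, (grouped_input.shape[0],))
# Mask every row at or beyond its group's cutoff...
mask = torch.arange(grouped_input.shape[-2]).expand(
    (grouped_input.shape[0], grouped_input.shape[1])
) >= m_vals.unsqueeze(-1)
# ...and zero it out (grouped_input is a view, so input_a is modified too).
grouped_input[mask] = 0.0
# One nonzero-row count per [M, K] matrix, shape [G] -- the corrected dimension.
zero_start_index_M = torch.count_nonzero(torch.sum(grouped_input, dim=-1), dim=-1)
assert zero_start_index_M.shape == (G,)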
34 changes: 14 additions & 20 deletions fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py
Expand Up @@ -2316,7 +2316,6 @@ def _kernel_quantize_fp8_row(
stride_ok,
stride_zb,
stride_zm,
stride_zn,
TL_FP8_DTYPE: tl.constexpr,
MAX_FP8: tl.constexpr,
EPS: tl.constexpr,
@@ -2354,7 +2353,6 @@ def _kernel_quantize_fp8_row(
         stride_ok (int): Stride of k dimension of output.
         stride_zb (int): Stride of b dimension of jagged index.
         stride_zm (int): Stride of m dimension of jagged index.
-        stride_zn (int): Stride of n dimension of jagged index.
         TL_FP8_DTYPE (tl.dtype): Target fp8 datatype.
         MAX_FP8 (float): Maximum expressible value for FP8.
         EPS (float): Epsilon value for numerical stability.
@@ -2380,24 +2378,22 @@ def _kernel_quantize_fp8_row(
         + (pid % (M * N)) % N * stride_on
     )
 
-    if JAGGED:
-        z_offset_base = (
-            pid // (M * N) * stride_zb
-            + (pid % (M * N)) // N * stride_zm
-            + (pid % (M * N)) % N * stride_zn
-        )
-        row_size = tl.load(zero_start_index_M + z_offset_base)
-    else:
-        row_size = K
+    K_in = K
 
-    blocks = tl.cdiv(row_size, BLOCK_SIZE)
+    if JAGGED:
+        z_offset_base = pid // (M * N) * stride_zb + (pid % (M * N)) // N * stride_zm
+        group_rows = tl.load(zero_start_index_M + z_offset_base)
+        current_row = pid % N
+        # If this row is empty, don't process any of it.
+        if current_row >= group_rows:
+            K_in = 0
 
     # Calculate max.
     cur_max = 0.0
-    for _k in range(0, blocks):
+    for _k in range(0, tl.cdiv(K_in, BLOCK_SIZE)):
         a = tl.load(
             A + a_offset_base + n_offset * stride_ak,
-            mask=n_offset < row_size,
+            mask=n_offset < K_in,
             other=0.0,
         )
         tile_max = tl.max(tl.abs(a))
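Each Triton program quantizes one row of the [B, M, N, K] view. A plain-Python mirror of the new index arithmetic (illustrative only, not the Triton kernel itself) shows how a program locates its group's row count and zeroes its workload when the row is padding:

def row_work_size(pid, M, N, K, zero_start_index_M):
    # pid enumerates rows of a [B, M, N, K] tensor in row-major order.
    K_in = K
    b = pid // (M * N)        # batch index
    m = (pid % (M * N)) // N  # group index within the batch
    n = pid % N               # row index within the group
    group_rows = zero_start_index_M[b][m]  # valid rows for this group
    if n >= group_rows:
        K_in = 0              # padded row: no elements to reduce or store
    return K_in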
@@ -2418,15 +2414,14 @@ def _kernel_quantize_fp8_row(
     for _k in range(0, tl.cdiv(K, BLOCK_SIZE)):
         a = tl.load(
             A + a_offset_base + n_offset * stride_ak,
-            mask=n_offset < row_size,
+            mask=n_offset < K_in,
             other=0.0,
         )
         a_fp8 = a * a_scale
         # Clamp A to fp8 range to make sure there's no overflow.
         # This is required for AMD. Nvidia's default saturation
         # handles it, but it's nice to have anyway.
-        a_fp8 = tl.clamp(a_fp8, -MAX_FP8, MAX_FP8)
-        a_fp8.to(TL_FP8_DTYPE)
+        a_fp8 = tl.clamp(a_fp8, -MAX_FP8, MAX_FP8).to(TL_FP8_DTYPE)
         tl.store(
             A_fp8 + a_fp8_offset_base + n_offset * stride_ok,
             a_fp8,
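The clamp change in this hunk also fixes a silent no-op: .to() returns a new value rather than casting in place, so the old second line discarded its result. Schematically:

# Before: the cast result was dropped, leaving a_fp8 unconverted.
a_fp8 = tl.clamp(a_fp8, -MAX_FP8, MAX_FP8)
a_fp8.to(TL_FP8_DTYPE)  # return value ignored
# After: clamp and cast are chained, so the stored value is truly fp8.
a_fp8 = tl.clamp(a_fp8, -MAX_FP8, MAX_FP8).to(TL_FP8_DTYPE)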
@@ -2481,7 +2476,6 @@ def triton_quantize_fp8_row(
         a_fp8.stride(3),
         zero_start_index_M.stride(0) if zero_start_index_M is not None else None,
         zero_start_index_M.stride(1) if zero_start_index_M is not None else None,
-        zero_start_index_M.stride(2) if zero_start_index_M is not None else None,
         TL_FP8_DTYPE=tl_dtype,
         MAX_FP8=max_fp8,
         EPS=eps,
@@ -2527,8 +2521,8 @@ def quantize_fp8_row(
         while a.dim() < 4:
             a = a.unsqueeze(0)
         if zero_start_index_M is not None:
-            while zero_start_index_M.dim() < 3:
-                zero_start_index_M = zero_start_index_M.unsqueeze(0)
+            # There should be one value of zero_start_index_M per NxK matrix.
+            zero_start_index_M = zero_start_index_M.view(a.shape[0], a.shape[1])
         a_fp8, a_scale = triton_quantize_fp8_row(a, scale_ub, zero_start_index_M)
         return a_fp8.view(a_shape), a_scale.view(a_shape[:-1])
     # else use pytorch implementation.
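Putting it together, a hedged usage sketch of the corrected calling convention (import path taken from this diff; sizes, dtype, and the CUDA device are illustrative assumptions for the Triton path):

import torch
from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import quantize_fp8_row

G, M, K = 4, 8, 16
x = torch.randn(G, M, K, device="cuda", dtype=torch.bfloat16)
# One entry per [M, K] group: the number of valid (nonzero) rows.
zero_start_index_M = torch.full((G,), M // 2, device="cuda", dtype=torch.int64)
# Internally x is unsqueezed to 4D and zero_start_index_M is reshaped to
# [a.shape[0], a.shape[1]], i.e. one value per NxK matrix.
x_fp8, x_scale = quantize_fp8_row(x, zero_start_index_M=zero_start_index_M)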
