From 7ced89cff3fc9f7df62ad6b11e7a5c32568aae7f Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Sun, 2 Feb 2025 17:12:54 -0800 Subject: [PATCH] Fix zero_start_index_M argument for triton rowwise quantize (#3639) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/714 D68797978 implemented a new feature that allowed partial rowwise quantization for jagged tensors in the hopes of improving MOE performance. However, it operated on the wrong dimension (oops). This update shifts the dimension to the proper per-group non zero row. Reviewed By: jasonjk-park, jiawenliu64 Differential Revision: D68872138 --- .../experimental/gemm/test/fp8_gemm_test.py | 17 +++++++--- .../experimental/gemm/triton_gemm/fp8_gemm.py | 34 ++++++++----------- 2 files changed, 26 insertions(+), 25 deletions(-) diff --git a/fbgemm_gpu/experimental/gemm/test/fp8_gemm_test.py b/fbgemm_gpu/experimental/gemm/test/fp8_gemm_test.py index 29bf06ad80..fa1f6f7660 100644 --- a/fbgemm_gpu/experimental/gemm/test/fp8_gemm_test.py +++ b/fbgemm_gpu/experimental/gemm/test/fp8_gemm_test.py @@ -60,16 +60,23 @@ def _test_quantize_fp8_row( # Apply sparsification if specified. zero_start_index_M = None if use_jagged: + # View input as [G, M, K] where G is the number of groups. + grouped_input = input_a.view( + -1, input_a.shape[-2], input_a.shape[-1] + ) m_vals = torch.randint( - 0, input_a.shape[-1] + 1, (input_a.shape[:-1]) + 0, grouped_input.shape[1] + 1, (grouped_input.shape[0],) ) - mask = torch.arange(input_a.shape[-1]).expand( - input_a.shape[:-1] + (input_a.shape[-1],) + mask = torch.arange(grouped_input.shape[-2]).expand( + (grouped_input.shape[0], grouped_input.shape[1]) ) >= m_vals.unsqueeze(-1) # Set corresponding values to 0. - input_a[mask] = 0.0 + grouped_input[mask] = 0.0 # Generate nonzero tensor in same layout as input. - zero_start_index_M = torch.count_nonzero(input_a, dim=-1) + zero_start_index_M = torch.count_nonzero( + torch.sum(grouped_input, dim=-1), dim=-1 + ) + a_fp8, a_scale = quantize_fp8_row( input_a, scale_ub=scale_ub, diff --git a/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py b/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py index 304840ecee..bf8d52cb1e 100644 --- a/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +++ b/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py @@ -2316,7 +2316,6 @@ def _kernel_quantize_fp8_row( stride_ok, stride_zb, stride_zm, - stride_zn, TL_FP8_DTYPE: tl.constexpr, MAX_FP8: tl.constexpr, EPS: tl.constexpr, @@ -2354,7 +2353,6 @@ def _kernel_quantize_fp8_row( stride_ok (int): Stride of k dimension of output. stride_zb (int): Stride of b dimension of jagged index. stride_zm (int): Stride of m dimension of jagged index. - stride_zn (int): Stride of n dimension of jagged index. TL_FP8_DTYPE (tl.dtype): Target fp8 datatype. MAX_FP8 (float): Maxmimum expressible value for FP8. EPS (float): Epsilon value for numerical stability. @@ -2380,24 +2378,22 @@ def _kernel_quantize_fp8_row( + (pid % (M * N)) % N * stride_on ) - if JAGGED: - z_offset_base = ( - pid // (M * N) * stride_zb - + (pid % (M * N)) // N * stride_zm - + (pid % (M * N)) % N * stride_zn - ) - row_size = tl.load(zero_start_index_M + z_offset_base) - else: - row_size = K + K_in = K - blocks = tl.cdiv(row_size, BLOCK_SIZE) + if JAGGED: + z_offset_base = pid // (M * N) * stride_zb + (pid % (M * N)) // N * stride_zm + group_rows = tl.load(zero_start_index_M + z_offset_base) + current_row = pid % N + # If this row is empty, dont process any of it. + if current_row >= group_rows: + K_in = 0 # Calculate max. cur_max = 0.0 - for _k in range(0, blocks): + for _k in range(0, tl.cdiv(K_in, BLOCK_SIZE)): a = tl.load( A + a_offset_base + n_offset * stride_ak, - mask=n_offset < row_size, + mask=n_offset < K_in, other=0.0, ) tile_max = tl.max(tl.abs(a)) @@ -2418,15 +2414,14 @@ def _kernel_quantize_fp8_row( for _k in range(0, tl.cdiv(K, BLOCK_SIZE)): a = tl.load( A + a_offset_base + n_offset * stride_ak, - mask=n_offset < row_size, + mask=n_offset < K_in, other=0.0, ) a_fp8 = a * a_scale # Clamp A to fp8 range to make sure there's no overflow. # This is required for AMD. Nvidia's default saturation # handles it, but it's nice to have anyway. - a_fp8 = tl.clamp(a_fp8, -MAX_FP8, MAX_FP8) - a_fp8.to(TL_FP8_DTYPE) + a_fp8 = tl.clamp(a_fp8, -MAX_FP8, MAX_FP8).to(TL_FP8_DTYPE) tl.store( A_fp8 + a_fp8_offset_base + n_offset * stride_ok, a_fp8, @@ -2481,7 +2476,6 @@ def triton_quantize_fp8_row( a_fp8.stride(3), zero_start_index_M.stride(0) if zero_start_index_M is not None else None, zero_start_index_M.stride(1) if zero_start_index_M is not None else None, - zero_start_index_M.stride(2) if zero_start_index_M is not None else None, TL_FP8_DTYPE=tl_dtype, MAX_FP8=max_fp8, EPS=eps, @@ -2527,8 +2521,8 @@ def quantize_fp8_row( while a.dim() < 4: a = a.unsqueeze(0) if zero_start_index_M is not None: - while zero_start_index_M.dim() < 3: - zero_start_index_M = zero_start_index_M.unsqueeze(0) + # There should be one value of zero_start_index_M per NxK matrix. + zero_start_index_M = zero_start_index_M.view(a.shape[0], a.shape[1]) a_fp8, a_scale = triton_quantize_fp8_row(a, scale_ub, zero_start_index_M) return a_fp8.view(a_shape), a_scale.view(a_shape[:-1]) # else use pytorch implementation.