diff --git a/fbgemm_gpu/experimental/gemm/test/fp8_gemm_test.py b/fbgemm_gpu/experimental/gemm/test/fp8_gemm_test.py index 29bf06ad80..fa1f6f7660 100644 --- a/fbgemm_gpu/experimental/gemm/test/fp8_gemm_test.py +++ b/fbgemm_gpu/experimental/gemm/test/fp8_gemm_test.py @@ -60,16 +60,23 @@ def _test_quantize_fp8_row( # Apply sparsification if specified. zero_start_index_M = None if use_jagged: + # View input as [G, M, K] where G is the number of groups. + grouped_input = input_a.view( + -1, input_a.shape[-2], input_a.shape[-1] + ) m_vals = torch.randint( - 0, input_a.shape[-1] + 1, (input_a.shape[:-1]) + 0, grouped_input.shape[1] + 1, (grouped_input.shape[0],) ) - mask = torch.arange(input_a.shape[-1]).expand( - input_a.shape[:-1] + (input_a.shape[-1],) + mask = torch.arange(grouped_input.shape[-2]).expand( + (grouped_input.shape[0], grouped_input.shape[1]) ) >= m_vals.unsqueeze(-1) # Set corresponding values to 0. - input_a[mask] = 0.0 + grouped_input[mask] = 0.0 # Generate nonzero tensor in same layout as input. - zero_start_index_M = torch.count_nonzero(input_a, dim=-1) + zero_start_index_M = torch.count_nonzero( + torch.sum(grouped_input, dim=-1), dim=-1 + ) + a_fp8, a_scale = quantize_fp8_row( input_a, scale_ub=scale_ub, diff --git a/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py b/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py index 304840ecee..bf8d52cb1e 100644 --- a/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +++ b/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py @@ -2316,7 +2316,6 @@ def _kernel_quantize_fp8_row( stride_ok, stride_zb, stride_zm, - stride_zn, TL_FP8_DTYPE: tl.constexpr, MAX_FP8: tl.constexpr, EPS: tl.constexpr, @@ -2354,7 +2353,6 @@ def _kernel_quantize_fp8_row( stride_ok (int): Stride of k dimension of output. stride_zb (int): Stride of b dimension of jagged index. stride_zm (int): Stride of m dimension of jagged index. - stride_zn (int): Stride of n dimension of jagged index. TL_FP8_DTYPE (tl.dtype): Target fp8 datatype. MAX_FP8 (float): Maxmimum expressible value for FP8. EPS (float): Epsilon value for numerical stability. @@ -2380,24 +2378,22 @@ def _kernel_quantize_fp8_row( + (pid % (M * N)) % N * stride_on ) - if JAGGED: - z_offset_base = ( - pid // (M * N) * stride_zb - + (pid % (M * N)) // N * stride_zm - + (pid % (M * N)) % N * stride_zn - ) - row_size = tl.load(zero_start_index_M + z_offset_base) - else: - row_size = K + K_in = K - blocks = tl.cdiv(row_size, BLOCK_SIZE) + if JAGGED: + z_offset_base = pid // (M * N) * stride_zb + (pid % (M * N)) // N * stride_zm + group_rows = tl.load(zero_start_index_M + z_offset_base) + current_row = pid % N + # If this row is empty, dont process any of it. + if current_row >= group_rows: + K_in = 0 # Calculate max. cur_max = 0.0 - for _k in range(0, blocks): + for _k in range(0, tl.cdiv(K_in, BLOCK_SIZE)): a = tl.load( A + a_offset_base + n_offset * stride_ak, - mask=n_offset < row_size, + mask=n_offset < K_in, other=0.0, ) tile_max = tl.max(tl.abs(a)) @@ -2418,15 +2414,14 @@ def _kernel_quantize_fp8_row( for _k in range(0, tl.cdiv(K, BLOCK_SIZE)): a = tl.load( A + a_offset_base + n_offset * stride_ak, - mask=n_offset < row_size, + mask=n_offset < K_in, other=0.0, ) a_fp8 = a * a_scale # Clamp A to fp8 range to make sure there's no overflow. # This is required for AMD. Nvidia's default saturation # handles it, but it's nice to have anyway. - a_fp8 = tl.clamp(a_fp8, -MAX_FP8, MAX_FP8) - a_fp8.to(TL_FP8_DTYPE) + a_fp8 = tl.clamp(a_fp8, -MAX_FP8, MAX_FP8).to(TL_FP8_DTYPE) tl.store( A_fp8 + a_fp8_offset_base + n_offset * stride_ok, a_fp8, @@ -2481,7 +2476,6 @@ def triton_quantize_fp8_row( a_fp8.stride(3), zero_start_index_M.stride(0) if zero_start_index_M is not None else None, zero_start_index_M.stride(1) if zero_start_index_M is not None else None, - zero_start_index_M.stride(2) if zero_start_index_M is not None else None, TL_FP8_DTYPE=tl_dtype, MAX_FP8=max_fp8, EPS=eps, @@ -2527,8 +2521,8 @@ def quantize_fp8_row( while a.dim() < 4: a = a.unsqueeze(0) if zero_start_index_M is not None: - while zero_start_index_M.dim() < 3: - zero_start_index_M = zero_start_index_M.unsqueeze(0) + # There should be one value of zero_start_index_M per NxK matrix. + zero_start_index_M = zero_start_index_M.view(a.shape[0], a.shape[1]) a_fp8, a_scale = triton_quantize_fp8_row(a, scale_ub, zero_start_index_M) return a_fp8.view(a_shape), a_scale.view(a_shape[:-1]) # else use pytorch implementation.