diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_grouped.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_grouped.cu
index 2c0f9c1a4c..437d3d8bf5 100644
--- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_grouped.cu
+++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_grouped.cu
@@ -462,8 +462,6 @@ std::tuple> f8f8bf16_rowwise_grouped_impl(
       reinterpret_cast(output_ptr),
       stride_output_ptr}};
 
-  int M = XQ[0].size(0);
-  int N = WQ[0].size(0);
   arguments.epilogue.thread = {
       {reinterpret_cast(
           x_scale_ptr)}, // x_scale
@@ -599,7 +597,13 @@ at::Tensor f8f8bf16_rowwise_grouped_dynamic(
   at::Tensor output = std::get<0>(dispatch_fp8_grouped_kernel(
       XQ, WQ, x_scale, w_scale, Y, zero_start_index_M));
   // View as proper shape.
-  output = output.view({-1, XQ[0].size(0), WQ[0].size(0)});
+  // When zero_start_index_M is provided, we can view the output as [G, M, N].
+  if (zero_start_index_M.has_value()) {
+    output = output.view({-1, XQ[0].size(0), WQ[0].size(0)});
+  } else {
+    // Otherwise, we view it as [total_M, N].
+    output = output.view({-1, WQ[0].size(0)});
+  }
   return output;
 }
 
diff --git a/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py b/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py
index 1f0f4f81f5..51fbcc6db8 100644
--- a/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py
+++ b/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py
@@ -726,7 +726,8 @@ def fp8_loopover_bmm(
         torch.testing.assert_close(y_ref, y_fp8, atol=8.0e-2, rtol=8.0e-2)
 
     @unittest.skipIf(
-        not torch.version.cuda, "Skip on AMD: GMM ops are not yet suported."
+        not torch.version.cuda and torch.version.hip < "6.2",
+        "Skip on AMD with ROCm < 6.2.",
     )
     @settings(deadline=None)
     @given(
@@ -805,63 +806,39 @@ def test_fp8_grouped_gemm(
         w_scale_group = torch.unbind(torch.stack(w_scale_group, dim=0).contiguous())
 
         # FP8 grouped gemm kernel
+        fp8_args = (
+            [
+                xq_group,
+                wq_group,
+                x_scale_group,
+                w_scale_group,
+                zero_start_index_M if use_padding_zeros else None,
+            ]
+            if use_dynamic
+            else [xq_group, wq_group, x_scale_group, w_scale_group]
+        )
+        fp8_op = (
+            torch.ops.fbgemm.f8f8bf16_rowwise_grouped_dynamic
+            if use_dynamic
+            else torch.ops.fbgemm.f8f8bf16_rowwise_grouped
+        )
         if use_cudagraph:
-            if use_padding_zeros:
-                # warmup
-                torch.ops.fbgemm.f8f8bf16_rowwise_grouped_dynamic(
-                    xq_group,
-                    wq_group,
-                    x_scale_group,
-                    w_scale_group,
-                    zero_start_index_M,
-                )
-                # With cudagraph
-                g = torch.cuda.CUDAGraph()
-                with torch.cuda.graph(g):
-                    y_fp8_group = torch.ops.fbgemm.f8f8bf16_rowwise_grouped_dynamic(
-                        xq_group,
-                        wq_group,
-                        x_scale_group,
-                        w_scale_group,
-                        zero_start_index_M,
-                    )
-                g.replay()
-                y_fp8_group = y_fp8_group.unbind(dim=0)
-            else:
-                # warmup
-                torch.ops.fbgemm.f8f8bf16_rowwise_grouped(
-                    xq_group,
-                    wq_group,
-                    x_scale_group,
-                    w_scale_group,
-                )
-                # With cudagraph
-                g = torch.cuda.CUDAGraph()
-                with torch.cuda.graph(g):
-                    y_fp8_group = torch.ops.fbgemm.f8f8bf16_rowwise_grouped(
-                        xq_group,
-                        wq_group,
-                        x_scale_group,
-                        w_scale_group,
-                    )
-                g.replay()
+            # warmup
+            fp8_op(*fp8_args)
+            # With cudagraph
+            g = torch.cuda.CUDAGraph()
+            with torch.cuda.graph(g):
+                y_fp8_group = fp8_op(*fp8_args)
+            g.replay()
         else:
-            if use_padding_zeros:
-                y_fp8_group = torch.ops.fbgemm.f8f8bf16_rowwise_grouped_dynamic(
-                    xq_group,
-                    wq_group,
-                    x_scale_group,
-                    w_scale_group,
-                    zero_start_index_M,
-                )
-                y_fp8_group = y_fp8_group.unbind(dim=0)
+            y_fp8_group = fp8_op(*fp8_args)
+
+        # Massage output into proper format.
+        if not isinstance(y_fp8_group, (tuple, list)):
+            if y_fp8_group.ndim == 2:
+                y_fp8_group = torch.split(y_fp8_group, tuple(ms.tolist()), dim=0)
             else:
-                y_fp8_group = torch.ops.fbgemm.f8f8bf16_rowwise_grouped(
-                    xq_group,
-                    wq_group,
-                    x_scale_group,
-                    w_scale_group,
-                )
+                y_fp8_group = torch.unbind(y_fp8_group)
 
         # BF16 grouped gemm kernel
         bf16_args = (