From c8da60bde754fce4bfa860f196e4bfe30dfeb077 Mon Sep 17 00:00:00 2001
From: Josh Fromm
Date: Mon, 27 Jan 2025 14:06:27 -0800
Subject: [PATCH] Improve handling for FP8 grouped gemm without
 zero_start_index_M (#3615)

Summary:
Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3615

X-link: https://github.com/facebookresearch/FBGEMM/pull/694

When zero_start_index_M isn't provided, inputs can have variable M values
across groups. To support this, we need to return a tensor with shape
[total_M, N] since it isn't possible to view the tensor as [G, M, N].

Reviewed By: jasonjk-park, bradleyhd

Differential Revision: D68686266

fbshipit-source-id: 4267de288d6f7f2d0dec82b881eba056c11ea737
---
 .../fp8_rowwise_grouped_gemm.hip | 41 +++++++++++++++----
 1 file changed, 34 insertions(+), 7 deletions(-)

diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/fp8_rowwise_grouped_gemm.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/fp8_rowwise_grouped_gemm.hip
index 87de8a0c52..c82cd46410 100644
--- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/fp8_rowwise_grouped_gemm.hip
+++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/fp8_rowwise_grouped_gemm.hip
@@ -481,15 +481,42 @@ at::Tensor f8f8bf16_rowwise_grouped_dynamic(
     TORCH_CHECK(ws.dtype() == at::kFloat, "Scales must be float32.");
   }
-  // Create a single chunk of tensor but view it as a list for compatibility.
+  at::Tensor Y_full;
+  std::vector<at::Tensor> Y;
   int M = XQ[0].size(0);
   int N = WQ[0].size(0);
-  // Allocate an empty output array. We will set its values to zero as part
-  // of kernel setup.
-  at::Tensor Y_full =
-      at::empty({group_count, M, N}, XQ[0].options().dtype(at::kBFloat16));
-  // Split the output into groups.
-  std::vector<at::Tensor> Y = at::unbind(Y_full, 0);
+  int K = XQ[0].size(1);
+  // When run with padding, we return an output with shape [G, M, N] since
+  // M and N are consistent across groups.
+  if (zero_start_index_M.has_value()) {
+    // Allocate an empty output array. We will set its values to zero as part
+    // of kernel setup.
+    Y_full =
+        at::empty({group_count, M, N}, XQ[0].options().dtype(at::kBFloat16));
+    // Split the output into groups.
+    Y = at::unbind(Y_full, 0);
+    // When padding isn't provided, we return a tensor with shape [Total_M, N]
+    // since viewing as groups isn't possible due to variable M.
+  } else {
+    int total_M = 0;
+    std::vector<int> group_sizes = {};
+    for (int i = 0; i < group_count; i++) {
+      TORCH_CHECK(
+          XQ[i].size(1) == K && WQ[i].size(0) == N && WQ[i].size(1) == K,
+          "Dynamic grouped gemm requires fixed N and K.");
+      int group_M = XQ[i].size(0);
+      total_M += group_M;
+      group_sizes.push_back(group_M);
+    }
+    // Allocate a contiguous array for all groups.
+    Y_full = at::empty({total_M, N}, XQ[0].options().dtype(at::kBFloat16));
+    // Split the full array into groups for downstream handling.
+    int offset = 0;
+    for (int size : group_sizes) {
+      Y.push_back(Y_full.narrow(0, offset, size));
+      offset += size;
+    }
+  }

   // Prepare kernel arguments by copying them to the proper device location.
   at::Tensor kernel_args =
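
For context, the sketch below (not part of the patch) illustrates the output layout the variable-M path produces: per-group results are packed into a single contiguous [total_M, N] buffer and each group is exposed as a narrow() view. It assumes only libtorch/ATen and uses made-up group sizes; names like group_Ms are illustrative, not FBGEMM APIs.

// Minimal standalone sketch, assuming libtorch/ATen is available; the group
// sizes are hypothetical and exist only to show the [total_M, N] packing.
#include <ATen/ATen.h>

#include <cstdio>
#include <vector>

int main() {
  const int64_t N = 8;
  // Hypothetical per-group M values, variable across groups.
  std::vector<int64_t> group_Ms = {3, 5, 2};

  int64_t total_M = 0;
  for (int64_t m : group_Ms) {
    total_M += m;
  }

  // One contiguous output buffer covering all groups.
  at::Tensor Y_full = at::empty({total_M, N}, at::dtype(at::kBFloat16));

  // Per-group [M_g, N] views into the buffer; narrow() does not copy.
  std::vector<at::Tensor> Y;
  int64_t offset = 0;
  for (int64_t m : group_Ms) {
    Y.push_back(Y_full.narrow(0, offset, m));
    offset += m;
  }

  std::printf(
      "Y_full: [%lld, %lld], %zu group views\n",
      static_cast<long long>(Y_full.size(0)),
      static_cast<long long>(Y_full.size(1)),
      Y.size());
  return 0;
}

Because each entry of Y aliases Y_full, a kernel that writes through the per-group views fills the flat [total_M, N] tensor in place, which is what lets the op return a single tensor when M varies across groups.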