Improve handling for FP8 grouped gemm without zero_start_index_M (#3615)
Summary:
Pull Request resolved: #3615

X-link: facebookresearch/FBGEMM#694

When zero_start_index_M isn't provided, inputs can have variable M values across groups. To support this, we need to return a tensor with shape [total_M, N], since a tensor with per-group variable M can't be viewed as [G, M, N]. (See the caller-side sketch of the two output layouts below.)

Reviewed By: jasonjk-park, bradleyhd

Differential Revision: D68686266

fbshipit-source-id: 4267de288d6f7f2d0dec82b881eba056c11ea737
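
For callers, the only visible difference between the two modes is how per-group views are recovered from the returned tensor. The sketch below is hypothetical and not part of this commit; split_grouped_output, group_Ms, and padded are illustrative names, and only ATen calls already used in the diff (at::unbind, Tensor::narrow) are assumed.

#include <ATen/ATen.h>
#include <vector>

// Recover per-group [M_i, N] views from either output layout.
std::vector<at::Tensor> split_grouped_output(
    const at::Tensor& Y_full,              // tensor returned by the grouped gemm
    const std::vector<int64_t>& group_Ms,  // per-group M values
    bool padded) {                         // true when zero_start_index_M was passed
  std::vector<at::Tensor> views;
  if (padded) {
    // Padded case: Y_full has shape [G, M, N]; each group is one full slice.
    for (const auto& g : at::unbind(Y_full, 0)) {
      views.push_back(g);
    }
  } else {
    // Variable-M case: Y_full has shape [total_M, N]; slice at each group's row offset.
    int64_t offset = 0;
    for (int64_t m : group_Ms) {
      views.push_back(Y_full.narrow(0, offset, m));
      offset += m;
    }
  }
  return views;
}

With padding, at::unbind yields one [M, N] view per group; without it, narrow slices the flat [total_M, N] buffer at each group's running row offset, which mirrors what the wrapper in the diff does internally.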
jwfromm authored and facebook-github-bot committed Jan 27, 2025
1 parent e7c08e3 commit c8da60b
Showing 1 changed file with 34 additions and 7 deletions.
@@ -481,15 +481,42 @@ at::Tensor f8f8bf16_rowwise_grouped_dynamic(
     TORCH_CHECK(ws.dtype() == at::kFloat, "Scales must be float32.");
   }
 
-  // Create a single chunk of tensor but view it as a list for compatibility.
+  at::Tensor Y_full;
+  std::vector<at::Tensor> Y;
   int M = XQ[0].size(0);
   int N = WQ[0].size(0);
-  // Allocate an empty output array. We will set its values to zero as part
-  // of kernel setup.
-  at::Tensor Y_full =
-      at::empty({group_count, M, N}, XQ[0].options().dtype(at::kBFloat16));
-  // Split the output into groups.
-  std::vector<at::Tensor> Y = at::unbind(Y_full, 0);
+  int K = XQ[0].size(1);
+  // When run with padding, we return an output with shape [G, M, N] since
+  // M and N are consistent across groups.
+  if (zero_start_index_M.has_value()) {
+    // Allocate an empty output array. We will set its values to zero as part
+    // of kernel setup.
+    Y_full =
+        at::empty({group_count, M, N}, XQ[0].options().dtype(at::kBFloat16));
+    // Split the output into groups.
+    Y = at::unbind(Y_full, 0);
+    // When padding isn't provided, we return a tensor with shape [total_M, N]
+    // since viewing the output as groups isn't possible due to variable M.
+  } else {
+    int total_M = 0;
+    std::vector<int> group_sizes = {};
+    for (int i = 0; i < group_count; i++) {
+      TORCH_CHECK(
+          XQ[i].size(1) == K && WQ[i].size(0) == N && WQ[i].size(1) == K,
+          "Dynamic grouped gemm requires fixed N and K.");
+      int group_M = XQ[i].size(0);
+      total_M += group_M;
+      group_sizes.push_back(group_M);
+    }
+    // Allocate a contiguous array covering all groups.
+    Y_full = at::empty({total_M, N}, XQ[0].options().dtype(at::kBFloat16));
+    // Split the full array into groups for downstream handling.
+    int offset = 0;
+    for (int size : group_sizes) {
+      Y.push_back(Y_full.narrow(0, offset, size));
+      offset += size;
+    }
+  }
 
   // Prepare kernel arguments by copying them to the proper device location.
   at::Tensor kernel_args =
