Fix handling of dynamic FP8 grouped gemm on Nvidia #3616

Closed
wants to merge 1 commit
@@ -462,8 +462,6 @@ std::tuple<at::Tensor, std::vector<at::Tensor>> f8f8bf16_rowwise_grouped_impl(
reinterpret_cast<GroupedGemmArgs::ElementOutput**>(output_ptr),
stride_output_ptr}};

int M = XQ[0].size(0);
int N = WQ[0].size(0);
arguments.epilogue.thread = {
{reinterpret_cast<const GroupedGemmArgs::ElementComputeEpilogue**>(
x_scale_ptr)}, // x_scale
@@ -599,7 +597,13 @@ at::Tensor f8f8bf16_rowwise_grouped_dynamic(
at::Tensor output = std::get<0>(dispatch_fp8_grouped_kernel(
XQ, WQ, x_scale, w_scale, Y, zero_start_index_M));
// View as proper shape.
output = output.view({-1, XQ[0].size(0), WQ[0].size(0)});
// When zero_start_index_M is provided, we can view as [G, M, N]
if (zero_start_index_M.has_value()) {
output = output.view({-1, XQ[0].size(0), WQ[0].size(0)});
// Otherwise we view as {total_M, N}.
} else {
output = output.view({-1, WQ[0].size(0)});
}
return output;
}
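As context for the view change above, here is a minimal sketch of the two output layouts the dynamic kernel can now return, using plain torch ops with made-up sizes (G, M, N, K and the stand-in tensors are illustrative, not taken from the PR):

```python
import torch

# Illustrative sizes only; not from the PR.
G, M, N, K = 4, 128, 256, 512

# The grouped API takes per-group input lists; stand-ins for XQ / WQ here.
XQ = [torch.randn(M, K) for _ in range(G)]
WQ = [torch.randn(N, K) for _ in range(G)]

# Stand-in for the kernel's flat output buffer: G row blocks stacked on dim 0.
flat_out = torch.randn(G * M, N)

# With zero_start_index_M, every group is padded to the same M, so the
# buffer can be viewed as [G, M, N].
padded_view = flat_out.view(-1, XQ[0].size(0), WQ[0].size(0))
assert padded_view.shape == (G, M, N)

# Without zero_start_index_M, group sizes may differ, so the only safe
# view is [total_M, N]; callers split it per group afterwards.
ragged_view = flat_out.view(-1, WQ[0].size(0))
assert ragged_view.shape == (G * M, N)
```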

87 changes: 32 additions & 55 deletions fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py
@@ -726,7 +726,8 @@ def fp8_loopover_bmm(
torch.testing.assert_close(y_ref, y_fp8, atol=8.0e-2, rtol=8.0e-2)

@unittest.skipIf(
not torch.version.cuda, "Skip on AMD: GMM ops are not yet suported."
not torch.version.cuda and torch.version.hip < "6.2",
"Skip on AMD with < RoCM 6.2",
)
@settings(deadline=None)
@given(
@@ -805,63 +806,39 @@ def test_fp8_grouped_gemm(
w_scale_group = torch.unbind(torch.stack(w_scale_group, dim=0).contiguous())

# FP8 grouped gemm kernel
fp8_args = (
[
xq_group,
wq_group,
x_scale_group,
w_scale_group,
zero_start_index_M if use_padding_zeros else None,
]
if use_dynamic
else [xq_group, wq_group, x_scale_group, w_scale_group]
)
fp8_op = (
torch.ops.fbgemm.f8f8bf16_rowwise_grouped_dynamic
if use_dynamic
else torch.ops.fbgemm.f8f8bf16_rowwise_grouped
)
if use_cudagraph:
if use_padding_zeros:
# warmup
torch.ops.fbgemm.f8f8bf16_rowwise_grouped_dynamic(
xq_group,
wq_group,
x_scale_group,
w_scale_group,
zero_start_index_M,
)
# With cudagraph
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
y_fp8_group = torch.ops.fbgemm.f8f8bf16_rowwise_grouped_dynamic(
xq_group,
wq_group,
x_scale_group,
w_scale_group,
zero_start_index_M,
)
g.replay()
y_fp8_group = y_fp8_group.unbind(dim=0)
else:
# warmup
torch.ops.fbgemm.f8f8bf16_rowwise_grouped(
xq_group,
wq_group,
x_scale_group,
w_scale_group,
)
# With cudagraph
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
y_fp8_group = torch.ops.fbgemm.f8f8bf16_rowwise_grouped(
xq_group,
wq_group,
x_scale_group,
w_scale_group,
)
g.replay()
# warmup
fp8_op(*fp8_args)
# With cudagraph
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
y_fp8_group = fp8_op(*fp8_args)
g.replay()
else:
if use_padding_zeros:
y_fp8_group = torch.ops.fbgemm.f8f8bf16_rowwise_grouped_dynamic(
xq_group,
wq_group,
x_scale_group,
w_scale_group,
zero_start_index_M,
)
y_fp8_group = y_fp8_group.unbind(dim=0)
y_fp8_group = fp8_op(*fp8_args)

# Massage output into proper format.
if not isinstance(y_fp8_group, (tuple, list)):
if y_fp8_group.ndim == 2:
y_fp8_group = torch.split(y_fp8_group, tuple(ms.tolist()), dim=0)
else:
y_fp8_group = torch.ops.fbgemm.f8f8bf16_rowwise_grouped(
xq_group,
wq_group,
x_scale_group,
w_scale_group,
)
y_fp8_group = torch.unbind(y_fp8_group)
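A tiny sketch of the two normalization paths above, using plain torch with made-up group sizes (`ms` and `N` here are illustrative stand-ins for the test's values): `torch.split` recovers per-group results from a flat [total_M, N] output, while `torch.unbind` handles the padded [G, M, N] case.

```python
import torch

# Hypothetical per-group M sizes; the test's `ms` plays this role.
ms = torch.tensor([3, 5, 2])
N = 4

# 2D output [total_M, N]: recover per-group results with torch.split.
flat = torch.randn(int(ms.sum()), N)
per_group = torch.split(flat, tuple(ms.tolist()), dim=0)
assert [t.shape[0] for t in per_group] == ms.tolist()

# 3D output [G, M_max, N] (padded): recover per-group results with unbind.
padded = torch.randn(len(ms), int(ms.max()), N)
per_group_padded = torch.unbind(padded, dim=0)
assert len(per_group_padded) == len(ms)
```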

# BF16 grouped gemm kernel
bf16_args = (
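For reference, the warmup, capture, replay flow used in the test follows the standard `torch.cuda.CUDAGraph` pattern. Below is a minimal, self-contained sketch with a generic callable; the helper name and the matmul example are illustrative, not part of the PR.

```python
import torch

def run_with_cudagraph(op, *args):
    # Warmup outside the capture so lazy initialization (handles, autotuning,
    # allocator warm-up) does not get recorded into the graph.
    op(*args)
    torch.cuda.synchronize()

    # Capture one invocation, then replay it.
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        out = op(*args)
    g.replay()
    return out

if __name__ == "__main__":
    if torch.cuda.is_available():
        a = torch.randn(64, 64, device="cuda")
        b = torch.randn(64, 64, device="cuda")
        y = run_with_cudagraph(torch.matmul, a, b)
        print(y.shape)
```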