Add preprocess stage to quantize bench operators (pytorch#3648)
Summary:

X-link: facebookresearch/FBGEMM#724

When benchmarking quantize functions, we'd like the measured overhead to mimic e2e behavior as closely as possible: for example, weights should be quantized ahead of time. The current design of quantize_bench does not allow this.

To accommodate this, I've added a new optional preprocess phase that allows some transformations to be applied independently of benchmarking. Here we use it to prepare data for the grouped GEMM benchmarks so they more accurately capture e2e behavior.
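
A minimal standalone sketch of the hook pattern this change introduces (illustrative only, not the FBGEMM code itself; the driver function name run_once is made up): the operator base class gains a pass-through preprocess, individual operators may override it to do one-time work such as weight quantization outside the benchmarked region, and the driver feeds its outputs into quantize.

class QuantizeOpBase:
    def preprocess(self, *args):
        # Default hook: pass inputs through untouched.
        return args

    def quantize(self, *args):
        raise NotImplementedError

    def compute(self, *args):
        raise NotImplementedError


def run_once(op: QuantizeOpBase, A, B):
    # Driver side: preprocess once, then quantize and compute on its outputs,
    # mirroring the call order added to quantize_bench.py below.
    preprocessed_args = op.preprocess(A, B)
    quantized_vals = op.quantize(*preprocessed_args)
    return op.compute(*quantized_vals)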

Reviewed By: jiawenliu64

Differential Revision: D68964950
jwfromm authored and facebook-github-bot committed Feb 3, 2025
1 parent 1203558 commit 146442e
Showing 2 changed files with 98 additions and 94 deletions.
13 changes: 7 additions & 6 deletions fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py
@@ -127,7 +127,8 @@ def benchmark_grouped(
# Also check if the operator is supported.
if kernel_requested and quantize_op.supported:
# Get the quantized tensors for this operator.
quantized_vals = quantize_op.quantize(A, B)
preprocessed_args = quantize_op.preprocess(A, B)
quantized_vals = quantize_op.quantize(*preprocessed_args)
# Compute the output given quantized values.
output = quantize_op.compute(*quantized_vals)
# Some kernels may pad output, just take the first m values of each row.
@@ -143,8 +144,7 @@
if bench_quantize:
# Benchmark both quantize and compute.
ms_runtime = quantize_op.benchmark(
A,
B,
*preprocessed_args,
bench_quantize=True,
use_rotating_buffer_bench=use_rotating_buffer_bench,
use_cuda_graph=use_cuda_graph,
@@ -218,8 +218,10 @@ def benchmark(
)
# Also check if the operator is supported.
if kernel_requested and quantize_op.supported:
# Preprocess data if needed.
preprocessed_args = quantize_op.preprocess(A, B)
# Get the quantized tensors for this operator.
quantized_vals = quantize_op.quantize(A, B)
quantized_vals = quantize_op.quantize(*preprocessed_args)
# Compute the output given quantized values.
output = quantize_op.compute(*quantized_vals)
# Compare the quantize op output to reference as a sanity check.
@@ -229,8 +231,7 @@
if bench_quantize:
# Benchmark both quantize and compute.
ms_runtime = quantize_op.benchmark(
A,
B,
*preprocessed_args,
bench_quantize=True,
use_rotating_buffer_bench=use_rotating_buffer_bench,
use_cuda_graph=use_cuda_graph,
179 changes: 91 additions & 88 deletions fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py
@@ -65,6 +65,10 @@ def quantize_and_compute(self, *args, **kwargs):
"""Function which quantizes inputs and performs main compute operation."""
pass

def preprocess(self, *args):
"""Preprocess inputs before benchmarking. These outputs will be passed to quantize."""
return args

def bench_with_rotating_buffer(self, fn, args, use_cuda_graph: bool = True):
import copy
import pickle
@@ -113,8 +117,13 @@ def benchmark(
) -> float:
"""Benchmark runtime of this operator."""
if bench_quantize:
with torch.cuda.stream(torch.cuda.Stream()):
t = triton.testing.do_bench_cudagraph(
if use_cuda_graph:
with torch.cuda.stream(torch.cuda.Stream()):
t = triton.testing.do_bench_cudagraph(
lambda: self.quantize_and_compute(*args, **kwargs)
)
else:
t = triton.testing.do_bench(
lambda: self.quantize_and_compute(*args, **kwargs)
)
else:
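
A condensed sketch of the CUDA-graph fallback added in the hunk above, assuming a Triton build that provides both do_bench and do_bench_cudagraph; the helper name time_fn is invented for illustration:

import torch
from triton.testing import do_bench, do_bench_cudagraph

def time_fn(fn, use_cuda_graph: bool = True) -> float:
    """Return the runtime of fn in milliseconds."""
    if use_cuda_graph:
        # Capture into a CUDA graph; run on a side stream, as the benchmark
        # above does, since graph capture needs a non-default stream.
        with torch.cuda.stream(torch.cuda.Stream()):
            return do_bench_cudagraph(fn)
    # Fallback when CUDA graphs are disabled: plain timing.
    return do_bench(fn)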
@@ -468,57 +477,52 @@ class FP8RowwiseGroupedGemm(QuantizeOpBase):
FP8 grouped matmul with rowwise scaling.
"""

def quantize_fixed_nk(self, x, w):
group_size = len(x)
m_values = [i.shape[0] for i in x]
# Inputs for fixed nk mode must be contiguous, however in the benchmark
# script they typically are not. Do a little special processing to make them
# work. In practice this wont be needed.
# Start by padding along m dimension with zeros.
max_m = max(m_values)
xq = [
torch.nn.functional.pad(i, (0, 0, 0, max_m - i.shape[0]), value=0)
for i in x
]
# Stack inputs into groups.
xq = torch.stack(xq).contiguous()
wq = torch.stack(w).contiguous()
# Apply quantization.
xq, x_scale = quantize_fp8_row(xq)
wq, w_scale = quantize_fp8_row(wq)
# View these unified tensors as lists of tensors.
xq = [x.squeeze() for x in xq.split(1, dim=0)]
wq = [w.squeeze() for w in wq.split(1, dim=0)]
x_scale = [xs.squeeze() for xs in x_scale.view(group_size, -1).split(1, dim=0)]
w_scale = [ws.squeeze() for ws in w_scale.view(group_size, -1).split(1, dim=0)]

# Return processed tensors.
return (
xq,
wq,
x_scale,
w_scale,
torch.tensor(m_values).to(dtype=torch.int64, device=xq[0].device),
)

def quantize(self, x, w):
assert isinstance(
x, (list, tuple)
), "Inputs to group gemm must be a list of tensors."

def preprocess(self, x, w):
# Apply sparsity to inputs if appropriate.
# First check if N and K are fixed.
m_values = [i.shape[0] for i in x]
n_values = [i.shape[0] for i in w]
k_values = [i.shape[1] for i in w]
# if so, do specialized version of initialization.
# If so, do specialized version of initialization.
if len(np.unique(n_values)) == 1 and len(np.unique(k_values)) == 1:
return self.quantize_fixed_nk(x, w)

# Otherwise handle in eager mode.
xq, x_scale = zip(*[quantize_fp8_row(i) for i in x])
m_values = [i.shape[0] for i in x]
# Inputs for fixed nk mode must be contiguous, however in the benchmark
# script they typically are not. Do a little special processing to make them
# work. In practice this wont be needed.
# Start by padding along m dimension with zeros.
max_m = max(m_values)
x = [
torch.nn.functional.pad(i, (0, 0, 0, max_m - i.shape[0]), value=0)
for i in x
]
# Stack inputs into groups.
x = torch.stack(x).contiguous()
w = torch.stack(w).contiguous()

# Preapply weight quantization.
wq, w_scale = quantize_fp8_row(w)
# Return processed tensors.
return (
x,
wq,
w_scale,
torch.tensor(m_values).to(dtype=torch.int64, device=x[0].device),
)
# Otherwise run without sparsity.
wq, w_scale = zip(*[quantize_fp8_row(i) for i in w])
m_values = None
return xq, wq, x_scale, w_scale, m_values
return x, wq, w_scale, None

def quantize(self, x, wq, w_scale, m_values=None):
# Handle case where inputs are explicitly grouped and non-sparse.
if isinstance(x, (tuple, list)):
xq, x_scale = zip(*[quantize_fp8_row(i) for i in x])
return xq, wq, x_scale, w_scale, m_values
# Otherwise inputs are unified tensors and sparse.
else:
B = x.shape[0]
xq, x_scale = quantize_fp8_row(x, zero_start_index_M=m_values)
x_scale = x_scale.view(B, -1)
return xq, wq, x_scale, w_scale, m_values

def compute(self, xq, wq, x_scale, w_scale, m_values, kernel_name=None):
if m_values is None:
@@ -530,17 +534,23 @@ def compute(xq, wq, x_scale, w_scale, m_values, kernel_name=None):
kernel_name=kernel_name,
)
else:
# Break tensor into groups, simulates what is done e2e.
B = xq.shape[0]
xq_group = [xq[i, :, :] for i in range(B)]
x_scale_group = [x_scale[i, :] for i in range(B)]
wq_group = [wq[i, :, :] for i in range(B)]
w_scale_group = [w_scale[i, :] for i in range(B)]
return torch.ops.fbgemm.f8f8bf16_rowwise_grouped_dynamic(
xq,
wq,
x_scale,
w_scale,
xq_group,
wq_group,
x_scale_group,
w_scale_group,
zero_start_index_M=m_values,
kernel_name=kernel_name,
)

def quantize_and_compute(self, x, w):
xq, wq, x_scale, w_scale, m_values = self.quantize(x, w)
def quantize_and_compute(self, x, wq, w_scale, m_values=None):
xq, wq, x_scale, w_scale, m_values = self.quantize(x, wq, w_scale, m_values)
return self.compute(xq, wq, x_scale, w_scale, m_values)

@property
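
To make the fixed-N/K path above easier to follow, here is a small self-contained illustration (shapes invented for the example, quantization omitted) of how ragged per-group activations are zero-padded along M, stacked into a single tensor, and paired with an m_values tensor so the dynamic grouped kernel can ignore the padded rows:

import torch

# Three groups with different M but a shared K (64 here).
x = [torch.randn(m, 64) for m in (3, 7, 5)]
m_values = [t.shape[0] for t in x]            # true rows per group -> [3, 7, 5]
max_m = max(m_values)

# Zero-pad each group along M, then stack into one (G, max_M, K) tensor.
x_padded = torch.stack(
    [torch.nn.functional.pad(t, (0, 0, 0, max_m - t.shape[0]), value=0) for t in x]
).contiguous()

# Passed downstream as zero_start_index_M so padded rows are skipped.
m_tensor = torch.tensor(m_values, dtype=torch.int64)
print(x_padded.shape, m_tensor.tolist())      # torch.Size([3, 7, 64]) [3, 7, 5]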
@@ -565,55 +575,48 @@ class BF16GroupedGemm(QuantizeOpBase):
BF16 grouped matmul implemented with CK or Cutlass.
"""

def quantize_fixed_nk(self, x, w):
m_values = [i.shape[0] for i in x]
# Inputs for fixed nk mode must be contiguous, however in the benchmark
# script they typically are not. Do a little special processing to make them
# work. In practice this wont be needed.
# Start by padding along m dimension with zeros.
max_m = max(m_values)
xp = [
torch.nn.functional.pad(i, (0, 0, 0, max_m - i.shape[0]), value=0)
for i in x
]
# Stack inputs into groups.
x = torch.stack(xp).contiguous()
w = torch.stack(w).contiguous()
# View these unified tensors as lists of tensors.
x = [xi.squeeze() for xi in x.split(1, dim=0)]
w = [wi.squeeze() for wi in w.split(1, dim=0)]

# Return processed tensors.
return (
x,
w,
torch.tensor(m_values).to(dtype=torch.int64, device=x[0].device),
)

def quantize(self, x, w):
assert isinstance(
x, (list, tuple)
), "Inputs to group gemm must be a list of tensors."

def preprocess(self, x, w):
# Apply sparsity to inputs if appropriate.
# First check if N and K are fixed.
m_values = [i.shape[0] for i in x]
n_values = [i.shape[0] for i in w]
k_values = [i.shape[1] for i in w]
# if so, do specialized version of initialization.
# If so, do specialized version of initialization.
if len(np.unique(n_values)) == 1 and len(np.unique(k_values)) == 1:
return self.quantize_fixed_nk(x, w)
m_values = [i.shape[0] for i in x]
# Inputs for fixed nk mode must be contiguous, however in the benchmark
# script they typically are not. Do a little special processing to make them
# work. In practice this wont be needed.
# Start by padding along m dimension with zeros.
max_m = max(m_values)
x = [
torch.nn.functional.pad(i, (0, 0, 0, max_m - i.shape[0]), value=0)
for i in x
]
# Stack inputs into groups.
x = torch.stack(x).contiguous()
w = torch.stack(w).contiguous()
return (
x,
w,
torch.tensor(m_values).to(dtype=torch.int64, device=x[0].device),
)
return x, w, None

m_values = None
def quantize(self, x, w, m_values=None):
# No action required.
return x, w, m_values

def compute(self, x, w, m_values):
if m_values is None:
return torch.ops.fbgemm.bf16bf16bf16_grouped(x, w)
else:
B = x.shape[0]
x = [x[i, :, :] for i in range(B)]
w = [w[i, :, :] for i in range(B)]
return torch.ops.fbgemm.bf16bf16bf16_grouped_dynamic(x, w, m_values)

def quantize_and_compute(self, x, w):
x, w, m_values = self.quantize(x, w)
def quantize_and_compute(self, x, w, m_values):
return self.compute(x, w, m_values)

@property
