diff --git a/benchmarks/benchmark_fp6.py b/benchmarks/benchmark_fp6.py index e9f9d2139..9b8dcf338 100644 --- a/benchmarks/benchmark_fp6.py +++ b/benchmarks/benchmark_fp6.py @@ -1,7 +1,7 @@ import torch import pandas as pd import torch.nn.functional as F -from torchao.dtypes import to_affine_quantized_floatx +from torchao.dtypes import to_affine_quantized_fpx from torchao.dtypes.floatx import FloatxTensorCoreAQTLayout, FloatxTensorCoreLayoutType from torchao.utils import benchmark_torch_function_in_microseconds from tqdm import tqdm @@ -9,7 +9,7 @@ def benchmark(m: int, k: int, n: int): float_data = torch.randn(n, k, dtype=torch.half, device="cuda") - fp6_weight = to_affine_quantized_floatx(float_data, FloatxTensorCoreLayoutType(3, 2)) + fp6_weight = to_affine_quantized_fpx(float_data, FloatxTensorCoreLayoutType(3, 2)) fp16_weight = fp6_weight.dequantize(torch.half) fp16_act = torch.randn(m, k, dtype=torch.half, device="cuda")