Skip to content

Commit

Permalink
Renaming fpx to floatx (#877)
Browse files Browse the repository at this point in the history
* Renaming fpx to floatx

Summary:
att, to allow float8 code to be moved to floatx folder
fpx_weight_only is not yet renamed to floatx_weight_only yet, we'll do that
in the future after we have more clarity on what specific dtypes we want to support (e.g. maybe we'll
just support fp4, fp6)

Test Plan:
python test/dtypes/test_floatx.py

Reviewers:

Subscribers:

Tasks:

Tags:

* fix test_ops
  • Loading branch information
jerryzh168 authored Sep 12, 2024
1 parent f82071d commit 8236a87
Show file tree
Hide file tree
Showing 17 changed files with 188 additions and 188 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ The best example we have combining the composability of lower bit dtype with com

We've added support for authoring and releasing [custom ops](./torchao/csrc/) that do not graph break with `torch.compile()` so if you love writing kernels but hate packaging them so they work all operating systems and cuda versions, we'd love to accept contributions for your custom ops. We have a few examples you can follow

1. [fp6](torchao/prototype/quant_llm/) for 2x faster inference over fp16 with an easy to use API `quantize_(model, fp6_llm_weight_only())`
1. [fp6](torchao/dtypes/floatx) for 2x faster inference over fp16 with an easy to use API `quantize_(model, fpx_weight_only(3, 2))`
2. [2:4 Sparse Marlin GEMM](https://github.com/pytorch/ao/pull/733) 2x speedups for FP16xINT4 kernels even at batch sizes up to 256
3. [int4 tinygemm unpacker](https://github.com/pytorch/ao/pull/415) which makes it easier to switch quantized backends for inference

Expand Down
6 changes: 3 additions & 3 deletions benchmarks/benchmark_fp6.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import torch
import pandas as pd
import torch.nn.functional as F
from torchao.dtypes import to_affine_quantized_fpx
from torchao.dtypes.fpx import FpxTensorCoreAQTLayout, FpxTensorCoreLayoutType
from torchao.dtypes import to_affine_quantized_floatx
from torchao.dtypes.floatx import FloatxTensorCoreAQTLayout, FloatxTensorCoreLayoutType
from torchao.utils import benchmark_torch_function_in_microseconds
from tqdm import tqdm


def benchmark(m: int, k: int, n: int):
float_data = torch.randn(n, k, dtype=torch.half, device="cuda")
fp6_weight = to_affine_quantized_fpx(float_data, FpxTensorCoreLayoutType(3, 2))
fp6_weight = to_affine_quantized_floatx(float_data, FloatxTensorCoreLayoutType(3, 2))
fp16_weight = fp6_weight.dequantize(torch.half)

fp16_act = torch.randn(m, k, dtype=torch.half, device="cuda")
Expand Down
72 changes: 36 additions & 36 deletions test/dtypes/test_fpx.py → test/dtypes/test_floatx.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@
parametrize,
run_tests,
)
from torchao.dtypes.fpx import (
FpxTensorCoreAQTLayout,
FpxTensorCoreLayoutType,
to_scaled_tc_fpx,
from_scaled_tc_fpx,
from torchao.dtypes.floatx import (
FloatxTensorCoreAQTLayout,
FloatxTensorCoreLayoutType,
to_scaled_tc_floatx,
from_scaled_tc_floatx,
)
from torchao.dtypes.fpx.fpx import _pack_tc_fpx, _pack_tc_fp6
from torchao.prototype.custom_fp_utils import _f32_to_fpx_unpacked, _fpx_unpacked_to_f32
from torchao.dtypes.floatx.floatx import _pack_tc_floatx, _pack_tc_fp6
from torchao.prototype.custom_fp_utils import _f32_to_floatx_unpacked, _floatx_unpacked_to_f32
from torchao.quantization import (
quantize_,
fpx_weight_only,
Expand All @@ -25,71 +25,71 @@


_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
_FPx_DTYPES = [(3, 2), (2, 2)]
_Floatx_DTYPES = [(3, 2), (2, 2)]


class TestFpxTensorCoreAQTLayout(TestCase):
class TestFloatxTensorCoreAQTLayout(TestCase):
@parametrize("device", _DEVICES)
def test_pack_tc_fp6_correctness(self, device):
x = torch.randint(256, size=(256, 64), dtype=torch.uint8, device=device)

expected = _pack_tc_fpx(x, 6)
expected = _pack_tc_floatx(x, 6)
actual = _pack_tc_fp6(x)
torch.testing.assert_close(actual, expected)

@parametrize("ebits,mbits", _FPx_DTYPES)
@parametrize("ebits,mbits", _Floatx_DTYPES)
@parametrize("device", _DEVICES)
def test_to_scaled_tc_fpx_compile(self, ebits, mbits, device):
def test_to_scaled_tc_floatx_compile(self, ebits, mbits, device):
x = torch.randn(256, 64, device=device)

expected = to_scaled_tc_fpx(x, ebits, mbits)
actual = torch.compile(to_scaled_tc_fpx, fullgraph=True)(x, ebits, mbits)
expected = to_scaled_tc_floatx(x, ebits, mbits)
actual = torch.compile(to_scaled_tc_floatx, fullgraph=True)(x, ebits, mbits)
torch.testing.assert_close(actual, expected)

@parametrize("ebits,mbits", _FPx_DTYPES)
@parametrize("ebits,mbits", _Floatx_DTYPES)
@parametrize("device", _DEVICES)
def test_from_tc_fpx_correctness(self, ebits, mbits, device):
def test_from_tc_floatx_correctness(self, ebits, mbits, device):
x = torch.randn(256, 64, device=device) * 100

# quantize and dequantize so that the values are exactly representable in FPx
x = _fpx_unpacked_to_f32(_f32_to_fpx_unpacked(x, ebits, mbits), ebits, mbits)
# quantize and dequantize so that the values are exactly representable in Floatx
x = _floatx_unpacked_to_f32(_f32_to_floatx_unpacked(x, ebits, mbits), ebits, mbits)

tc_fpx, scale = to_scaled_tc_fpx(x, ebits, mbits)
actual = from_scaled_tc_fpx(tc_fpx, ebits, mbits, scale=scale)
tc_floatx, scale = to_scaled_tc_floatx(x, ebits, mbits)
actual = from_scaled_tc_floatx(tc_floatx, ebits, mbits, scale=scale)
torch.testing.assert_close(actual, x)

@parametrize("ebits,mbits", _FPx_DTYPES)
@parametrize("ebits,mbits", _Floatx_DTYPES)
@parametrize("device", _DEVICES)
def test_from_scaled_tc_fpx_compile(self, ebits, mbits, device):
def test_from_scaled_tc_floatx_compile(self, ebits, mbits, device):
M, N = 256, 64
nbits = 1 + ebits + mbits
x = torch.randint(256, size=(M, N // 8 * nbits), dtype=torch.uint8, device=device)
scale = torch.randn(M, device=device)

expected = from_scaled_tc_fpx(x, ebits, mbits, scale)
actual = torch.compile(from_scaled_tc_fpx, fullgraph=True)(x, ebits, mbits, scale)
expected = from_scaled_tc_floatx(x, ebits, mbits, scale)
actual = torch.compile(from_scaled_tc_floatx, fullgraph=True)(x, ebits, mbits, scale)
torch.testing.assert_close(actual, expected)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@parametrize("ebits,mbits", _FPx_DTYPES)
@parametrize("ebits,mbits", _Floatx_DTYPES)
def test_to_copy_device(self, ebits, mbits):
from torchao.quantization.quant_primitives import (
choose_qparams_affine_fpx,
quantize_affine_fpx,
choose_qparams_affine_floatx,
quantize_affine_floatx,
)

x = torch.randn(256, 64)
scale = choose_qparams_affine_fpx(x, ebits, mbits)
x = quantize_affine_fpx(x, scale, ebits, mbits)
layout_type = FpxTensorCoreLayoutType(ebits, mbits)
fpx_layout_tensor = FpxTensorCoreAQTLayout.from_plain(x, scale, None, layout_type).cuda()
assert fpx_layout_tensor.device.type == "cuda"
fpx_layout_tensor = fpx_layout_tensor.cpu()
assert fpx_layout_tensor.device.type == "cpu"
scale = choose_qparams_affine_floatx(x, ebits, mbits)
x = quantize_affine_floatx(x, scale, ebits, mbits)
layout_type = FloatxTensorCoreLayoutType(ebits, mbits)
floatx_layout_tensor = FloatxTensorCoreAQTLayout.from_plain(x, scale, None, layout_type).cuda()
assert floatx_layout_tensor.device.type == "cuda"
floatx_layout_tensor = floatx_layout_tensor.cpu()
assert floatx_layout_tensor.device.type == "cpu"

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="quantization only works with torch.compile for 2.5+")
@parametrize("ebits,mbits", _FPx_DTYPES)
@parametrize("ebits,mbits", _Floatx_DTYPES)
@parametrize("bias", [False, True])
def test_fpx_weight_only(self, ebits, mbits, bias):
N, OC, IC = 4, 256, 64
Expand All @@ -106,7 +106,7 @@ def test_fpx_weight_only(self, ebits, mbits, bias):
torch.testing.assert_close(actual, expected)


instantiate_parametrized_tests(TestFpxTensorCoreAQTLayout)
instantiate_parametrized_tests(TestFloatxTensorCoreAQTLayout)


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion test/dtypes/test_uintx.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import torch

from torchao.dtypes.uintx.Uintx import to_uintx
from torchao.dtypes.uintx.uintx import to_uintx
from torchao.quantization.quant_api import quantize_, uintx_weight_only
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_3,
Expand Down
26 changes: 13 additions & 13 deletions test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
)
from torch.testing._internal.optests import opcheck
from torchao.utils import is_fbcode, TORCH_VERSION_AT_LEAST_2_5, compute_max_diff
from torchao.dtypes.fpx import from_scaled_tc_fpx
from torchao.dtypes.floatx import from_scaled_tc_floatx
from torchao.sparsity.marlin import marlin_24_workspace, pack_to_marlin_24, inject_24
import pytest

Expand All @@ -33,13 +33,13 @@


class TestOps(TestCase):
def _create_fpx_inputs(self, ebits: int, mbits: int, BS: int, OC: int, IC: int, device):
def _create_floatx_inputs(self, ebits: int, mbits: int, BS: int, OC: int, IC: int, device):
# Randomly initialize each byte
nbits = 1 + ebits + mbits
fpx_weight = torch.randint(256, (OC, IC // 8 * nbits), dtype=torch.uint8)
floatx_weight = torch.randint(256, (OC, IC // 8 * nbits), dtype=torch.uint8)
scale = torch.rand(OC).half() + 0.5
fp16_act = torch.rand(BS, IC).half() + 0.5
return fpx_weight.to(device), scale.to(device), fp16_act.to(device)
return floatx_weight.to(device), scale.to(device), fp16_act.to(device)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@parametrize("ebits,mbits", [(3, 2), (2, 2)])
Expand All @@ -48,28 +48,28 @@ def test_quant_llm_linear(self, ebits, mbits):
OC = 256
IC = 256
splitK = 1
fpx_weight, scale, fp16_act = self._create_fpx_inputs(ebits, mbits, BS, OC, IC, "cuda")
floatx_weight, scale, fp16_act = self._create_floatx_inputs(ebits, mbits, BS, OC, IC, "cuda")

# smoke test
torchao.ops.quant_llm_linear(ebits, mbits, fp16_act, fpx_weight, scale, splitK)
torchao.ops.quant_llm_linear(ebits, mbits, fp16_act, floatx_weight, scale, splitK)

# comprehensive testing
test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"]
opcheck(torch.ops.torchao.quant_llm_linear, (ebits, mbits, fp16_act, fpx_weight, scale, splitK), test_utils=test_utils)
opcheck(torch.ops.torchao.quant_llm_linear, (ebits, mbits, fp16_act, floatx_weight, scale, splitK), test_utils=test_utils)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@parametrize("BS,OC,IC,splitK", [(1, 2048, 4096, 5), (2, 8192, 8192, 6)])
@parametrize("ebits,mbits", [(3, 2), (2, 2)])
def test_quant_llm_linear_correctness(self, ebits, mbits, BS, OC, IC, splitK):
# adapted from https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/tests/python/kernel_test_fpx.py
fpx_weight, scale, fp16_act = self._create_fpx_inputs(ebits, mbits, BS, OC, IC, "cuda")
floatx_weight, scale, fp16_act = self._create_floatx_inputs(ebits, mbits, BS, OC, IC, "cuda")

results_fpx = torchao.ops.quant_llm_linear(ebits, mbits, fp16_act, fpx_weight, scale, splitK)
results_floatx = torchao.ops.quant_llm_linear(ebits, mbits, fp16_act, floatx_weight, scale, splitK)

fp16_weight = from_scaled_tc_fpx(fpx_weight, ebits, mbits, scale).half()
fp16_weight = from_scaled_tc_floatx(floatx_weight, ebits, mbits, scale).half()
results_fp16 = fp16_act @ fp16_weight.T

error = (results_fpx - results_fp16).abs().mean()
error = (results_floatx - results_fp16).abs().mean()
gt = results_fp16.abs().mean()
relative_error = error / gt
assert relative_error < 1e-3
Expand Down Expand Up @@ -319,7 +319,7 @@ def test_dequantize_tensor_core_tiled_layout_op(shape, inner_k_tiles, group_size
MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128]

MARLIN_TEST_PARAMS = list(itertools.product(
MARLIN_24_BATCH_SIZE, MARLIN_24_K_CHUNKS, MARLIN_24_N_CHUNKS,
MARLIN_24_BATCH_SIZE, MARLIN_24_K_CHUNKS, MARLIN_24_N_CHUNKS,
MARLIN_24_SUPPORTED_NUM_BITS, MARLIN_24_SUPPORTED_GROUP_SIZES, MNK_FACTORS
))

Expand Down Expand Up @@ -405,7 +405,7 @@ def test_marlin_24(batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_facto
workspace_24 = marlin_24_workspace(size_n)

fn_inputs = (
input_2d, marlin_24_q_w_comp, meta, marlin_24_scale, workspace_24,
input_2d, marlin_24_q_w_comp, meta, marlin_24_scale, workspace_24,
num_bits, a_input_in, marlin_24_scale.shape[1], a_input_out,
)
output = torchao.ops.marlin_24_gemm(*fn_inputs)
Expand Down
Loading

0 comments on commit 8236a87

Please sign in to comment.