[Float8] Add static constructor that will be used in Observer workflow (
drisspg authored Sep 13, 2024
1 parent 16b40fd commit 3fa38aa
Showing 6 changed files with 386 additions and 68 deletions.
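
For orientation before the diffs: a minimal usage sketch of the constructor this commit adds. This is a sketch that mirrors the new test below rather than any official documentation; it assumes a CUDA GPU with float8 support (compute capability >= 8.9) and torchao at this commit, and the toy Sequential model stands in for the test's ToyLinearModel.

import torch
from torchao.quantization import quantize_
from torchao.quantization.observer import PerTensor
from torchao.quantization.quant_api import float8_static_activation_float8_weight
from torchao.quantization.quant_primitives import MappingType, choose_qparams_affine

# Toy model in bfloat16 on CUDA (stands in for ToyLinearModel from the test)
model = torch.nn.Sequential(
    torch.nn.Linear(128, 256), torch.nn.Linear(256, 128)
).eval().to(torch.bfloat16).to("cuda")

# Derive a static activation scale from a representative input,
# the same way the new test does with choose_qparams_affine
calibration_input = torch.randn(32, 128, dtype=torch.bfloat16, device="cuda")
scale, _ = choose_qparams_affine(
    calibration_input,
    MappingType.SYMMETRIC,
    calibration_input.shape,
    torch.float8_e4m3fn,
    scale_dtype=torch.float32,
)

# Quantize weights to float8 and attach the static activation scale to each linear
quantize_(model, float8_static_activation_float8_weight(scale=scale, granularity=PerTensor()))
output = model(calibration_input)
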
2 changes: 2 additions & 0 deletions ruff.toml
@@ -11,4 +11,6 @@ include = [
"torchao/quantization/linear_activation_weight_observer.py",
"test/quantization/test_observer.py",
"test/dtypes/test_affine_quantized_float.py",
"torchao/quantization/weight_tensor_linear_activation_quantization.py"

]
143 changes: 86 additions & 57 deletions test/dtypes/test_affine_quantized_float.py
@@ -14,6 +14,10 @@
float8_weight_only,
float8_dynamic_activation_float8_weight,
)
from torchao.quantization.quant_api import (
float8_static_activation_float8_weight,
)
from torchao.quantization.quant_primitives import choose_qparams_affine, MappingType
from torchao.quantization.observer import PerTensor, PerRow
from torchao.float8.float8_utils import compute_error
import torch
@@ -50,7 +54,7 @@ class TestAffineQuantizedFloat8Compile(InductorTestCase):
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9")
@common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
@common_utils.parametrize("mode", ["dynamic", "weight-only"])
@common_utils.parametrize("mode", ["dynamic", "weight-only", "static"])
@common_utils.parametrize("compile", [True, False])
@common_utils.parametrize(
"granularity", [PerTensor(), PerRow()] if is_H100 else [PerTensor()]
@@ -60,45 +64,57 @@ class TestAffineQuantizedFloat8Compile(InductorTestCase):
"sizes",
[
((128,), 256, 128),
((256,), 512, 256),
((64,), 128, 64),
((32, 128), 64, 256),
((64, 256), 512, 128),
],
)
def test_fp8_linear_variants(
self, dtype: torch.dtype, mode: str, compile: bool, sizes: Tuple, granularity
):
raises = (
isinstance(granularity, PerRow)
and mode == "dynamic"
and dtype != torch.bfloat16
)
context = (
nullcontext()
if not raises
else pytest.raises(
AssertionError,
match="PerRow quantization only works for bfloat16 precision",
)
error_message = None
if isinstance(granularity, PerRow):
if mode == "dynamic" and dtype != torch.bfloat16:
error_message = "PerRow quantization only works for bfloat16 precision"
elif mode == "static":
error_message = (
"Static quantization only supports PerTensor granularity"
)

error_context = (
pytest.raises(AssertionError, match=error_message)
if error_message
else nullcontext()
)
with context:

with error_context:
M, N, K = sizes
input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda")

# Get a "reasonable" scale for the input tensor even though
# we use the same scale for multiple activations
scale, _ = choose_qparams_affine(
input_tensor,
MappingType.SYMMETRIC,
input_tensor.shape,
torch.float8_e4m3fn,
scale_dtype=torch.float32,
)
mode_map = {
"dynamic": partial(
float8_dynamic_activation_float8_weight, granularity=granularity
),
"weight-only": float8_weight_only,
"static": partial(
float8_static_activation_float8_weight,
scale=scale,
granularity=granularity,
),
}

            # Create a toy linear model in the parametrized dtype on CUDA
model = ToyLinearModel(K, N).eval().to(dtype).to("cuda")

quantized_model = copy.deepcopy(model)
factory = mode_map[mode]()
quantize_(model, factory)
quantize_(quantized_model, factory)

if compile:
quantized_model = torch.compile(quantized_model, fullgraph=True)
@@ -145,14 +161,23 @@ def test_per_row_with_float32(self):

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9")
@common_utils.parametrize("mode", ["dynamic", "weight-only"])
@common_utils.parametrize("mode", ["dynamic", "weight-only", "static"])
def test_serialization(self, mode: str):
# Create and quantize the model
model = ToyLinearModel(16, 32).to(device="cuda")
if mode == "dynamic":
factory = float8_dynamic_activation_float8_weight()
else:
factory = float8_weight_only()

mode_map = {
"dynamic": partial(
float8_dynamic_activation_float8_weight, granularity=PerTensor()
),
"weight-only": float8_weight_only,
"static": partial(
float8_static_activation_float8_weight,
scale=torch.tensor(1.0, dtype=torch.float32, device="cuda"),
granularity=PerTensor(),
),
}
factory = mode_map[mode]()
quantize_(model, factory)

# Save the state dict to an in-memory buffer
@@ -163,46 +188,50 @@ def test_serialization(self, mode: str):
buffer.seek(0)

# Load the state dict from the buffer
loaded_state_dict = torch.load(buffer)
weights_only_load = True
if mode == "dynamic":
# TODO will fix in followup
weights_only_load = False

loaded_state_dict = torch.load(buffer, weights_only=weights_only_load)

# Create a new model and load the state dict
with torch.device("meta"):
new_model = ToyLinearModel(16, 32)
if mode == "static":
quantize_(new_model, factory)
new_model.load_state_dict(loaded_state_dict, assign=True)

# Compare the original and loaded models
if mode == "weight-only":
model_weight_1 = model.linear1.weight.layout_tensor.float8_data.to(
torch.float32
)
new_model_weight_1 = new_model.linear1.weight.layout_tensor.float8_data.to(
torch.float32
)

model_weight_2 = model.linear2.weight.layout_tensor.float8_data.to(
torch.float32
)
new_model_weight_2 = new_model.linear2.weight.layout_tensor.float8_data.to(
torch.float32
)

else:
model_weight_1 = model.linear1.weight.original_weight_tensor.layout_tensor.float8_data.to(
torch.float32
)
new_model_weight_1 = new_model.linear1.weight.original_weight_tensor.layout_tensor.float8_data.to(
torch.float32
)

model_weight_2 = model.linear2.weight.original_weight_tensor.layout_tensor.float8_data.to(
torch.float32
)
new_model_weight_2 = new_model.linear2.weight.original_weight_tensor.layout_tensor.float8_data.to(
torch.float32
)

assert torch.allclose(model_weight_1, new_model_weight_1)
assert torch.allclose(model_weight_2, new_model_weight_2)
for layer_name in ["linear1", "linear2"]:
original_layer = getattr(model, layer_name)
new_layer = getattr(new_model, layer_name)

# Compare weights
if mode == "weight-only":
original_weight = original_layer.weight.layout_tensor.float8_data.to(
torch.float32
)
new_weight = new_layer.weight.layout_tensor.float8_data.to(
torch.float32
)
else:
original_weight = original_layer.weight.original_weight_tensor.layout_tensor.float8_data.to(
torch.float32
)
new_weight = new_layer.weight.original_weight_tensor.layout_tensor.float8_data.to(
torch.float32
)

assert torch.allclose(
original_weight, new_weight
), f"Weights do not match for {layer_name}"

# Compare scales
if hasattr(original_layer.weight, "scale"):
assert torch.allclose(
original_layer.weight.scale, new_layer.weight.scale
), f"Scales do not match for {layer_name}"


common_utils.instantiate_parametrized_tests(TestAffineQuantizedFloat8Compile)
3 changes: 2 additions & 1 deletion torchao/float8/__init__.py
@@ -23,6 +23,7 @@
LinearMMConfig,
ScaledMMConfig,
)
from torchao.float8.inference import Float8MMConfig
from torchao.float8.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp

from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
@@ -31,7 +32,7 @@
if TORCH_VERSION_AT_LEAST_2_5:
# Needed to load Float8Tensor with weights_only = True
from torch.serialization import add_safe_globals
add_safe_globals([Float8Tensor, ScaledMMConfig, GemmInputRole, LinearMMConfig])
add_safe_globals([Float8Tensor, ScaledMMConfig, GemmInputRole, LinearMMConfig, Float8MMConfig])

__all__ = [
# configuration
5 changes: 5 additions & 0 deletions torchao/quantization/observer.py
@@ -5,6 +5,7 @@
MappingType,
ZeroPointDomain,
)
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

from abc import ABCMeta, abstractmethod
from dataclasses import dataclass
@@ -222,3 +223,7 @@ def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]:
self.preserve_zero,
self.zero_point_domain,
)

if TORCH_VERSION_AT_LEAST_2_5:
# Allow a model with LinearActivationQuantizedTensor weights to be loaded with `weights_only=True`
torch.serialization.add_safe_globals([PerRow, PerTensor])
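
These PerTensor/PerRow registrations, together with the safe globals added in torchao/float8/__init__.py above and in torchao/quantization/quant_api.py below, are what allow test_serialization to reload a quantized model with weights_only=True. A rough sketch of that round trip, mirroring the test (assumes a CUDA GPU with float8 support; the plain nn.Linear stands in for ToyLinearModel):

import io
import torch
from torchao.quantization import quantize_
from torchao.quantization.observer import PerTensor
from torchao.quantization.quant_api import float8_static_activation_float8_weight

model = torch.nn.Linear(16, 32, device="cuda")
factory = float8_static_activation_float8_weight(
    scale=torch.tensor(1.0, dtype=torch.float32, device="cuda"),
    granularity=PerTensor(),
)
quantize_(model, factory)

# Save and reload the state dict; weights_only=True works because the
# tensor-subclass metadata types are registered as safe globals in this commit
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)
buffer.seek(0)
state_dict = torch.load(buffer, weights_only=True)

# For the static mode, the transform is re-applied to the fresh model
# before loading, as in the test
with torch.device("meta"):
    new_model = torch.nn.Linear(16, 32)
quantize_(new_model, factory)
new_model.load_state_dict(state_dict, assign=True)
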
100 changes: 90 additions & 10 deletions torchao/quantization/quant_api.py
@@ -27,6 +27,7 @@
from torchao.dtypes import (
to_affine_quantized_intx,
to_affine_quantized_floatx,
to_affine_quantized_floatx_static,
TensorCoreTiledLayoutType,
PlainLayoutType,
AffineQuantizedTensor,
@@ -47,6 +48,9 @@
LinearActivationQuantizedTensor,
to_linear_activation_quantized,
)
from torchao.quantization.weight_tensor_linear_activation_quantization import (
to_weight_tensor_with_linear_activation_quantization_metadata,
)

from .quant_primitives import (
MappingType,
@@ -678,24 +682,40 @@ def _normalize_granularity(
raise ValueError(f"Invalid granularity specification: {granularity}, only PerTensor or PerRow are supported.")


def _input_quant_func_dyanmic_fp8(
def _input_activation_quant_func_fp8(
x: torch.Tensor,
activation_granularity: _fp8_granularities,
activation_dtype: torch.dtype,
scale: Optional[torch.Tensor] = None,
zero_point: Optional[torch.Tensor] = None,
):
"""This function is used to quantize the input activation tensor for an aqt_float variant. If scale
is not provided it will be dynamically calculate the scales otherwise it will use the provided scale.
"""
assert zero_point is None, "Zero point is not supported for dynamic FP8 quantization"
if isinstance(activation_granularity, PerRow):
assert (
x.dtype == torch.bfloat16
), "PerRow quantization only works for bfloat16 precision input activation"

block_size = get_block_size(x.shape, activation_granularity)
activation = to_affine_quantized_floatx(
input_float=x,
block_size=block_size,
target_dtype=activation_dtype,
scale_dtype=torch.float32,
layout_type=Float8LayoutType(mm_config=None), # Config is stored on weight
)
if scale is None:
activation = to_affine_quantized_floatx(
input_float=x,
block_size=block_size,
target_dtype=activation_dtype,
scale_dtype=torch.float32,
layout_type=Float8LayoutType(mm_config=None), # Config is stored on weight
)
else:
assert isinstance(activation_granularity, PerTensor), "Static quantization only supports PerTensor granularity"
activation = to_affine_quantized_floatx_static(
input_float=x,
block_size=block_size,
scale=scale,
target_dtype=activation_dtype,
layout_type=Float8LayoutType(mm_config=None), # Config is stored on weight
)
return activation


@@ -742,7 +762,7 @@ def apply_float8_dynamic_activation_quant(weight: torch.Tensor):
)

input_quant_func = partial(
_input_quant_func_dyanmic_fp8,
_input_activation_quant_func_fp8,
activation_granularity=activation_granularity,
activation_dtype=activation_dtype,
)
@@ -755,6 +775,60 @@ def apply_float8_dynamic_activation_quant(weight: torch.Tensor):
return _get_linear_subclass_inserter(apply_float8_dynamic_activation_quant)


def float8_static_activation_float8_weight(
scale: torch.Tensor,
activation_dtype: torch.dtype = torch.float8_e4m3fn,
weight_dtype: torch.dtype = torch.float8_e4m3fn,
granularity: Optional[
Union[_fp8_granularities, Tuple[_fp8_granularities, _fp8_granularities]]
] = None,
mm_config: Optional[Float8MMConfig] = None,
):
"""
    Applies float8 static symmetric quantization to the activations and weights of linear layers.
Args:
scale (torch.Tensor): The scale tensor for activation quantization.
        activation_dtype (torch.dtype): The target data type for activation quantization. Default is torch.float8_e4m3fn.
        weight_dtype (torch.dtype): The target data type for weight quantization. Default is torch.float8_e4m3fn.
mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation.
"""
if mm_config is None:
mm_config = Float8MMConfig(use_fast_accum=True)

activation_granularity, weight_granularity = _normalize_granularity(granularity)
assert isinstance(
activation_granularity, PerTensor
), "Static quantization only supports PerTensor granularity"

def apply_float8_static_activation_quant(weight: torch.Tensor):
block_size = get_block_size(weight.shape, weight_granularity)
quantized_weight = to_affine_quantized_floatx(
input_float=weight,
block_size=block_size,
target_dtype=weight_dtype,
scale_dtype=torch.float32,
layout_type=Float8LayoutType(mm_config=mm_config),
)

input_quant_func = _input_activation_quant_func_fp8
input_quant_kwargs = {
"activation_granularity": activation_granularity,
"activation_dtype": activation_dtype,
}

quantized_weight = to_weight_tensor_with_linear_activation_quantization_metadata(
quantized_weight,
input_quant_func,
scale=scale,
zero_point=None,
quant_kwargs=input_quant_kwargs
)
return quantized_weight

return _get_linear_subclass_inserter(apply_float8_static_activation_quant)


def uintx_weight_only(dtype, group_size=64, pack_dim=-1, use_hqq=False):
"""
Applies uintx weight-only asymmetric per-group quantization to linear layers, using uintx quantization where
@@ -836,4 +910,10 @@ def apply_quant_llm(weight: torch.Tensor) -> torch.Tensor:


if TORCH_VERSION_AT_LEAST_2_5:
torch.serialization.add_safe_globals([_int8_asymm_per_token_quant, _int8_symm_per_token_reduced_range_quant])
torch.serialization.add_safe_globals(
[
_int8_asymm_per_token_quant,
_int8_symm_per_token_reduced_range_quant,
_input_activation_quant_func_fp8,
]
)
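
One constraint worth noting from the changes above: the static path asserts PerTensor granularity, both in float8_static_activation_float8_weight and in _input_activation_quant_func_fp8, and the updated test expects an AssertionError when PerRow is combined with mode="static". A small sketch of that behavior (assumes torchao at this commit; no GPU is needed because the assertion fires while the config is being built):

import torch
from torchao.quantization.observer import PerRow
from torchao.quantization.quant_api import float8_static_activation_float8_weight

scale = torch.tensor(1.0, dtype=torch.float32)
try:
    # Requesting per-row granularity for static activation quantization is rejected up front
    float8_static_activation_float8_weight(scale=scale, granularity=PerRow())
except AssertionError as e:
    print(e)  # Static quantization only supports PerTensor granularity
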