Add tutorial for trainable tensor subclass (#908)
Summary: The new tutorial provides an example of how to implement
a trainable tensor subclass that wraps quantized data. This extends
the existing `MyDTypeTensor` with a few necessary steps to ensure
proper gradient updates during training, namely (see the sketch below):

1. Define a differentiable constructor
2. Define backward pass for ops of interest (e.g. torch.nn.functional.linear)
3. Handle special ops used by the optimizer (e.g. aten.add, aten.add_)

Test Plan:
python tutorials/developer_api_guide/my_trainable_tensor_subclass.py
andrewor14 authored Sep 20, 2024
1 parent 53b6b78 commit 23321fb
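
The diff below implements all three steps in full. For orientation, here is a minimal, self-contained sketch of the pattern behind step 1 only: a differentiable constructor built as a torch.autograd.Function whose backward is a straight-through estimator. It fake-quantizes a plain float tensor instead of constructing the MyTrainableDTypeTensor subclass, so the class name _ToQuantized and the int8 scheme are illustrative assumptions, not part of this commit.

import torch

class _ToQuantized(torch.autograd.Function):
    """Differentiable "constructor": quantize on the way in, pass the gradient straight through on the way back."""

    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        # Symmetric int8-style fake-quantization: quantize then dequantize,
        # so the output is still a float tensor the rest of autograd understands.
        scale = x.abs().max().clamp(min=1e-8) / 127.0
        return (x / scale).round().clamp(-128, 127) * scale

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        # Straight-through estimator: treat quantization as identity for gradients.
        return grad_output

w = torch.randn(4, 4, requires_grad=True)
_ToQuantized.apply(w).sum().backward()
print(w.grad)  # all ones: the gradient flowed straight through the quantization

Steps 2 and 3 apply the same idea to torch.nn.functional.linear and to the optimizer's aten.add / aten.add_ calls, as shown in my_trainable_tensor_subclass.py below.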
Showing 3 changed files with 244 additions and 36 deletions.
Empty file.
80 changes: 44 additions & 36 deletions tutorials/developer_api_guide/my_dtype_tensor_subclass.py
@@ -77,6 +77,7 @@ def __new__(
layout_tensor: MyDTypeLayout,
shape: torch.Size,
dtype: Optional[torch.dtype] = None,
+        requires_grad: bool = False,
):
kwargs = {}
kwargs["device"] = layout_tensor.device
@@ -86,14 +87,15 @@
else layout_tensor.layout
)
kwargs["dtype"] = dtype
-        kwargs["requires_grad"] = False
+        kwargs["requires_grad"] = requires_grad
return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined]

def __init__(
self,
layout_tensor: MyDTypeLayout,
shape: torch.Size,
dtype: Optional[torch.dtype] = None,
+        requires_grad: bool = False,
):
self.layout_tensor = layout_tensor

@@ -108,7 +110,7 @@ def __tensor_flatten__(self):
The first one contains any tensor fields such as int_data and scale as keys to a dictionary
The second one contains all other non tensor type fields as values of a list
"""
-        return ["layout_tensor"], [self.shape, self.dtype]
+        return ["layout_tensor"], [self.shape, self.dtype, self.requires_grad]

@classmethod
def __tensor_unflatten__(
@@ -120,11 +122,12 @@ def __tensor_unflatten__(
tensor_attributes contains all other non tensor type fields
"""
layout_tensor = tensor_data_dict["layout_tensor"]
-        shape, dtype = tensor_attributes
+        shape, dtype, requires_grad = tensor_attributes
return cls(
layout_tensor,
shape if outer_size is None else outer_size,
dtype=dtype,
+            requires_grad=requires_grad,
)

"""classmethod that converts from a floating point Tensor (fp32/fp16/bf16) to the current dtype
@@ -330,37 +333,42 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
########
# Test #
########
-from torchao.utils import benchmark_model
-
-m = M()
-example_inputs = (100 * torch.randn(1024, 1024),)
-NUM_WARMUPS = 10
-NUM_RUNS = 100
-
-for _ in range(NUM_WARMUPS):
-    m(*example_inputs)
-print("before quantization:", benchmark_model(m, NUM_RUNS, example_inputs))
-
-compiled = torch.compile(m, mode="max-autotune")
-for _ in range(NUM_WARMUPS):
-    compiled(*example_inputs)
-print("after compile:", benchmark_model(compiled, NUM_RUNS, example_inputs))
-
-# convert weights to quantized weights
-m.linear.weight = torch.nn.Parameter(
-    to_my_dtype(m.linear.weight), requires_grad=False
-)
-
-for _ in range(NUM_WARMUPS):
-    m(*example_inputs)
-
-print("after quantization:", benchmark_model(m, NUM_RUNS, example_inputs))
-
-m = torch.compile(m, mode="max-autotune")
-
-for _ in range(NUM_WARMUPS):
-    m(*example_inputs)
-
-# NOTE: currently there is no speedup because we just dequantize the weight in the _quantized_linear op
-# we plan to add custom op example in the future and that will help us to get speedup
-print("after quantization and compile:", benchmark_model(m, NUM_RUNS, example_inputs))
+def test():
+    from torchao.utils import benchmark_model
+
+    m = M()
+    example_inputs = (100 * torch.randn(1024, 1024),)
+    NUM_WARMUPS = 10
+    NUM_RUNS = 100
+
+    for _ in range(NUM_WARMUPS):
+        m(*example_inputs)
+    print("before quantization:", benchmark_model(m, NUM_RUNS, example_inputs))
+
+    compiled = torch.compile(m, mode="max-autotune")
+    for _ in range(NUM_WARMUPS):
+        compiled(*example_inputs)
+    print("after compile:", benchmark_model(compiled, NUM_RUNS, example_inputs))
+
+    # convert weights to quantized weights
+    m.linear.weight = torch.nn.Parameter(
+        to_my_dtype(m.linear.weight), requires_grad=False
+    )
+
+    for _ in range(NUM_WARMUPS):
+        m(*example_inputs)
+
+    print("after quantization:", benchmark_model(m, NUM_RUNS, example_inputs))
+
+    m = torch.compile(m, mode="max-autotune")
+
+    for _ in range(NUM_WARMUPS):
+        m(*example_inputs)
+
+    # NOTE: currently there is no speedup because we just dequantize the weight in the _quantized_linear op
+    # we plan to add custom op example in the future and that will help us to get speedup
+    print("after quantization and compile:", benchmark_model(m, NUM_RUNS, example_inputs))
+
+if __name__ == "__main__":
+    test()
200 changes: 200 additions & 0 deletions tutorials/developer_api_guide/my_trainable_tensor_subclass.py
@@ -0,0 +1,200 @@
"""
This is an example of a tensor subclass representing a simple dtype
that can be used in training.
We extend our previous example of `MyDTypeTensor` with a few extra steps
needed to ensure proper gradient updates during training:
1. Define a differentiable constructor
2. Define backward pass for ops of interest (e.g. torch.nn.functional.linear)
3. Handle special ops used by the optimizer (e.g. aten.add, aten.add_)
"""

import torch

from torch.utils._python_dispatch import return_and_correct_aliasing
from torchao.quantization.quant_primitives import choose_qparams_affine, MappingType
from torchao.dtypes.utils import LayoutType, PlainLayoutType
from my_dtype_tensor_subclass import MyDTypeLayout, MyDTypeTensor

aten = torch.ops.aten


##############################
# Tensor Subclass Definition #
##############################

class MyTrainableDTypeTensor(MyDTypeTensor):
"""
Example tensor subclass that extends `MyDTypeTensor` to support training.
"""

@classmethod
def _quantize(
cls,
input_float: torch.Tensor,
layout_type: LayoutType,
) -> MyDTypeLayout:
"""
Convert from a floating point tensor (fp32/fp16/bf16) to the desired dtype.
"""
mapping_type = MappingType.SYMMETRIC
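        # block_size equals the full tensor shape, i.e. per-tensor symmetric
        # quantization with a single scale for the whole weight.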
block_size = input_float.shape
dtype = torch.int16
scale, _ = choose_qparams_affine(input_float, mapping_type, block_size, dtype)
int_data = (input_float / scale).to(torch.int8)
layout_tensor_ctr = cls.get_layout_tensor_constructor(type(layout_type))
return layout_tensor_ctr(int_data, scale, layout_type)

@classmethod
def from_float(
cls,
input_float: torch.Tensor,
layout_type: LayoutType = PlainLayoutType(),
) -> "MyTrainableDTypeTensor":
"""
Main entry point for creating a `MyTrainableDTypeTensor`.
This instantiates the tensor subclass in a differentiable constructor
to ensure gradients are passed to the tensor subclass properly during training.
"""
return _ToMyTrainableDTypeTensor.apply(input_float, layout_type)

class _ToMyTrainableDTypeTensor(torch.autograd.Function):
"""
Differentiable constructor for `MyTrainableDTypeTensor`.
"""

@staticmethod
def forward(
ctx: torch.autograd.function.FunctionCtx,
input_float: torch.Tensor,
layout_type: LayoutType,
) -> "MyTrainableDTypeTensor":
layout_tensor = MyTrainableDTypeTensor._quantize(input_float, layout_type)
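        # Construct the subclass with requires_grad=True so autograd records
        # subsequent ops (such as the linear below) on the quantized weight.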
return MyTrainableDTypeTensor(
layout_tensor,
input_float.shape,
requires_grad=True,
)

@staticmethod
def backward(ctx, gy):
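        # Straight-through estimator: hand the incoming gradient back to the
        # float input unchanged; the non-tensor layout_type argument gets no gradient.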
return gy, None

to_my_trainable_dtype = MyTrainableDTypeTensor.from_float


#####################################################
# torch functional and aten operator implementation #
#####################################################

implements = MyTrainableDTypeTensor.implements

class _QuantizedLinearOp(torch.autograd.Function):
"""
Forward and backward definition for linear with quantized weights.
Weights are dequantized during both the forward and the backward passes.
"""

@staticmethod
def forward(
ctx: torch.autograd.function.FunctionCtx,
input_tensor: torch.Tensor,
weight_tensor: torch.Tensor,
) -> torch.Tensor:
assert isinstance(weight_tensor, MyTrainableDTypeTensor)
ctx.save_for_backward(input_tensor, weight_tensor)
weight_tensor = weight_tensor.dequantize()
return torch.nn.functional.linear(input_tensor, weight_tensor)

@staticmethod
def backward(ctx, grad_output):
input_tensor, weight_tensor = ctx.saved_tensors
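        # Standard linear backward, computed in float on the dequantized weight:
        #   grad_input  = grad_output @ W
        #   grad_weight = grad_output^T @ input (leading batch dims flattened)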
grad_input = torch.matmul(grad_output, weight_tensor.dequantize())
grad_weight = torch.matmul(
grad_output.view(-1, weight_tensor.shape[0]).T,
input_tensor.view(-1, weight_tensor.shape[1]),
)
return grad_input, grad_weight

@implements(torch.nn.functional.linear)
def _(func, types, args, kwargs):
"""
Handle the linear op with quantized weights.
For simplicity, we run both the forward and backward passes entirely in float.
"""
assert isinstance(args[1], MyTrainableDTypeTensor)
if len(args) > 2 and args[2] is not None:
raise NotImplementedError("linear bias not yet supported")
return _QuantizedLinearOp.apply(args[0], args[1])

@implements(aten.add_.Tensor)
def _(func, types, args, kwargs):
"""
Handle the in-place add op, called by the optimizer to update
the quantized weight during training.
"""
assert len(args) == 2
assert isinstance(args[0], MyTrainableDTypeTensor)
assert args[0].layout_tensor.int_data.dtype == torch.int8
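    # The update is done as: dequantize -> add in float -> re-quantize, then the
    # new layout tensor is swapped into the existing subclass instance in place.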
float0 = args[0].dequantize()
float1 = args[1].dequantize() if isinstance(args[1], MyTrainableDTypeTensor) else args[1]
new_value = torch.add(float0, float1, **kwargs)
new_layout_tensor = MyTrainableDTypeTensor._quantize(
new_value,
args[0].layout_tensor.get_layout_type(),
)
args[0].layout_tensor = new_layout_tensor
return return_and_correct_aliasing(func, args, kwargs, args[0])

@implements(aten.add.Tensor)
def _(func, types, args, kwargs):
"""Handle the add op, called by the optimizer during training."""
assert len(args) == 2
assert not isinstance(args[0], MyTrainableDTypeTensor)
assert isinstance(args[1], MyTrainableDTypeTensor)
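    # Only the second operand is quantized here (e.g. a float gradient term plus
    # the quantized weight), so dequantize it and return a plain float result.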
out = torch.add(args[0], args[1].dequantize(), **kwargs)
return return_and_correct_aliasing(func, args, kwargs, out)


########
# Test #
########

class M(torch.nn.Module):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.linear = torch.nn.Linear(512, 1024, bias=False)

def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.linear(x)

def main():
m = M().cuda()
NUM_TRAIN_STEPS = 10
VERBOSE = True

# Convert weights to quantized weights
m.linear.weight = torch.nn.Parameter(
to_my_trainable_dtype(m.linear.weight), requires_grad=True,
)

# Dummy training loop
optimizer = torch.optim.SGD(m.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-5)
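    # The SGD step updates the quantized weight through the aten.add / aten.add_
    # overrides above; the in-place add_ re-quantizes the weight on every step.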
loss_fn = torch.nn.CrossEntropyLoss()
for i in range(NUM_TRAIN_STEPS):
example_inputs = (torch.randn(512).cuda(),)
target = torch.randn(1024).cuda()
output = m(*example_inputs)
loss = loss_fn(output, target)
loss.backward()
if VERBOSE:
weight = m.linear.weight.layout_tensor.int_data.flatten()[:3]
weight_grad = m.linear.weight.grad.flatten()[:3]
print(" * step %s: weight grad = %s, weight value = %s" % (i, weight_grad, weight))
optimizer.step()
optimizer.zero_grad()

if __name__ == "__main__":
main()
