Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tutorial for trainable tensor subclass #908

Merged
merged 1 commit into from
Sep 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
80 changes: 44 additions & 36 deletions tutorials/developer_api_guide/my_dtype_tensor_subclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def __new__(
layout_tensor: MyDTypeLayout,
shape: torch.Size,
dtype: Optional[torch.dtype] = None,
requires_grad: bool = False,
):
kwargs = {}
kwargs["device"] = layout_tensor.device
Expand All @@ -86,14 +87,15 @@ def __new__(
else layout_tensor.layout
)
kwargs["dtype"] = dtype
kwargs["requires_grad"] = False
kwargs["requires_grad"] = requires_grad
return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined]

def __init__(
self,
layout_tensor: MyDTypeLayout,
shape: torch.Size,
dtype: Optional[torch.dtype] = None,
requires_grad: bool = False,
):
self.layout_tensor = layout_tensor

Expand All @@ -108,7 +110,7 @@ def __tensor_flatten__(self):
The first one contains any tensor fields such as int_data and scale as keys to a dictionary
The second one contains all other non tensor type fields as values of a list
"""
return ["layout_tensor"], [self.shape, self.dtype]
return ["layout_tensor"], [self.shape, self.dtype, self.requires_grad]

@classmethod
def __tensor_unflatten__(
Expand All @@ -120,11 +122,12 @@ def __tensor_unflatten__(
tensor_attributes contains all other non tensor type fields
"""
layout_tensor = tensor_data_dict["layout_tensor"]
shape, dtype = tensor_attributes
shape, dtype, requires_grad = tensor_attributes
return cls(
layout_tensor,
shape if outer_size is None else outer_size,
dtype=dtype,
requires_grad=requires_grad,
)

"""classmethod that converts from a floating point Tensor (fp32/fp16/bf16) to the current dtype
Expand Down Expand Up @@ -330,37 +333,42 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
########
# Test #
########
from torchao.utils import benchmark_model

m = M()
example_inputs = (100 * torch.randn(1024, 1024),)
NUM_WARMUPS = 10
NUM_RUNS = 100

for _ in range(NUM_WARMUPS):
m(*example_inputs)
print("before quantization:", benchmark_model(m, NUM_RUNS, example_inputs))

compiled = torch.compile(m, mode="max-autotune")
for _ in range(NUM_WARMUPS):
compiled(*example_inputs)
print("after compile:", benchmark_model(compiled, NUM_RUNS, example_inputs))

# convert weights to quantized weights
m.linear.weight = torch.nn.Parameter(
to_my_dtype(m.linear.weight), requires_grad=False
)

for _ in range(NUM_WARMUPS):
m(*example_inputs)

print("after quantization:", benchmark_model(m, NUM_RUNS, example_inputs))

m = torch.compile(m, mode="max-autotune")

for _ in range(NUM_WARMUPS):
m(*example_inputs)

# NOTE: currently there is no speedup because we just dequantize the weight in the _quantized_linear op
# we plan to add custom op example in the future and that will help us to get speedup
print("after quantization and compile:", benchmark_model(m, NUM_RUNS, example_inputs))
def test():
from torchao.utils import benchmark_model

m = M()
example_inputs = (100 * torch.randn(1024, 1024),)
NUM_WARMUPS = 10
NUM_RUNS = 100

for _ in range(NUM_WARMUPS):
m(*example_inputs)
print("before quantization:", benchmark_model(m, NUM_RUNS, example_inputs))

compiled = torch.compile(m, mode="max-autotune")
for _ in range(NUM_WARMUPS):
compiled(*example_inputs)
print("after compile:", benchmark_model(compiled, NUM_RUNS, example_inputs))

# convert weights to quantized weights
m.linear.weight = torch.nn.Parameter(
to_my_dtype(m.linear.weight), requires_grad=False
)

for _ in range(NUM_WARMUPS):
m(*example_inputs)

print("after quantization:", benchmark_model(m, NUM_RUNS, example_inputs))

m = torch.compile(m, mode="max-autotune")

for _ in range(NUM_WARMUPS):
m(*example_inputs)

# NOTE: currently there is no speedup because we just dequantize the weight in the _quantized_linear op
# we plan to add custom op example in the future and that will help us to get speedup
print("after quantization and compile:", benchmark_model(m, NUM_RUNS, example_inputs))

if __name__ == "__main__":
test()
200 changes: 200 additions & 0 deletions tutorials/developer_api_guide/my_trainable_tensor_subclass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
"""
This is an example for a tensor subclass representing a simple dtype
that can be used in training.
We extend our previous example of `MyDTypeTensor` with a few extra steps
needed to ensure proper gradient updates during training:
1. Define a differentiable constructor
2. Define backward pass for ops of interest (e.g. torch.nn.functional.linear)
3. Handle special ops used by the optimizer (e.g. aten.add, aten.add_)
"""

import torch

from torch.utils._python_dispatch import return_and_correct_aliasing
from torchao.quantization.quant_primitives import choose_qparams_affine, MappingType
from torchao.dtypes.utils import LayoutType, PlainLayoutType
from my_dtype_tensor_subclass import MyDTypeLayout, MyDTypeTensor

aten = torch.ops.aten


##############################
# Tensor Subclass Definition #
##############################

class MyTrainableDTypeTensor(MyDTypeTensor):
"""
Example tensor subclass that extends `MyDTypeTensor` to support training.
"""

@classmethod
def _quantize(
cls,
input_float: torch.Tensor,
layout_type: LayoutType,
) -> MyDTypeLayout:
"""
Convert from a floating point tensor (fp32/fp16/bf16) to the desired dtype.
"""
mapping_type = MappingType.SYMMETRIC
block_size = input_float.shape
dtype = torch.int16
scale, _ = choose_qparams_affine(input_float, mapping_type, block_size, dtype)
int_data = (input_float / scale).to(torch.int8)
layout_tensor_ctr = cls.get_layout_tensor_constructor(type(layout_type))
return layout_tensor_ctr(int_data, scale, layout_type)

@classmethod
def from_float(
cls,
input_float: torch.Tensor,
layout_type: LayoutType = PlainLayoutType(),
) -> "MyTrainableDTypeTensor":
"""
Main entry point for creating a `MyTrainableDTypeTensor`.
This instantiates the tensor subclass in a differentiable constructor
to ensure gradients are passed to the tensor subclass properly during training.
"""
return _ToMyTrainableDTypeTensor.apply(input_float, layout_type)

class _ToMyTrainableDTypeTensor(torch.autograd.Function):
"""
Differentiable constructor for `MyTrainableDTypeTensor`.
"""

@staticmethod
def forward(
ctx: torch.autograd.function.FunctionCtx,
input_float: torch.Tensor,
layout_type: LayoutType,
) -> "MyTrainableDTypeTensor":
layout_tensor = MyTrainableDTypeTensor._quantize(input_float, layout_type)
return MyTrainableDTypeTensor(
layout_tensor,
input_float.shape,
requires_grad=True,
)

@staticmethod
def backward(ctx, gy):
return gy, None

to_my_trainable_dtype = MyTrainableDTypeTensor.from_float


#####################################################
# torch functional and aten operator implementation #
#####################################################

implements = MyTrainableDTypeTensor.implements

class _QuantizedLinearOp(torch.autograd.Function):
"""
Forward and backward definition for linear with quantized weights.
Weights are dequantized during both the forward and the backward passes.
"""

@staticmethod
def forward(
ctx: torch.autograd.function.FunctionCtx,
input_tensor: torch.Tensor,
weight_tensor: torch.Tensor,
) -> torch.Tensor:
assert isinstance(weight_tensor, MyTrainableDTypeTensor)
ctx.save_for_backward(input_tensor, weight_tensor)
weight_tensor = weight_tensor.dequantize()
return torch.nn.functional.linear(input_tensor, weight_tensor)

@staticmethod
def backward(ctx, grad_output):
input_tensor, weight_tensor = ctx.saved_tensors
grad_input = torch.matmul(grad_output, weight_tensor.dequantize())
Copy link
Contributor

@jerryzh168 jerryzh168 Sep 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sorry I guess I didn't mean to dequantize this, but was asking if this call into something like F.linear(grad_output, weight_tensor) and be dispatched to the quantized linear impl of weight_tensor.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh actually the old code was also doing dequantize math. I just rewrote it so it's cleaner. Not sure if I understand how we can dispatch to the quantized linear impl, since this is called from F.linear(..., weight_subclass_tensor) already?

grad_weight = torch.matmul(
grad_output.view(-1, weight_tensor.shape[0]).T,
input_tensor.view(-1, weight_tensor.shape[1]),
)
return grad_input, grad_weight

@implements(torch.nn.functional.linear)
def _(func, types, args, kwargs):
"""
Handle the linear op with quantized weights.
For simplicity, we run both the forward and backward passes entirely in float.
"""
assert isinstance(args[1], MyTrainableDTypeTensor)
if len(args) > 2 and args[2] is not None:
raise NotImplementedError("linear bias not yet supported")
return _QuantizedLinearOp.apply(args[0], args[1])

@implements(aten.add_.Tensor)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

are add_/add both going to be differentiable as well?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think they don't need to be. These are just used by the optimizer

def _(func, types, args, kwargs):
"""
Handle the in-place add op, called by the optimizer to update
the quantized weight during training.
"""
assert len(args) == 2
assert isinstance(args[0], MyTrainableDTypeTensor)
assert args[0].layout_tensor.int_data.dtype == torch.int8
float0 = args[0].dequantize()
float1 = args[1].dequantize() if isinstance(args[1], MyTrainableDTypeTensor) else args[1]
new_value = torch.add(float0, float1, **kwargs)
new_layout_tensor = MyTrainableDTypeTensor._quantize(
new_value,
args[0].layout_tensor.get_layout_type(),
)
args[0].layout_tensor = new_layout_tensor
return return_and_correct_aliasing(func, args, kwargs, args[0])

@implements(aten.add.Tensor)
def _(func, types, args, kwargs):
"""Handle the add op, called by the optimizer during training."""
assert len(args) == 2
assert not isinstance(args[0], MyTrainableDTypeTensor)
assert isinstance(args[1], MyTrainableDTypeTensor)
out = torch.add(args[0], args[1].dequantize(), **kwargs)
return return_and_correct_aliasing(func, args, kwargs, out)


########
# Test #
########

class M(torch.nn.Module):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.linear = torch.nn.Linear(512, 1024, bias=False)

def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.linear(x)

def main():
m = M().cuda()
NUM_TRAIN_STEPS = 10
VERBOSE = True

# Convert weights to quantized weights
m.linear.weight = torch.nn.Parameter(
to_my_trainable_dtype(m.linear.weight), requires_grad=True,
)

# Dummy training loop
optimizer = torch.optim.SGD(m.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-5)
loss_fn = torch.nn.CrossEntropyLoss()
for i in range(NUM_TRAIN_STEPS):
example_inputs = (torch.randn(512).cuda(),)
target = torch.randn(1024).cuda()
output = m(*example_inputs)
loss = loss_fn(output, target)
loss.backward()
if VERBOSE:
weight = m.linear.weight.layout_tensor.int_data.flatten()[:3]
weight_grad = m.linear.weight.grad.flatten()[:3]
print(" * step %s: weight grad = %s, weight value = %s" % (i, weight_grad, weight))
optimizer.step()
optimizer.zero_grad()

if __name__ == "__main__":
main()
Loading