Adding example for quantized tensor + tensor parallelism #785

Merged · 18 commits · Sep 23, 2024
132 changes: 100 additions & 32 deletions tutorials/developer_api_guide/my_dtype_tensor_subclass.py
@@ -15,7 +15,12 @@
import torch

from torch.utils._python_dispatch import return_and_correct_aliasing
-from torchao.quantization.quant_primitives import choose_qparams_affine, MappingType
+from torchao.quantization.quant_primitives import (
+    choose_qparams_affine,
+    MappingType,
+    quantize_affine,
+    dequantize_affine,
+)
from torchao.dtypes.utils import (
LayoutType,
PlainLayoutType,
@@ -24,6 +29,32 @@

aten = torch.ops.aten

# TODO: move to torchao/utils.py
def fill_defaults(args, n, defaults_tail):
"""
__torch_dispatch__ doesn't guarantee the number of arguments you are
passed (e.g., defaulted arguments are not passed); but usually it is
convenient to pad out the arguments list with defaults. This function
helps you do that.
Args:
args: the list of positional arguments passed to __torch_dispatch__
n: the number of arguments you are expecting to get
defaults_tail: default values for the arguments, starting from the
end of the list
Example:
>>> fill_defaults([1, 2, 3], 5, [3, 4, 5])
[1, 2, 3, 4, 5]
>>> fill_defaults([1, 2, 3], 5, [None, None, None])
    [1, 2, 3, None, None]
"""
if n - len(defaults_tail) > len(args):
raise RuntimeError("not enough defaults to fill arguments")
r = list(args)
for i in range(len(args), n):
r.append(defaults_tail[i - n + len(defaults_tail)])
return r
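A quick illustration of the helper (hypothetical values, not part of this PR's diff): `aten.slice.Tensor` often arrives through `__torch_dispatch__` with its trailing defaults dropped, exactly as in the slice handler further down:

    x = torch.randn(4, 4)
    # pad (x, dim=1, start=0, end=2) back out to the five expected arguments
    _, dim, start, end, step = fill_defaults((x, 1, 0, 2), 5, [0, None, None, 1])
    assert (dim, start, end, step) == (1, 0, 2, 1)  # `step` restored from defaults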


###############################
# Base Layout Tensor Subclass #
###############################
@@ -140,10 +171,10 @@ def from_float(
layout_type: LayoutType = PlainLayoutType(),
):
mapping_type = MappingType.SYMMETRIC
-        block_size = input_float.shape
+        block_size = (1, input_float.shape[-1])
         dtype = torch.int16
-        scale, _ = choose_qparams_affine(input_float, mapping_type, block_size, dtype)
-        int_data = (input_float / scale).to(torch.int8)
+        scale, zero_point = choose_qparams_affine(input_float, mapping_type, block_size, dtype)
+        int_data = quantize_affine(input_float, block_size, scale, zero_point, dtype)
layout_tensor_ctr = get_layout_tensor_constructor(type(layout_type))
layout_tensor = layout_tensor_ctr(int_data, scale, layout_type)
return cls(layout_tensor, input_float.shape)
@@ -160,7 +191,14 @@ def dequantize(self, output_dtype=None):
if output_dtype is None:
output_dtype = torch.get_default_dtype()
int_data, scale = self.layout_tensor.get_plain()
-        return int_data.to(output_dtype) * scale
+        transposed = False
+        block_size = (1, int_data.shape[-1])
+        if hasattr(self.layout_tensor, "transposed") and self.layout_tensor.transposed:
+            transposed = True
+        res = dequantize_affine(int_data, block_size, scale, None, int_data.dtype, output_dtype=output_dtype)
+        if transposed:
+            res = res.t()
+        return res
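The `transposed` flag is what makes transposition cheap: the `aten.t.default` handler added further down only flips the flag, and the actual transpose is deferred until the tensor is materialized here, after dequantizing in the stored row-major orientation. A toy check (illustrative only, reusing the layout classes defined below):

    plain = PlainMyDTypeLayout.from_plain(
        torch.randint(-8, 8, (4, 8), dtype=torch.int16),  # int_data
        torch.rand(4, 1),                                  # per-row scale
        PlainLayoutType(),
    )
    assert plain.t().transposed  # flag flip only; int_data is untouched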

def __repr__(self):
return (
@@ -203,6 +241,7 @@ def __new__(
cls,
int_data: torch.Tensor,
scale: torch.Tensor,
transposed: bool,
layout_type: LayoutType,
):
kwargs = {}
@@ -219,22 +258,24 @@ def __init__(
self,
int_data: torch.Tensor,
scale: torch.Tensor,
transposed: bool,
layout_type: LayoutType,
):
self.int_data = int_data
self.scale = scale
self.transposed = transposed
self.layout_type = layout_type

def __tensor_flatten__(self):
return ["int_data", "scale"], [self.layout_type]
return ["int_data", "scale"], [self.transposed, self.layout_type]

@classmethod
def __tensor_unflatten__(
cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
):
int_data, scale = tensor_data_dict["int_data"], tensor_data_dict["scale"]
-        layout_type, = tensor_attributes
-        return cls(int_data, scale, layout_type)
+        transposed, layout_type, = tensor_attributes
+        return cls(int_data, scale, transposed, layout_type)
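Every new field has to be threaded through `__tensor_flatten__`/`__tensor_unflatten__`, because `torch.compile` and DTensor tear the subclass down to its inner tensors plus metadata and rebuild it from exactly these pieces; forgetting `transposed` here would silently drop it. A minimal round-trip sketch (illustrative, reusing `plain` from the sketch above):

    names, attrs = plain.__tensor_flatten__()   # (["int_data", "scale"], [False, layout_type])
    inner = {name: getattr(plain, name) for name in names}
    rebuilt = PlainMyDTypeLayout.__tensor_unflatten__(inner, attrs, plain.size(), None)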

@classmethod
def from_plain(
@@ -247,12 +288,13 @@ def from_plain(
extra metadata for packing etc.
"""
assert isinstance(layout_type, PlainLayoutType)
-        return cls(int_data, scale, layout_type)
+        return cls(int_data, scale, False, layout_type)

def _apply_fn_to_data(self, fn):
return self.__class__(
fn(self.int_data),
fn(self.scale),
self.transposed,
self.layout_type,
)

@@ -265,8 +307,36 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
)

# Tensor parallel support START
elif func in [aten._to_copy.default, aten.clone.default]:
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
)
elif func is aten.split.Tensor:
int_data_list = func(args[0].int_data, *args[1:], **kwargs)
scale_list = func(args[0].scale, *args[1:], **kwargs)
out = [PlainMyDTypeLayout(int_data, scale, args[0].transposed, args[0].layout_type) for int_data, scale in zip(int_data_list, scale_list)]
return out
elif func is aten.empty_like.default:
int_data_empty_like = func(args[0].int_data, *args[1:], **kwargs)
return PlainMyDTypeLayout(int_data_empty_like, args[0].scale, args[0].transposed, args[0].layout_type)
elif func is aten.slice.Tensor:
self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
if dim == 0:
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(lambda x: aten.slice.Tensor(x, dim, start, end, step))
)
            elif dim == 1:
                # a column slice keeps every row, so the per-row scale is carried
                # over whole; view(-1, 1) just normalizes it to a column vector
                return PlainMyDTypeLayout(aten.slice.Tensor(self.int_data, dim, start, end, step), self.scale.view(-1, 1), self.transposed, self.layout_type)
else:
raise NotImplementedError(f"PlainMyDTypeLayout dispatch: attempting to run {func}, with dim={dim}, that is not supported")
elif func is aten.t.default:
return return_and_correct_aliasing(func, args, kwargs, PlainMyDTypeLayout(args[0].int_data, args[0].scale, not args[0].transposed, args[0].layout_type))

# Tensor parallel support END

raise NotImplementedError(
f"MyDTypeLayout dispatch: attempting to run {func}, this is not supported"
f"PlainMyDTypeLayout dispatch: attempting to run {func}, this is not supported"
)
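Taken together, these branches cover roughly the aten ops DTensor dispatches when sharding a weight: `detach`/`clone`/`_to_copy` while wrapping, `split`/`slice` to carve out the local shard, `empty_like` for redistribution buffers, and `t` for column-parallel layouts. A rough driver sketch, assuming a multi-GPU process group is already initialized (hypothetical code, not the example file added by this PR):

    import torch.distributed as dist
    from torch.distributed._tensor import DeviceMesh, Shard, distribute_tensor

    mesh = DeviceMesh("cuda", list(range(dist.get_world_size())))
    w = to_my_dtype(torch.randn(1024, 1024, device="cuda"))
    w_shard = distribute_tensor(w, mesh, [Shard(0)])  # row-shards int_data and scale together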

#####################################################
@@ -315,15 +385,6 @@ def _(func, types, args, kwargs):
func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
)


-class M(torch.nn.Module):
-    def __init__(self, *args, **kwargs) -> None:
-        super().__init__(*args, **kwargs)
-        self.linear = torch.nn.Linear(1024, 1024)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return self.linear(x)

#####################
# Factory functions #
#####################
@@ -333,42 +394,49 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
########
# Test #
########

-def test():
+def main():
from torchao.utils import benchmark_model


class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(1024, 128)

def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.linear(x)

m = M()
-    example_inputs = (100 * torch.randn(1024, 1024),)
+    example_inputs = (100 * torch.randn(512, 1024),)
NUM_WARMUPS = 10
NUM_RUNS = 100

for _ in range(NUM_WARMUPS):
m(*example_inputs)
print("before quantization:", benchmark_model(m, NUM_RUNS, example_inputs))

compiled = torch.compile(m, mode="max-autotune")
for _ in range(NUM_WARMUPS):
compiled(*example_inputs)
print("after compile:", benchmark_model(compiled, NUM_RUNS, example_inputs))

# convert weights to quantized weights
m.linear.weight = torch.nn.Parameter(
to_my_dtype(m.linear.weight), requires_grad=False
)

for _ in range(NUM_WARMUPS):
m(*example_inputs)

print("after quantization:", benchmark_model(m, NUM_RUNS, example_inputs))

m = torch.compile(m, mode="max-autotune")

for _ in range(NUM_WARMUPS):
m(*example_inputs)

    # NOTE: currently there is no speedup because we just dequantize the weight in the _quantized_linear op;
    # we plan to add a custom op example in the future, which should get us a real speedup
print("after quantization and compile:", benchmark_model(m, NUM_RUNS, example_inputs))

if __name__ == "__main__":
-    test()
+    main()
6 changes: 3 additions & 3 deletions tutorials/developer_api_guide/my_trainable_tensor_subclass.py
@@ -61,7 +61,7 @@ def from_float(
return _ToMyTrainableDTypeTensor.apply(input_float, layout_type)

class _ToMyTrainableDTypeTensor(torch.autograd.Function):
"""
"""
Differentiable constructor for `MyTrainableDTypeTensor`.
"""

@@ -163,8 +163,8 @@ def _(func, types, args, kwargs):
########

class M(torch.nn.Module):
-    def __init__(self, *args, **kwargs) -> None:
-        super().__init__(*args, **kwargs)
+    def __init__(self) -> None:
+        super().__init__()
self.linear = torch.nn.Linear(512, 1024, bias=False)

def forward(self, x: torch.Tensor) -> torch.Tensor: