diff --git a/tutorials/developer_api_guide/my_dtype_tensor_subclass.py b/tutorials/developer_api_guide/my_dtype_tensor_subclass.py
index 6baa6dfc6..03b0d3159 100644
--- a/tutorials/developer_api_guide/my_dtype_tensor_subclass.py
+++ b/tutorials/developer_api_guide/my_dtype_tensor_subclass.py
@@ -15,7 +15,12 @@ import torch
 
 from torch.utils._python_dispatch import return_and_correct_aliasing
 
-from torchao.quantization.quant_primitives import choose_qparams_affine, MappingType
+from torchao.quantization.quant_primitives import (
+    choose_qparams_affine,
+    MappingType,
+    quantize_affine,
+    dequantize_affine,
+)
 from torchao.dtypes.utils import (
     LayoutType,
     PlainLayoutType,
@@ -24,6 +29,32 @@
 
 aten = torch.ops.aten
 
+# TODO: move to torchao/utils.py
+def fill_defaults(args, n, defaults_tail):
+    """
+    __torch_dispatch__ doesn't guarantee the number of arguments you are
+    passed (e.g., defaulted arguments are not passed); but usually it is
+    convenient to pad out the arguments list with defaults. This function
+    helps you do that.
+    Args:
+        args: the list of positional arguments passed to __torch_dispatch__
+        n: the number of arguments you are expecting to get
+        defaults_tail: default values for the arguments, starting from the
+            end of the list
+    Example:
+        >>> fill_defaults([1, 2, 3], 5, [3, 4, 5])
+        [1, 2, 3, 4, 5]
+        >>> fill_defaults([1, 2, 3], 5, [None, None, None])
+        [1, 2, 3, None, None]
+    """
+    if n - len(defaults_tail) > len(args):
+        raise RuntimeError("not enough defaults to fill arguments")
+    r = list(args)
+    for i in range(len(args), n):
+        r.append(defaults_tail[i - n + len(defaults_tail)])
+    return r
+
+
 ###############################
 # Base Layout Tensor Subclass #
 ###############################
@@ -140,10 +171,10 @@ def from_float(
         layout_type: LayoutType = PlainLayoutType(),
     ):
         mapping_type = MappingType.SYMMETRIC
-        block_size = input_float.shape
+        block_size = (1, input_float.shape[-1])
         dtype = torch.int16
-        scale, _ = choose_qparams_affine(input_float, mapping_type, block_size, dtype)
-        int_data = (input_float / scale).to(torch.int8)
+        scale, zero_point = choose_qparams_affine(input_float, mapping_type, block_size, dtype)
+        int_data = quantize_affine(input_float, block_size, scale, zero_point, dtype)
         layout_tensor_ctr = get_layout_tensor_constructor(type(layout_type))
         layout_tensor = layout_tensor_ctr(int_data, scale, layout_type)
         return cls(layout_tensor, input_float.shape)
@@ -160,7 +191,14 @@ def dequantize(self, output_dtype=None):
         if output_dtype is None:
             output_dtype = torch.get_default_dtype()
         int_data, scale = self.layout_tensor.get_plain()
-        return int_data.to(output_dtype) * scale
+        transposed = False
+        block_size = (1, int_data.shape[-1])
+        if hasattr(self.layout_tensor, "transposed") and self.layout_tensor.transposed:
+            transposed = True
+        res = dequantize_affine(int_data, block_size, scale, None, int_data.dtype, output_dtype=output_dtype)
+        if transposed:
+            res = res.t()
+        return res
 
     def __repr__(self):
         return (
@@ -203,6 +241,7 @@ def __new__(
         cls,
         int_data: torch.Tensor,
         scale: torch.Tensor,
+        transposed: bool,
         layout_type: LayoutType,
     ):
         kwargs = {}
@@ -219,22 +258,24 @@ def __init__(
         self,
         int_data: torch.Tensor,
         scale: torch.Tensor,
+        transposed: bool,
         layout_type: LayoutType,
     ):
         self.int_data = int_data
         self.scale = scale
+        self.transposed = transposed
         self.layout_type = layout_type
 
     def __tensor_flatten__(self):
-        return ["int_data", "scale"], [self.layout_type]
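+        # `transposed` is a plain bool rather than a Tensor, so it is carried in the
+        # metadata list (the second return value) instead of as a tensor attribute name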
+        return ["int_data", "scale"], [self.transposed, self.layout_type]
 
     @classmethod
     def __tensor_unflatten__(
         cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
     ):
         int_data, scale = tensor_data_dict["int_data"], tensor_data_dict["scale"]
-        layout_type, = tensor_attributes
-        return cls(int_data, scale, layout_type)
+        transposed, layout_type, = tensor_attributes
+        return cls(int_data, scale, transposed, layout_type)
 
     @classmethod
     def from_plain(
@@ -247,12 +288,13 @@ def from_plain(
         extra metadata for packing etc.
         """
         assert isinstance(layout_type, PlainLayoutType)
-        return cls(int_data, scale, layout_type)
+        return cls(int_data, scale, False, layout_type)
 
     def _apply_fn_to_data(self, fn):
         return self.__class__(
             fn(self.int_data),
             fn(self.scale),
+            self.transposed,
             self.layout_type,
         )
 
@@ -265,8 +307,36 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
                 func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
             )
 
+        # Tensor parallel support START
+        elif func in [aten._to_copy.default, aten.clone.default]:
+            return return_and_correct_aliasing(
+                func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
+            )
+        elif func is aten.split.Tensor:
+            int_data_list = func(args[0].int_data, *args[1:], **kwargs)
+            scale_list = func(args[0].scale, *args[1:], **kwargs)
+            out = [PlainMyDTypeLayout(int_data, scale, args[0].transposed, args[0].layout_type) for int_data, scale in zip(int_data_list, scale_list)]
+            return out
+        elif func is aten.empty_like.default:
+            int_data_empty_like = func(args[0].int_data, *args[1:], **kwargs)
+            return PlainMyDTypeLayout(int_data_empty_like, args[0].scale, args[0].transposed, args[0].layout_type)
+        elif func is aten.slice.Tensor:
+            self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
+            if dim == 0:
+                return return_and_correct_aliasing(
+                    func, args, kwargs, args[0]._apply_fn_to_data(lambda x: aten.slice.Tensor(x, dim, start, end, step))
+                )
+            elif dim == 1:
+                return PlainMyDTypeLayout(aten.slice.Tensor(self.int_data, dim, start, end, step), self.scale.view(-1, 1), self.transposed, self.layout_type)
+            else:
+                raise NotImplementedError(f"PlainMyDTypeLayout dispatch: attempting to run {func} with dim={dim}, this is not supported")
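+        # aten.t is handled lazily: only the `transposed` flag is flipped here; the
+        # actual transpose is applied to the dequantized result in MyDTypeTensor.dequantize()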
+        elif func is aten.t.default:
+            return return_and_correct_aliasing(func, args, kwargs, PlainMyDTypeLayout(args[0].int_data, args[0].scale, not args[0].transposed, args[0].layout_type))
+
+        # Tensor parallel support END
+
         raise NotImplementedError(
-            f"MyDTypeLayout dispatch: attempting to run {func}, this is not supported"
+            f"PlainMyDTypeLayout dispatch: attempting to run {func}, this is not supported"
         )
 
#####################################################
@@ -315,15 +385,6 @@ def _(func, types, args, kwargs):
         func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
     )
 
-
-class M(torch.nn.Module):
-    def __init__(self, *args, **kwargs) -> None:
-        super().__init__(*args, **kwargs)
-        self.linear = torch.nn.Linear(1024, 1024)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return self.linear(x)
-
 #####################
 # Factory functions #
 #####################
@@ -333,42 +394,49 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 ########
 # Test #
 ########
-
-def test():
+def main():
     from torchao.utils import benchmark_model
-    
+
+    class M(torch.nn.Module):
+        def __init__(self) -> None:
+            super().__init__()
+            self.linear = torch.nn.Linear(1024, 128)
+
+        def forward(self, x: torch.Tensor) -> torch.Tensor:
+            return self.linear(x)
+
     m = M()
-    example_inputs = (100 * torch.randn(1024, 1024),)
+    example_inputs = (100 * torch.randn(512, 1024),)
     NUM_WARMUPS = 10
     NUM_RUNS = 100
-    
+
     for _ in range(NUM_WARMUPS):
         m(*example_inputs)
     print("before quantization:", benchmark_model(m, NUM_RUNS, example_inputs))
-    
+
     compiled = torch.compile(m, mode="max-autotune")
     for _ in range(NUM_WARMUPS):
         compiled(*example_inputs)
     print("after compile:", benchmark_model(compiled, NUM_RUNS, example_inputs))
-    
+
     # convert weights to quantized weights
     m.linear.weight = torch.nn.Parameter(
         to_my_dtype(m.linear.weight), requires_grad=False
     )
-    
+
     for _ in range(NUM_WARMUPS):
         m(*example_inputs)
-    
+
     print("after quantization:", benchmark_model(m, NUM_RUNS, example_inputs))
-    
+
     m = torch.compile(m, mode="max-autotune")
-    
+
     for _ in range(NUM_WARMUPS):
         m(*example_inputs)
-    
+
     # NOTE: currently there is no speedup because we just dequantize the weight in the _quantized_linear op
     # we plan to add custom op example in the future and that will help us to get speedup
     print("after quantization and compile:", benchmark_model(m, NUM_RUNS, example_inputs))
 
 if __name__ == "__main__":
-    test()
+    main()
diff --git a/tutorials/developer_api_guide/my_trainable_tensor_subclass.py b/tutorials/developer_api_guide/my_trainable_tensor_subclass.py
index a3fc0af8d..b702ac4f9 100644
--- a/tutorials/developer_api_guide/my_trainable_tensor_subclass.py
+++ b/tutorials/developer_api_guide/my_trainable_tensor_subclass.py
@@ -61,7 +61,7 @@ def from_float(
         return _ToMyTrainableDTypeTensor.apply(input_float, layout_type)
 
 class _ToMyTrainableDTypeTensor(torch.autograd.Function):
-    """ 
+    """
     Differentiable constructor for `MyTrainableDTypeTensor`.
     """
 
@@ -163,8 +163,8 @@ def _(func, types, args, kwargs):
 ########
 # Test #
 ########
 
 class M(torch.nn.Module):
-    def __init__(self, *args, **kwargs) -> None:
-        super().__init__(*args, **kwargs)
+    def __init__(self) -> None:
+        super().__init__()
         self.linear = torch.nn.Linear(512, 1024, bias=False)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
diff --git a/tutorials/developer_api_guide/tensor_parallel.py b/tutorials/developer_api_guide/tensor_parallel.py
new file mode 100644
index 000000000..a94d84fe0
--- /dev/null
+++ b/tutorials/developer_api_guide/tensor_parallel.py
@@ -0,0 +1,191 @@
+import os
+import torch
+import torch.distributed as dist
+from torch.distributed import DeviceMesh
+from torch.distributed._tensor import DTensor, Replicate, Shard
+from torch.utils._python_dispatch import return_and_correct_aliasing
+from my_dtype_tensor_subclass import MyDTypeTensor, fill_defaults
+
+# a tensor subclass that supports tensor parallelism with DTensor
+class MyDTypeTensorTP(MyDTypeTensor):
+    pass
+
+implements = MyDTypeTensorTP.implements
+
+aten = torch.ops.aten
+
+@implements([aten._to_copy.default, aten.clone.default])
+def _(func, types, args, kwargs):
+    return return_and_correct_aliasing(
+        func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
+    )
+
+@implements([aten.split.Tensor])
+def _(func, types, args, kwargs):
+    layout_tensor_list = func(args[0].layout_tensor, *args[1:], **kwargs)
+    out = [MyDTypeTensorTP(layout_tensor, layout_tensor.shape) for layout_tensor in layout_tensor_list]
+    return out
+
+@implements([aten.empty_like.default])
+def _(func, types, args, kwargs):
+    empty_like_layout_tensor = func(args[0].layout_tensor, *args[1:], **kwargs)
+    return MyDTypeTensorTP(empty_like_layout_tensor, empty_like_layout_tensor.shape)
+
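+# aten.slice is what the sharding helpers below hit when they index the quantized
+# weight (e.g. `weight[start:end, :]`), so the layout tensor and the outer shape
+# have to be sliced consistently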
+@implements(aten.slice.Tensor)
+def _(func, types, args, kwargs):
+    self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
+    assert step == 1
+    if end >= self.shape[dim]:
+        end = self.shape[dim]
+    shape = list(self.shape)
+    shape[dim] = end - start
+    return self.__class__(aten.slice.Tensor(self.layout_tensor, dim, start, end, step), shape, self.dtype)
+
+# this is needed for DTensor.from_local() and for flattening the tensor
+@implements(aten.view.default)
+def _(func, types, args, kwargs):
+    x, shape = args
+
+    if tuple(x.shape) == tuple(shape):
+        return x.__class__(x.layout_tensor, x.shape, x.dtype)
+
+    if len(shape) == 1 and shape[0] == -1:
+        return x.__class__(x.layout_tensor, (x.numel(),), x.dtype)
+
+    raise ValueError(f"{x.__class__.__name__} only supports .view() with same shape or shape=[-1]")
+
+@implements(aten.t.default)
+def _(func, types, args, kwargs):
+    tensor = args[0]
+    shape = tensor.shape[::-1]
+    new = tensor.__class__(tensor.layout_tensor.t(), shape, tensor.dtype)
+    return return_and_correct_aliasing(func, args, kwargs, new)
+
+@implements(aten.addmm.default)
+def _(func, types, args, kwargs):
+    input_tensor, weight_tensor, bias = (
+        args[1],
+        args[2],
+        args[0],
+    )
+    weight_tensor = weight_tensor.dequantize()
+    # aten.addmm(self, mat1, mat2) computes self + mat1 @ mat2, so bias goes first
+    return aten.addmm(bias, input_tensor, weight_tensor)
+
+@implements(aten.mm.default)
+def _(func, types, args, kwargs):
+    input_tensor, weight_tensor, bias = (
+        args[0],
+        args[1],
+        None
+    )
+    weight_tensor = weight_tensor.dequantize()
+    return aten.mm(input_tensor, weight_tensor)
+
+
+class M(torch.nn.Module):
+    def __init__(self, in_features, out_features, **kwargs) -> None:
+        super().__init__(**kwargs)
+        self.linear = torch.nn.Linear(in_features, out_features, bias=False)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.linear(x)
+
+to_my_dtype_tp = MyDTypeTensorTP.from_float
+
+def quantize(m: torch.nn.Module) -> torch.nn.Module:
+    """
+    Quantize the model
+    """
+    m.linear.weight = torch.nn.Parameter(
+        to_my_dtype_tp(m.linear.weight), requires_grad=False
+    )
+    return m
+
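+# Shard(dim) below tells DTensor along which dimension each rank's local shard was cut,
+# while Replicate() (used for the activations in main()) means every rank holds a full copy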
+def colwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
+    """
+    Shard linear layer of the model in column-wise fashion
+    """
+    # Column-wise is w.r.t. A^T, so for A it is row-wise.
+    # Number of rows per rank
+    orig_weight = m.linear.weight
+    n_local_rows = orig_weight.size(0) // mesh.size()
+    rank = mesh.get_local_rank()
+    local_shard = orig_weight[rank * n_local_rows : (rank + 1) * n_local_rows, :]
+    # Construct DTensor from local shard
+    dtensor = DTensor.from_local(local_shard, mesh, [Shard(0)])
+    # Replace parameter in module
+    m.linear.weight = torch.nn.Parameter(
+        dtensor, requires_grad=False
+    )
+    return m
+
+def rowwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
+    """
+    Shard linear layer of the model in row-wise fashion
+    """
+    # Row-wise is w.r.t. A^T, so for A it is column-wise.
+    # Number of columns per rank
+    orig_weight = m.linear.weight
+    n_local_cols = orig_weight.size(1) // mesh.size()
+    rank = mesh.get_local_rank()
+    local_shard = orig_weight[:, rank * n_local_cols : (rank + 1) * n_local_cols]
+    # Construct DTensor from local shard
+    dtensor = DTensor.from_local(local_shard, mesh, [Shard(1)])
+    # Replace parameter in module
+    m.linear.weight = torch.nn.Parameter(
+        dtensor, requires_grad=False
+    )
+    return m
+
+########
+# Test #
+########
+def main():
+    # To make sure different ranks create the same module
+    torch.manual_seed(5)
+
+    # Get rank and device
+    world_size = int(os.environ["WORLD_SIZE"])
+    rank = int(os.environ["RANK"])
+    device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
+
+    # Original model
+    proj_up = M(1024, 2048).to(device)
+    proj_dn = M(2048, 1024).to(device)
+    example_input = 100 * torch.randn(128, 1024, device=device)
+    y = proj_dn(proj_up(example_input))
+
+    # Quantize the model
+    up_quant = quantize(proj_up)
+    dn_quant = quantize(proj_dn)
+    y_q = dn_quant(up_quant(example_input))
+    print("Quantization works!")
+
+    # Create a device mesh
+    dist.init_process_group(backend="nccl")
+    mesh = dist.init_device_mesh("cuda", (world_size,))
+
+    # Shard the models
+    up_dist = colwise_shard(up_quant, mesh)
+    dn_dist = rowwise_shard(dn_quant, mesh)
+
+    # We need to turn inputs into DTensor form as well -- just a format change
+    input_dtensor = DTensor.from_local(
+        example_input, mesh, [Replicate()]
+    )
+
+    y_d = dn_dist(up_dist(input_dtensor))
+    print("Distributed result:", y_d)
+    print("Distributed works!")
+
+    up_compiled = torch.compile(up_dist)
+    y_up = up_compiled(input_dtensor)
+    dn_compiled = torch.compile(dn_dist)
+    y_dn = dn_compiled(y_up)
+    print("compiled result:", y_dn)
+    print("torch.compile works!")
+
+    dist.destroy_process_group()
+
+if __name__ == "__main__":
+    main()
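+
+# Launching (assumed usage, not spelled out in the example itself): main() reads the
+# WORLD_SIZE and RANK environment variables that distributed launchers such as torchrun
+# set, so a command along the lines of
+#   torchrun --nproc_per_node=2 tensor_parallel.py
+# on a machine with at least 2 CUDA devices should be able to run this file.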