Make torch TP composable with torchao (#2436)
kwen2501 authored Dec 11, 2024
1 parent 0fb88aa commit ece7249
Showing 1 changed file with 66 additions and 5 deletions: python/sglang/srt/model_parallel.py
@@ -2,25 +2,69 @@
Common utilities for torch model parallelism.
"""

-from typing import Optional
+from typing import Optional, Sequence

import torch
import torch.nn as nn
from torch.distributed.device_mesh import DeviceMesh

try:
-    from torch.distributed.tensor import DTensor, Shard
+    import torch.distributed.tensor as dt
except ImportError:
# torch 2.4 or older
-    from torch.distributed._tensor import DTensor, Shard
+    import torch.distributed._tensor as dt

from torch.distributed._functional_collectives import AsyncCollectiveTensor
from torch.distributed.tensor.parallel import (
ColwiseParallel,
RowwiseParallel,
parallelize_module,
)


def _shard_tensor(
full_tensor: torch.Tensor,
device_mesh: DeviceMesh,
placements: Sequence[dt.Shard],
) -> "dt.DTensor":
"""
Locally shards a full tensor based on indicated sharding arrangement, and
returns a DTensor containing the local shard.
.. warning:: This is a private API that is subject to change. It skips the
communication otherwise required by `distribute_tensor`. It is only
applicable to cases where all ranks have the same `full_tensor`. For
example, in distributed inference all ranks load from the same
checkpoint. This API will not check for data equality between ranks, it
is thus user's responsibility to ensure the `full_tensor` is the same
across ranks.
Args:
full_tensor (torch.Tensor): the full tensor to be sharded.
device_mesh (:class:`DeviceMesh`): DeviceMesh to place the
DTensor. Must have same dimension as the number of placements.
placements (Sequence[:class:`Shard`]): the placements that
describes how to place the local tensor on DeviceMesh.
Returns:
A :class:`DTensor` object with the shard as its local tensor.
Examples:
>>> # xdoctest: +SKIP("need world_size and rank")
>>> device_mesh = dist.init_device_mesh("cuda", (world_size,))
>>> full_tensor = torch.arange(world_size, device=f"cuda:{rank}")
>>> dtensor = _shard_tensor(full_tensor, device_mesh, [Shard(1)])
"""
shape, offset = dt._utils.compute_local_shape_and_global_offset(
full_tensor.shape, device_mesh, placements
)
slices = [
slice(cur_offset, cur_offset + cur_shape)
for cur_shape, cur_offset in zip(shape, offset)
]
local_tensor = full_tensor[slices]
return dt.DTensor.from_local(local_tensor, device_mesh, placements)


class ColwiseParallelSharded(ColwiseParallel):
"""
    A version of ColwiseParallel where the local weight has already been
@@ -34,7 +78,7 @@ def _partition_linear_fn(self, name, module, device_mesh):
# means Colwise as Linear is input * weight^T + bias, where
# weight would become Shard(1)
for name, param in module.named_parameters():
-        dtensor = DTensor.from_local(param, device_mesh, [Shard(0)])
+        dtensor = dt.DTensor.from_local(param, device_mesh, [dt.Shard(0)])
dist_param = torch.nn.Parameter(dtensor, requires_grad=False)
module.register_parameter(name, dist_param)

@@ -47,6 +91,23 @@ class RowwiseParallelMaybeWait(RowwiseParallel):
AsyncCollectiveTensor and custom ops, such as `class RMSNorm(CustomOp)`.
"""

def _partition_linear_fn(self, name, module, device_mesh):
        # Rowwise: shard the weight along dim 1 (Shard(1)) and replicate the bias.
        # nn.Linear computes input @ weight^T + bias, so a Shard(1) weight is
        # equivalent to sharding the logical weight^T along dim 0.
module.register_parameter(
"weight",
nn.Parameter(_shard_tensor(module.weight, device_mesh, [dt.Shard(1)])),
)
if getattr(module, "bias", None) is not None:
# The Linear module has bias
module.register_parameter(
"bias",
nn.Parameter(
dt.distribute_tensor(module.bias, device_mesh, [dt.Replicate()])
),
)

@staticmethod
def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
outputs = super(
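
For reference, here is a minimal sketch (not part of the commit) of what the new _shard_tensor helper does on a 1-D mesh: every rank slices its own piece out of the same full tensor and wraps it with DTensor.from_local, so no scatter or broadcast is issued, unlike dt.distribute_tensor. The commit title suggests this is what keeps tensor parallelism composable with torchao-wrapped weights, though the diff itself only shows the sharding path. Launch details below are assumptions (torchrun, one GPU per rank).

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

try:
    import torch.distributed.tensor as dt
except ImportError:  # torch 2.4 or older
    import torch.distributed._tensor as dt

from sglang.srt.model_parallel import _shard_tensor

dist.init_process_group("nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank)
mesh = init_device_mesh("cuda", (dist.get_world_size(),))

# Every rank holds the same full weight, e.g. loaded from the same checkpoint.
full_weight = torch.arange(16.0, device="cuda").reshape(4, 4)

# Slice out this rank's columns and wrap them as a DTensor; no collective runs.
sharded = _shard_tensor(full_weight, mesh, [dt.Shard(1)])
print(rank, sharded.to_local().shape)  # e.g. (4, 2) with two ranks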

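The RowwiseParallelMaybeWait docstring above mentions waiting on AsyncCollectiveTensor outputs before they reach custom ops such as RMSNorm(CustomOp); the _prepare_output_fn body is cut off in this view, but the pattern it refers to looks roughly like the following sketch (an illustration, not the commit's exact code).

import torch
from torch.distributed._functional_collectives import AsyncCollectiveTensor


def _maybe_wait(tensor: torch.Tensor) -> torch.Tensor:
    # The rowwise all-reduce may return an AsyncCollectiveTensor, which only
    # completes lazily; custom ops that read the raw storage directly would
    # miss that, so materialize the result before handing it on.
    if isinstance(tensor, AsyncCollectiveTensor):
        return tensor.wait()
    return tensor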
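
Finally, a hypothetical end-to-end sketch of how the two styles combine under parallelize_module. ToyMLP, its dimensions, and the plan keys are invented for illustration; in sglang the plan is built by the model code itself. The column-parallel layer is assumed to have loaded only its local shard already (that is what ColwiseParallelSharded expects), while the row-parallel layer stays full-sized and is sliced locally by the new _partition_linear_fn.

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module

from sglang.srt.model_parallel import (
    ColwiseParallelSharded,
    RowwiseParallelMaybeWait,
)


class ToyMLP(nn.Module):
    def __init__(self, dim: int, tp_size: int):
        super().__init__()
        # Already holds only this rank's slice of the output dimension.
        self.up_proj = nn.Linear(dim, (4 * dim) // tp_size, bias=False)
        # Full-sized; RowwiseParallelMaybeWait shards its weight to Shard(1).
        self.down_proj = nn.Linear(4 * dim, dim, bias=False)

    def forward(self, x):
        return self.down_proj(torch.relu(self.up_proj(x)))


def shard_toy_mlp(dim: int = 1024) -> ToyMLP:
    dist.init_process_group("nccl")
    tp_size = dist.get_world_size()
    torch.cuda.set_device(dist.get_rank())
    mesh = init_device_mesh("cuda", (tp_size,))
    mlp = ToyMLP(dim, tp_size).cuda()
    plan = {
        "up_proj": ColwiseParallelSharded(),
        "down_proj": RowwiseParallelMaybeWait(),
    }
    return parallelize_module(mlp, mesh, plan)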