Make torch TP composable with torchao #2436

Merged 1 commit on Dec 11, 2024
71 changes: 66 additions & 5 deletions python/sglang/srt/model_parallel.py
@@ -2,25 +2,69 @@
Common utilities for torch model parallelism.
"""

-from typing import Optional
+from typing import Optional, Sequence

import torch
import torch.nn as nn
from torch.distributed.device_mesh import DeviceMesh

try:
-    from torch.distributed.tensor import DTensor, Shard
+    import torch.distributed.tensor as dt
except ImportError:
    # torch 2.4 or older
-    from torch.distributed._tensor import DTensor, Shard
+    import torch.distributed._tensor as dt

from torch.distributed._functional_collectives import AsyncCollectiveTensor
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


def _shard_tensor(
    full_tensor: torch.Tensor,
    device_mesh: DeviceMesh,
    placements: Sequence[dt.Shard],
) -> "dt.DTensor":
    """
    Locally shards a full tensor based on the indicated sharding arrangement and
    returns a DTensor containing the local shard.

    .. warning:: This is a private API that is subject to change. It skips the
        communication otherwise required by `distribute_tensor`. It is only
        applicable to cases where all ranks have the same `full_tensor`. For
        example, in distributed inference all ranks load from the same
        checkpoint. This API will not check for data equality between ranks;
        it is thus the user's responsibility to ensure that `full_tensor` is
        the same across ranks.

    Args:
        full_tensor (torch.Tensor): the full tensor to be sharded.
        device_mesh (:class:`DeviceMesh`): DeviceMesh to place the
            DTensor. Its number of dimensions must match the number of
            placements.
        placements (Sequence[:class:`Shard`]): the placements that
            describe how to place the local tensor on the DeviceMesh.

    Returns:
        A :class:`DTensor` object with the shard as its local tensor.

    Examples:
        >>> # xdoctest: +SKIP("need world_size and rank")
        >>> device_mesh = dist.init_device_mesh("cuda", (world_size,))
        >>> full_tensor = torch.arange(world_size, device=f"cuda:{rank}")
        >>> dtensor = _shard_tensor(full_tensor, device_mesh, [Shard(0)])
    """
    shape, offset = dt._utils.compute_local_shape_and_global_offset(
        full_tensor.shape, device_mesh, placements
    )
    slices = [
        slice(cur_offset, cur_offset + cur_shape)
        for cur_shape, cur_offset in zip(shape, offset)
    ]
    local_tensor = full_tensor[slices]
    return dt.DTensor.from_local(local_tensor, device_mesh, placements)

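# The following sketch is an editor's illustration, not part of this diff: it
# assumes a 1-D mesh of size `world_size` on which every rank holds an identical
# `full_tensor`, and shows that _shard_tensor slices the local shard directly,
# skipping the scatter that dt.distribute_tensor would otherwise perform.
def _example_shard_tensor_usage(world_size: int) -> "dt.DTensor":
    from torch.distributed.device_mesh import init_device_mesh

    device_mesh = init_device_mesh("cuda", (world_size,))
    # Deterministic values so the tensor is bitwise-identical on every rank.
    full_tensor = torch.arange(8 * world_size, device="cuda").reshape(4, 2 * world_size)
    # Shard(1): each rank keeps its own contiguous block of 2 columns.
    return _shard_tensor(full_tensor, device_mesh, [dt.Shard(1)])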

class ColwiseParallelSharded(ColwiseParallel):
"""
A version of ColwiseParallel where the local weight has been already
Expand All @@ -34,7 +78,7 @@ def _partition_linear_fn(self, name, module, device_mesh):
# means Colwise as Linear is input * weight^T + bias, where
# weight would become Shard(1)
for name, param in module.named_parameters():
dtensor = DTensor.from_local(param, device_mesh, [Shard(0)])
dtensor = dt.DTensor.from_local(param, device_mesh, [dt.Shard(0)])
dist_param = torch.nn.Parameter(dtensor, requires_grad=False)
module.register_parameter(name, dist_param)

@@ -47,6 +91,23 @@ class RowwiseParallelMaybeWait(RowwiseParallel):
    AsyncCollectiveTensor and custom ops, such as `class RMSNorm(CustomOp)`.
    """

    def _partition_linear_fn(self, name, module, device_mesh):
        # Rowwise: shard the weight as Shard(1) and replicate the bias. Since
        # nn.Linear computes input @ weight^T + bias, a weight stored as
        # Shard(1) (along in_features) corresponds to sharding weight^T
        # row-wise, i.e. Shard(0).
        module.register_parameter(
            "weight",
            nn.Parameter(_shard_tensor(module.weight, device_mesh, [dt.Shard(1)])),
        )
        if getattr(module, "bias", None) is not None:
            # The Linear module has a bias; replicate it across the mesh.
            module.register_parameter(
                "bias",
                nn.Parameter(
                    dt.distribute_tensor(module.bias, device_mesh, [dt.Replicate()])
                ),
            )

    @staticmethod
    def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
        outputs = super(
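As a usage sketch (an editor's addition; the sub-module names below are hypothetical, not taken from this PR), both styles plug into `parallelize_module` like any other ParallelStyle: `ColwiseParallelSharded` expects weights that are already sharded per rank, while `RowwiseParallelMaybeWait` now shards a replicated full weight locally via `_shard_tensor` and replicates the bias.

def _example_apply_tp(model: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
    # Hypothetical plan; a real plan maps the model's actual projection layers.
    plan = {
        "mlp.gate_up_proj": ColwiseParallelSharded(),
        "mlp.down_proj": RowwiseParallelMaybeWait(),
    }
    return parallelize_module(model, device_mesh, plan)

After parallelization, the rowwise weight ends up with placement Shard(1) and the bias with Replicate(), matching the comments in `_partition_linear_fn` above.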