diff --git a/python/sglang/srt/model_parallel.py b/python/sglang/srt/model_parallel.py
index 778347b8ef3..53c8b622e4d 100644
--- a/python/sglang/srt/model_parallel.py
+++ b/python/sglang/srt/model_parallel.py
@@ -2,18 +2,18 @@
 Common utilities for torch model parallelism.
 """
 
-from typing import Optional
+from typing import Optional, Sequence
 
 import torch
+import torch.nn as nn
 from torch.distributed.device_mesh import DeviceMesh
 
 try:
-    from torch.distributed.tensor import DTensor, Shard
+    import torch.distributed.tensor as dt
 except ImportError:
     # torch 2.4 or older
-    from torch.distributed._tensor import DTensor, Shard
+    import torch.distributed._tensor as dt
 
-from torch.distributed._functional_collectives import AsyncCollectiveTensor
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,
     RowwiseParallel,
@@ -21,6 +21,50 @@
 )
 
 
+def _shard_tensor(
+    full_tensor: torch.Tensor,
+    device_mesh: DeviceMesh,
+    placements: Sequence[dt.Shard],
+) -> "dt.DTensor":
+    """
+    Locally shards a full tensor based on indicated sharding arrangement, and
+    returns a DTensor containing the local shard.
+
+    .. warning:: This is a private API that is subject to change. It skips the
+        communication otherwise required by `distribute_tensor`. It is only
+        applicable to cases where all ranks have the same `full_tensor`. For
+        example, in distributed inference all ranks load from the same
+        checkpoint. This API will not check for data equality between ranks, it
+        is thus user's responsibility to ensure the `full_tensor` is the same
+        across ranks.
+
+    Args:
+        full_tensor (torch.Tensor): the full tensor to be sharded.
+        device_mesh (:class:`DeviceMesh`): DeviceMesh to place the
+            DTensor. Must have same dimension as the number of placements.
+        placements (Sequence[:class:`Shard`]): the placements that
+            describes how to place the local tensor on DeviceMesh.
+
+    Returns:
+        A :class:`DTensor` object with the shard as its local tensor.
+
+    Examples:
+        >>> # xdoctest: +SKIP("need world_size and rank")
+        >>> device_mesh = dist.init_device_mesh("cuda", (world_size,))
+        >>> full_tensor = torch.arange(world_size, device=f"cuda:{rank}")
+        >>> dtensor = _shard_tensor(full_tensor, device_mesh, [Shard(1)])
+    """
+    shape, offset = dt._utils.compute_local_shape_and_global_offset(
+        full_tensor.shape, device_mesh, placements
+    )
+    slices = [
+        slice(cur_offset, cur_offset + cur_shape)
+        for cur_shape, cur_offset in zip(shape, offset)
+    ]
+    local_tensor = full_tensor[slices]
+    return dt.DTensor.from_local(local_tensor, device_mesh, placements)
+
+
 class ColwiseParallelSharded(ColwiseParallel):
     """
     A version of ColwiseParallel where the local weight has been already
@@ -34,7 +78,7 @@ def _partition_linear_fn(self, name, module, device_mesh):
         # means Colwise as Linear is input * weight^T + bias, where
         # weight would become Shard(1)
         for name, param in module.named_parameters():
-            dtensor = DTensor.from_local(param, device_mesh, [Shard(0)])
+            dtensor = dt.DTensor.from_local(param, device_mesh, [dt.Shard(0)])
             dist_param = torch.nn.Parameter(dtensor, requires_grad=False)
             module.register_parameter(name, dist_param)
 
@@ -47,6 +91,23 @@ class RowwiseParallelMaybeWait(RowwiseParallel):
     AsyncCollectiveTensor and custom ops, such as `class RMSNorm(CustomOp)`.
     """
 
+    def _partition_linear_fn(self, name, module, device_mesh):
+        # Rowwise shard weight to Shard(1), bias to Replicate(), weight be Shard(1)
+        # means Rowwise as nn.Linear is input * weight^T + bias, where
+        # weight would become Shard(0)
+        module.register_parameter(
+            "weight",
+            nn.Parameter(_shard_tensor(module.weight, device_mesh, [dt.Shard(1)])),
+        )
+        if getattr(module, "bias", None) is not None:
+            # The Linear module has bias
+            module.register_parameter(
+                "bias",
+                nn.Parameter(
+                    dt.distribute_tensor(module.bias, device_mesh, [dt.Replicate()])
+                ),
+            )
+
     @staticmethod
     def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
         outputs = super(