Today there are mainly three ways to scale up distributed training: Data Parallel, Tensor Parallel and Pipeline Parallel.

An ideal scenario is that users could just build their models as if on a single node/device, without worrying about how to do distributed training in a cluster, and our solutions would help them run distributed training efficiently. For example, researchers just need to build their big transformer model, and PyTorch Distributed automatically figures out how to split the model and run pipeline parallel across different nodes, and how to run data parallel and tensor parallel within each node. In order to achieve this, we need to translate a single-device model into a distributed version and train/serve it with our runtime. To represent the distributed version of the model and facilitate the translation, we need some common abstractions to represent data distribution and computation.

There are many recent works on tensor-level parallelism that provide common abstractions; see `Related Works` at the end of this document for more details. Inspired by [GSPMD](https://arxiv.org/pdf/2105.04663.pdf), [Oneflow](https://arxiv.org/pdf/2110.15032.pdf) and [TF’s DTensor](https://www.tensorflow.org/guide/dtensor_overview), we introduce a DistributedTensor concept to represent generic data distributions across hosts. DistributedTensor is the next evolution of ShardedTensor and provides basic abstractions to distribute storage and compute. It serves as one of the basic building blocks for distributed program translations and describes the layout of a distributed training program. With the DistributedTensor abstraction, we can easily build different parallelism strategies, including generic tensor parallelism as well as DDP/FSDP parallelism patterns.
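
To make this concrete, here is a minimal sketch of how sharding and replication could be expressed with such an abstraction. The `DeviceMesh`, `Shard`, `Replicate` and `distribute_tensor` names and the import path are assumptions for illustration only; see the `DistributedTensor API` section below for the proposed surface.

```python
# Illustrative sketch only: the names below and the import path are assumptions.
# Assumes torch.distributed is already initialized with 4 ranks.
import torch
from torch.distributed._tensor import DeviceMesh, Shard, Replicate, distribute_tensor

# A 2-D device mesh over 4 ranks: 2 hosts x 2 GPUs per host.
mesh = DeviceMesh("cuda", [[0, 1], [2, 3]])

big_tensor = torch.randn(8, 8)

# Shard tensor dim 0 across the first mesh dimension, replicate across the second.
dtensor = distribute_tensor(big_tensor, mesh, placements=[Shard(0), Replicate()])

# Each rank now holds a (4, 8) local shard: ranks {0, 1} hold rows 0:4,
# ranks {2, 3} hold rows 4:8, and shards are replicated within each mesh row.
local_shard = dtensor.to_local()
```

The same placement vocabulary covers pure sharding (tensor parallel), pure replication (DDP-like data parallel), and mixtures of the two on different mesh dimensions.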

## Value Proposition

The primary value of DistributedTensor includes:
- DistributedTensor could be used as a basic building block of a compiler-based solution for distributed training
- DistributedTensor could act as an SPMD programming model entry point for ML System Engineers, providing a good UX for mixing different types of parallelism.


## PyTorch Distributed Tensor

### DistributedTensor API

```python
def distribute_module(
    ...
):
    '''
    ...
    '''
```
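
The definition above is truncated in this excerpt. As a hedged sketch only, a module-level distribution helper of this shape might look roughly like the following; the parameter names, the `partition_fn` callback shape, and the body are assumptions inferred from the usage example in the next subsection, not the RFC's exact definition.

```python
from typing import Callable, Optional

import torch.nn as nn

def distribute_module(
    module: nn.Module,
    device_mesh: Optional["DeviceMesh"] = None,  # assumed parameter
    partition_fn: Optional[Callable[[str, nn.Module, "DeviceMesh"], None]] = None,  # assumed callback shape
) -> nn.Module:
    '''
    Sketch: walk the module tree and let `partition_fn` decide, per submodule,
    which parameters to convert into DistributedTensors placed on `device_mesh`.
    '''
    if partition_fn is not None:
        for name, submodule in module.named_modules():
            partition_fn(name, submodule, device_mesh)
    return module
```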

#### High-level API examples:

```python
class MyModule(nn.Module):
    ...

# ... (model construction and the shard_fc partition function omitted) ...
sharded_module = distribute_module(model, device_mesh, partition_fn=shard_fc)
```
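
Since the example above is truncated in this excerpt, here is a hedged, more complete sketch of what the high-level flow could look like end to end; the layers of `MyModule`, the `shard_fc` partition function, and the `DeviceMesh`/`Shard`/`distribute_tensor` names are illustrative assumptions rather than the RFC's exact example.

```python
import torch
import torch.nn as nn

# Assumed names and import path for illustration (see the sketches above).
from torch.distributed._tensor import DeviceMesh, Shard, distribute_module, distribute_tensor

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(8, 8)
        self.fc2 = nn.Linear(8, 8)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))

# A 1-D mesh over 4 ranks (assumes torch.distributed is initialized).
device_mesh = DeviceMesh("cuda", [0, 1, 2, 3])

def shard_fc(name: str, module: nn.Module, mesh: DeviceMesh) -> None:
    # Hypothetical partition function: shard fc1's weight along dim 0
    # across the mesh and leave every other parameter replicated.
    if name == "fc1":
        module.weight = nn.Parameter(
            distribute_tensor(module.weight, mesh, [Shard(0)])
        )

model = MyModule()
sharded_module = distribute_module(model, device_mesh, partition_fn=shard_fc)
```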
DistributedTensor provides efficient solutions for cases like Tensor Parallelism. But when replication is used heavily with DistributedTensor, it might become noticeably slower than our existing solutions like DDP/FSDP. This is mainly because replication in eager mode is part of data parallel training, and existing solutions like DDP/FSDP have a global view of the entire model architecture, so they can optimize specifically for data parallel, e.g. with collective fusion and computation/communication overlap. DistributedTensor itself is only a Tensor-like object that knows its local computation; it is not aware of the operations that happen afterwards.

In order to recover the performance when using DistributedTensor directly for training (e.g. users might want to use DistributedTensor to do DDP-like training), DistributedTensor also needs a global view to perform things like communication optimization. We are exploring a compiler-based solution that accompanies DistributedTensor so that we can run such optimizations on top of it; details will be shared later.

## Related Works

This work is mainly inspired by [GSPMD](https://arxiv.org/pdf/2105.04663.pdf), [Oneflow](https://arxiv.org/pdf/2110.15032.pdf) and [TF’s DTensor](https://www.tensorflow.org/guide/dtensor_overview). All three use a single “distributed tensor” concept for both replication and sharding, and their solutions enable users to build distributed training programs with a uniform SPMD programming model. Specifically:

GSPMD:
- GSPMD is now the fundamental component of JAX/TensorFlow distributed training; it enables various optimizations with the XLA compiler, allowing users to train their models efficiently at large scale.
- Fundamentally, GSPMD has three sharding strategies within a tensor, “tiled”, “replicated”, and “partially tiled”, to represent sharding and replication.
- At its core, the GSPMD partitioner utilizes the XLA compiler to do advanced optimizations, e.g. sharding propagation and compiler-based fusion.
- XLA mark_sharding API: PyTorch/XLA’s [mark_sharding](https://github.com/pytorch/xla/pull/3476) API uses the [XLAShardedTensor](https://github.com/pytorch/xla/issues/3871) abstraction (i.e. sharding specs) in PyTorch/XLA. Under the hood, XLAShardedTensor utilizes the GSPMD partitioner to enable SPMD-style training on TPU.

OneFlow GlobalTensor:
- OneFlow is building its own solution around the “GlobalTensor” concept, a variant form of GSPMD sharding that allows users to explore different parallel strategies with GlobalTensor.
- OneFlow also has three distribution types, but they are slightly different from GSPMD’s: “split”, “broadcast”, and “partial sum”. It does not use “partially tiled” and instead has a “partial sum” concept to partition the values.

TensorFlow DTensor:
- The [DTensor Concepts](https://www.tensorflow.org/guide/dtensor_overview) guide describes DTensor, an extension of TensorFlow for synchronous distributed training, covering its sharding API, supported features, and its compilation passes with MLIR.
- DTensor also allows sharding and replication on an n-d, mesh-like device network.
- DTensor implements MLIR passes to handle propagation and operator implementations.

There are also several cutting-edge research projects that embed tensor sharding as part of the system, e.g. [Megatron-LM](https://arxiv.org/pdf/1909.08053.pdf) for tensor parallelism on Transformer-based models, and [DeepSpeed](https://github.com/microsoft/DeepSpeed) for training large-scale models with different optimization techniques on top of tensor sharding.

In PyTorch, we have the existing [ShardedTensor](https://docs.google.com/document/d/1WEjwKYv022rc1lSrYcNWh3Xjx9fJC7zsrTKTg0wbPj0/edit?usp=sharing) work in the prototype stage, which introduces basic PyTorch sharding primitives as our Tensor Parallelism solution. But ShardedTensor only supports tensor sharding, which makes it hard for users to describe other data distribution strategies such as replication or replication + sharding. For a distributed system developer who wants to explore more parallelism patterns, it is crucial to have a basic building block that describes data distribution in a uniform way.
