From 0ac25af2e6feaffa83fa336032750b04058d2acc Mon Sep 17 00:00:00 2001 From: James Reed Date: Fri, 19 Nov 2021 00:53:47 +0000 Subject: [PATCH 1/9] RFC-0020/0021 RFCs for Pipeline Parallelism --- RFC-0020-Distributed-Pipeline-Parallelism.md | 209 ++++++ ...Distributed-Pipeline-Parallel-Technical.md | 620 ++++++++++++++++++ ...el-Partitioning-in-Pipeline-Parallelism.md | 255 +++++++ 3 files changed, 1084 insertions(+) create mode 100644 RFC-0020-Distributed-Pipeline-Parallelism.md create mode 100644 RFC-0021-Distributed-Pipeline-Parallel-Technical.md create mode 100644 RFC-0022-Model-Partitioning-in-Pipeline-Parallelism.md diff --git a/RFC-0020-Distributed-Pipeline-Parallelism.md b/RFC-0020-Distributed-Pipeline-Parallelism.md new file mode 100644 index 0000000..9f3dade --- /dev/null +++ b/RFC-0020-Distributed-Pipeline-Parallelism.md @@ -0,0 +1,209 @@ +# [RFC] Ceci n'est pas pipeline parallelism (Pipeline Parallelism 2021Q4/2022 Plan) + +This is an RFC for the strategic plan for further developing pipeline parallelism in PyTorch. **We invite our users and partners to comment on this plan and the corresponding technical plan** to help us develop the best APIs for PP in PyTorch. + +Goal: *Provide a flexible, composable, and reconfigurable interface for pipeline parallelism in PyTorch that allows scaling a wide variety of PyTorch models on a wide variety of hardware configurations*. + +## Motivation + +* Pipeline parallelism (PP) is used as a lower-communication-volume technique for model parallelism. It is especially applicable when data must be transmitted across comparatively slower interconnects. +* Several research-oriented frameworks exist that implement PP (e.g. fairscale, megatron, deepspeed), but we would like to provide a production-quality implementation and support contract for PP. +* The existing PP implementation in PyTorch (`torch.distributed.pipeline.sync`) only supports intra-host pipeline parallelism across GPUs and does not support techniques like 1F1B scheduling. We can deliver inter-host pipelining and other features. +* Ultimately, we want to use this body of work as a driving force for research in delivering both performance AND usability of parallelism paradigms. We invite developers and researchers to participate in the design and development of this project. + +## Stage 1: Requirements Gathering (2021Q4) + +We have spent a good amount of time this calendar quarter researching the user requirements and systems research directions of pipeline parallelism and will continue to do so going forward. **We invite additional comments to fill in details we have not captured here**, if any. + +### Prior Work + +The research literature[1-12] has a rich body of work. This includes: + +* Synchronous vs. Asynchronous pipeline parallelism. + * Synchronous pipeline parallelism where a mini-batch is split into micro batches and the pipeline is filled and drained, blocking until the mini-batch is completed. This is the typical use case we are designing for + * Asynchronous pipeline parallelism that keeps the pipeline continually occupied. Various clever techniques such as weight stashing and weight prediction have been proposed to address the consistency issues from the "locking" nature of SGD in these cases. These techniques may introduce additional design concerns in a pipeline parallelism API. +* Pipeline scheduling, where the execution order of `forward` or `backward` micro-batches follows a specified policy. The infrastructure for implementing these schedules can be an important consideration for the design on a PP API. + * Fill-drain schedule, where all forward micro-batches are run to completion before all backward micro-batches are run and parameter updates are applied. + * 1F1B schedule, where `backward` micro-batches are triggered by the last pipeline stage and stages ensure that they alternate between running `forward` and `backward` micro-batches at steady-state. This helps to reduce the amount of state stored on each pipeline stage. + * More, including interleaved 1F1B and new research schedules. + +### Key Stakeholders + +This section is meant to capture key users/researchers who would benefit from such a pipeline parallelism API. **We invite additional comments to fill in users/researchers who would benefit from this API and would like to see their requirements satisfied**. + +#### P0: HF Transformers + +HF transformers [wants to](https://github.com/huggingface/transformers/issues/13690) incorporate 3d parallelism including Pipeline Parallelism, however the [current PyTorch implementation](https://github.com/pytorch/pytorch/blob/9f4e004abd8c5d11fc23f4ab705328cb9b4050bb/torch/distributed/pipeline/sync/pipe.py#L220) has limitations that we should address (Px is a priority, with lower x being higher priority. We assigned these priorities based on a) user need and b) implementation time/complexity, but we can adjust them based on user feedback): + +* Frontend limitations: + * **P0**: Cannot pass arbitrary data types between pipeline stages + * **P0**: Unclear composability in 3d parallelism scheme (data, pipeline, model parallel) + * **P1**: User needs to rewrite their model as an `nn.Sequential` instance +* Backend Limitations: + * **P(-1)**: No cross-host support for PT pipeline parallelism API + * **P0**: No support off-the-shelf schedules (1F1B or interleaving) + * **P1**: No support arbitrary programmable schedules +* Non-requirements: + * Composability with ZeRO-2/3 is not required. Theoretically possible, but reportedly will not give any perf gain. +* Success Criteria: + * **to be determined**: Feedback on this would be appreciated + +### Prior Implementations and Proposed Approach + +An analysis of prior implementations and a proposed technical approach for pipeline parallelism can be seen in [[RFC] Distributed Pipeline Parallel Training Technical Approach](https://github.com/pytorch/rfcs/blob/master/RFC-0021-Distributed-Pipeline-Parallel-Technical.md). In this document, we further split execution into stages and correlate those to the PyTorch external release schedule. + +## Stage 2: Ship prototype synchronous multi-node pipeline parallelism (torchgpipe-style) (1.11 Prototype Release) + +### P(-1): Implement cross-host support for pipeline parallelism + +Existing approaches that support this (in no particular order): + +* Fairscale [experimental distributed pipeline parallelism](https://github.com/facebookresearch/fairscale/tree/main/fairscale/experimental/nn/distributed_pipeline) +* Sagemaker [model parallelism](https://arxiv.org/abs/2111.05972) +* [DeepSpeed pipeline parallelism](https://www.deepspeed.ai/tutorials/pipeline/) +* [OneFlow](https://github.com/Oneflow-Inc/oneflow) + +Proposed approach short-list: (all approaches can be seen in [[RFC] Distributed Pipeline Parallel Training Technical Approach](https://github.com/pytorch/rfcs/blob/master/RFC-0021-Distributed-Pipeline-Parallel-Technical.md) + +1. Selected approach: "Approach 3 with Modifications" + * Inherit RemoteModule + torchpipe-based implementation from fairscale [experimental distributed pipeline parallelism](https://github.com/facebookresearch/fairscale/tree/main/fairscale/experimental/nn/distributed_pipeline). + * Switch autograd off of distributed autograd and onto manual handling of autograd in the pipeline, to facilitate implementing schedules (e.g. 1F1B) + * Abstract the runtime for each RemoteModule to allow for programming in execution schedules + * Switch from using DistributedLoss to having a loss callback to facilitate the last pipeline stage calling the loss locally rather than relying on the training loop to calculate the loss via RPC and call distributed autograd. This will be necessary with arbitrary schedules. + +### P0: Implement support for passing arbitrary data-types between pipeline stages + +Existing approaches that support this (in no particular order): + +* Some amount of [support](https://github.com/pytorch/pytorch/issues/53952) in existing PT implementation + +Proposed approach short-list: + +1. Hopefully should just work out of the box with the RPC API, but need to keep it in mind. + +### P0: 1.11 Prototype Release and out-of-tree demo on HF Transformers + +* Release API as prototype in the 1.11 release to facilitate gathering feedback +* Validation: Out-of-tree demo on HF transformers repo - hack it together to get it to work and pull out work items to improve the API to remove places where code edits are needed +* 1.11 Release Dates + * Feature submission: 11/30 EOD + * Branch cut 1/31/2022 + + + +### **P0**: Support off-the-shelf schedules (1F1B or interleaving) + +Existing approaches that support this (in no particular order): + +* Megatron hardcoded schedules: [1f1b](https://github.com/NVIDIA/Megatron-LM/blob/5ac5571ba0265af4c491ee0af1508ca7589450c6/megatron/schedules.py#L517), [interleaved](https://github.com/NVIDIA/Megatron-LM/blob/5ac5571ba0265af4c491ee0af1508ca7589450c6/megatron/schedules.py#L187) + +Proposed approach short-list: + +1. "Approach 3 with Modifications" + * Once manual autograd is handled, and we abstract the workers, we can implement 1F1B or interleaved 1F1B using that infrastructure. + +### P0: Composability of PP with TP and DP (3d Parallelism) + +Existing approaches that support this (in no particular order): + +* `torch.distributed` APIs via module wrapper composition +* [DeepSpeed](https://www.deepspeed.ai/tutorials/pipeline/) + +Proposed approach short-list: + +1. DistributedDataParallel wraps PipelineParallel API which operates on upcoming ShardedTensors +2. Unified programming model (Stage 5) + +## Stage 3: Figure out how to reconcile local pipeline vs. distributed pipeline (2022H1) + +The existing approaches live in different corners of a 2-dimensional space with axes on **single-driver vs. actors** and **local vs. distributed**. + + +| |single-driver |actors | +|--- |--- |--- | +|local |torchgpipe/fairscale Pipe/distributed.sync.Pipe | | +|distributed | |Fairscale distributed_pipeline, DeepSpeed, Megatron-LM | + +### Design Speculation + +We can interpolate the missing spaces: + +* **single-driver, distributed**: “macro SIMD” style distributed execution. I believe this is actually what was envisioned in @pritamdamania87’s [RFC](https://github.com/pytorch/pytorch/issues/44827) with the `torch.distributed.pipeline_sync` API. The current `distributed.sync.Pipe` API is a fork of the `torchgpipe` implementation (transitively forked in `fairscale`), which is hard-coded for single-node execution issuing commands via CUDA streams (or a fake CPU stream stand-in they implemented) +* **actors, local**: We can take the event-driven approach taken in fairscale’s `distributed_pipeline` and extend that to having worker processes/threads that both a) feed a corresponding CUDA device and b) feed data through to the successor in the pipeline. This is sort-of already done by the `torchgpipe` lineage of implementations which use [worker threads](https://github.com/pytorch/pytorch/blob/master/torch/distributed/pipeline/sync/worker.py) that run the actual forward computation but still have a central coordinating thread issuing each of those workers commands nonetheless. Potentially if done in a multi-process setting, this could lead to higher performance (need to measure). + +I believe the way to go in the future may be to consolidate on actors for both local and distributed. This may represent lower complexity than the torchgpipe-style execution (at least when I think about it) and can avoid issues with a single driver process being a bottleneck (as evidenced by the fact that `torchgpipe` already uses threads for speed). + + +## Stage 4: Generalize pipeline parallelism interface to allow for more coverage of different techniques in the literature (e.g. async, scheduling, auto-partitioning, composition with tensor parallelism) (2022, OSS releases 1.11-1.15) + +### P1: Pipeline parallelism without `nn.Sequential` rewrite + +Existing approaches/proposals that support this (in no particular order): + +* Sagemaker [model parallelism](https://drive.google.com/file/d/1N2eo5Yr_QOw0EtKv-MYBDWKvyRYxKv2o/view) +* @zdevito's [sequential-free splitting approach](https://colab.research.google.com/drive/1lGg2NqlvDwVmvBqejzni2yTmYE9rxfdr?usp=sharing) +* [OneFlow](https://github.com/Oneflow-Inc/oneflow) +* [[RFC] Model Partitioning in Pipeline Parallelism](https://github.com/pytorch/rfcs/blob/master/RFC-0022-Model-Partitioning-in-Pipeline-Parallelism.md) + +Proposed approach short-list: + +1. [[RFC] Model Partitioning in Pipeline Parallelism](https://github.com/pytorch/rfcs/blob/master/RFC-0022-Model-Partitioning-in-Pipeline-Parallelism.md) +2. @zdevito's [sequential-free splitting approach](https://colab.research.google.com/drive/1lGg2NqlvDwVmvBqejzni2yTmYE9rxfdr?usp=sharing) +3. Construct a pipeline parallelism API that uses a different approach, such as the one used in SageMaker model parallelism. This introduces trade-offs elsewhere, such as in support for schedules/the requirement for an optimization pass to be applied to implement "true" pipeline parallelism. + +These approaches can be composed on top of an existing API that takes an `nn.Sequential`. We may consider in the future to develop a "v2" API that is centered more natively around non-`nn.Sequential` models using technologies from Sagemaker, OneFlow, or other research developments. + +### P1: Support arbitrary programmable schedules (e.g. fill-drain, 1F1B, interleaved 1F1B) + +Existing approaches that support this (in no particular order): + +* DeepSpeed [PipeSchedule](https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/runtime/pipe/schedule.py) is an instruction format that allows customizing the order in which forward/backward jobs on different stages should be executed. + +Proposed approach short-list: + +1. Programmable instruction stream + interpreter à la [PipeSchedule](https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/runtime/pipe/schedule.py). This should be enabled by the infrastructural work done in Stage 2. + +### P2: Asynchronous Pipeline Parallelism - Mechanics of Asynchronous Training Loop + +**Call for Stakeholders**: Do you have a project that would benefit from Asynchronous Pipeline Parallelism in PyTorch? Please comment on the RFC and we can incorporate your requirements. + +* async training is like a self-perpetuating engine v.s. a synchronous procedure call as is typical in Python. How do we bridge these two? What would the Pythonic experience for async look like? + +### P2: Asynchronous Pipeline Parallelism - Weight stashing + +**Call for Stakeholders**: Do you have a project that would benefit from Asynchronous Pipeline Parallelism in PyTorch? Please comment on the RFC and we can incorporate your requirements. + +* [Parametrization](https://pytorch.org/tutorials/intermediate/parametrizations.html) as an approach? + +### P2: Asynchronous Pipeline Parallelism - Double-Buffered Weight Stashing + +**Call for Stakeholders**: Do you have a project that would benefit from Asynchronous Pipeline Parallelism in PyTorch? Please comment on the RFC and we can incorporate your requirements. + +* [Parametrization](https://pytorch.org/tutorials/intermediate/parametrizations.html) as an approach? + +### P2: Asynchronous Pipeline Parallelism - Weight Prediction + +**Call for Stakeholders**: Do you have a project that would benefit from Asynchronous Pipeline Parallelism in PyTorch? Please comment on the RFC and we can incorporate your requirements. + +* [Parametrization](https://pytorch.org/tutorials/intermediate/parametrizations.html) as an approach? + +## Stage 5: Integrate into Unified Programming Models Research (2022?) + +Going into the future, we would like to develop theory and implementation for a unified distributed, parallel programming model that brings together all of data parallel, model parallel, pipeline parallel, expert parallel, and more. Various ideas are floating around, including building on top of the Actor model (as in Ray, OneFlow, etc) or extending the MPI-style SPMD model to support spatial parallelism like pipeline parallelism and predicated expert parallelism. Hopefully, this pipeline parallelism project will help to inform us on the correct model here and we can publish our findings in the future. + + +## References + +1. Efficient and Robust Parallel DNN Training through Model Parallelism on Multi-GPU Platform https://arxiv.org/abs/1809.02839 +2. ElasticPipe: An Efficient and Dynamic Model-Parallel Solution to DNN Training https://dl.acm.org/doi/10.1145/3322795.3331463 +3. XPipe: Efficient Pipeline Model Parallelism for Multi-GPU DNN Training https://arxiv.org/abs/1911.04610 +4. PipeDream: Fast and Efficient Pipeline Parallel DNN Training https://arxiv.org/abs/1806.03377 +5. GPipe: Efficient Training of Giant Neural Networks using Pipeline Parallelism https://arxiv.org/abs/1811.06965 +6. torchgpipe: On-the-fly Pipeline Parallelism for Training Giant Models https://arxiv.org/abs/2004.09910 +7. Pipelined Backpropagation at Scale: Training Large Models without Batches https://arxiv.org/abs/2003.11666 +8. Memory-Efficient Pipeline-Parallel DNN Training https://arxiv.org/abs/2006.09503 +9. Efficient Large-Scale Language Model Training on GPU Clusters https://arxiv.org/abs/2104.04473 +10. Performance analysis of a pipelined backpropagation parallel algorithm https://ieeexplore.ieee.org/document/286892 +11. PipeMare: Asynchronous Pipeline Parallel DNN Training https://arxiv.org/abs/1910.05124 +12. Scaling Language Model Training to a Trillion Parameters Using Megatron + https://developer.nvidia.com/blog/scaling-language-model-training-to-a-trillion-parameters-using-megatron/ \ No newline at end of file diff --git a/RFC-0021-Distributed-Pipeline-Parallel-Technical.md b/RFC-0021-Distributed-Pipeline-Parallel-Technical.md new file mode 100644 index 0000000..c4523a0 --- /dev/null +++ b/RFC-0021-Distributed-Pipeline-Parallel-Technical.md @@ -0,0 +1,620 @@ +# [RFC] Distributed Pipeline Parallel Training Technical Approach + +## Background - PyTorch Training Loop + +PyTorch does not vend a standard training loop abstraction. As a result, the training process for a PyTorch model consists of free-form Python code. An example of a standard training loop in PyTorch might look something like this (borrowed from the PyTorch transfer learning [tutorial](https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html)): + +``` +def train_model(model, criterion, optimizer, scheduler, num_epochs=25): + since = time.time() + + best_model_wts = copy.deepcopy(model.state_dict()) + best_acc = 0.0 + + for epoch in range(num_epochs): + print('Epoch {}/{}'.format(epoch, num_epochs - 1)) + print('-' * 10) + + # Each epoch has a training and validation phase + for phase in ['train', 'val']: + if phase == 'train': + model.train() # Set model to training mode + else: + model.eval() # Set model to evaluate mode + + running_loss = 0.0 + running_corrects = 0 + + # Iterate over data. + for inputs, labels in dataloaders[phase]: + inputs = inputs.to(device) + labels = labels.to(device) + + # zero the parameter gradients + optimizer.zero_grad() + + # forward + # track history if only in train + with torch.set_grad_enabled(phase == 'train'): + outputs = model(inputs) + _, preds = torch.max(outputs, 1) + loss = criterion(outputs, labels) + + # backward + optimize only if in training phase + if phase == 'train': + loss.backward() + optimizer.step() + + # statistics + running_loss += loss.item() * inputs.size(0) + running_corrects += torch.sum(preds == labels.data) + if phase == 'train': + scheduler.step() + + epoch_loss = running_loss / dataset_sizes[phase] + epoch_acc = running_corrects.double() / dataset_sizes[phase] + + print('{} Loss: {:.4f} Acc: {:.4f}'.format( + phase, epoch_loss, epoch_acc)) + + # deep copy the model + if phase == 'val' and epoch_acc > best_acc: + best_acc = epoch_acc + best_model_wts = copy.deepcopy(model.state_dict()) + + print() + + time_elapsed = time.time() - since + print('Training complete in {:.0f}m {:.0f}s'.format( + time_elapsed // 60, time_elapsed % 60)) + print('Best val Acc: {:4f}'.format(best_acc)) + + # load best model weights + model.load_state_dict(best_model_wts) + return model + +model_ft = models.resnet18(pretrained=True) +num_ftrs = model_ft.fc.in_features +# Here the size of each output sample is set to 2. +# Alternatively, it can be generalized to nn.Linear(num_ftrs, len(class_names)). +model_ft.fc = nn.Linear(num_ftrs, 2) + +model_ft = model_ft.to(device) + +criterion = nn.CrossEntropyLoss() + +# Observe that all parameters are being optimized +optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9) + +# Decay LR by a factor of 0.1 every 7 epochs +exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1) + +model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, + num_epochs=25) +``` + +As you can see, the code is written in a way that is very free-form and configurable. There is some similarity between training loops, such as the common appearance of various constructs: + + +* A DataLoader object that yields input and target data for the training process, potentially after being subject to data augmentation techniques +* An Optimizer object that manages the behavior of updating parameter values given gradients subject to a specific update policy +* A model to be run in forward propagation as well as back-propagation +* A Loss function that takes the model output, targets, and yields a differentiable expression that reduces the divergence between the two down to a scalar value to be minimized +* Calls to `backward()` for backpropagation and `optimizer.step()` to apply parameter updates after gradient computation +* A learning rate scheduler that decays the learning rate according to a policy + +However, these things may not all be present during a training process and may be used in unconventional ways within the free-form Python code. + +## Background - Pipeline Parallel Training + +Pipeline parallel deep learning training is a technique to distribute the process of training (or inferencing) a deep learning model over a series of machines. As a technique for splitting up a large model, pipeline parallel training often requires less communication bandwidth between machines than other techniques such as tensor splitting. Pipeline parallel training can also result in high overlap between numerical computation and cross-stage communication operations. A graphical example of pipeline parallelism from [1] can be seen below: + +![An example pipeline-parallel assignment](https://i.imgur.com/ODy3ws4.png) + +However, model placement is not the whole story. The training process encompasses not just the forward propagation as shown in the figure, but also loss calculation, backward propagation, and parameter update. A graphical representation of this whole process from [2] can be seen in the figure below. + +![Synchronous Pipeline Parallelism in Deep Learning Training](https://i.imgur.com/IHbuIm0.png) + +If we refer back to the previous section, we can see that pipeline parallelism **encompasses a large part of the training loop**. A key consideration for a programming interface that enables pipeline parallelism is: how does it interact with the code within the training loop (as opposed to the code in the model?) + +## Motivation - Pipeline Parallel Training with as Few Edits as Possible + +We would like to deliver a pipeline parallelism solution that is as unintrusive as possible to the developer. However, given that 1) PyTorch does not vend a standard training loop abstraction and 2) pipeline parallel training encompasses a large part of the training loop, we must design a novel solution to run model training in a pipelined way. + +To set a goal, we would like to run the training loop from the beginning of the document in a pipeline parallel fashion with as few code changes as possible. In this document, we explore multiple approaches for implementing Pipeline Parallel training and examine how it affects the training loop authoring process. + +## Desiderata + +We would like to consider the following (lifted from [[RFC] Ceci n'est pas pipeline parallelism (Pipeline Parallelism 2021Q4/2022 Plan))](https://github.com/pytorch/rfcs/blob/master/RFC-0021-Distributed-Pipeline-Parallelism.md)) when comparing alternatives (D* is an identifier for later reference, P* is a priority based on the roadmap): + + +* **D0 (P-1)** Cross-host pipeline parallel support (table stakes) +* **D1 (P0)** Support for passing arbitrary data types between stages (table stakes) +* **D2 (P0)** Support for pipeline parallel schedules (e.g. GPipe fill-drain, 1F1B, or interleaved 1F1B) + * **P1** Support for arbitrary programmable schedules +* **D3 (P0)** Composability with other parallelism schemes (Tensor Parallelism, Data Parallelism) in a 3D parallelism scheme +* **D4 (P1)** Composability with other parallelism schemes in an *arbitrary scheme* +* **D5 (P1)** Off-the-shelf support for pipelining without manual conversion of a model to `nn.Sequential` +* **D6 (P2)** Support for asynchronous pipeline parallelism + * Continuous data loading + * Weight stashing/Weight prediction +* **D7 (P2)** Research: Fits into a unified, highly configurable programming model encompassing all parallelism schemes +* **D8 (P1)** The user can use an idiomatic Python-based training loop with no or minimal modifications from their “normal” training loop + +## Approach 1: SPMD with Predicated Training Loop and Message Passing + +**NOTE**: This approach is only an abstract proposal and the ideas are still in development. + +Suppose I have a model consisting of 5 layers and I have 5 processors I want to run that model on. The model pseudocode might look something like: + +``` +def model(x): + x = layer1(x) # Assigned to processor 1 + x = layer2(x) # Assigned to processor 2 + x = layer3(x) # Assigned to processor 3 + x = layer4(x) # Assigned to processor 4 + x = layer5(x) # Assigned to processor 5 + return x +``` + +One way to implement this is to convert the code into a form that is programmatically manipulable (such as a `torch.fx` Graph), partition that IR such that the code for each stage resides in a separate program, and distribute that program to each of the processors. A runtime on each of those processors would then handle loading data (from a DataLoader or from a previous stage), running the partitioned code, calculating the loss (on the last stage), running backpropagation, and applying gradient updates. This is the [approach](https://github.com/microsoft/DeepSpeed/blob/af443f63f483f6ea6769b78b4b0f2407023e9aed/deepspeed/runtime/pipe/engine.py#L46) DeepSpeed takes in pipeline parallel execution (for example, [see](https://www.deepspeed.ai/tutorials/pipeline/) how you must pass your DataLoader and use an opaque `engine.train_batch` API in DeepSpeed, relinquishing control to the opaque runtime). However, this runtime essentially replaces the PyTorch training loop, and converting your training process to this scheme may be a burden for the end user (see Approach 4 for a comprehensive treatment of the design considerations for this approach). + +@zdevito proposed transparently [splitting a model into stages](https://colab.research.google.com/drive/1lGg2NqlvDwVmvBqejzni2yTmYE9rxfdr?usp=sharing) by employing a type of predicated execution[3] For each stage, the Python code of the whole model is run, however operations outside of the ones relevant to each specific pipeline stage are “no-op”ed. + +![Predicated execution of a model in pipeline parallelism](https://i.imgur.com/opv5Wic.png) + +One thing not represented in this diagram is a 3rd dimension of time, i.e. there will be a “fill” stage where `forward` micro-batches fill in from left-to-right, and a drain stage where `backward` micro-batches drain out of the pipeline from right-to-left. The diagram above represents the pipeline in a “steady-state” condition. + + +### Expanding Predication to Include the Training Loop + +The next logical step is to expand this proposal to encompass the training loop, allowing the user to continue to write their arbitrary Python training loop, but overlaying pipeline parallelism semantics on this using predicated execution. + +We can pull out the parts of the canonical training loop from the beginning and investigate how they should be executed under a predicated, pipeline parallel training loop: + +**Data Loader** + +The data loader should only load input data on rank 0. We can view this as predicating `true` on rank 0 and false on all other ranks. We can also commingle this with `recv` for stages != 0. i.e. under the hood the dataloader object will return an input micro-batch on rank 0, but will return the intermediate value received from `rank - 1` on all ranks != 0. + +**Optimizer - Zero Grad** + +In synchronous PP, the optimizer should zero out gradients at the beginning of the entire mini-batch (i.e. during micro-batch 0) and *not* zero the gradients for subsequent micro-batches. Gradients for each of the micro-batches should be accumulated, but not applied, preserving the mathematical integrity of SGD optimization. The GPipe diagram from above is reproduced to demonstrate this: + +![Synchronous Pipeline Parallelism in Deep Learning Training](https://i.imgur.com/IHbuIm0.png) + +This scheme of zeroing grads on the first micro-batch can be trivially implemented by predicating `zero_grads` as `True` for the first micro-batch on each pipeline stage, and predicating it as `False` for each subsequent micro-batch. Applying the update with accumulated gradients is covered later. + + +**Forward Propagation** + +Predication of forward-propagation can be done as in Zach’s [proposal](https://colab.research.google.com/drive/1lGg2NqlvDwVmvBqejzni2yTmYE9rxfdr?usp=sharing). + +Note that we may extend the predicated training loop scheme to include schedules such as 1F1B or interleaved 1F1B, discussed later. + +**Loss Calculation** + +The loss calculation is only valid for the last pipeline stage, i.e. `rank == world_size - 1`. We can predicate loss calculation as `False` for all ranks except the last one. There may be some complication here, as (as in the transfer learning example), there may be additional computation between the model forward pass and the loss calculation, such as a `max()` operation for extracting a single top prediction. Potentially we could extend the predicated tensor data-type from Zach’s proposal and simply use that scheme to predicate out the loss and any interstitial computation. + +**Backward Propagation** + +I think back-prop could happen similarly to Zach’s proposal. Similar extensions for 1F1B schedules etc apply. + +**Optimizer - Step** + +As described in the Zero Grad section, the optimizer step should only be applied after the backwards pass for every micro-batch has gone through the pipeline. The optimizer step call should predicate `False` for every invocation except the very last one on each stage. + +**LR Scheduler Step** + +This occurs outside the micro-batch region, so I think this can execute as-is + + +**Graphical Representation** + +![Predicated execution of the whole training loop](https://i.imgur.com/9RlorjQ.png) + +Note that forward and backward stages do not necessarily always run in a given rank, for example in the “fill” or "drain" phases of pipelined execution or according to specific pipeline schedules. Time should be considered as an axis into and out of the page where different states of the pipeline (fill/drain) can be represented. + +### Pros and Cons of the Approach + +**Pro** + +* Similarly to other SPMD schemes, there is no possibility for this approach to be “front-end bound”. There is no remote coordinating process and there is no network latency delaying commands to each of the workers. The program text resides locally on each processor. +* (**D8**) It is not necessary to use specialized DistributedLoss or DistributedOptimizer structures. There are no RPCs occurring during loss calculation or optimization. + * Note: to implement this as a clean API, maybe need to create PredicatedLoss or PredicatedOptimizer, so this may be a wash, but at least we save the RPC cost. +* (**D2**) Pipeline parallel schedules can be programmed by controlling the number of iterations each stage executes and predicating `true` the parts that should execute on that iteration. Arbitrary schedules can be programmed by the user via Python code. +* (**D3, D4, D7**) Composes with SPMD execution model; can likely readily interoperate well with SPMD Tensor Parallelism. Is potentially the basis for converging parallelism on the SPMD model (other alternative is converging parallelism on the Actor model)(needs more research - likely can be the basis for (a) paper(s)) +* (**D5**) Does not require the user to manually partition their model into an `nn.Sequential` +* (**D6**) There is no concept of a synchronous “call” or “dispatch” into the training loop; this scheme can likely readily support asynchronous pipeline parallelism with continuous data loading and training. + +**Con** + +* Value predication imposes certain restrictions on the classes of programs that can be used in this scheme + * For a predicated tensor without metadata, restrictions would be essentially equivalent to those in FX tracing + * For a predicated tensor with metadata (i.e. carrying forward a MetaTensor), restrictions would be equivalent to those for MetaTensor capability. + * Potentially, we could short-circuit evaluate Module dispatches where all Tensor inputs are predicated `false`. This would require knowing the output type of the Module, however, to return a value with the correct structure. +* Program text predication (e.g. as in CUDA) would be ideal, but I’m not aware of any global “code identifier” structure we could use here. (**TODO**: investigate Python runtime structures that might be useful here?) +* (**D8**) The practice of having the training loop tiled over multiple machines may be confusing to the user. Outside of the components covered here (e.g. DataLoader, optimizer, etc), the user may need to reason about what is and isn’t valid in their training loop under an SPMD execution model. + +## Approach 2 - RPC with RemoteModule and torchgpipe-style single coordinator (@pritamdamania87 RFC) + +One proposal for an API for pipeline parallelism is the `pipeline_sync` API proposed in @pritamdamania87’s [RFC](https://github.com/pytorch/pytorch/issues/44827) (Certain lines are called out with end-of-line comments containing an alphabetical identifier): + +``` +# Note: This API is very similar to torchgpipe and inspired from it. +# torchgpipe API for reference: https://torchgpipe.readthedocs.io/en/stable/api.html + +torch.distributed.pipeline_sync( + pipeline: nn.Sequential, + checkpoint: CheckpointEnum = EXCEPT_LAST, # ALWAYS, EXCEPT_LAST, NEVER + chunks: int = 1) -> PipelineSyncModel + +Arguments: + +pipeline: nn.Sequential (https://pytorch.org/docs/stable/generated/torch.nn.Sequential.html) where each nn.Module in the list is placed on the + appropriate device(CPU or GPU)/machine by the user. Note that + nn.Sequential could also consist of RemoteModule (https://github.com/pytorch/pytorch/blob/master/torch/distributed/nn/api/remote_module.py#L147) for cross host + pipelining. +checkpoint: Enum that determines which checkpointing mode to use. +chunks: Number of micro-batches. + +Returns: + An instance of PipelineSyncModel + +Forward Method + +PipelineSyncModel.forward(self, *input, **kwargs) -> RRef + +Returns: + RRef to output corresponding to the result of the minibatch. + Since we plan to support cross host pipelining, the RRef could be on a + device on a different host. + +Example: + +# This is an example of a pipeline across two machines each using one GPU. +# On worker 0 +layer1 = nn.Linear(10, 5).cuda(0) +# Need to enhance RemoteModule to include device for this purposes. +layer2 = RemoteModule("worker1", device="cuda:0", nn.Linear, 5, 1) + +pipeline = nn.Sequential(layer1, layer2) +model = torch.distributed.pipeline_sync(pipeline, chunks = 4) # A + +rref_params = [RRef(param) for param in layer1.parameters()] +# Need to enhance RemoteModule for `get_rref_parameters` +rref_params.append(layer2.get_rref_parameters()) + +dist_optim = DistributedOptimizer(optim.SGD, rref_params, lr=0.05) # B + +# Helper functions +def compute_loss(output_rref, target_rref): + return F.cross_entropy*(*output_rref.local_value()*,* target_rref.local_value()*)* + +def identity_fn(inp): + return inp + +for epoch in range(epochs): + for minibatch, target in data: + # Use dist autograd context for distributed autograd. + with dist_autograd.context() as context_id: + target_rref = rpc.remote("worker1", identity_fn, target) # C + output_rref = model(minibatch) # D + loss_rref = rpc.remote("worker1", compute_loss, output_rref, target_rref) # E + # Can enhance RRef to ensure this calls "dist_autograd.backward" on the last + # node in the pipeline. + loss_rref.backward(context_id) # F + dist_optim****.step() # G +``` + +This proposal has the training loop running on a single machine and makes copious use of the `torch.distributed.rpc` APIs available in PyTorch. We can examine different parts of the loop, highlighted with alphabetical labels: + +**A - Model Pipelining** + +The `pipeline_sync` function wraps the given `Sequential` of layers in a runtime that will do micro-batch splitting and execution in a pipelined fashion. I believe by the comment, the implication is that this `pipeline_sync` API functions similarly to `torchgpipe`, where a single driver schedules commands for each pipeline stage onto some “stream” abstraction, issuing subsequent micro-batches on the same stream on the forward pass and scheduling “virtual” dependencies in the autograd graph to serialize execution of gradient computation on micro-batches on each stage (see section 3.2.2 of the torchgpipe [paper](https://arxiv.org/abs/2004.09910)). + +`pipeline_sync` returns an `RRef` referring to the output of the pipeline *for the whole mini-batch*. + +**B - Distributed Optimizer** + +The training script instantiates a [DistributedOptimizer](https://pytorch.org/docs/master/distributed.optim.html) to wrap the vanilla SGD optimizer. This DistributedOptimizer takes RRefs to the parameters distributed among the pipeline stages. During the `step()` call in stage (G), the DistributedOptimizer will make async RPC calls to all of the remotes the run an optimizer step. + +**D - Model Execution** + +As mentioned in (A), this is likely going to follow the `torchgpipe` logic, which issues commands for each stage according to a schedule (e.g. GPipe fill/drain schedule, see section 3.2.1 of torchgpipe [paper](https://arxiv.org/abs/2004.09910)). In this case, `forward()` calls would be issued as calls to the RemoteModule instances in sequence. Calls to `forward()` would be recorded in the distributed autograd context for later backpropagation in stage (F). + +**C/E - Loss execution** + +The loss calculation is something that would usually happen directly in the training loop. However, in the case of Pipeline Parallel execution, the output of forward propagation resides on the last pipeline stage and backprop should begin from that stage, moving in reverse order through the pipeline. Thus, the loss computation is formulated as an RPC onto the last stage. The training loop calls `rpc.remote`, feeding the loss calculation as target and the returned minibatch output RRef as argument (along with target values moved to the remote). This then returns an RRef referring to the loss value calculated on the remote that is backpropagated through in stage (F). + +**F - Backprop** + +This proposal uses the distributed autograd engine to backpropagate through the forward passes computed in pipeline parallel execution. As mentioned earlier, the per-stage execution order is likely mediated by “fork” and “join” virtual dependencies in the autograd graph, due to section 3.2.2 in the [paper](https://arxiv.org/abs/2004.09910). + +**NOTE**: I don’t believe that forward and backward jobs are serialized; they may run concurrently. Is this true? + +**G - Optimizer** + +As mentioned in stage (B), the DistributedOptimizer will make async RPC calls to all stages to apply the selected optimizer to the parameters contained within that stage. + +### Pros and Cons of the Approach + +**Pro** + +* (**D8**) Training loop looks pretty close to the original PyTorch code. Training loop runs on a single machine, so user does not need to reason about correctness of their training loop under SPMD, as in Approach 1. + +**Con** + +* (**D5**) In its current conception, requires manual splitting of the model into an `nn.Sequential` instance +* High possibility of being “front-end bound”. Every forward computation, loss calculation, autograd, and optimizer are all mediated by RPCs from a central coordinator. + * Speed of issuing these commands on the host may be an issue. torchgpipe addresses this via scheduling jobs in execution order (section 3.2.1 of the [paper](https://arxiv.org/abs/2004.09910.pdf)). However, `torchgpipe` still finds that this is not fast enough, so uses worker threads to actually run the forward computations (likely to parallelize CPU-bound tasks such as CUDA allocator invocation). Depending on the overhead of issuing these commands over RPC, the central coordinator overhead may become an issue + * When expanding the torchgpipe single-coordinator scheme to cross-host execution, network latency, jitter, and instability may contribute to front-end-boundedness issues +* (**D2**) In its current conception, it’s not clear to me if schedules are representable in this scheme due to reliance on distributed autograd execution. GPipe’s fill-drain schedule is implemented via careful data dependency programming in the autograd graph. It’s not clear to me if things like 1F1B, interleaved 1F1B, Varuna scheduling, or other research schedules are (easily) implementable in this scheme. +* (**D6**) This approach has a strong concept of “synchronous dispatch” into the pipeline. The single coordinator calls into the pipeline with a mini-batch, the execution is scheduled internally to that, and the pipeline returns an RRef to the result value. It’s not clear how continuous, asynchronous training would fit into this without retrofitting an event-driven handler for the training loop to feed another mini-batch in. + +## Approach 3 - RPC with RemoteModule and message passing (fairscale experimental) + +This describes the approach used in the fairscale experimental [distributed pipeline](https://github.com/facebookresearch/fairscale/tree/main/fairscale/experimental/nn/distributed_pipeline). A sample training loop implementation can be seen starting [here](https://github.com/wayi1/pipeline_experiments/blob/7e0fe6f884edfab026379cce1b5ae03b5c2489cd/BERT/main.py#L200). The syntax in that example is not particularly clean, but it is similar to the loop in approach (2). We can distill its essence in the following: + +``` +layers = nn.Sequential(...) +graph = make_graph(layers) +model = DistributedPipline(graph,chunks=chunks) # A + +optimizer = DistributedOptimizer(torch.optim.SGD, model.parameter_rrefs(), lr=lr) # B + +criterion = nn.CrossEntropyLoss() +class Loss(nn.Module): + def __init__(self, criterion, ntokens): + super().__init__() + self.ntokens = ntokens + self.criterion = criterion + #self.criterion = nn.CrossEntropyLoss() + + def forward(self, input, target): + return self.criterion(input.view(-1, self.ntokens), target.to(input.device)) + +for epoch in range(epochs): + loss_module = DistributedLoss(Loss, criterion, ntokens) + + for minibatch, targets in dataloader: + with dist_autograd.context() as context_id: + minibatch = minibatch.transpose(0, 1) + output = model(minibatch) # C + loss = loss_module(output, rpc.RRef(targets)).to_here() # D + dist_autograd.backward(context_id, [loss]) # E + optimizer.step(context_id) # F +``` + +This proposal has the training loop running on a single machine and makes copious use of the `torch.distributed.rpc` APIs available in PyTorch. We can examine different parts of the loop, highlighted with alphabetical labels: + +**A - Model Pipelining** + +As opposed to the torchgpipe-based Approach 2, this approach instantiates actors (specifically [PartitionHandler](https://github.com/facebookresearch/fairscale/blob/6f3931a4d3464056231f1bbb92a67fbd14f30d96/fairscale/experimental/nn/distributed_pipeline/partition_handler.py#L140) instances) that execute the pipeline in an event-driven manner. PartitionHandler instances own a [DistributedPipelineRecord](https://github.com/facebookresearch/fairscale/blob/6f3931a4d3464056231f1bbb92a67fbd14f30d96/fairscale/experimental/nn/distributed_pipeline/partition_handler.py#L27) instance, which has a “feed” method to be called via RPC to add a data item for processing. + +**B - Distributed Optimizer** + +DistributedOptimizer is used in the same way as Approach 2. The training script instantiates a [DistributedOptimizer](https://pytorch.org/docs/master/distributed.optim.html) to wrap the vanilla SGD optimizer. This DistributedOptimizer takes RRefs to the parameters distributed among the pipeline stages. During the `step()` call in stage (G), the DistributedOptimizer will make async RPC calls to all of the remotes the run an optimizer step. + +**C - Model Execution** + +`PartitionHandler` has a worker thread that runs the pipeline stage given the input data in series. Then, it forwards the result to the successor by calling its `feed()` method via RPC. + +**D - Loss calculation** + +Loss calculation happens similarly to in Approach 2, the single driver calls into [DistributedLoss](https://github.com/facebookresearch/fairscale/blob/6f3931a4d3464056231f1bbb92a67fbd14f30d96/fairscale/experimental/nn/distributed_pipeline/loss.py#L16), which under the hood makes an async RPC to the last pipeline stage to execute the loss calculation. + +**E - Backprop** + +Backpropagation through the pipeline is similarly implemented via distributed autograd, as in Approach 2. Note that the same fork/join barrier approach is used to [serialize](https://github.com/facebookresearch/fairscale/blob/6f3931a4d3464056231f1bbb92a67fbd14f30d96/fairscale/experimental/nn/distributed_pipeline/partition_handler.py#L103) execution of micro-batches on the backward pass. + +**NOTE**: I don’t believe that forward and backward jobs are serialized; they may run concurrently. Is this true? + +**F - Optimizer Step** + +The optimizer step uses DistributedOptimizer in the same was as Approach 2. DistributedOptimizer will make async RPC calls to all stages to apply the selected optimizer to the parameters contained within that stage. + +### Pros and Cons of the Approach + +**Pro** + + +* (**D8**) Training loop looks pretty close to the original PyTorch code. Training loop runs on a single machine, so user does not need to reason about correctness of their training loop under SPMD, as in Approach 1. + * OTOH, some of the set-up in the [example](https://github.com/wayi1/pipeline_experiments/blob/7e0fe6f884edfab026379cce1b5ae03b5c2489cd/BERT/main.py#L200) is pretty hairy and could probably be improved +* Compared to Approach 2, much less risk of being “front-end bound”. The burden of issuing commands is distributed throughout the ranks, i.e. a rank receives micro-batches and dispatches completed micro-batches to its successor. + +**Con** + + +* (**D5**) In its current conception, requires manual splitting of the model into an `nn.Sequential` instance +* The system may still be “front-end bound” for loss calculation, distributed autograd, and DistributedOptimizer step. +* (**D2**) In its current conception, it’s not clear to me if schedules are representable in this scheme due to reliance on distributed autograd execution. GPipe’s fill-drain schedule is implemented via careful data dependency programming in the autograd graph. It’s not clear to me if things like 1F1B, interleaved 1F1B, Varuna scheduling, or other research schedules are (easily) implementable in this scheme. +* (**D6**) This approach has a strong concept of “synchronous dispatch” into the pipeline. The single coordinator calls into the pipeline with a mini-batch, the execution is scheduled internally to that, and the pipeline returns an RRef to the result value. It’s not clear how continuous, asynchronous training would fit into this without retrofitting an event-driven handler for the training loop to feed another mini-batch in. + +## Approach 4 - MPMD with a custom interpreter/instruction format and message passing (DeepSpeed) + +This is the approached used in [DeepSpeed pipeline parallelism](https://www.deepspeed.ai/tutorials/pipeline/). From that tutorial, we can see that the training loop ends up looking like this: + +``` +class AlexNetPipe(AlexNet): + def to_layers(self): + layers = [ + *self.features, + self.avgpool, + lambda x: torch.flatten(x, 1), + *self.classifier + ] + return layers + +from deepspeed.pipe import PipelineModule +net = AlexNetPipe() +net = PipelineModule(layers=net.to_layers(), num_stages=2) + +engine, _, _, _ = deepspeed.initialize( + args=args, + model=net, + model_parameters=[p for p in net.parameters() if p.requires_grad], + training_data=cifar_trainset()) + +for step in range(args.steps): + loss = engine.train_batch() +``` + +The DeepSpeed `engine` here encapsulates the runtime semantics of the training loop, including data loading, pipeline parallel execution, scheduling, and optimization. The implementation of [train_batch](https://github.com/microsoft/DeepSpeed/blob/488105ebd200bbd1f6d7cbe863412e41d9ab4221/deepspeed/runtime/pipe/engine.py#L278) shows that DeepSpeed uses a type of programmable interpreter to run the instructions constituting pipeline parallel execution. The method constructs a [TrainSchedule](https://github.com/microsoft/DeepSpeed/blob/488105ebd200bbd1f6d7cbe863412e41d9ab4221/deepspeed/runtime/pipe/schedule.py#L182), which yields [commands](https://github.com/microsoft/DeepSpeed/blob/488105ebd200bbd1f6d7cbe863412e41d9ab4221/deepspeed/runtime/pipe/schedule.py#L189) for the processor to run to implement the proper sequencing of events for pipeline parallel execution. The instructions available for this interpreter are the following (with self-explanatory names): + + +* OptimizerStep +* ReduceGrads +* [ReduceTiedGrads](https://www.deepspeed.ai/tutorials/pipeline/#tied-layers) +* LoadMicroBatch +* ForwardPass +* BackwardPass +* SendActivation +* RecvActivation +* SendGrad +* RecvGrad + + +The implementations for each of these instructions can be referenced from this [lookup table](https://github.com/microsoft/DeepSpeed/blob/488105ebd200bbd1f6d7cbe863412e41d9ab4221/deepspeed/runtime/pipe/engine.py#L1307). + +### Pros and Cons of the Approach + +**Pro** + +* (**D2**) (hypothetically) supports arbitrary schedules through the [PipeSchedule](https://github.com/microsoft/DeepSpeed/blob/488105ebd200bbd1f6d7cbe863412e41d9ab4221/deepspeed/runtime/pipe/schedule.py#L6) abstraction. However, there don’t seem to be any schedules implemented beyond the default +* (**D3, D4?**) Usable in 3d parallelism, as detailed by the [blog post](https://www.deepspeed.ai/tutorials/pipeline/). +* (**D6**) Since data is pulled from the data loader rather than being pushed by a synchronous call in the training loop, this approach could *hypothetically* support async PP. +* (**D7**) The approach seems to account for many different types of parallelism. + +**Con** + +* (**D1**) Does not support passing arbitrary data between stages, only supports Tensor and tuple of Tensor (because of `nn.Sequential` front-end) +* (**D5**) Only supports models fit into an `nn.Sequential` +* (**D8**) This approach takes control away from the user. The training loop is now implemented by the DeepSpeed engine abstraction, rather than being free-form Python code. + +## Approach 5: RPC with remote modules and generalized Module-server architecture (SageMaker) + +The [SageMaker model parallelism](https://arxiv.org/abs/2111.05972) design uses a single Python-native training loop with a “module-server” architecture. The system divides the model based on the Module hierarchy and assigns each module onto a specific pipeline parallel rank (PP_RANK). During execution, when there is a dispatch to a `Module` that resides on another PP_RANK, a remote request-response RPC is made to run the appropriate forward/backward pass for the Module on the remote PP_RANK. + +[Image: req_resp.png] + +PP_RANK 0 drives the process by scheduling instances of the training loop function (a UDF annotated by `@smp.step`): two for each micro-batch (one for forward, one for backward). PP_RANK 0 can implement different “schedules” by dispatching these `(micro-batch, phase)` tuples in a given order. The orders that they present are: + + +* Simple pipeline (aka GPipe fill-drain). This is implemented by having PP_RANK 0 dispatch the `phase=forward` tuples for each micro-batch in sequence. Then, dispatching the `phase=backward` tuples for each micro-batch in sequence. +* “interleaved” pipeline (**NB**: this is not the same as the *interleaved 1F1B* from Narayanan, 2021). PP_RANK 0 will schedule `phase=forward` jobs and opportunistically schedule `phase=backward` jobs *as soon as the forward pass for that micro-batch is done*. + +![Module-server request-response execution in SageMaker pipeline parallelism](https://i.imgur.com/y9MZJ3b.png) + +**NOTE**: The schedules here do not necessarily run stage in a given order on each stage. Network latency and other affects may change the order of when micro-batches are executed. + +### Pros and Cons of the Approach + +**Pro** + +* (**D1**) I *believe* passing arbitrary types works, assuming their P2P communication backend supports it. There is no fundamental limitation precluding this from working with this design. +* (**D2**) (split pro/con) Support for *some* kind of schedules, but not strictly implemented as those described in the literature +* (**D3/D4**) Composes with other parallelism schemes +* (**D5**) Does not require model to be an `nn.Sequential` +* (**D8**) User’s original training loop is preserved with only slight modifications (`@smp.step` annotation and other things) + +**Con** + +* (**D2**) (split pro/con) Support for *some* kind of schedules, but not strictly implemented as those described in the literature +* (**D6**) (not sure about this one) Seems to still rely on synchronous training loop. May need modifications to support async (but not sure if there are any fundamental limitations?) + +## Approach 6: SPMD with Program capture/JIT compilation and message passing (OneFlow) + +[OneFlow](https://arxiv.org/abs/2110.15032) is a deep learning framework that purports to redesign the deep learning programming model to better support distributed computation. The framework uses a “consistent view” abstraction to represent distributed memory. The framework sports a compiler that can run global optimization of device placement. It also uses a unified actor model abstraction for its distributed runtime. + +An example of using OneFlow for pipeline parallelism can be seen in this [tutorial](https://docs.oneflow.org/en/master/parallelism/06_pipeline.html). The example instantiates a sequential two-stage model architecture and uses the `to_consistent` API to move the submodules to the appropriate CUDA devices. It also sets `config.stage_id` on each of the submodules to give a monotonically increasing stage number. Finally, on the `nn.Graph` it uses `config.set_gradient_accumulation_steps` to delay optimization for 2 micro-batches, and calls `add_optimizer` to add the optimizer class. + +`oneflow.distributed.launch` will launch the processes for each “actor”. Eager mode and graph mode reportedly go down the same (or similar) code path for distributed graph processing, with eager mode using a LazyTensor-like model to capture the program. The compiler knows which rank the compilation is for and will emit code specifically for that rank. Micro-batch splitting and `1f1b` seem to be hard-coded, if they are implemented. + +### Pros and Cons of the Approach + +**Pro** + +* (split pro/con) (**D2**) Not clear if `1f1b` or other schedules are implemented +* (**D3/D4**) Composable with other parallelism schemes via their “Consistent View” definition +* (**D5**) `nn.Sequential` not needed, but potentially an `nn.Graph` instance may be needed in some cases +* (**D6**) Async is probably supportable but not clear. From their presentation, the actor/register model with backpressure can implement on-demand data loading, but I’m not 100% sure what that API looks like +* (**D7**) Unified programming model that already exists + +**Con** + +* (split pro/con) (**D2**) Not clear if `1f1b` or other schedules are implemented/implementable? +* (**D8**) Not clear what the training loop abstraction looks like. The optimizer is installed via an `nn.Graph` API. Loss calculation is created in the `nn.Graph.build()` method. + +## Final Analysis + +### General Design Axes + +Looking through the various approaches above, we can pull out some general design axes we should consider: + + +* (**DA1**) Single-coordinator vs. distributed coordination + * Sub-decisions: model execution, autograd, optimizers +* (**DA2**) Python training loop vs. custom encapsulated training loop (e.g. as in DeepSpeed) +* (**DA3**) `nn.Sequential` vs. more free-form partitioning +* (**DA4**) Local vs. distributed loss +* (**DA5**) Distributed autograd framework vs. manual autograd + * As it relates to pipeline schedules +* (**DA6**) Local vs. distributed optimizer +* (**DA7**) Synchronous pipeline dispatch vs. asynchronous +* (**DA8**) Predication scheme (Approach 1 only) +* (**DA9**) Instruction format (Approach 4 only) + + +We can start analyzing the approaches by these design axes + +* Approach 1: SPMD with Predicated Training Loop and message passing +* Approach 2: RPC with RemoteModule and torchgpipe-style single coordinator (@pritamdamania87 RFC) + * Single-coordinator + * CUDA command buffer analogy + * Continuation passing +* Approach 3: MPMD with RemoteModule and message passing (fairscale experimental) +* Approach 4: MPMD with a custom interpreter/instruction format and message passing (DeepSpeed) +* Approach 5: RPC with remote modules and generalized Module-server architecture (SageMaker) +* Approach 6: SPMD with Program capture/JIT compilation and message passing (OneFlow) + +| |DA1 |DA2 |DA3 |DA4 |DA5 |DA6 |DA7 |DA8 |DA9 |Notes | +|--- |--- |--- |--- |--- |--- |--- |--- |--- |--- |--- | +|Approach 1 |multi |py |FF |local |manual |local |async |? |X | | +|Approach 2 |single |py |seq |dist |dist |dist |sync |X |X | | +|Approach 3 |single |py |seq |dist |dist |dist |sync |X |X | | +|Approach 4 |multi |interp |seq |local? |manual |local |async* |X |? | | +|Approach 5 |single* |py |FF |local? |manual |local? |sync? |X |X |Schedules? | +|Approach 6 |multi |interp |FF |local? |manual? (graph?) |local? |async? |X |X | | + +### Decision - Approach 3 with Modifications + +After deliberation, we want to build the API with the least complexity, at least initially. We will modify/build the API in FairScale experimental with a few modifications: + + +* (**DA5**) Rather than using torchgpipe-style virtual dependencies in the distributed autograd graph, we want each stage to manually handle running `forward` and `backward` stages (the latter by explicitly calling `torch.autograd.backward()`). This will give easier and finer-grained control over the execution schedule of the pipeline +* (**DA9**) We want to abstract the runtime for each actor (similar to DeepSpeed) so that the actor can run forward/backward phases in a prescribed order. This will allow us to program schedules like `1f1b` or `interleaved 1f1b` . Further, we can define an instruction format similar to DeepSpeed to make these schedules arbitrarily programmable by the end user. +* (**DA4**) In the current implementation, the loss is implemented via DistributedLoss, which is issued in the training loop over the whole mini-batch. This will not be compatible with arbitrary pipeline schedules, which need to compute the loss and launch backward micro-batches asynchronously. So, the loss will need to be implemented as a callback that the pipeline can schedule on its own, rather than something called in the training loop + + +Approach 3 with modifications then looks like: + + +| |DA1 |DA2 |DA3 |DA4 |DA5 |DA6 |DA7 |DA8 |DA9 |Notes | +|--- |--- |--- |--- |--- |--- |--- |--- |--- |--- |--- | +|Approach 3 with Modifications |single |py |seq |local |manual |dist |sync |X |? | | + + +**Future Extensibility** + +This approach leaves some design improvements on the table. In particular, we will have to keep an eye on these extensibility points: + + +* (**DA3**) Expanding this API to work on programs that are not already `nn.Sequential`. We can explore using a predicated scheme such as in Zach’s [notebook](https://colab.research.google.com/drive/1lGg2NqlvDwVmvBqejzni2yTmYE9rxfdr?usp=sharing) or a `torch.fx`-based scheme as in Wanchao’s proposal. Further, we could create a "V2" API down the road using technology from Approaches 4-6 or if new research finds a better programming model. +* (**DA4**) In the current implementation, optimization is implemented via a `DistributedOptimizer` step that is called explicitly from the training loop. This is likely okay for synchronous PP. We should keep an eye on (a) the overhead from having a single coordinator issue RPCs to do the optimization step and (b) how this design might change with asynchronous pipeline parallelism, considering the pipeline itself will need to trigger all these events (async PP will likely be a different API, but it would be nice if it didn’t need to be) + +## References + + +1. https://arxiv.org/abs/1806.03377 +2. https://arxiv.org/abs/1811.06965 +3. https://en.wikipedia.org/wiki/Predication_(computer_architecture) diff --git a/RFC-0022-Model-Partitioning-in-Pipeline-Parallelism.md b/RFC-0022-Model-Partitioning-in-Pipeline-Parallelism.md new file mode 100644 index 0000000..cc51a2e --- /dev/null +++ b/RFC-0022-Model-Partitioning-in-Pipeline-Parallelism.md @@ -0,0 +1,255 @@ +# [RFC] Model Partitioning in Pipeline Parallelism + +Credit for writing this RFC goes to @WanchaoL. + +## Background + +We introduced Pipeline Parallelism API in PyTorch 1.8. The current Pipeline parallelism provides an intuitive API to use. But currently we are passing the obligation of model partitioning to the user, where user need to explicitly convert different parts to different devices before passing to `sync.Pipe`. This might be OK for simple models, but as models scale in complexity, it becomes very hard to estimate and partition the model properly by hand. The user might need to do a lot of experiments before actually running the model in Pipeline Parallelism for efficiency purposes. Therefore, it’s crucial that we provide a way to automatically partition the model so that we could reduce the burden for users, this might also give some insights when we do other model parallelism work. This doc did some explorations on the existing works from different libraries, and proposing a way for PyTorch’s pipeline parallelism to do automatic partitioning. + + +## Relevant works + +How is the industry dealing with model partitioning for pipeline parallelism so far? There’re plenty exploration works going on from different pipeline parallelism implementations, and some of them produces pretty nice performance improvements. + +***FairScale/torchgpipe:*** +FairScale and torchgpipe use similar/same approaches when partitioning the model, their partition APIs: + +* `Pipe(..., balance: List[int],...)` use `balance` to partition the model during construction +* two utils: `balance_by_time` and `balance_by_sizes` to generate the `balance` list automatically based on execution time or memory sizes (parameter + optim states) + +Partition Algorithm: use “Block Partition of Sequences” algorithm to find a balanced partition base on the costs. +Note they only allow to partition on the top-level modules of the Sequential list based on the costs (memory/time) + +***SageMaker*** +Major logic is under `model_partition.py` and the basic ideas: + +* `ModuleParitioner` to drive the model partition progress base on trace results +* `ModulePartitioner.partition` create a tree of `ModuleNodes`, populate the costs (execution time + memory cost) and run the partition algorithm, return a `Dict[Module, device_id]` + +Partition Algorithm: SageMaker provided an automated model parallelism by their own proposed model partition algorithm based on BFS + DP-based method for device allocations, plus reallocation using d'Hondt method + +***DeepSpeed:*** +Provide several mechanisms for partition the model across GPUs with `partition_method`: + +* `partition_method="parameters"` (default): balances the number of trainable parameters on each pipeline stage. +* `partition_method="type:[regex]"`: balances layers whose class names match [regex]. +* `partition_method="uniform"`: balances the number of layers per stage. + +Partition Algorithm: For partition_method = “parameters”, DeepSpeed use the mechanism of counting the layer’s parameters size, and do a binary search (find the smallest weight of the heaviest partition) to find a balanced partition. + +***PipeDream*** +More on asynchronous training pipeline planning, like parameter server approaches, the algorithm itself also uses Dynamic Programming, but the states is highly correlated with the communication and synchronization costs, which is not a good choice for us (unless we consider async training in the future). + +***DAPPLE*** + +DAPPLE combines DDP with pipeline parallelism to form a bigger search space and use a more efficient schedule if possible. Their partition approach is a DP-based algorithm, it first tries to find the “pivotal” stage, then optimize the overall latency, the “overall latency” optimization here tries to reduce the bubbles in the pipeline as small as possible. + + + +## Proposing: Automated Model Partitioning in PyTorch Pipeline Parallelism + +Given the existing work and their limitations, we introduce a set of APIs that’s flexible enough for future improvements and intuitive to use. We expose a single API `create_balanced_partition`, to take the model and do the partition under the hood. + +``` +def create_balanced_partition(model: nn.Module, + devices: List[RemoteDevice], + *sample_input) -> List[RemoteModule]: + #device = torch.device("cuda")): + # if model itself if larger than a single cuda device memory, + # should we allow the user to profile the model on cpu? + # Decision: probably no, as the characteristics are different + + # Step 1 + # fx profiler to collect statistics based on a single run of + # the sample_input + interp = ProfilingInterpreter(model) + interp.run(sample_input) + + # Step2 + # Partition algorithm to calculate the least + # cost partition assignment + num_partitions = len(devices) + partitions = _partition_by_cost(interp, num_partitions) + + # Step3 + # Use the results from Step 2 to assign each part of submodule to the device + # returned a fx-based model with partition results applied (assigned to devices already) + return partitioned_model +``` + + +Note that this will be the only API we expose to the user, it returns a partitioned model which already been transferred to the corresponding devices. This allow us to iterate on the underlying implementation and experiment more efficient partition algorithms in the future. + +## FX compatibility + +Pipeline parallelism auto partition capability needs more advanced knowledge in order to accurately divide a model into a balanced partition set. In order to achieve auto partitioning with the most balanced approach, we need to get the model execution graph and try to split the model base on the graph. Using torch.fx can give us a helpful graph representation in order for us to do more advanced partition with extensive analysis with tracing and profiling. + +The model that passed in should be fx compatible in order to generate the partitions with the most accurate estimation. What models does not have fx compatibility currently: + +1. models that contain input dependent control flow (no way to fix as far as I know +2. models that have tensor constructors (this could potentially be fixed) +3. models that contains builtin-op with non tensor inputs https://github.com/pytorch/pytorch/issues/53937 (this could potentially be fixed) + +For 2 and 3, I think it could be fixed, for 1, a fundamental limitation is there. +Should we make this assumption? + +We should try to symbolically trace the model first, and if that fails with tracing exceptions (i.e. detecting data dependent control flows, we could detect that during tracing), we should fall back to a legacy partition algorithm with python available only (i.e. without extensive graph analysis, we can simply try to balance the model with the top-level submodules like nn.Sequentials) + + +@jamesr66a: We could also try doing unintrusive tracing, similar to SageMaker or the upcoming define-by-run quantization API. `torch.fx` symbolic tracing supposes you want to extract a freestanding representation of the program, so is rather strict in the error conditions for which it fails. OTOH, we can do a "best effort" program recording using `__torch_function__`/`__torch_dispatch__` and record information about the structure of the program, but not necessarily require that it fully represent the whole program. + +## Profiling using torch.fx + +We will use torch.fx to do a profiling run to collect the statistics (i.e. execution_time, parameter_size, execution_order, etc.). The profiler still needs user to pass in a sample input, and we do a full run on the model base on this sample input. We can adjust the statistics to collect base on the partition algorithm we choose (i.e. module-level statistics or op-level statistics) + +``` +from torch.fx import Interpreter + +class ProfilingInterpreter(Interpreter): + + def __init__(self, mod : torch.nn.Module): + gm = torch.fx.symbolic_trace(mod) + super().__init__(gm) + + # We are going to store away three things here: + # + # 1. execution time of each module/node + # 2. parameter_sizes of each module/node + # 3. activation_sizes of each module + self.execution_times_sec : Dict[str, float] = [] + self.parameter_sizes : Dict[str, float] = {} + self.activation_sizes: Dict[str, float] = {} + + def run_node(self, n : torch.fx.Node) -> Any: + # Record the time we started running the op + t_start = time.time() + # Run the op + return_val = super().run_node(n) + # Record the time we finished running the op + t_end = time.time() + self.execution_times_sec[module_qual_name] = t_end - t_start + + # also update the parameter_size and activation_size for + # module nodes + + return return_val +``` + + + + + +## Partition Algorithm + +There’re several partition algorithms as we mentioned in the relevant works section, each of them have its own pros/cons. For example, torchgpipe/fairscale only allows partitioning the model on the top level, SageMaker only partition on the module level instead of operation level. So there might be some unbalanced cases for those approaches. DAPPLE explores a bigger search space by controlling the stages/schedules, but the algorithm complexity is too high. + +Our approach here starts with an approach similar to fairscale, which only do top-level partition, but it could be improved further by exploring more partition approaches, i.e. with operation level partitioning by torch.fx, we could partition the model in a more even manner, which further improves the efficiency. Since we only expose a simple API and return the partitioned model with stages assigned to the devices, the internal implementation can be improved as we explore more ideas. The plan for partition algorithm exploration: + +* Use fairscale/gpipe's [Block Partition of Sequences](https://arxiv.org/pdf/1308.2452.pdf) as our first step +* See if we can apply [Fiduccia-Mattheyses Heuristic](https://en.wikipedia.org/wiki/Fiduccia%E2%80%93Mattheyses_algorithm) to partition the graph and how the performance compare with the default one +* See if we can do operation-level tracing, and partition the model into several balanced `torch.fx.GraphModule` instead of using the original module architecture + +### Approach that enables more possibility of balanced partitioning + +* fx trace the graph, collect the statistics of each node +* partition the graph using a bottom up approach: starting from treating each node as a partition, and merge them if it lowers the cost, use BFS to scan the merge order +* The result will be an `fx.GraphModule` for each partition, the original module architecture is not preserved + * qual names of param might change + +Partition: +Dict[fx.graphmodule, remotedevice] + +in order to apply the algorithm, we need to assign each top-level submodule with a proper cost, we calculate the cost based on memory and execution costs. + + +### Cost calculation + +The cost of a submodule is defined as the following: + +``` +cost(m) = time_weight * execution_time(m) + + memory_weight * memory_cost(m) + +memory_cost(m) = parameter cost * optimizer_multiplier + + activation_cost +``` + +where `memory_weight + time_weight = 1`. How to weigh between time and memory cost? Undecided, a heuristic number `time_weight=0.6` + +* @mrshenli: regarding balancing time between memory, I kind of feel we should prioritize time over memory. Because people usually use Pipeline parallel to accelerate training, and the training speed the ultimate goal. Balanced execution time has a directly impact on total pipeline makespan (suppose one phase is slower than others, then all other devices will need to wait for that phase). If the above assumption is correct, it looks like we should first try to balance time, and only try to balance memory when a shard of the optimally time-balanced model cannot fit in some devices. + +Time > memory (memory constraint only, need to count in some buffer) + +How do we decide the `optimizer_multiplier` since different optimizers have different states maintained? + +We can use the param_scale mechanism which specify the scale for each parameters like in fairscale or torchgpipe. One disadvantage is that this `param_scale` is required from the user as a parameter, and user need to know the scale base on our documentation? Still seems a bad UX experience, but if we provide a default one, it’s likely not accurate. + + +``` + ========= ============= ========================================= + Optimizer `param_scale` Internal State + ========= ============= ========================================= + SGD 2--3 (momentum_buffer) + Adam 4--5 exp_avg, exp_avg_sq, (max_exp_avg_sq) + Adadelta 4 square_avg, acc_delta + Adagrad 3 sum + RMSprop 3--5 square_avg, (momentum_buffer), (grad_avg) + ========= ============= ========================================= +``` + + +Since we assigned the cost to each top-level module, we can do the partition, basic partition function will be like the below: + +``` +def _partition_by_cost(interp: torch.fx.Interpreter, + num_partitions): + submodule_costs = interp.get_costs() + # Dict[module_qual_name, cost], execution order + + # use Block Partition of Sequences to minimize the variance + # of the submodules' cost list + partitions = + block_partition.solve(submodule_costs, num_partitions) + return partitions +``` + + +Q: How do we collect the communication costs? + +* activation_size and gradient size, could potentially contribute to part of the communication cost +* What about the device info? how do we know each pairs are NVLink or PCIe? + + + +### Further exploration of partition algorithm + +Since we hide the underlying partition algorithm, we can do further partition algorithm explorations, some potential explorations we can do: + +* exploration like in DAPPLE when using non-simple schedule (i.e. we reduce the bubble by do the backward as early as possible). +* Using torch.fx, we can trace the module architecture in module level, and possibly split the entire traced graph into several submodules (this might not fully resemble the existing module architecture). +* DP-based algorithm like SageMaker, but could be on operation-level with more granularity + + + +## Potential issues + +**What if a model with a big memory requirement that couldn’t fit into a single GPU?** +A: we construct the model on CPU and do fx tracing on CPU, after the partition we move it to the corresponding device + +**What if a “module” (submodule) with a big memory requirement that couldn’t fit into a single GPU?** +Can we use ShardedTensor to shard the module parameters? or we could do operation level tracing, partition this module into 2 submodules, then assign to different devices. + +**What if constructing the model on CPU itself is hard?** +User created the model first on meta device, we use the model that’s not materializing to do symbolic tracing, but we couldn’t do profiling since it’s not materialized yet, after we do symbolic tracing, we should use a simple partition algorithm (i.e. only count the param sizes) to do the partition, then materialize afterwards. + + + +## Reference + + +1. SageMaker https://arxiv.org/abs/2111.05972 +2. Deepspeed https://www.deepspeed.ai/tutorials/pipeline/#load-balancing-pipeline-modules +3. Torchgpipe [https://github.com/kakaobrain/torchgpipe](https://github.com/kakaobrain/torchgpipe/tree/master/torchgpipe) +4. Fairscale https://github.com/facebookresearch/fairscale +5. FX_IR partitioner for accelerators +https://fb.workplace.com/notes/wang-xu/fx-ir-partitioner/165617378604863 From 7e5ad3310e2a5b97eb60bd170c16921c177ace66 Mon Sep 17 00:00:00 2001 From: James Reed Date: Wed, 1 Dec 2021 13:13:37 -0800 Subject: [PATCH 2/9] Update RFC-0020-Distributed-Pipeline-Parallelism.md Co-authored-by: Stas Bekman --- RFC-0020-Distributed-Pipeline-Parallelism.md | 1 + 1 file changed, 1 insertion(+) diff --git a/RFC-0020-Distributed-Pipeline-Parallelism.md b/RFC-0020-Distributed-Pipeline-Parallelism.md index 9f3dade..f548fe9 100644 --- a/RFC-0020-Distributed-Pipeline-Parallelism.md +++ b/RFC-0020-Distributed-Pipeline-Parallelism.md @@ -9,6 +9,7 @@ Goal: *Provide a flexible, composable, and reconfigurable interface for pipeline * Pipeline parallelism (PP) is used as a lower-communication-volume technique for model parallelism. It is especially applicable when data must be transmitted across comparatively slower interconnects. * Several research-oriented frameworks exist that implement PP (e.g. fairscale, megatron, deepspeed), but we would like to provide a production-quality implementation and support contract for PP. * The existing PP implementation in PyTorch (`torch.distributed.pipeline.sync`) only supports intra-host pipeline parallelism across GPUs and does not support techniques like 1F1B scheduling. We can deliver inter-host pipelining and other features. +* `nn.Sequential` requirement creates a huge barrier to users who have models that don't lend easily to be converted to `nn.Sequential`. In particular with models that have dynamic control flow for some large segments (e.g. conditional encoder). * Ultimately, we want to use this body of work as a driving force for research in delivering both performance AND usability of parallelism paradigms. We invite developers and researchers to participate in the design and development of this project. ## Stage 1: Requirements Gathering (2021Q4) From ecfd1f82ffff9a7dc786a7a3015a35f877bd4fe2 Mon Sep 17 00:00:00 2001 From: James Reed Date: Wed, 1 Dec 2021 13:24:00 -0800 Subject: [PATCH 3/9] Update RFC-0020-Distributed-Pipeline-Parallelism.md Co-authored-by: Stas Bekman --- RFC-0020-Distributed-Pipeline-Parallelism.md | 1 + 1 file changed, 1 insertion(+) diff --git a/RFC-0020-Distributed-Pipeline-Parallelism.md b/RFC-0020-Distributed-Pipeline-Parallelism.md index f548fe9..4452609 100644 --- a/RFC-0020-Distributed-Pipeline-Parallelism.md +++ b/RFC-0020-Distributed-Pipeline-Parallelism.md @@ -144,6 +144,7 @@ Existing approaches/proposals that support this (in no particular order): * Sagemaker [model parallelism](https://drive.google.com/file/d/1N2eo5Yr_QOw0EtKv-MYBDWKvyRYxKv2o/view) * @zdevito's [sequential-free splitting approach](https://colab.research.google.com/drive/1lGg2NqlvDwVmvBqejzni2yTmYE9rxfdr?usp=sharing) * [OneFlow](https://github.com/Oneflow-Inc/oneflow) +* [Varuna](gttps://github.com/microsoft/varuna) / [paper](https://arxiv.org/abs/2111.04007) * [[RFC] Model Partitioning in Pipeline Parallelism](https://github.com/pytorch/rfcs/blob/master/RFC-0022-Model-Partitioning-in-Pipeline-Parallelism.md) Proposed approach short-list: From 29126ef77c42c12f43df926c827c7e6d3a3a46ff Mon Sep 17 00:00:00 2001 From: James Reed Date: Mon, 6 Dec 2021 11:57:49 -0800 Subject: [PATCH 4/9] Add Varuna details to the RFC --- RFC-0020-Distributed-Pipeline-Parallelism.md | 6 +- ...Distributed-Pipeline-Parallel-Technical.md | 63 ++++++++++++++----- 2 files changed, 51 insertions(+), 18 deletions(-) diff --git a/RFC-0020-Distributed-Pipeline-Parallelism.md b/RFC-0020-Distributed-Pipeline-Parallelism.md index 4452609..1f6a2bc 100644 --- a/RFC-0020-Distributed-Pipeline-Parallelism.md +++ b/RFC-0020-Distributed-Pipeline-Parallelism.md @@ -63,6 +63,7 @@ Existing approaches that support this (in no particular order): * Sagemaker [model parallelism](https://arxiv.org/abs/2111.05972) * [DeepSpeed pipeline parallelism](https://www.deepspeed.ai/tutorials/pipeline/) * [OneFlow](https://github.com/Oneflow-Inc/oneflow) +* [Varuna](https://github.com/microsoft/varuna)[13] Proposed approach short-list: (all approaches can be seen in [[RFC] Distributed Pipeline Parallel Training Technical Approach](https://github.com/pytorch/rfcs/blob/master/RFC-0021-Distributed-Pipeline-Parallel-Technical.md) @@ -155,7 +156,7 @@ Proposed approach short-list: These approaches can be composed on top of an existing API that takes an `nn.Sequential`. We may consider in the future to develop a "v2" API that is centered more natively around non-`nn.Sequential` models using technologies from Sagemaker, OneFlow, or other research developments. -### P1: Support arbitrary programmable schedules (e.g. fill-drain, 1F1B, interleaved 1F1B) +### P1: Support arbitrary programmable schedules (e.g. fill-drain, 1F1B, interleaved 1F1B) Existing approaches that support this (in no particular order): @@ -208,4 +209,5 @@ Going into the future, we would like to develop theory and implementation for a 10. Performance analysis of a pipelined backpropagation parallel algorithm https://ieeexplore.ieee.org/document/286892 11. PipeMare: Asynchronous Pipeline Parallel DNN Training https://arxiv.org/abs/1910.05124 12. Scaling Language Model Training to a Trillion Parameters Using Megatron - https://developer.nvidia.com/blog/scaling-language-model-training-to-a-trillion-parameters-using-megatron/ \ No newline at end of file + https://developer.nvidia.com/blog/scaling-language-model-training-to-a-trillion-parameters-using-megatron/ +13. Varuna: Scalable, Low-cost Training of Massive Deep Learning Models https://arxiv.org/abs/2111.04007 diff --git a/RFC-0021-Distributed-Pipeline-Parallel-Technical.md b/RFC-0021-Distributed-Pipeline-Parallel-Technical.md index c4523a0..1a9b221 100644 --- a/RFC-0021-Distributed-Pipeline-Parallel-Technical.md +++ b/RFC-0021-Distributed-Pipeline-Parallel-Technical.md @@ -72,7 +72,7 @@ def train_model(model, criterion, optimizer, scheduler, num_epochs=25): # load best model weights model.load_state_dict(best_model_wts) return model - + model_ft = models.resnet18(pretrained=True) num_ftrs = model_ft.fc.in_features # Here the size of each output sample is set to 2. @@ -189,7 +189,7 @@ This scheme of zeroing grads on the first micro-batch can be trivially implement Predication of forward-propagation can be done as in Zach’s [proposal](https://colab.research.google.com/drive/1lGg2NqlvDwVmvBqejzni2yTmYE9rxfdr?usp=sharing). -Note that we may extend the predicated training loop scheme to include schedules such as 1F1B or interleaved 1F1B, discussed later. +Note that we may extend the predicated training loop scheme to include schedules such as 1F1B or interleaved 1F1B, discussed later. **Loss Calculation** @@ -237,7 +237,7 @@ Note that forward and backward stages do not necessarily always run in a given r ## Approach 2 - RPC with RemoteModule and torchgpipe-style single coordinator (@pritamdamania87 RFC) -One proposal for an API for pipeline parallelism is the `pipeline_sync` API proposed in @pritamdamania87’s [RFC](https://github.com/pytorch/pytorch/issues/44827) (Certain lines are called out with end-of-line comments containing an alphabetical identifier): +One proposal for an API for pipeline parallelism is the `pipeline_sync` API proposed in @pritamdamania87’s [RFC](https://github.com/pytorch/pytorch/issues/44827) (Certain lines are called out with end-of-line comments containing an alphabetical identifier): ``` # Note: This API is very similar to torchgpipe and inspired from it. @@ -245,14 +245,14 @@ One proposal for an API for pipeline parallelism is the `pipeline_sync` API prop torch.distributed.pipeline_sync( pipeline: nn.Sequential, - checkpoint: CheckpointEnum = EXCEPT_LAST, # ALWAYS, EXCEPT_LAST, NEVER + checkpoint: CheckpointEnum = EXCEPT_LAST, # ALWAYS, EXCEPT_LAST, NEVER chunks: int = 1) -> PipelineSyncModel Arguments: -pipeline: nn.Sequential (https://pytorch.org/docs/stable/generated/torch.nn.Sequential.html) where each nn.Module in the list is placed on the - appropriate device(CPU or GPU)/machine by the user. Note that - nn.Sequential could also consist of RemoteModule (https://github.com/pytorch/pytorch/blob/master/torch/distributed/nn/api/remote_module.py#L147) for cross host +pipeline: nn.Sequential (https://pytorch.org/docs/stable/generated/torch.nn.Sequential.html) where each nn.Module in the list is placed on the + appropriate device(CPU or GPU)/machine by the user. Note that + nn.Sequential could also consist of RemoteModule (https://github.com/pytorch/pytorch/blob/master/torch/distributed/nn/api/remote_module.py#L147) for cross host pipelining. checkpoint: Enum that determines which checkpointing mode to use. chunks: Number of micro-batches. @@ -265,17 +265,17 @@ Forward Method PipelineSyncModel.forward(self, *input, **kwargs) -> RRef Returns: - RRef to output corresponding to the result of the minibatch. - Since we plan to support cross host pipelining, the RRef could be on a + RRef to output corresponding to the result of the minibatch. + Since we plan to support cross host pipelining, the RRef could be on a device on a different host. - + Example: # This is an example of a pipeline across two machines each using one GPU. # On worker 0 layer1 = nn.Linear(10, 5).cuda(0) # Need to enhance RemoteModule to include device for this purposes. -layer2 = RemoteModule("worker1", device="cuda:0", nn.Linear, 5, 1) +layer2 = RemoteModule("worker1", device="cuda:0", nn.Linear, 5, 1) pipeline = nn.Sequential(layer1, layer2) model = torch.distributed.pipeline_sync(pipeline, chunks = 4) # A @@ -300,7 +300,7 @@ for epoch in range(epochs): target_rref = rpc.remote("worker1", identity_fn, target) # C output_rref = model(minibatch) # D loss_rref = rpc.remote("worker1", compute_loss, output_rref, target_rref) # E - # Can enhance RRef to ensure this calls "dist_autograd.backward" on the last + # Can enhance RRef to ensure this calls "dist_autograd.backward" on the last # node in the pipeline. loss_rref.backward(context_id) # F dist_optim****.step() # G @@ -375,7 +375,7 @@ class Loss(nn.Module): for epoch in range(epochs): loss_module = DistributedLoss(Loss, criterion, ntokens) - + for minibatch, targets in dataloader: with dist_autograd.context() as context_id: minibatch = minibatch.transpose(0, 1) @@ -389,7 +389,7 @@ This proposal has the training loop running on a single machine and makes copiou **A - Model Pipelining** -As opposed to the torchgpipe-based Approach 2, this approach instantiates actors (specifically [PartitionHandler](https://github.com/facebookresearch/fairscale/blob/6f3931a4d3464056231f1bbb92a67fbd14f30d96/fairscale/experimental/nn/distributed_pipeline/partition_handler.py#L140) instances) that execute the pipeline in an event-driven manner. PartitionHandler instances own a [DistributedPipelineRecord](https://github.com/facebookresearch/fairscale/blob/6f3931a4d3464056231f1bbb92a67fbd14f30d96/fairscale/experimental/nn/distributed_pipeline/partition_handler.py#L27) instance, which has a “feed” method to be called via RPC to add a data item for processing. +As opposed to the torchgpipe-based Approach 2, this approach instantiates actors (specifically [PartitionHandler](https://github.com/facebookresearch/fairscale/blob/6f3931a4d3464056231f1bbb92a67fbd14f30d96/fairscale/experimental/nn/distributed_pipeline/partition_handler.py#L140) instances) that execute the pipeline in an event-driven manner. PartitionHandler instances own a [DistributedPipelineRecord](https://github.com/facebookresearch/fairscale/blob/6f3931a4d3464056231f1bbb92a67fbd14f30d96/fairscale/experimental/nn/distributed_pipeline/partition_handler.py#L27) instance, which has a “feed” method to be called via RPC to add a data item for processing. **B - Distributed Optimizer** @@ -405,7 +405,7 @@ Loss calculation happens similarly to in Approach 2, the single driver calls int **E - Backprop** -Backpropagation through the pipeline is similarly implemented via distributed autograd, as in Approach 2. Note that the same fork/join barrier approach is used to [serialize](https://github.com/facebookresearch/fairscale/blob/6f3931a4d3464056231f1bbb92a67fbd14f30d96/fairscale/experimental/nn/distributed_pipeline/partition_handler.py#L103) execution of micro-batches on the backward pass. +Backpropagation through the pipeline is similarly implemented via distributed autograd, as in Approach 2. Note that the same fork/join barrier approach is used to [serialize](https://github.com/facebookresearch/fairscale/blob/6f3931a4d3464056231f1bbb92a67fbd14f30d96/fairscale/experimental/nn/distributed_pipeline/partition_handler.py#L103) execution of micro-batches on the backward pass. **NOTE**: I don’t believe that forward and backward jobs are serialized; they may run concurrently. Is this true? @@ -482,7 +482,7 @@ The implementations for each of these instructions can be referenced from this [ * (**D2**) (hypothetically) supports arbitrary schedules through the [PipeSchedule](https://github.com/microsoft/DeepSpeed/blob/488105ebd200bbd1f6d7cbe863412e41d9ab4221/deepspeed/runtime/pipe/schedule.py#L6) abstraction. However, there don’t seem to be any schedules implemented beyond the default * (**D3, D4?**) Usable in 3d parallelism, as detailed by the [blog post](https://www.deepspeed.ai/tutorials/pipeline/). -* (**D6**) Since data is pulled from the data loader rather than being pushed by a synchronous call in the training loop, this approach could *hypothetically* support async PP. +* (**D6**) Since data is pulled from the data loader rather than being pushed by a synchronous call in the training loop, this approach could *hypothetically* support async PP. * (**D7**) The approach seems to account for many different types of parallelism. **Con** @@ -545,6 +545,37 @@ An example of using OneFlow for pipeline parallelism can be seen in this [tutori * (split pro/con) (**D2**) Not clear if `1f1b` or other schedules are implemented/implementable? * (**D8**) Not clear what the training loop abstraction looks like. The optimizer is installed via an `nn.Graph` API. Loss calculation is created in the `nn.Graph.build()` method. +## Approach 7: Varuna + +Varuna (https://arxiv.org/abs/2111.04007) proposes a system for large-scale training that focuses on training on commodity hardware. In particular, Varuna focuses on pipeline parallelism (to work on commodity interconnects), tuning the pipeline to optimally trade-off pipeline bubble size vs. allreduce bandwidth, dynamic scheduling of pipeline stages to account for network latency and jitter, and elastic scheduling. + +The workflow of Varuna looks like the following: +* The user manually annotates their model with [CutPoints](https://github.com/microsoft/varuna/blob/ea13ce6a8934bfe829b662a56af4dc29044fcb34/docs/cutpoint.rst). These are points in the program where the system _may_ place a pipeline stage +* The user wraps their model in the [Varuna](https://github.com/microsoft/varuna/blob/ea13ce6a8934bfe829b662a56af4dc29044fcb34/docs/varuna.rst#id9) class and configures pipeline parallelism via this interface. This includes parameters like chunk size, installing the optimizer, and informing the system of the rank of the current running instance. + * Internally, the system is going to do a roudn of profiling to determine the optimal pipeline balance and choose a subset of the cut-points together to represent the code for different pipeline stages. +* User calls the `Varuna.step()` method to run the program in pipeline parallel execution (forward, loss, backward) + * Varuna uses an opportunistic scheduling policy (described in section 3.2 of the paper), which will run ahead and run `forward()` micro-batches if `backward()` micro-batches are not available +* User applies the optimizer's `step()` method to finally update the parameters give accumulated gradients. +* Outside of the Python script, the user uses the `run_varuna` launcher script to orchestrate (elastic) job scheduling + +### Pros and Cons of the Approach + +**Pro** + +* (**D2**) Varuna implements scheduling, particularly their [opportunistic scheduling](https://github.com/microsoft/varuna/blob/79aaf45995a2b06bf5186e825b7baf81b9145837/varuna/pipeline.py#L280) policy. See in the "con", I'm not super convinced by this scheme, but the system supports scheduling (and can probably be extended or hacked to support more traditional schedules) +* (**D5**) The system nominally supports pipeline partitioning without wrapping into a `Sequential`, but see "con" for commentary about the soundness of this approach. + + +**Con** + +* (**D5**) Varuna's approach to partitioning models does not seem sound. + * It assumes that the order of invocation of modules matches the order of the modules as [enumerated](https://github.com/microsoft/varuna/blob/ea13ce6a8934bfe829b662a56af4dc29044fcb34/varuna/partitioned_model.py#L396) in `nn.Module`. This is not necessarily the case, as there can be any arbitrary set of `use-def` relationships between modules and call-sites in a PyTorch module + * It [nulls out](https://github.com/microsoft/varuna/blob/ea13ce6a8934bfe829b662a56af4dc29044fcb34/varuna/partitioned_model.py#L434) modules that are not relevant to the current rank's computation by replacing them with [PassThroughModule](https://github.com/microsoft/varuna/blob/ea13ce6a8934bfe829b662a56af4dc29044fcb34/varuna/partitioned_model.py#L635), which is implemented to simply return `None` from `forward()`. This works when the model is composed purely of module calls and the data dependencies between them, but if at any point there is a non-trivial construct in the code (e.g. any operations on the output of the Module), this will break. +* (**D2**) I'm not convinced by [opportunistic scheduling](https://github.com/microsoft/varuna/blob/79aaf45995a2b06bf5186e825b7baf81b9145837/varuna/pipeline.py#L280). The concept is sound, and reminds me of Tomasulo algorithm-style out-of-order execution for dealing with stochastic latencies (e.g. from memory latencies in a processor), but an extra constraint in pipeline parallel execution for deep learning is: value lifetimes. Activations from forward jobs must be saved for use in the backward job, meaning the memory high-watermark of the pipeline stage is increased for every forward() job that is admitted without a corresponding backward() job to release those values. The literature addresses this with static execution schedules such as the [1F1B strategy](https://arxiv.org/abs/1806.03377) and OneFlow solves this by implementing [Registers and back pressure](https://oneflow2020.medium.com/runtime-of-oneflow-based-on-boxing-and-actor-model-part-3-f2b786dc14a0) in the pipeline. As far as I can tell, Varuna will run ahead indiscriminately until OOM +* (**D8**) The extent to which the training loop must be modified ([BERT](https://github.com/microsoft/varuna/blob/ea13ce6a8934bfe829b662a56af4dc29044fcb34/examples/BERT/bert.patch), [megatron](https://github.com/microsoft/varuna/blob/ea13ce6a8934bfe829b662a56af4dc29044fcb34/examples/Megatron-LM/megatron.patch)) to work with Varuna is pretty extreme +* (**D3/D4/D7**) The system does not seem to compose with tensor parallelism, at least that is not described in the paper or the README. +* (**D6**) Does not support async + ## Final Analysis ### General Design Axes From 2a8701bfdb1f846bb54c2d86578a099cecd0f2bb Mon Sep 17 00:00:00 2001 From: James Reed Date: Mon, 6 Dec 2021 12:58:20 -0800 Subject: [PATCH 5/9] update release timeline to 1.12 since we missed the cutoff for 1.11 --- RFC-0020-Distributed-Pipeline-Parallelism.md | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/RFC-0020-Distributed-Pipeline-Parallelism.md b/RFC-0020-Distributed-Pipeline-Parallelism.md index 1f6a2bc..6566e0f 100644 --- a/RFC-0020-Distributed-Pipeline-Parallelism.md +++ b/RFC-0020-Distributed-Pipeline-Parallelism.md @@ -53,7 +53,7 @@ HF transformers [wants to](https://github.com/huggingface/transformers/issues/13 An analysis of prior implementations and a proposed technical approach for pipeline parallelism can be seen in [[RFC] Distributed Pipeline Parallel Training Technical Approach](https://github.com/pytorch/rfcs/blob/master/RFC-0021-Distributed-Pipeline-Parallel-Technical.md). In this document, we further split execution into stages and correlate those to the PyTorch external release schedule. -## Stage 2: Ship prototype synchronous multi-node pipeline parallelism (torchgpipe-style) (1.11 Prototype Release) +## Stage 2: Ship prototype synchronous multi-node pipeline parallelism (torchgpipe-style) (1.12 Prototype Release) ### P(-1): Implement cross-host support for pipeline parallelism @@ -83,13 +83,11 @@ Proposed approach short-list: 1. Hopefully should just work out of the box with the RPC API, but need to keep it in mind. -### P0: 1.11 Prototype Release and out-of-tree demo on HF Transformers +### P0: 1.12 Prototype Release and out-of-tree demo on HF Transformers -* Release API as prototype in the 1.11 release to facilitate gathering feedback +* Release API as prototype in the 1.12 release to facilitate gathering feedback * Validation: Out-of-tree demo on HF transformers repo - hack it together to get it to work and pull out work items to improve the API to remove places where code edits are needed -* 1.11 Release Dates - * Feature submission: 11/30 EOD - * Branch cut 1/31/2022 +* 1.12 release date sometime in April 2022 @@ -136,7 +134,7 @@ We can interpolate the missing spaces: I believe the way to go in the future may be to consolidate on actors for both local and distributed. This may represent lower complexity than the torchgpipe-style execution (at least when I think about it) and can avoid issues with a single driver process being a bottleneck (as evidenced by the fact that `torchgpipe` already uses threads for speed). -## Stage 4: Generalize pipeline parallelism interface to allow for more coverage of different techniques in the literature (e.g. async, scheduling, auto-partitioning, composition with tensor parallelism) (2022, OSS releases 1.11-1.15) +## Stage 4: Generalize pipeline parallelism interface to allow for more coverage of different techniques in the literature (e.g. async, scheduling, auto-partitioning, composition with tensor parallelism) (2022, OSS releases 1.12-1.15) ### P1: Pipeline parallelism without `nn.Sequential` rewrite From 9e6522af2ad2ea52ae50c067851055cd0acbf1b3 Mon Sep 17 00:00:00 2001 From: James Reed Date: Mon, 6 Dec 2021 13:33:56 -0800 Subject: [PATCH 6/9] Add more desiderata and add note about rethinking the design --- ...Distributed-Pipeline-Parallel-Technical.md | 23 +++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/RFC-0021-Distributed-Pipeline-Parallel-Technical.md b/RFC-0021-Distributed-Pipeline-Parallel-Technical.md index 1a9b221..c987d54 100644 --- a/RFC-0021-Distributed-Pipeline-Parallel-Technical.md +++ b/RFC-0021-Distributed-Pipeline-Parallel-Technical.md @@ -140,6 +140,8 @@ We would like to consider the following (lifted from [[RFC] Ceci n'est pas pipel * Weight stashing/Weight prediction * **D7 (P2)** Research: Fits into a unified, highly configurable programming model encompassing all parallelism schemes * **D8 (P1)** The user can use an idiomatic Python-based training loop with no or minimal modifications from their “normal” training loop +* **D9 (P?)** Compatibility with non-monadic topology (e.g. skip connections) +* **D10 (P?)** Support for dynamic programs, e.g. those with control flow that changes across runs (e.g. conditional encoder) ## Approach 1: SPMD with Predicated Training Loop and Message Passing @@ -225,6 +227,8 @@ Note that forward and backward stages do not necessarily always run in a given r * (**D3, D4, D7**) Composes with SPMD execution model; can likely readily interoperate well with SPMD Tensor Parallelism. Is potentially the basis for converging parallelism on the SPMD model (other alternative is converging parallelism on the Actor model)(needs more research - likely can be the basis for (a) paper(s)) * (**D5**) Does not require the user to manually partition their model into an `nn.Sequential` * (**D6**) There is no concept of a synchronous “call” or “dispatch” into the training loop; this scheme can likely readily support asynchronous pipeline parallelism with continuous data loading and training. +* (**D9**) Supports arbitrary value connectivity +* (**D10**) Hypothetically supports dynamic programs **Con** @@ -341,6 +345,7 @@ As mentioned in stage (B), the DistributedOptimizer will make async RPC calls to **Pro** * (**D8**) Training loop looks pretty close to the original PyTorch code. Training loop runs on a single machine, so user does not need to reason about correctness of their training loop under SPMD, as in Approach 1. +* (**D10**) Supports dynamic programs within each stage **Con** @@ -350,6 +355,8 @@ As mentioned in stage (B), the DistributedOptimizer will make async RPC calls to * When expanding the torchgpipe single-coordinator scheme to cross-host execution, network latency, jitter, and instability may contribute to front-end-boundedness issues * (**D2**) In its current conception, it’s not clear to me if schedules are representable in this scheme due to reliance on distributed autograd execution. GPipe’s fill-drain schedule is implemented via careful data dependency programming in the autograd graph. It’s not clear to me if things like 1F1B, interleaved 1F1B, Varuna scheduling, or other research schedules are (easily) implementable in this scheme. * (**D6**) This approach has a strong concept of “synchronous dispatch” into the pipeline. The single coordinator calls into the pipeline with a mini-batch, the execution is scheduled internally to that, and the pipeline returns an RRef to the result value. It’s not clear how continuous, asynchronous training would fit into this without retrofitting an event-driven handler for the training loop to feed another mini-batch in. +* (**D9**) Will need special handling for things like skip connections +* (**D10**) Does not support dynamic programs across the stages ## Approach 3 - RPC with RemoteModule and message passing (fairscale experimental) @@ -421,6 +428,7 @@ The optimizer step uses DistributedOptimizer in the same was as Approach 2. Dist * (**D8**) Training loop looks pretty close to the original PyTorch code. Training loop runs on a single machine, so user does not need to reason about correctness of their training loop under SPMD, as in Approach 1. * OTOH, some of the set-up in the [example](https://github.com/wayi1/pipeline_experiments/blob/7e0fe6f884edfab026379cce1b5ae03b5c2489cd/BERT/main.py#L200) is pretty hairy and could probably be improved * Compared to Approach 2, much less risk of being “front-end bound”. The burden of issuing commands is distributed throughout the ranks, i.e. a rank receives micro-batches and dispatches completed micro-batches to its successor. +* (**D10**) Supports dynamic programs within each stage **Con** @@ -429,6 +437,8 @@ The optimizer step uses DistributedOptimizer in the same was as Approach 2. Dist * The system may still be “front-end bound” for loss calculation, distributed autograd, and DistributedOptimizer step. * (**D2**) In its current conception, it’s not clear to me if schedules are representable in this scheme due to reliance on distributed autograd execution. GPipe’s fill-drain schedule is implemented via careful data dependency programming in the autograd graph. It’s not clear to me if things like 1F1B, interleaved 1F1B, Varuna scheduling, or other research schedules are (easily) implementable in this scheme. * (**D6**) This approach has a strong concept of “synchronous dispatch” into the pipeline. The single coordinator calls into the pipeline with a mini-batch, the execution is scheduled internally to that, and the pipeline returns an RRef to the result value. It’s not clear how continuous, asynchronous training would fit into this without retrofitting an event-driven handler for the training loop to feed another mini-batch in. +* (**D9**) Will need special handling for things like skip connections +* (**D10**) Does not support dynamic programs across the stages ## Approach 4 - MPMD with a custom interpreter/instruction format and message passing (DeepSpeed) @@ -484,12 +494,16 @@ The implementations for each of these instructions can be referenced from this [ * (**D3, D4?**) Usable in 3d parallelism, as detailed by the [blog post](https://www.deepspeed.ai/tutorials/pipeline/). * (**D6**) Since data is pulled from the data loader rather than being pushed by a synchronous call in the training loop, this approach could *hypothetically* support async PP. * (**D7**) The approach seems to account for many different types of parallelism. +* (**D10**) Supports dynamic programs within each stage + **Con** * (**D1**) Does not support passing arbitrary data between stages, only supports Tensor and tuple of Tensor (because of `nn.Sequential` front-end) * (**D5**) Only supports models fit into an `nn.Sequential` * (**D8**) This approach takes control away from the user. The training loop is now implemented by the DeepSpeed engine abstraction, rather than being free-form Python code. +* (**D9**) I don't think it supports skip connections (it supports tied layers for accumulating gradients but not skip connections, I think) +* (**D10**) Does not support dynamic programs across the stages ## Approach 5: RPC with remote modules and generalized Module-server architecture (SageMaker) @@ -516,6 +530,8 @@ PP_RANK 0 drives the process by scheduling instances of the training loop functi * (**D3/D4**) Composes with other parallelism schemes * (**D5**) Does not require model to be an `nn.Sequential` * (**D8**) User’s original training loop is preserved with only slight modifications (`@smp.step` annotation and other things) +* (**D9**) Supports arbitrary topology +* (**D10**) Supports fully dynamic programs **Con** @@ -539,6 +555,8 @@ An example of using OneFlow for pipeline parallelism can be seen in this [tutori * (**D5**) `nn.Sequential` not needed, but potentially an `nn.Graph` instance may be needed in some cases * (**D6**) Async is probably supportable but not clear. From their presentation, the actor/register model with backpressure can implement on-demand data loading, but I’m not 100% sure what that API looks like * (**D7**) Unified programming model that already exists +* (**D9**) Supports arbitrary topology +* (**D10**) Supports dynamic programs (via just-in-time program capture) **Con** @@ -564,6 +582,7 @@ The workflow of Varuna looks like the following: * (**D2**) Varuna implements scheduling, particularly their [opportunistic scheduling](https://github.com/microsoft/varuna/blob/79aaf45995a2b06bf5186e825b7baf81b9145837/varuna/pipeline.py#L280) policy. See in the "con", I'm not super convinced by this scheme, but the system supports scheduling (and can probably be extended or hacked to support more traditional schedules) * (**D5**) The system nominally supports pipeline partitioning without wrapping into a `Sequential`, but see "con" for commentary about the soundness of this approach. +* (**D10**) Supports dynamic control within stages **Con** @@ -575,6 +594,8 @@ The workflow of Varuna looks like the following: * (**D8**) The extent to which the training loop must be modified ([BERT](https://github.com/microsoft/varuna/blob/ea13ce6a8934bfe829b662a56af4dc29044fcb34/examples/BERT/bert.patch), [megatron](https://github.com/microsoft/varuna/blob/ea13ce6a8934bfe829b662a56af4dc29044fcb34/examples/Megatron-LM/megatron.patch)) to work with Varuna is pretty extreme * (**D3/D4/D7**) The system does not seem to compose with tensor parallelism, at least that is not described in the paper or the README. * (**D6**) Does not support async +* (**D9**) Don't believe it supports arbitrary topology for activations (only for weights) +* (**D10**) Does not support dynamic control between stages ## Final Analysis @@ -619,6 +640,8 @@ We can start analyzing the approaches by these design axes ### Decision - Approach 3 with Modifications +**NOTE**: Based on feedback on this RFC, we are currently reconsidering this design and will update this section accordingly as we hash out the details. + After deliberation, we want to build the API with the least complexity, at least initially. We will modify/build the API in FairScale experimental with a few modifications: From afea9ec4b32f56bd44233a0af65cf523286d8cb1 Mon Sep 17 00:00:00 2001 From: James Reed Date: Mon, 6 Dec 2021 15:04:30 -0800 Subject: [PATCH 7/9] format tables --- ...Distributed-Pipeline-Parallel-Technical.md | 30 +++++++++---------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/RFC-0021-Distributed-Pipeline-Parallel-Technical.md b/RFC-0021-Distributed-Pipeline-Parallel-Technical.md index c987d54..4a7aa23 100644 --- a/RFC-0021-Distributed-Pipeline-Parallel-Technical.md +++ b/RFC-0021-Distributed-Pipeline-Parallel-Technical.md @@ -130,11 +130,11 @@ We would like to consider the following (lifted from [[RFC] Ceci n'est pas pipel * **D0 (P-1)** Cross-host pipeline parallel support (table stakes) * **D1 (P0)** Support for passing arbitrary data types between stages (table stakes) -* **D2 (P0)** Support for pipeline parallel schedules (e.g. GPipe fill-drain, 1F1B, or interleaved 1F1B) +* **D2 (P1)** Support for pipeline parallel schedules (e.g. GPipe fill-drain, 1F1B, or interleaved 1F1B) * **P1** Support for arbitrary programmable schedules * **D3 (P0)** Composability with other parallelism schemes (Tensor Parallelism, Data Parallelism) in a 3D parallelism scheme * **D4 (P1)** Composability with other parallelism schemes in an *arbitrary scheme* -* **D5 (P1)** Off-the-shelf support for pipelining without manual conversion of a model to `nn.Sequential` +* **D5 (P0)** Off-the-shelf support for pipelining without manual conversion of a model to `nn.Sequential` * **D6 (P2)** Support for asynchronous pipeline parallelism * Continuous data loading * Weight stashing/Weight prediction @@ -509,7 +509,7 @@ The implementations for each of these instructions can be referenced from this [ The [SageMaker model parallelism](https://arxiv.org/abs/2111.05972) design uses a single Python-native training loop with a “module-server” architecture. The system divides the model based on the Module hierarchy and assigns each module onto a specific pipeline parallel rank (PP_RANK). During execution, when there is a dispatch to a `Module` that resides on another PP_RANK, a remote request-response RPC is made to run the appropriate forward/backward pass for the Module on the remote PP_RANK. -[Image: req_resp.png] +![Module-server request-response execution in SageMaker pipeline parallelism](https://i.imgur.com/y9MZJ3b.png) PP_RANK 0 drives the process by scheduling instances of the training loop function (a UDF annotated by `@smp.step`): two for each micro-batch (one for forward, one for backward). PP_RANK 0 can implement different “schedules” by dispatching these `(micro-batch, phase)` tuples in a given order. The orders that they present are: @@ -517,8 +517,6 @@ PP_RANK 0 drives the process by scheduling instances of the training loop functi * Simple pipeline (aka GPipe fill-drain). This is implemented by having PP_RANK 0 dispatch the `phase=forward` tuples for each micro-batch in sequence. Then, dispatching the `phase=backward` tuples for each micro-batch in sequence. * “interleaved” pipeline (**NB**: this is not the same as the *interleaved 1F1B* from Narayanan, 2021). PP_RANK 0 will schedule `phase=forward` jobs and opportunistically schedule `phase=backward` jobs *as soon as the forward pass for that micro-batch is done*. -![Module-server request-response execution in SageMaker pipeline parallelism](https://i.imgur.com/y9MZJ3b.png) - **NOTE**: The schedules here do not necessarily run stage in a given order on each stage. Network latency and other affects may change the order of when micro-batches are executed. ### Pros and Cons of the Approach @@ -629,14 +627,14 @@ We can start analyzing the approaches by these design axes * Approach 5: RPC with remote modules and generalized Module-server architecture (SageMaker) * Approach 6: SPMD with Program capture/JIT compilation and message passing (OneFlow) -| |DA1 |DA2 |DA3 |DA4 |DA5 |DA6 |DA7 |DA8 |DA9 |Notes | -|--- |--- |--- |--- |--- |--- |--- |--- |--- |--- |--- | -|Approach 1 |multi |py |FF |local |manual |local |async |? |X | | -|Approach 2 |single |py |seq |dist |dist |dist |sync |X |X | | -|Approach 3 |single |py |seq |dist |dist |dist |sync |X |X | | -|Approach 4 |multi |interp |seq |local? |manual |local |async* |X |? | | -|Approach 5 |single* |py |FF |local? |manual |local? |sync? |X |X |Schedules? | -|Approach 6 |multi |interp |FF |local? |manual? (graph?) |local? |async? |X |X | | +| |DA1 |DA2 |DA3|DA4 |DA5 |DA6 |DA7 |DA8|DA9|Notes | +|--- |--- |--- |---|--- |--- |--- |--- |---|---|--- | +|Approach 1|multi |py |FF |local |manual |local |async |? |X | | +|Approach 2|single |py |seq|dist |dist |dist |sync |X |X | | +|Approach 3|single |py |seq|dist |dist |dist |sync |X |X | | +|Approach 4|multi |interp|seq|local?|manual |local |async*|X |? | | +|Approach 5|single*|py |FF |local?|manual |local?|sync? |X |X |Schedules?| +|Approach 6|multi |interp|FF |local?|manual? (graph?)|local?|async?|X |X | | ### Decision - Approach 3 with Modifications @@ -653,9 +651,9 @@ After deliberation, we want to build the API with the least complexity, at least Approach 3 with modifications then looks like: -| |DA1 |DA2 |DA3 |DA4 |DA5 |DA6 |DA7 |DA8 |DA9 |Notes | -|--- |--- |--- |--- |--- |--- |--- |--- |--- |--- |--- | -|Approach 3 with Modifications |single |py |seq |local |manual |dist |sync |X |? | | +| |DA1 |DA2|DA3|DA4 |DA5 |DA6 |DA7 |DA8|DA9|Notes| +|--- |--- |---|---|--- |--- |--- |--- |---|---|--- | +|Approach 3 with Modifications|single|py |seq|local|manual|dist|sync|X |? | | **Future Extensibility** From 75ffa08eeb493f53da3c79f86faff41a6356079b Mon Sep 17 00:00:00 2001 From: James Reed Date: Tue, 7 Dec 2021 14:07:38 -0800 Subject: [PATCH 8/9] Update RFC-0021-Distributed-Pipeline-Parallel-Technical.md Co-authored-by: Stas Bekman --- RFC-0021-Distributed-Pipeline-Parallel-Technical.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/RFC-0021-Distributed-Pipeline-Parallel-Technical.md b/RFC-0021-Distributed-Pipeline-Parallel-Technical.md index 4a7aa23..9c540c5 100644 --- a/RFC-0021-Distributed-Pipeline-Parallel-Technical.md +++ b/RFC-0021-Distributed-Pipeline-Parallel-Technical.md @@ -568,7 +568,7 @@ Varuna (https://arxiv.org/abs/2111.04007) proposes a system for large-scale trai The workflow of Varuna looks like the following: * The user manually annotates their model with [CutPoints](https://github.com/microsoft/varuna/blob/ea13ce6a8934bfe829b662a56af4dc29044fcb34/docs/cutpoint.rst). These are points in the program where the system _may_ place a pipeline stage * The user wraps their model in the [Varuna](https://github.com/microsoft/varuna/blob/ea13ce6a8934bfe829b662a56af4dc29044fcb34/docs/varuna.rst#id9) class and configures pipeline parallelism via this interface. This includes parameters like chunk size, installing the optimizer, and informing the system of the rank of the current running instance. - * Internally, the system is going to do a roudn of profiling to determine the optimal pipeline balance and choose a subset of the cut-points together to represent the code for different pipeline stages. + * Internally, the system is going to do a round of profiling to determine the optimal pipeline balance and choose a subset of the cut-points together to represent the code for different pipeline stages. * User calls the `Varuna.step()` method to run the program in pipeline parallel execution (forward, loss, backward) * Varuna uses an opportunistic scheduling policy (described in section 3.2 of the paper), which will run ahead and run `forward()` micro-batches if `backward()` micro-batches are not available * User applies the optimizer's `step()` method to finally update the parameters give accumulated gradients. From 589ff2fc5294d96f2f86d49126f3286be0b620e2 Mon Sep 17 00:00:00 2001 From: James Reed Date: Tue, 7 Dec 2021 14:14:35 -0800 Subject: [PATCH 9/9] more fixups for nn.sequential --- RFC-0020-Distributed-Pipeline-Parallelism.md | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/RFC-0020-Distributed-Pipeline-Parallelism.md b/RFC-0020-Distributed-Pipeline-Parallelism.md index 6566e0f..fcd095a 100644 --- a/RFC-0020-Distributed-Pipeline-Parallelism.md +++ b/RFC-0020-Distributed-Pipeline-Parallelism.md @@ -39,7 +39,7 @@ HF transformers [wants to](https://github.com/huggingface/transformers/issues/13 * Frontend limitations: * **P0**: Cannot pass arbitrary data types between pipeline stages * **P0**: Unclear composability in 3d parallelism scheme (data, pipeline, model parallel) - * **P1**: User needs to rewrite their model as an `nn.Sequential` instance + * **P0**: User needs to rewrite their model as an `nn.Sequential` instance * Backend Limitations: * **P(-1)**: No cross-host support for PT pipeline parallelism API * **P0**: No support off-the-shelf schedules (1F1B or interleaving) @@ -136,7 +136,7 @@ I believe the way to go in the future may be to consolidate on actors for both l ## Stage 4: Generalize pipeline parallelism interface to allow for more coverage of different techniques in the literature (e.g. async, scheduling, auto-partitioning, composition with tensor parallelism) (2022, OSS releases 1.12-1.15) -### P1: Pipeline parallelism without `nn.Sequential` rewrite +### P0: Pipeline parallelism without `nn.Sequential` rewrite Existing approaches/proposals that support this (in no particular order): @@ -152,7 +152,6 @@ Proposed approach short-list: 2. @zdevito's [sequential-free splitting approach](https://colab.research.google.com/drive/1lGg2NqlvDwVmvBqejzni2yTmYE9rxfdr?usp=sharing) 3. Construct a pipeline parallelism API that uses a different approach, such as the one used in SageMaker model parallelism. This introduces trade-offs elsewhere, such as in support for schedules/the requirement for an optimization pass to be applied to implement "true" pipeline parallelism. -These approaches can be composed on top of an existing API that takes an `nn.Sequential`. We may consider in the future to develop a "v2" API that is centered more natively around non-`nn.Sequential` models using technologies from Sagemaker, OneFlow, or other research developments. ### P1: Support arbitrary programmable schedules (e.g. fill-drain, 1F1B, interleaved 1F1B)