Add DDP Communication Hooks #2841
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Thanks! Overall this looks great to me, a few requests though:
- Can we add an example in `examples/by_feature` specifically showcasing this usage?
- Can we add some documentation to the official docs on it? Ideally a Usage Guide if nothing else (though a Concept Guide too would be ideal!)
src/accelerate/utils/dataclasses.py
Outdated
```python
- **BATCHED_POWER_SGD** -- DDP communication hook to use batched PowerSGD
"""

# Subclassing str as well as Enum allows the `DistributedType` to be JSON-serializable out of the box.
```
Copy/paste leftover :)
Thanks for fixing this :)
Thank you for the review and suggestions :) In response to your feedback, I have made the following updates:
- Added an example in `examples/by_feature` showcasing DDP communication hooks.
- Added a usage guide on DDP communication hooks to the official docs.

These additions aim to provide clear guidance on how to utilize DDP communication hooks with the 🤗 Accelerate library, enhancing the usability and performance of distributed training. Please let me know if there are any further adjustments or additions required.
Co-authored-by: Zach Mueller <[email protected]>
Thanks for the clean PR @yhna940! LGTM! Just a nit, could you add some details about `comm_wrapper` or `comm_state_option`? From the tests, it looks like this is only useful for POWER_SGD.
Thank you for the feedback and the suggestion @SunMarc. I have added more details about `comm_wrapper` and `comm_state_option`. However, there are other state-using options in PyTorch, such as the post-localSGD hook:

```python
import torch
import torch.distributed as dist
import torch.distributed.algorithms.model_averaging.averagers as averagers
import torch.nn as nn
from torch.distributed.optim import PostLocalSGDOptimizer
from torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook import (
    PostLocalSGDState,
    post_localSGD_hook,
)

# `module`, `rank`, `data`, `labels`, and `loss_fn` are placeholders assumed to be
# defined by the surrounding training script.
model = nn.parallel.DistributedDataParallel(
    module, device_ids=[rank], output_device=rank
)

# Register a post-localSGD communication hook.
state = PostLocalSGDState(process_group=None, subgroup=None, start_localSGD_iter=100)
model.register_comm_hook(state, post_localSGD_hook)

# Create a post-localSGD optimizer that wraps a local optimizer.
# Note that `warmup_steps` used in `PostLocalSGDOptimizer` must be the same as
# `start_localSGD_iter` used in `PostLocalSGDState`.
local_optim = torch.optim.SGD(params=model.parameters(), lr=0.01)
opt = PostLocalSGDOptimizer(
    optim=local_optim,
    averager=averagers.PeriodicModelAverager(period=4, warmup_steps=100)
)

# In the first 100 steps, DDP runs global gradient averaging at every step.
# After 100 steps, DDP runs gradient averaging within each subgroup (intra-node by default),
# and the post-localSGD optimizer runs global model averaging every 4 steps after applying the local optimizer.
for step in range(0, 200):
    opt.zero_grad()
    output = model(data)  # forward pass (omitted in the original snippet)
    loss = loss_fn(output, labels)
    loss.backward()
    opt.step()
```

These advanced hooks were not included in this PR because the PyTorch official documentation primarily highlights the PowerSGD, FP16, and BF16 hooks. You can find more information about additional hooks in the PyTorch DDP Communication Hooks documentation and the PyTorch GitHub repository.
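For context, here is a rough sketch of how `comm_state_option` is intended to be passed for PowerSGD through 🤗 Accelerate (a sketch only, assuming the `DistributedDataParallelKwargs` fields added in this PR forward the extra kwargs to `PowerSGDState`):

```python
from accelerate import Accelerator, DistributedDataParallelKwargs
from accelerate.utils import DDPCommunicationHookType

# PowerSGD keeps error-feedback state across iterations, so PowerSGDState arguments
# (e.g. `matrix_approximation_rank`) can be supplied via `comm_state_option`.
# Field names here follow this PR's kwargs handler and are assumptions of this sketch.
ddp_kwargs = DistributedDataParallelKwargs(
    comm_hook=DDPCommunicationHookType.POWER_SGD,
    comm_state_option={"matrix_approximation_rank": 2},
)
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])
```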
Awesome @yhna940! Thanks for iterating! Could you also add the link to the new usage guide you wrote to the `docs/source/_toctree.yml` file? I think that you can put it in the Training section!
Thanks for the quick review @SunMarc, I've done it :)
Very nice job! 👏
docs/source/_toctree.yml
Outdated
```diff
@@ -58,6 +58,8 @@
     title: Apple M1 GPUs
   - local: usage_guides/ipex
     title: IPEX training with CPU
+  - local: usage_guides/ddp_comm_hook
```
I would consider putting it after or before Fully Sharded Data Parallelism since they're more closely related.
```md
## Converting it to 🤗 Accelerate

Now, let's see how to use the same hooks with the 🤗 Accelerate library.

### Using FP16 Compression Hook with 🤗 Accelerate
```
You could use the `<hfoption>` tags to create separate tabs the user can click on to view the native version vs using it in Accelerate. It'll look like the tabs in this section here. For example, for FP16 compression you can do:

```md
<hfoptions id="fp16">
<hfoption id="PyTorch">

code

</hfoption>
<hfoption id="Accelerate">

code

</hfoption>
</hfoptions>
```
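For illustration, a minimal sketch of what the Accelerate tab could contain for FP16 compression (assuming the `DistributedDataParallelKwargs` fields and `DDPCommunicationHookType` enum introduced in this PR):

```python
import torch
import torch.nn as nn
from accelerate import Accelerator, DistributedDataParallelKwargs
from accelerate.utils import DDPCommunicationHookType

# Select the built-in FP16 compression hook via the kwargs handler; the hook is
# registered on the DDP-wrapped model when running under a multi-process launch
# (e.g. `accelerate launch`).
ddp_kwargs = DistributedDataParallelKwargs(comm_hook=DDPCommunicationHookType.FP16)
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])

model = nn.Linear(100, 10)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
model, optimizer = accelerator.prepare(model, optimizer)
```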
```md
For more advanced usage and additional hooks, refer to the [PyTorch DDP Communication Hooks documentation](https://pytorch.org/docs/stable/ddp_comm_hooks.html).

This demonstrates how to use DDP communication hooks to optimize gradient communication in distributed training with the 🤗 Accelerate library.
```
This sentence would be more impactful in the beginning where you describe what the tutorial will do.
```md
### comm_wrapper

`comm_wrapper` is an option to wrap a communication hook with additional functionality. For example, it can be used to combine FP16 compression with other communication strategies. Currently supported wrappers are `no`, `fp16`, and `bf16`.
```
It would be awesome if you could provide a code example!
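For reference, such an example might look roughly like the following sketch (assuming the `comm_hook` and `comm_wrapper` fields added to `DistributedDataParallelKwargs` in this PR):

```python
from accelerate import Accelerator, DistributedDataParallelKwargs
from accelerate.utils import DDPCommunicationHookType

# Combine PowerSGD compression with an FP16 wrapper so the communicated tensors are
# additionally cast to half precision; field names are assumptions based on this PR.
ddp_kwargs = DistributedDataParallelKwargs(
    comm_hook=DDPCommunicationHookType.POWER_SGD,
    comm_wrapper=DDPCommunicationHookType.FP16,
)
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])
```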
Co-authored-by: Steven Liu <[email protected]>
Thank you for your detailed review @stevhliu :) I've changed the guide as you suggested, I used the `<hfoption>` tags.
Nicely done @yhna940!
What does this PR do?
This PR adds support for DDP communication hooks to the `accelerate` library. Similar to frameworks like PyTorch Lightning and Detectron, these hooks provide an interface to control how gradients are communicated across workers, overriding the standard allreduce in DistributedDataParallel. This feature enables the use of performance-improving communication hooks when using multiple nodes.

Motivation and Context
DDP communication hooks allow users to customize and optimize gradient communication, potentially improving training performance in distributed settings.
Based on the official PyTorch documentation here, I've implemented three default hooks: PowerSGD, FP16, and BF16. These hooks provide performance improvements in distributed training scenarios.
The implementation for registering these hooks was inspired by the PyTorch Lightning implementation, which can be found here.
Fixes # (issue)
N/A
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case: here
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.