-
Notifications
You must be signed in to change notification settings - Fork 282
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[feat] FSDP: add auto_wrap_bn #531
Conversation
- add an utility function to handle wrapping of BN
return not isinstance(module, tuple(default_auto_wrap_policy.FORCE_LEAF_MODULES)) # type: ignore | ||
else: | ||
return is_bn and not isinstance(module, tuple(default_auto_wrap_policy.EXCLUDE_WRAP_MODULES)) # type: ignore | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nice, I find the next couple of lines (config, single group, then guided auto wrap) very elegant
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Quick question: could dist.new_group(ranks=[my_rank])
impacts performance in any ways ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
no, should not really, AFAIK the overhead is minimal
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think there should be any perf impact since FSDP has special casing for world_size == 1. But perhaps @myleott can think of something else?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I understood the question being about the perf cost of having many groups in pytorch distributed basically, vs. few, not specific to FSDP. I might be wrong, but that was the reasoning behind my reply
@@ -54,7 +45,16 @@ def forward(self, x): | |||
# TODO (Min): check DDP equivalency. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
having been burnt a little by that, I would recommend not waiting too long for that part
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
definitely. see my plan below.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
looks very good to me, I guess it's good that the others have a look though, I'm missing some context probably, but seems very clean and reasonable
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+1 LGTM 😄
return not isinstance(module, tuple(default_auto_wrap_policy.FORCE_LEAF_MODULES)) # type: ignore | ||
else: | ||
return is_bn and not isinstance(module, tuple(default_auto_wrap_policy.EXCLUDE_WRAP_MODULES)) # type: ignore | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Quick question: could dist.new_group(ranks=[my_rank])
impacts performance in any ways ?
Thank you guys @blefaudeux @myleott @tchaton for quick and high quality reviews. To forecast a bit:
|
Before submitting
What does this PR do?
Fixes # (issue).
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.
Did you have fun?
Make sure you had fun coding 🙃