[FSDP] use all_gather for 10X OSD consolidation speedup #595
Conversation
return non_tensor_state, tensor_state

def _gather_optim_state(self, sd_state: Dict[int, Dict[str, Any]]) -> Dict[int, Dict[str, List]]:
This is the new _all_gather logic.
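To make that concrete, here is a minimal, hypothetical sketch of gathering one sharded optimizer-state tensor with all_gather. The real _gather_optim_state has a different signature and handles the full per-parameter state dict, so the helper name and shapes here are assumptions, not the library's code.

```python
# Hypothetical sketch: collect one flat optimizer-state shard (e.g. Adam's
# "exp_avg") from every rank with a single all_gather call.
# Assumes each rank's shard has already been padded to the same numel.
import torch
import torch.distributed as dist

def gather_tensor_state_shard(shard: torch.Tensor, group=None) -> torch.Tensor:
    """Return the concatenation of `shard` across all ranks, in rank order."""
    world_size = dist.get_world_size(group=group)
    buffers = [torch.zeros_like(shard) for _ in range(world_size)]
    dist.all_gather(buffers, shard, group=group)  # one collective, no per-rank loop
    return torch.cat(buffers, dim=0)  # caller trims any tail padding
```

Compared with looping over ranks and broadcasting each shard, a single all_gather keeps every rank busy at once, which is presumably where the speedup in the PR title comes from.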
@@ -627,15 +627,18 @@ def __init__(self, group, wrapper_config, checkpoint_act=False, delay_before_fre

     # "expert" params are different on each rank
     torch.manual_seed(42 + group.rank())
-    d_expert = 16
-    expert = nn.Linear(d_expert, 4)
+    d_expert = 23
make sure we unpad expert params correctly.
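Since this change is meant to exercise the padding path, a hedged sketch of the pad-then-unpad round trip involved may help. The helper names and the contiguous-sharding assumption (only the last shard is short, so padding sits at the tail) are mine, not the library's.

```python
# Illustrative only: pad a local shard so every rank contributes an
# equal-sized tensor to all_gather, then drop the padding afterwards.
import torch

def pad_shard(shard: torch.Tensor, target_numel: int):
    """Pad `shard` with zeros up to `target_numel`; return (padded, num_padded)."""
    num_padded = target_numel - shard.numel()
    if num_padded > 0:
        shard = torch.cat([shard, shard.new_zeros(num_padded)])
    return shard, num_padded

def unpad_full_tensor(full: torch.Tensor, num_padded: int) -> torch.Tensor:
    """Drop trailing padding from the concatenated, gathered tensor."""
    return full[: full.numel() - num_padded] if num_padded else full
```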
Planning to merge 10am PT tomorrow, barring further comments.
I don't have time to fully review this, but this looks good at a high level.
"""Return the last known global optimizer state. The returned state is compatible with Pytorch, in that the | ||
sharded properties are not exposed. Multiple parameter groups are not yet supported. | ||
|
||
This should be called only on the root FSDP instance. | ||
Nested FSDP instances are supported as long as they have the same world_size as the parent or world_size=1. |
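The excerpt does not show which method this docstring belongs to; assuming it documents FSDP's gather_full_optim_state_dict (or an equivalent consolidation entry point), a typical call pattern would look like the sketch below, with the rank-0-only return behaviour treated as an assumption.

```python
# Hedged usage sketch: the method name, signature, and "None on non-zero
# ranks" behaviour are assumptions based on the docstring above.
import torch

def save_consolidated_optim_state(fsdp_model, optimizer, path="optim_state.pt"):
    # Call on the root FSDP instance; every rank joins the collectives,
    # but only the recipient rank is expected to get the full state back.
    full_osd = fsdp_model.gather_full_optim_state_dict(optimizer)
    if full_osd is not None:
        torch.save(full_osd, path)
```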
nice
cc @QuentinDuval @prigoyal FYI
sorry for the delay, I was out for most of the last week.
new_sd = {"state": new_state, "param_groups": copy.deepcopy(sd["param_groups"])} | ||
for k in sd.keys(): # if there are extra keys, like loss_scale, don't delete them | ||
if k not in {"state", "param_groups", "uncollected_local_ids", "param_id_map"}: |
I'm slightly uneasy about this falling out of sync with line 160. Thoughts on some way to enforce parity between them?
Module-level constant + comment + assert:
# These return keys are used by fairseq. To change, add @sshleifer as a reviewer.
UNFLAT_RETURN_KEYS = {"state", "param_groups", "uncollected_local_ids", "param_id_map"}
...
assert set(unflat_optim_state_dict.keys()) == UNFLAT_RETURN_KEYS
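Expanding the "..." above into a hypothetical helper (names are illustrative, not from the PR), the parity check being proposed could look like:

```python
# Hypothetical expansion of the suggestion: one module-level constant used
# wherever the unflattened state dict is built, plus an assert so the two
# sites cannot silently drift apart.
UNFLAT_RETURN_KEYS = {"state", "param_groups", "uncollected_local_ids", "param_id_map"}

def _check_unflat_return_keys(unflat_optim_state_dict: dict) -> dict:
    extra_or_missing = set(unflat_optim_state_dict.keys()) ^ UNFLAT_RETURN_KEYS
    assert not extra_or_missing, f"return keys out of sync: {extra_or_missing}"
    return unflat_optim_state_dict
```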
LGTM!
TLDR: Using all_gather instead of broadcast for optimizer state consolidation appears to be a speed win without a memory cost.

Approach:
- Tensor state is collected with all_gather.
- For the non-tensor state (loss_scale, param_groups, num_padded) we use broadcast (a hedged sketch of this half follows below).
- Consolidation assumes recipient_rank=0, since there are no other callers.

Evidence of Win:
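The all_gather half is sketched next to the _gather_optim_state comment earlier in this thread; for the non-tensor keys, here is a hedged sketch of the broadcast half. The per-rank loop and the use of torch.distributed.broadcast_object_list are assumptions; the description above only says that broadcast is used for these keys.

```python
# Illustrative sketch: collect the small non-tensor state
# (loss_scale, param_groups, num_padded) by letting each rank broadcast
# its own copy in turn; every rank ends up with one entry per rank.
import torch.distributed as dist

def collect_non_tensor_state(non_tensor_state: dict, group=None) -> list:
    world_size = dist.get_world_size(group=group)
    rank = dist.get_rank(group=group)
    per_rank = []
    for src in range(world_size):
        obj = [non_tensor_state if rank == src else None]
        # broadcast_object_list pickles on src and fills `obj` in place elsewhere.
        dist.broadcast_object_list(obj, src=src, group=group)
        per_rank.append(obj[0])
    return per_rank
```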