[FSDP] add no_broadcast_optim_state option #560
Conversation
@msbaines Mandeep, is the failure here on a pipeline test sporadic?
This is nice. I think I mostly followed, but I can't say I checked the logic of every line. Hopefully the test is sufficient to cover the new case this enables and ensure no regression. Some comments below are not blocking. The only real issue is the commented-out assertion, which seems scary.
@@ -21,22 +21,22 @@ def flatten_optim_state_dict(sd: Dict) -> Dict:
     non_tensor_state = {}

     # Populate `new_state["state"]`. (Assuming sd is sorted)
-    for expanded_pid, buffers in sd["state"].items():
-        consolidated_pid = param_id_map[expanded_pid]
+    for global_id, buffers in sd["state"].items():
I like the rename of the variables!
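For readers skimming the diff, here is a toy sketch of the pattern these hunks touch: per-parameter state keyed by a global (expanded) id is folded back onto the local (flattened) id via param_id_map, tensor buffers are concatenated, and scalar state (e.g. a step counter) is carried through. This is an illustration under those assumptions, not the library's flatten_optim_state_dict.

import copy
from collections import defaultdict

import torch


def toy_flatten_optim_state(sd, param_id_map):
    # Group tensor buffers by the local (flattened) param id they belong to.
    buffers_by_local = defaultdict(lambda: defaultdict(list))
    non_tensor_state = {}
    for global_id, buffers in sd["state"].items():
        local_id = param_id_map[global_id]
        for name, value in buffers.items():
            if torch.is_tensor(value):
                buffers_by_local[local_id][name].append(value.reshape(-1))
            else:
                non_tensor_state[name] = value  # e.g. an Adam "step" counter

    # Concatenate buffers per local id and re-attach the scalar state.
    new_state = {}
    for local_id, named in buffers_by_local.items():
        new_state[local_id] = {n: torch.cat(ts) for n, ts in named.items()}
        new_state[local_id].update(non_tensor_state)

    # Deep-copy param_groups so the new dict does not alias the input's.
    return {"state": new_state, "param_groups": copy.deepcopy(sd["param_groups"])}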
     new_state[local_id][buffer_name] = torch.cat(tensors)
     new_state[local_id].update(non_tensor_state)
-    new_sd = {"state": new_state, "param_groups": sd["param_groups"]}
+    new_sd = {"state": new_state, "param_groups": copy.deepcopy(sd["param_groups"])}
Add a comment on the deep copy? Also, the last time I checked, it seemed that deepcopy doesn't really copy tensors. You may want to double-check (with some asserts and testing) to verify that the deep copy here is doing the right thing.
Luckily there are no tensors in param_groups, but that's very useful to know!
+1 to adding a comment about param groups not having tensors (thus deepcopy being okay)
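As a quick standalone double-check of the deepcopy concern (plain PyTorch, not the repository's test code): param_groups holds only Python scalars and lists, and for plain leaf tensors copy.deepcopy does allocate fresh storage.

import copy

import torch

sd = {
    "state": {0: {"exp_avg": torch.zeros(3), "step": 10}},
    "param_groups": [{"lr": 0.1, "weight_decay": 0.0, "params": [0]}],
}
copied = copy.deepcopy(sd)

# param_groups: equal values, independent objects.
assert copied["param_groups"] == sd["param_groups"]
assert copied["param_groups"][0] is not sd["param_groups"][0]

# Leaf tensors are deep-copied too: mutating the copy leaves the original alone.
copied["state"][0]["exp_avg"].add_(1.0)
assert sd["state"][0]["exp_avg"].sum().item() == 0.0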
LGTM, although I will need to spend some time stepping through the code at some point to fully understand the implementation 😄
Should we also test that no_broadcast_optim_state is inferred properly for the MixtureOfExperts model?
         name_func=rename_test,
     )
     def test_consolidate_optimizer(self, optim_fn, transformer):
         config = {"mixed_precision": True, "flatten_parameters": True}
+        config["compute_dtype"] = torch.float32
I'm guessing this and the autocast change were required to get tests to pass?
yes, I followed the logic of __test_identical_outputs
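For context on why the test pins compute_dtype (my reading of the thread, not a quote of the test harness): with mixed_precision=True, fairscale's FSDP defaults compute_dtype to fp16, so forcing fp32 keeps outputs numerically comparable to the fp32 reference model, as __test_identical_outputs does.

import torch

# The wrapper config discussed above; the comments are an interpretation.
config = {
    "mixed_precision": True,         # communicate params/grads in fp16
    "flatten_parameters": True,
    "compute_dtype": torch.float32,  # ...but compute in fp32 so results match the baseline
}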
Problem
- Optimizer state dicts (OSD) must be collected carefully to avoid OOM; rebuild_full_params already implements this, but the approach is recursive.
- There is no way to keep the OSD state of certain submodules without gathering it.

Approach
- Set m.no_broadcast_optim_state=True if a child instance m has world_size=1 and its own process_group.
- Strip no_broadcast_optim_state parameters from each rank's OSD before collection.
- The gathered OSD has the correct number of entries and the correct shapes for expert parameters, just like state_dict.
- Rank i's shard has the correct contents.

Determination in _set_is_root
- The tests show that we can recognize MoE experts like this and avoid OOM without any change to callers.
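The condition described under Approach is small enough to sketch in plain Python. This mirrors the stated rule (child world_size == 1 and its own process group), not fairscale's actual _set_is_root code.

class FakeFSDP:
    """Stand-in for an FSDP instance: just the fields the rule inspects."""

    def __init__(self, process_group, world_size, children=()):
        self.process_group = process_group
        self.world_size = world_size
        self.fsdp_children = list(children)
        self.no_broadcast_optim_state = False


def mark_no_broadcast(root):
    # Rule from the description above: a child that is unsharded (world_size == 1)
    # and owns a process group different from the root's keeps its optimizer
    # state local instead of broadcasting it during consolidation.
    for child in root.fsdp_children:
        if child.world_size == 1 and child.process_group is not root.process_group:
            child.no_broadcast_optim_state = True


main_pg, expert_pg = object(), object()   # placeholders for real process groups
expert = FakeFSDP(expert_pg, world_size=1)
dense = FakeFSDP(main_pg, world_size=8)
root = FakeFSDP(main_pg, world_size=8, children=[expert, dense])

mark_no_broadcast(root)
assert expert.no_broadcast_optim_state and not dense.no_broadcast_optim_state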
TODO
- osd['param_groups'] from each worker.
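A hedged sketch of one way that TODO could go, using torch.distributed's object collective (assumes an initialized process group; this is not the PR's code):

import torch.distributed as dist


def collect_param_groups(osd, group=None):
    """Gather every worker's osd["param_groups"] onto all ranks."""
    world_size = dist.get_world_size(group=group)
    gathered = [None] * world_size
    # all_gather_object pickles arbitrary Python objects; param_groups holds
    # only scalars and lists, so this stays cheap.
    dist.all_gather_object(gathered, osd["param_groups"], group=group)
    return gathered  # index r holds rank r's param_groups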