[FSDP] add no_broadcast_optim_state option #560

Merged
merged 10 commits into from
Apr 4, 2021

Conversation

sshleifer
Contributor

@sshleifer sshleifer commented Mar 31, 2021

Problem

  • Some parameters are so large that their optimizer state cannot be gathered, for example the experts in a mixture of experts (MoE).
  • We need a way to allow callers to control which keys in the optimizer state dict (OSD) are collected to avoid OOM.
    • rebuild_full_params already implements this, but the approach is recursive.
  • Callers can now write custom logic to save OSD state of certain submodules without gathering.

Approach

  • We set m.no_broadcast_optim_state=True if a child instance m has world_size=1 and its own process_group.
  • We remove no_broadcast_optim_state parameters from each rank's OSD before collection.
  • We add back rank 0's expert parameters to OSD after collection.
  • This allows us to produce an OSD with the correct number of entries and the correct shapes for expert parameters, just like state_dict.
  • We let callers overwrite the state for the expert on rank i with the correct contents.
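
The remove-then-restore flow in the list above can be sketched with plain dictionaries standing in for each rank's optimizer state. This is only an illustration under stated assumptions, not the fairscale implementation: `strip_uncollected`, `consolidate`, and the `no_broadcast_ids` set are hypothetical names, and the cross-rank gather is simulated by a list of per-rank dicts.

```python
from typing import Dict, List, Set, Tuple

# Hypothetical stand-in: each rank's OSD "state" maps param id -> named buffers.
State = Dict[int, Dict[str, List[float]]]


def strip_uncollected(state: State, no_broadcast_ids: Set[int]) -> Tuple[State, State]:
    """Split one rank's state into (entries to gather, entries kept local)."""
    to_gather = {pid: buf for pid, buf in state.items() if pid not in no_broadcast_ids}
    kept_local = {pid: buf for pid, buf in state.items() if pid in no_broadcast_ids}
    return to_gather, kept_local


def consolidate(per_rank_states: List[State], no_broadcast_ids: Set[int]) -> State:
    """Simulated collection: concatenate the shards of gathered params across
    ranks, then add back rank 0's copy of the uncollected (expert) params."""
    merged: State = {}
    rank0_local: State = {}
    for rank, state in enumerate(per_rank_states):
        to_gather, kept_local = strip_uncollected(state, no_broadcast_ids)
        if rank == 0:
            rank0_local = kept_local
        for pid, buffers in to_gather.items():
            slot = merged.setdefault(pid, {})
            for name, shard in buffers.items():
                slot.setdefault(name, []).extend(shard)
    merged.update(rank0_local)
    return merged


# Param 0 is sharded across two ranks; param 1 is an "expert" local to each rank.
rank0 = {0: {"exp_avg": [1.0, 2.0]}, 1: {"exp_avg": [10.0]}}
rank1 = {0: {"exp_avg": [3.0, 4.0]}, 1: {"exp_avg": [20.0]}}
full = consolidate([rank0, rank1], no_broadcast_ids={1})
```

In this sketch the consolidated dict has the right keys and shapes, but entry 1 holds rank 0's expert state; a caller on rank 1 would overwrite it with its own contents before saving.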

Determination in _set_is_root

m.no_broadcast_optim_state = m.no_broadcast_optim_state or (
    (m.world_size == 1) and (m.world_size < self.world_size) and (m.process_group != self.process_group)
)

The tests show that we can recognize MoE experts this way and avoid OOM without any change to callers.
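
To make the condition concrete, here is a small standalone sketch of the same boolean check. `FakeModule` and `infer_no_broadcast` are hypothetical stand-ins introduced for illustration; they model only the three attributes the expression reads and are not fairscale classes.

```python
from dataclasses import dataclass


@dataclass
class FakeModule:
    """Hypothetical stand-in for an FSDP instance (only the fields the check reads)."""
    world_size: int
    process_group: object
    no_broadcast_optim_state: bool = False


def infer_no_broadcast(child: FakeModule, root: FakeModule) -> bool:
    # Same expression as in the PR description, with `self` renamed to `root`.
    return child.no_broadcast_optim_state or (
        child.world_size == 1
        and child.world_size < root.world_size
        and child.process_group != root.process_group
    )


global_group, expert_group = object(), object()
root = FakeModule(world_size=4, process_group=global_group)
expert = FakeModule(world_size=1, process_group=expert_group)  # MoE-style child
shared = FakeModule(world_size=4, process_group=global_group)  # ordinary child

expert_flag = infer_no_broadcast(expert, root)
shared_flag = infer_no_broadcast(shared, root)
```

The `world_size < root.world_size` clause keeps a single-process run from marking every child: when the root itself has world_size 1, nothing is flagged.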

TODO

  • Optimization: Do not copy osd['param_groups'] from each worker.

@facebook-github-bot facebook-github-bot added the CLA Signed label Mar 31, 2021
@sshleifer sshleifer changed the title from "[FSDP] no_broadcast_optim_state option to avoid extra communication" to "[FSDP/optim state] pass no_broadcast_optim_state=True option to avoid OOM" Mar 31, 2021
@sshleifer sshleifer changed the title from "[FSDP/optim state] pass no_broadcast_optim_state=True option to avoid OOM" to "[FSDP] add no_broadcast_optim_state option" Mar 31, 2021
@sshleifer sshleifer requested review from myleott and min-xu-ai March 31, 2021 17:44
@sshleifer sshleifer added the FSDP label Mar 31, 2021
@min-xu-ai
Contributor

@msbaines Mandeep, is the failure here on a pipeline test sporadic?

Contributor

@min-xu-ai min-xu-ai left a comment

This is nice. I think I mostly followed, but I can't say I checked the logic of every line. Hopefully the test is sufficient to cover the new case this enables and to ensure no regression. Some of my comments are not blocking. The only real issue is the commented-out assertion, which seems scary.

@@ -21,22 +21,22 @@ def flatten_optim_state_dict(sd: Dict) -> Dict:
     non_tensor_state = {}

     # Populate `new_state["state"]`. (Assuming sd is sorted)
-    for expanded_pid, buffers in sd["state"].items():
-        consolidated_pid = param_id_map[expanded_pid]
+    for global_id, buffers in sd["state"].items():
Contributor

I like the rename of the variables!

new_sd = {"state": new_state, "param_groups": sd["param_groups"]}
new_state[local_id][buffer_name] = torch.cat(tensors)
new_state[local_id].update(non_tensor_state)
new_sd = {"state": new_state, "param_groups": copy.deepcopy(sd["param_groups"])}
Contributor

Could you add a comment on the deep copy? Also, last time I checked, it seemed that deepcopy doesn't really copy tensors. You may want to double-check (with some asserts and testing) that the deep copy here is doing the right thing.

Contributor Author

Luckily there are no tensors in param_groups, but that's very useful to know!

Contributor

+1 to adding a comment about param groups not having tensors (thus deepcopy being okay)
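
The point about param_groups can be illustrated with a small standalone sketch. It assumes, as stated in the reply above, that param_groups holds only plain Python values (floats, tuples, lists of ints); the concrete hyperparameters below are made up for illustration and say nothing about how deepcopy treats tensors.

```python
import copy

# Made-up param_groups containing only plain Python values, no tensors.
param_groups = [{"lr": 0.01, "betas": (0.9, 0.999), "params": [0, 1, 2]}]
sd = {"state": {}, "param_groups": param_groups}

# Same pattern as the diff: deep-copy param_groups into the new state dict.
new_sd = {"state": {}, "param_groups": copy.deepcopy(sd["param_groups"])}

# Mutating the copy must not leak back into the original.
new_sd["param_groups"][0]["params"].append(3)
new_sd["param_groups"][0]["lr"] = 0.1

original_params = sd["param_groups"][0]["params"]
original_lr = sd["param_groups"][0]["lr"]
```

With a shallow reference instead of `copy.deepcopy`, both mutations would have shown up in `sd` as well.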

Contributor

@myleott myleott left a comment

LGTM, although I will need to spend some time stepping through the code at some point to fully understand the implementation 😄

Should we also test that no_broadcast_optim_state is inferred properly for the MixtureOfExperts model?

        name_func=rename_test,
    )
    def test_consolidate_optimizer(self, optim_fn, transformer):
        config = {"mixed_precision": True, "flatten_parameters": True}
        config["compute_dtype"] = torch.float32
Contributor

I'm guessing this and the autocast change were required to get the tests to pass?

Contributor Author

Yes, I followed the logic of __test_identical_outputs.

@msbaines
Contributor

msbaines commented Apr 3, 2021

@msbaines Mandeep, is the failure here on a pipeline test sporadic?

Yep. Here is a fix: #575

@sshleifer sshleifer merged commit 1fcbd62 into master Apr 4, 2021
@sshleifer sshleifer deleted the fsdp-optim-uncollectable branch April 4, 2021 19:08
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. FSDP FullyShardedDataParallel (zero-3)
5 participants