Fix loading universal checkpoint for BF16_Optimizer
PR#5104 (Remove optimizer step on initialization) breaks loading a universal
checkpoint for BF16_Optimizer.
This is because the universal checkpoint loader attempts to load the optimizer
states into the lp._hp_mapping.optim_state dictionary before they have been
initialized (which happens on the first optimizer step).

As a workaround, when loading a universal checkpoint, perform an optimizer step
and initialize the hp params' optimizer states before loading from the
universal checkpoint files.
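The ordering bug described above can be sketched with a toy stand-in (hypothetical names, not DeepSpeed internals): per-param optimizer state is created lazily on the first step(), so loading checkpoint values into that state before step() finds nothing to fill.

```python
# Toy sketch of the ordering bug: optimizer state for each low-precision (lp)
# param is allocated lazily on the first step(), so a load performed before
# step() has no state slots to write into.

class LazyOptimizer:
    """Hypothetical stand-in for BF16_Optimizer's per-param state handling."""

    def __init__(self, params):
        self.params = params
        self.optim_state = {}  # empty until step() runs

    def step(self):
        # State tensors are allocated on the first step (as after PR#5104).
        for p in self.params:
            self.optim_state.setdefault(p, {"exp_avg": 0.0, "exp_avg_sq": 0.0})

    def load_state(self, saved):
        # Mirrors loading a universal checkpoint into already-existing state:
        # only slots that exist get filled.
        loaded = 0
        for p, s in saved.items():
            if p in self.optim_state:
                self.optim_state[p].update(s)
                loaded += 1
        return loaded


saved = {"w": {"exp_avg": 0.5, "exp_avg_sq": 0.25}}

broken = LazyOptimizer(["w"])
assert broken.load_state(saved) == 0  # nothing loaded: state not yet created

fixed = LazyOptimizer(["w"])
fixed.step()                          # workaround: materialize states first...
assert fixed.load_state(saved) == 1   # ...then loading succeeds
```

This mirrors the fix in the diff below: call step() and initialize the hp params' optimizer state before loading the checkpoint values.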

Signed-off-by: Moshe Island <[email protected]>
mosheisland committed Feb 29, 2024
1 parent aed599b commit 9bc33ea
Showing 2 changed files with 3 additions and 1 deletion.
deepspeed/runtime/bf16_optimizer.py (2 additions, 0 deletions)

@@ -445,6 +445,8 @@ def _load_legacy_checkpoint(self, state_dict_list, load_optimizer_states=True, l
         self._link_all_hp_params()
 
     def _load_universal_checkpoint(self, checkpoint_folder, load_optimizer_states, load_from_fp32_weights):
+        self.optimizer.step()
+        self._lazy_init_hp_params_optimizer_state()
         self._load_hp_checkpoint_state(checkpoint_folder)
 
     @property
deepspeed/runtime/engine.py (1 addition, 1 deletion)

@@ -2785,7 +2785,7 @@ def load_checkpoint(self,
         if self.load_universal_checkpoint():
             self.optimizer.update_lp_params()
             if load_zero_checkpoint:
-                self.update_optimizer_step(step=client_states['iteration'] + 1)
+                self.update_optimizer_step(step=client_states['iteration'])
 
         return load_path, client_states
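The engine.py change drops the "+ 1" when restoring the optimizer step. A toy sketch (hypothetical counter, not DeepSpeed internals) of one plausible reading: the load path now performs a dummy step() to materialize optimizer state, and update_optimizer_step then overwrites the counter with the absolute iteration from the checkpoint's client state, so no compensation is needed.

```python
# Toy sketch of restoring an absolute step counter from checkpoint client
# state. Names and behavior are assumptions for illustration only.

class ToyOptimizer:
    def __init__(self):
        self.step_count = 0

    def step(self):
        self.step_count += 1

    def update_optimizer_step(self, step):
        # Overwrite with the absolute step from the checkpoint, so the dummy
        # step() performed while loading does not leak into the counter.
        self.step_count = step


client_states = {"iteration": 100}

opt = ToyOptimizer()
opt.step()  # dummy step performed during universal-checkpoint loading
opt.update_optimizer_step(step=client_states["iteration"])
assert opt.step_count == 100
```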
