Skip to content

Commit

Permalink
fix #5305
Browse files Browse the repository at this point in the history
  • Loading branch information
hiyouga committed Aug 29, 2024
1 parent 0f5a0f6 commit 364b757
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion src/llamafactory/train/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def fix_valuehead_checkpoint(
path_to_checkpoint = os.path.join(output_dir, WEIGHTS_NAME)
state_dict: Dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu")

os.remove(path_to_checkpoint)
decoder_state_dict = {}
v_head_state_dict = {}
for name, param in state_dict.items():
Expand All @@ -91,7 +92,6 @@ def fix_valuehead_checkpoint(
else:
torch.save(v_head_state_dict, os.path.join(output_dir, V_HEAD_WEIGHTS_NAME))

os.remove(path_to_checkpoint)
logger.info("Value head model saved at: {}".format(output_dir))


Expand Down

0 comments on commit 364b757

Please sign in to comment.