
Commit

fix ds optimizer
hiyouga committed Mar 26, 2024
1 parent b29d556 commit 3bcd41b
Showing 4 changed files with 8 additions and 8 deletions.
src/llmtuner/train/dpo/trainer.py (4 changes: 2 additions & 2 deletions)
@@ -64,10 +64,10 @@ def __init__(
                 self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)

     def create_optimizer_and_scheduler(self, num_training_steps: int) -> None:
-        self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args, num_training_steps)
         if self.optimizer is None:
-            self.create_optimizer()
+            self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args, num_training_steps)

+        self.create_optimizer()
         self.create_scheduler(num_training_steps=num_training_steps, optimizer=self.optimizer)

     def sft_loss(self, chosen_logits: torch.FloatTensor, chosen_labels: torch.LongTensor) -> torch.Tensor:
src/llmtuner/train/pt/trainer.py (4 changes: 2 additions & 2 deletions)
@@ -23,8 +23,8 @@ def __init__(self, finetuning_args: "FinetuningArguments", **kwargs) -> None:
         self.finetuning_args = finetuning_args

     def create_optimizer_and_scheduler(self, num_training_steps: int) -> None:
-        self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args, num_training_steps)
         if self.optimizer is None:
-            self.create_optimizer()
+            self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args, num_training_steps)

+        self.create_optimizer()
         self.create_scheduler(num_training_steps=num_training_steps, optimizer=self.optimizer)
src/llmtuner/train/rm/trainer.py (4 changes: 2 additions & 2 deletions)
@@ -30,10 +30,10 @@ def __init__(self, finetuning_args: "FinetuningArguments", **kwargs) -> None:
         self.can_return_loss = True  # override property to return eval_loss

     def create_optimizer_and_scheduler(self, num_training_steps: int) -> None:
-        self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args, num_training_steps)
         if self.optimizer is None:
-            self.create_optimizer()
+            self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args, num_training_steps)

+        self.create_optimizer()
         self.create_scheduler(num_training_steps=num_training_steps, optimizer=self.optimizer)

     def compute_loss(
src/llmtuner/train/sft/trainer.py (4 changes: 2 additions & 2 deletions)
@@ -30,10 +30,10 @@ def __init__(self, finetuning_args: "FinetuningArguments", **kwargs) -> None:
         self.finetuning_args = finetuning_args

     def create_optimizer_and_scheduler(self, num_training_steps: int) -> None:
-        self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args, num_training_steps)
         if self.optimizer is None:
-            self.create_optimizer()
+            self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args, num_training_steps)

+        self.create_optimizer()
         self.create_scheduler(num_training_steps=num_training_steps, optimizer=self.optimizer)

     def prediction_step(
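
All four trainers end up with the same method. The annotated sketch below is illustrative rather than a verbatim copy, and it assumes the stock Hugging Face Trainer behaviour that create_optimizer() leaves an already-set self.optimizer untouched:

    def create_optimizer_and_scheduler(self, num_training_steps: int) -> None:
        if self.optimizer is None:  # e.g. DeepSpeed has not already built an optimizer
            # create_custom_optimzer may return None when no custom optimizer is configured
            self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args, num_training_steps)

        self.create_optimizer()  # only builds the default optimizer if self.optimizer is still None
        # the scheduler is created against whichever optimizer will actually be stepped
        self.create_scheduler(num_training_steps=num_training_steps, optimizer=self.optimizer)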

2 comments on commit 3bcd41b

@hiyouga (Owner, Author)

Probably fixes #2983 and #2991

@hiyouga (Owner, Author)

memo: the logic before this fix recreated the optimizer but left the scheduler bound to the original one, so the scheduler had no effect on training
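
A minimal, hypothetical PyTorch snippet (not from this repository) showing that symptom: a scheduler keeps driving the optimizer it was constructed with, so recreating the optimizer afterwards leaves the learning-rate schedule with no effect on the optimizer actually used for training.

    import torch

    params = [torch.nn.Parameter(torch.zeros(1))]

    original_opt = torch.optim.SGD(params, lr=1.0)
    scheduler = torch.optim.lr_scheduler.LambdaLR(original_opt, lambda step: 1.0 / (step + 1))

    recreated_opt = torch.optim.SGD(params, lr=1.0)  # what the pre-fix code ended up training with

    scheduler.step()
    print(original_opt.param_groups[0]["lr"])   # 0.5 -- schedule applied to the abandoned optimizer
    print(recreated_opt.param_groups[0]["lr"])  # 1.0 -- the optimizer in use keeps a constant lr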
