System Info
`transformers` version: 4.40.0
Who can help?
@muellerzr and @younesbelkada
Information
Tasks
An officially supported task in the `examples` folder (such as GLUE/SQuAD, ...)
Reproduction
Experiments are conducted using the following scripts:
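The original scripts are not reproduced here; a minimal sketch of the kind of setup described (the model, dataset, and hyperparameters below are illustrative assumptions, not the reporter's actual configuration) could look like this:

```python
# Illustrative sketch of a layerwise GaLore run with a non-zero warmup,
# which is where the scheduler hook is expected to matter.
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # assumed model choice
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(model_id)

dataset = load_dataset("imdb", split="train[:1%]")
dataset = dataset.map(
    lambda batch: tokenizer(batch["text"], truncation=True, max_length=128),
    batched=True,
    remove_columns=dataset.column_names,
)

args = TrainingArguments(
    output_dir="galore-layerwise-warmup-test",
    per_device_train_batch_size=4,
    max_steps=100,
    learning_rate=1e-4,
    warmup_steps=10,  # converges with warmup_steps=0, fails with warmup_steps=10
    logging_steps=1,
    optim="galore_adamw_layerwise",
    optim_target_modules=["attn", "mlp"],
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=dataset,
    data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),
)
trainer.train()
```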
With `warmup_steps=0`, the loss converges normally. With `warmup_steps=10`, the loss fails to converge.
Expected behavior
The implementation of the layerwise GaLore optimizer relies on per-parameter hooks. The trainer first attaches a hook to each parameter for the optimizers:
transformers/src/transformers/trainer.py
Lines 1351 to 1358 in 8c12690
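Paraphrasing the referenced lines rather than quoting them (here `optimizer_dict` maps each parameter to its own per-layer optimizer), the attached hook is roughly of this shape:

```python
def optimizer_hook(param):
    if param.grad is not None:
        optimizer_dict[param].step()
        optimizer_dict[param].zero_grad()  # sets param.grad back to None

for param in model.parameters():
    if param.requires_grad:
        param.register_post_accumulate_grad_hook(optimizer_hook)
```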
and then attaches another hook to each parameter for the schedulers:
transformers/src/transformers/optimization.py
Lines 445 to 453 in 8c12690
However, since the scheduler hook is attached after the optimizer hook, the parameter's gradient has already been cleared by the optimizer hook by the time the scheduler hook runs. The condition `param.grad is not None` therefore never holds inside the scheduler hook, and `scheduler_dict[param].step()` is never actually called during training.
Apart from the experiment above, we can also validate this by adding a print statement in the scheduler hook:
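```python
def scheduler_hook(param):
    # Since the optimizer hook has been already attached we only need to
    # attach the scheduler hook
    if param.grad is not None:
        print("scheduler step")
        scheduler_dict[param].step()
```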
As we can see, this string is never printed during training. Consequently, the scheduler unexpectedly has no effect on training when the layerwise optimizer is used.
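For completeness, the ordering problem can be reproduced outside of the Trainer with a tiny self-contained script (the per-parameter `optimizer_dict`/`scheduler_dict` below only mimic the pattern described above, they are not the library code; requires torch >= 2.1 for `register_post_accumulate_grad_hook`):

```python
import torch

model = torch.nn.Linear(4, 1)

# One optimizer and one scheduler per parameter, as in the layerwise setup.
optimizer_dict = {p: torch.optim.AdamW([p], lr=1e-3) for p in model.parameters()}
scheduler_dict = {
    p: torch.optim.lr_scheduler.LambdaLR(optimizer_dict[p], lambda _: 1.0)
    for p in model.parameters()
}

def optimizer_hook(param):
    # Registered first: steps the per-parameter optimizer and clears the grad.
    if param.grad is not None:
        optimizer_dict[param].step()
        optimizer_dict[param].zero_grad()

def scheduler_hook(param):
    # Registered second: by the time it runs, zero_grad() has already set
    # param.grad to None, so this branch is never entered.
    if param.grad is not None:
        print("scheduler step")
        scheduler_dict[param].step()

for p in model.parameters():
    p.register_post_accumulate_grad_hook(optimizer_hook)
for p in model.parameters():
    p.register_post_accumulate_grad_hook(scheduler_hook)

loss = model(torch.randn(2, 4)).sum()
loss.backward()  # "scheduler step" is never printed
```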