Layerwise GaLore optimizer cannot converge with warmup scheduler #30371

Closed
hiyouga opened this issue Apr 21, 2024 · 0 comments · Fixed by #30372

hiyouga commented Apr 21, 2024

System Info

  • transformers version: 4.40.0
  • Platform: Linux-5.15.0-100-generic-x86_64-with-glibc2.35
  • Python version: 3.11.8
  • Huggingface_hub version: 0.21.4
  • Safetensors version: 0.4.2
  • Accelerate version: 0.28.0
  • PyTorch version (GPU?): 2.2.0+cu121 (True)

Who can help?

@muellerzr and @younesbelkada

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Experiments are conducted with the following script, run once with warmup_steps=10 and once with warmup_steps=0:

import matplotlib.pyplot as plt
import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

model_id = "Qwen/Qwen1.5-0.5B-Chat"
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)

dataset = load_dataset("mlabonne/guanaco-llama2-1k", split="train[:200]").map(
    lambda x: tokenizer(x["text"], max_length=1024, truncation=True, add_special_tokens=False),
    batched=True,
)
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

training_args = TrainingArguments(
    output_dir="test_galore",
    learning_rate=1e-4,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=1,
    num_train_epochs=6.0,
    logging_steps=10,
    warmup_steps=10, # warmup_steps=0
    optim="galore_adamw_layerwise",
    optim_args="scale=2.0,update_proj_gap=400",
    optim_target_modules="all-linear",
    gradient_checkpointing=True,
    report_to="none",
)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
)
trainer.train()

steps, losses = [], []
for i in range(len(trainer.state.log_history)):
    if "loss" in trainer.state.log_history[i]:
        steps.append(trainer.state.log_history[i]["step"])
        losses.append(trainer.state.log_history[i]["loss"])

plt.figure()
plt.plot(steps, losses)
plt.savefig("loss.png", format="png", dpi=100)

With warmup_steps=0, the loss converges normally:

[Figure loss_a: training loss curve with warmup_steps=0; the loss decreases and converges]

With warmup_steps=10, the loss cannot converge:

[Figure loss_b: training loss curve with warmup_steps=10; the loss does not converge]

Expected behavior

The implementation of the layerwise GaLore optimizer relies on parameter hooks. The trainer first attaches a hook to each parameter for the per-parameter optimizers:

def optimizer_hook(param):
    if param.grad is not None:
        optimizer_dict[param].step()
        optimizer_dict[param].zero_grad()

for param in model.parameters():
    if param.requires_grad:
        param.register_post_accumulate_grad_hook(optimizer_hook)

and then attaches another hook to each parameter for the schedulers:

def scheduler_hook(param):
    # Since the optimizer hook has been already attached we only need to
    # attach the scheduler hook
    if param.grad is not None:
        scheduler_dict[param].step()

for param in optimizer_dict.keys():
    if param.requires_grad:
        param.register_post_accumulate_grad_hook(scheduler_hook)

However, since the scheduler hook is attached after the optimizer hook, the parameter's gradient has already been cleared by the optimizer hook by the time the scheduler hook runs. Therefore the condition param.grad is not None can never hold inside the scheduler hook, and scheduler_dict[param].step() is never actually called during training.
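
To make the mechanism concrete, here is a minimal standalone sketch (not taken from the Trainer source; the single parameter and SGD optimizer are only for illustration). It shows that post-accumulate-grad hooks fire in registration order and that, with optimizer.zero_grad() defaulting to set_to_none=True in recent PyTorch, the second hook observes param.grad as None:

import torch

param = torch.nn.Parameter(torch.randn(4))
optimizer = torch.optim.SGD([param], lr=0.1)

def optimizer_hook(p):
    # Registered first, so it fires first.
    if p.grad is not None:
        optimizer.step()
        optimizer.zero_grad()  # set_to_none=True by default -> p.grad becomes None

def scheduler_hook(p):
    # Registered second, so it fires after the gradient has been cleared.
    print("scheduler hook sees grad:", p.grad)  # prints: scheduler hook sees grad: None

param.register_post_accumulate_grad_hook(optimizer_hook)
param.register_post_accumulate_grad_hook(scheduler_hook)

(param ** 2).sum().backward()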

Besides the experiment above, we can also validate this by adding a print statement to the scheduler hook:

def scheduler_hook(param):
    # Since the optimizer hook has been already attached we only need to
    # attach the scheduler hook
    if param.grad is not None:
        print("scheduler step")
        scheduler_dict[param].step()

As we can see, this string is never printed during training. Consequently, the scheduler unexpectedly has no effect on training when the layerwise optimizer is used.
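
For reference, one way to avoid the ordering problem would be to step the per-parameter scheduler inside the same hook as the optimizer, before the gradient is cleared. This is only an illustrative sketch and does not necessarily match the actual fix in #30372:

def optimizer_hook(param):
    if param.grad is not None:
        optimizer_dict[param].step()
        # Step the per-parameter scheduler here, while param.grad is still set,
        # so the warmup schedule actually takes effect.
        scheduler_dict[param].step()
        optimizer_dict[param].zero_grad()

for param in model.parameters():
    if param.requires_grad:
        param.register_post_accumulate_grad_hook(optimizer_hook)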
