
Commit

Fix lora single device fine tune checkpoint saving & nan loss when use_dora=True (#1909)
mirceamironenco authored Oct 31, 2024
1 parent 2fa6a54 commit eab21f0
Showing 5 changed files with 41 additions and 19 deletions.
4 changes: 3 additions & 1 deletion recipes/lora_finetune_single_device.py
@@ -120,7 +120,6 @@ class LoRAFinetuneRecipeSingleDevice(FTRecipeInterface):
     """
 
     def __init__(self, cfg: DictConfig) -> None:
-
         self._device = utils.get_device(device=cfg.device)
         # Reduced precision logic
         self._dtype = training.get_dtype(cfg.dtype, device=self._device)
@@ -438,6 +437,9 @@ def _setup_model(
         # This is for any adapters that need to be initialized after base weights
         # have been loaded (e.g. DoRA).
         if self._is_dora:
+            for m in model.modules():
+                if hasattr(m, "initialize_dora_magnitude"):
+                    m.initialize_dora_magnitude()
             load_dora_magnitudes(model)
         if lora_weights_state_dict:
             lora_missing, lora_unexpected = model.load_state_dict(
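
The recipe change makes the ordering explicit: DoRA magnitudes are computed from the base weights, so they must be initialized only after the base state dict has been loaded, and before any previously trained adapter weights are restored; initializing them earlier can leave garbage magnitudes and NaN losses. A minimal sketch of that ordering follows; the function name and the base_sd / adapter_sd arguments are assumptions for illustration, not part of this commit.

from typing import Optional

import torch

def setup_dora_model(
    model: torch.nn.Module,
    base_sd: dict,
    adapter_sd: Optional[dict] = None,
) -> torch.nn.Module:
    # 1. Load base weights first so magnitudes are computed from real values.
    model.load_state_dict(base_sd, strict=False)

    # 2. Initialize every DoRA magnitude from the merged weight norm.
    for m in model.modules():
        if hasattr(m, "initialize_dora_magnitude"):
            m.initialize_dora_magnitude()

    # 3. Only then (e.g. when resuming) load previously trained adapter weights,
    #    which may overwrite the freshly initialized magnitudes.
    if adapter_sd is not None:
        model.load_state_dict(adapter_sd, strict=False)
    return model
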
8 changes: 6 additions & 2 deletions tests/recipes/test_lora_finetune_single_device.py
@@ -259,8 +259,9 @@ def test_training_state_on_resume(
             loss_values, expected_loss_values, rtol=1e-5, atol=1e-5
         )
 
+    @pytest.mark.parametrize("use_dora", [False, True])
     @pytest.mark.integration_test
-    def test_save_and_load_merged_weights(self, tmpdir, monkeypatch):
+    def test_save_and_load_merged_weights(self, tmpdir, monkeypatch, use_dora):
         ckpt = "llama2_tune"
         ckpt_path = Path(CKPT_MODEL_PATHS[ckpt])
         ckpt_dir = ckpt_path.parent
@@ -280,7 +281,10 @@ def test_save_and_load_merged_weights(self, tmpdir, monkeypatch):
             enable_activation_offloading=False \
         """.split()
 
-        model_config = MODEL_TEST_CONFIGS["llama2_lora"]
+        if use_dora:
+            model_config = MODEL_TEST_CONFIGS["llama2_dora"]
+        else:
+            model_config = MODEL_TEST_CONFIGS["llama2_lora"]
 
         cmd = cmd + self._get_test_config_overrides() + model_config
         monkeypatch.setattr(sys, "argv", cmd)
10 changes: 10 additions & 0 deletions tests/recipes/utils.py
@@ -135,6 +135,7 @@ def lora_llama2_test_config(
     lora_rank: int = 8,
     lora_alpha: float = 16,
     quantize_base: bool = False,
+    use_dora: bool = False,
 ) -> List[str]:
     return [
         # Note: we explicitly use _component_ so that we can also call
@@ -154,6 +155,7 @@ def lora_llama2_test_config(
         f"model.lora_alpha={lora_alpha}",
         "model.lora_dropout=0.0",
         f"model.quantize_base={quantize_base}",
+        f"model.use_dora={use_dora}",
     ]
 
 
@@ -207,6 +209,14 @@ def write_hf_ckpt_config(ckpt_dir: str):
         lora_rank=8,
         lora_alpha=16,
     ),
+    "llama2_dora": lora_llama2_test_config(
+        lora_attn_modules=["q_proj", "k_proj", "v_proj", "output_proj"],
+        apply_lora_to_mlp=False,
+        apply_lora_to_output=False,
+        lora_rank=8,
+        lora_alpha=16,
+        use_dora=True,
+    ),
     "llama2_qlora": lora_llama2_test_config(
         lora_attn_modules=["q_proj", "k_proj", "v_proj", "output_proj"],
         apply_lora_to_mlp=True,
35 changes: 20 additions & 15 deletions torchtune/models/convert_weights.py
@@ -6,7 +6,7 @@
 
 import re
 
-from typing import Any, Dict
+from typing import Any, Dict, Optional
 
 import torch
 
@@ -252,23 +252,28 @@ def tune_to_peft_adapter_weights(
     num_heads: int = 32,
     num_kv_heads: int = 32,
     dim: int = 4096,
-    head_dim: int = None,
+    head_dim: Optional[int] = None,
 ):
     converted_state_dict = {}
     full_mapping = {}
-    # Rather than recreate a separate mapping for LoRA adapter weights, we just
-    # re-use the _FROM_HF mapping for base model weights. We iterate over it twice:
-    # once to add mappings for LoRA A matrices and once to add mappings for LoRA B matrices.
-    for k, v in _TO_PEFT_KEYS.items():
-        full_mapping.update(
-            {
-                vv.replace(".weight", f".{k}.weight"): kk.replace(
-                    ".weight", f".{v}.weight"
-                )
-                for kk, vv in _FROM_HF.items()
-                if vv is not None
-            }
-        )
+    # Rather than recreate a separate mapping for LoRA adapter weights, we re-use the
+    # _FROM_HF mapping for base model weights. The mapping is adapted to account for:
+    # LoRA A matrices, LoRA B matrices and the dora magnitude parameter.
+    for peft_key, peft_val in _TO_PEFT_KEYS.items():
+        for hf_key, hf_val in _FROM_HF.items():
+            if hf_val is None:
+                continue
+
+            if peft_key == "magnitude":
+                # e.g. attn.q_proj.magnitude -> attn.q_proj.lora_magnitude_vector
+                adapter_key = hf_val.replace(".weight", f".{peft_key}")
+                adapter_val = hf_key.replace(".weight", f".{peft_val}")
+            else:
+                # e.g. attn.q_proj.lora_a.weight -> attn.q_proj.lora_A.weight
+                adapter_key = hf_val.replace(".weight", f".{peft_key}.weight")
+                adapter_val = hf_key.replace(".weight", f".{peft_val}.weight")
+
+            full_mapping.update({adapter_key: adapter_val})
 
     if head_dim is None:
         head_dim = dim // num_heads
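
The rewritten loop builds one flat mapping from torchtune adapter keys to PEFT adapter keys; the only special case is the DoRA magnitude, which is a bare vector and so carries no ".weight" suffix on either side. A standalone sketch of the same key-construction logic is below, using toy stand-in dictionaries: the entries of _FROM_HF and _TO_PEFT_KEYS shown here are illustrative assumptions, not the full tables in torchtune.

# Toy stand-ins for the real mapping tables (illustrative only).
_FROM_HF = {
    "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attn.q_proj.weight",
    "model.layers.{}.self_attn.rotary_emb.inv_freq": None,  # None values are skipped
}
_TO_PEFT_KEYS = {
    "lora_a": "lora_A",
    "lora_b": "lora_B",
    "magnitude": "lora_magnitude_vector",
}

full_mapping = {}
for peft_key, peft_val in _TO_PEFT_KEYS.items():
    for hf_key, hf_val in _FROM_HF.items():
        if hf_val is None:
            continue
        if peft_key == "magnitude":
            # DoRA magnitude has no ".weight" suffix on either side.
            adapter_key = hf_val.replace(".weight", f".{peft_key}")
            adapter_val = hf_key.replace(".weight", f".{peft_val}")
        else:
            adapter_key = hf_val.replace(".weight", f".{peft_key}.weight")
            adapter_val = hf_key.replace(".weight", f".{peft_val}.weight")
        full_mapping.update({adapter_key: adapter_val})

# e.g. full_mapping["layers.{}.attn.q_proj.magnitude"]
#      == "model.layers.{}.self_attn.q_proj.lora_magnitude_vector"
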
3 changes: 2 additions & 1 deletion torchtune/modules/peft/dora.py
@@ -79,6 +79,7 @@ def initialize_parameters(self):
         _lora_a_init_params(self.lora_a)
         _lora_b_init_params(self.lora_b)
 
+    @torch.no_grad()
     def initialize_dora_magnitude(self):
         """
         DoRA initializes the magnitude vector such that its outputs are initially
@@ -87,7 +88,7 @@ def initialize_dora_magnitude(self):
         base_weight = self.weight.to(self.lora_a.weight.dtype)
         lora_weight = self.lora_b.weight @ self.lora_a.weight
         weight_norm = self._get_weight_norm(base_weight, lora_weight)
-        self.magnitude = nn.Parameter(weight_norm, requires_grad=True)
+        self.magnitude.copy_(weight_norm)
 
     def _create_weight_and_bias(self):
         """
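
Reassigning self.magnitude to a brand-new nn.Parameter detaches it from anything that already holds a reference to the original parameter, such as optimizer param groups or the bookkeeping used when saving adapter checkpoints. Copying into the existing parameter under torch.no_grad() keeps the same tensor object and only updates its values. A small self-contained illustration of the difference is below; the DemoLayer class is made up for this example and is not torchtune code.

import torch
import torch.nn as nn

class DemoLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.magnitude = nn.Parameter(torch.zeros(4))

layer = DemoLayer()
opt = torch.optim.SGD(layer.parameters(), lr=0.1)

# In-place copy: the optimizer still tracks the same Parameter object.
with torch.no_grad():
    layer.magnitude.copy_(torch.ones(4))
assert layer.magnitude is opt.param_groups[0]["params"][0]

# Reassigning a new Parameter breaks that link: the optimizer keeps
# updating the old, now-orphaned tensor.
layer.magnitude = nn.Parameter(torch.full((4,), 2.0))
assert layer.magnitude is not opt.param_groups[0]["params"][0]
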
