[Feature] Add BAdam algorithm #3287

Merged (14 commits) on Apr 16, 2024
35 changes: 35 additions & 0 deletions examples/extras/badam/sft.sh
@@ -0,0 +1,35 @@
#!/bin/bash

CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
--stage sft \
--do_train \
--model_name_or_path meta-llama/Llama-2-7b-hf \
--dataset alpaca_gpt4_en,glaive_toolcall \
--dataset_dir ../../../data \
--template default \
--finetuning_type full \
--use_badam \
--badam_switch_mode descending \
--badam_switch_block_every 50 \
--badam_verbose 2 \
--output_dir ../../../saves/LLaMA2-7B/badam/sft \
--overwrite_cache \
--overwrite_output_dir \
--cutoff_len 1024 \
--preprocessing_num_workers 16 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 8 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--warmup_steps 20 \
--save_steps 100 \
--eval_steps 100 \
--evaluation_strategy steps \
--load_best_model_at_end \
--learning_rate 5e-5 \
--num_train_epochs 3.0 \
--max_samples 3000 \
--val_size 0.1 \
--plot_loss \
--pure_bf16
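
For readers who prefer launching from Python, the same BAdam configuration can be passed programmatically. The sketch below is illustrative and not part of this PR; it assumes `run_exp` from `llmtuner.train.tuner` (the function that `train_bash.py` calls) accepts a dict of CLI-style overrides, and it lists only a subset of the flags from the script above.

# Illustrative only: mirrors examples/extras/badam/sft.sh from Python.
# Assumption: run_exp(args: dict) parses the same keys as the CLI flags above.
from llmtuner.train.tuner import run_exp

run_exp(dict(
    stage="sft",
    do_train=True,
    model_name_or_path="meta-llama/Llama-2-7b-hf",
    dataset="alpaca_gpt4_en,glaive_toolcall",
    template="default",
    finetuning_type="full",
    use_badam=True,                  # enable the BAdam block-wise optimizer
    badam_switch_mode="descending",  # pick blocks from the last layer to the first
    badam_switch_block_every=50,     # change the active block every 50 steps
    badam_verbose=2,                 # print the trainable parameters of each block
    output_dir="saves/LLaMA2-7B/badam/sft",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    learning_rate=5e-5,
    num_train_epochs=3.0,
    pure_bf16=True,                  # keep the model in bf16 end to end
))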
1 change: 1 addition & 0 deletions setup.py
@@ -24,6 +24,7 @@ def get_requires():
"metrics": ["nltk", "jieba", "rouge-chinese"],
"unsloth": ["torch==2.2.0", "unsloth[cu121-ampere-torch220]"],
"galore": ["galore-torch"],
"badam": ["badam"],
"vllm": ["vllm>=0.3.3"],
"bitsandbytes": ["bitsandbytes>=0.39.0"],
"gptq": ["optimum>=1.16.0", "auto-gptq>=0.5.0"],
51 changes: 49 additions & 2 deletions src/llmtuner/hparams/finetuning_args.py
@@ -172,7 +172,7 @@ class GaloreArguments:

use_galore: bool = field(
default=False,
metadata={"help": "Whether or not to use gradient low-Rank projection."},
metadata={"help": "Whether or not to use the gradient low-Rank projection (GaLore)."},
)
galore_target: str = field(
default="all",
@@ -204,7 +204,54 @@ class GaloreArguments:


@dataclass
class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments):
class BAdamArgument:
r"""
Arguments pertaining to the BAdam optimizer.
"""

use_badam: bool = field(
default=False,
metadata={"help": "Whether or not to use the BAdam optimizer."},
)
badam_mode: Literal["layer", "ratio"] = field(
default="layer",
metadata={"help": "Whether to use layer-wise or ratio-wise BAdam optimizer."},
)
badam_start_block: Optional[int] = field(
default=None,
metadata={"help": "The starting block index for layer-wise BAdam."},
)
badam_switch_block_every: Optional[int] = field(
default=50,
metadata={"help": "How often to switch model's block update. Set to -1 to disable the block update."},
)
badam_switch_mode: Optional[Literal["ascending", "descending", "random", "fixed"]] = field(
default="ascending",
metadata={"help": "the strategy of picking block to update for layer-wise BAdam."},
)
badam_update_ratio: float = field(
default=0.0,
metadata={"help": "The ratio of the update for ratio-wise BAdam."},
)
badam_mask_mode: Literal["adjacent", "scatter"] = field(
default="adjacent",
metadata={
"help": """The mode of the mask for BAdam optimizer. \
`adjacent` means that the trainable parameters are adjacent to each other, \
`scatter` means that trainable parameters are randomly choosed from the weight."""
},
)
badam_verbose: int = field(
default=0,
metadata={
"help": """The verbosity level of BAdam optimizer. \
0 for no print, 1 for print the block prefix, 2 for print trainable parameters"""
},
)


@dataclass
class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments, BAdamArgument):
r"""
Arguments pertaining to which techniques we are going to fine-tuning with.
"""
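
To illustrate how these new fields map onto the `--badam_*` flags used in sft.sh, the sketch below parses an abridged copy of `BAdamArgument` with `transformers.HfArgumentParser`, the same parser the project uses for `FinetuningArguments`. It is a standalone illustration, not code from this patch.

# Standalone illustration: parse the BAdam flags into the dataclass defined above (abridged).
from dataclasses import dataclass, field
from typing import Literal, Optional

from transformers import HfArgumentParser


@dataclass
class BAdamArgument:
    use_badam: bool = field(default=False)
    badam_mode: Literal["layer", "ratio"] = field(default="layer")
    badam_start_block: Optional[int] = field(default=None)
    badam_switch_block_every: Optional[int] = field(default=50)
    badam_switch_mode: Optional[Literal["ascending", "descending", "random", "fixed"]] = field(default="ascending")
    badam_update_ratio: float = field(default=0.0)


parser = HfArgumentParser(BAdamArgument)
(badam_args,) = parser.parse_args_into_dataclasses(
    args=["--use_badam", "--badam_switch_mode", "descending", "--badam_switch_block_every", "50"]
)
print(badam_args.badam_mode)  # "layer" by default: one transformer block is trained at a time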
15 changes: 14 additions & 1 deletion src/llmtuner/hparams/parser.py
@@ -88,6 +88,9 @@ def _check_extra_dependencies(
if finetuning_args.use_galore:
require_version("galore_torch", "To fix: pip install galore_torch")

if finetuning_args.use_badam:
require_version("badam", "To fix: pip install badam")

if training_args is not None and training_args.predict_with_generate:
require_version("jieba", "To fix: pip install jieba")
require_version("nltk", "To fix: pip install nltk")
@@ -151,6 +154,9 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
if training_args.do_train and training_args.predict_with_generate:
raise ValueError("`predict_with_generate` cannot be set as True while training.")

if training_args.do_train and model_args.quantization_device_map == "auto":
raise ValueError("Cannot use device map for quantized models in training.")

if finetuning_args.use_dora and model_args.use_unsloth:
raise ValueError("Unsloth does not support DoRA.")

@@ -169,7 +175,14 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
raise ValueError("Distributed training does not support layer-wise GaLore.")

if finetuning_args.use_galore and training_args.deepspeed is not None:
raise ValueError("GaLore is incompatible with DeepSpeed.")
raise ValueError("GaLore is incompatible with DeepSpeed yet.")

if (
finetuning_args.use_badam
and finetuning_args.badam_mode == "layer"
and training_args.parallel_mode.value == "distributed"
):
raise ValueError("Layer-wise BAdam does not yet support distributed training, use ratio-wise BAdam.")

if model_args.infer_backend == "vllm":
raise ValueError("vLLM backend is only available for API, CLI and Web.")
14 changes: 9 additions & 5 deletions src/llmtuner/model/adapter.py
@@ -37,7 +37,7 @@ def init_adapter(

if finetuning_args.finetuning_type == "full" and is_trainable:
logger.info("Fine-tuning method: Full")
-if not finetuning_args.pure_bf16:
+if (not finetuning_args.pure_bf16) and (not finetuning_args.use_badam):
model = model.float()

if finetuning_args.finetuning_type == "freeze" and is_trainable:
@@ -82,7 +82,7 @@ def init_adapter(

for name, param in model.named_parameters():
if any(trainable_layer in name for trainable_layer in trainable_layers):
-if not finetuning_args.pure_bf16:
+if (not finetuning_args.pure_bf16) and (not finetuning_args.use_badam):
param.data = param.data.to(torch.float32)
else:
param.requires_grad_(False)
@@ -145,24 +145,28 @@ def init_adapter(
"lora_alpha": finetuning_args.lora_alpha,
"lora_dropout": finetuning_args.lora_dropout,
"use_rslora": finetuning_args.use_rslora,
"modules_to_save": finetuning_args.additional_target,
}

if model_args.use_unsloth:
from unsloth import FastLanguageModel # type: ignore

unsloth_peft_kwargs = {"model": model, "max_seq_length": model_args.model_max_length}
unsloth_peft_kwargs = {
"model": model,
"max_seq_length": model_args.model_max_length,
"use_gradient_checkpointing": "unsloth",
}
model = FastLanguageModel.get_peft_model(**peft_kwargs, **unsloth_peft_kwargs)
else:
lora_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
inference_mode=False,
-modules_to_save=finetuning_args.additional_target,
use_dora=finetuning_args.use_dora,
**peft_kwargs,
)
model = get_peft_model(model, lora_config)

-if not finetuning_args.pure_bf16:
+if (not finetuning_args.pure_bf16) and (not finetuning_args.use_badam):
for param in filter(lambda p: p.requires_grad, model.parameters()):
param.data = param.data.to(torch.float32)

8 changes: 5 additions & 3 deletions src/llmtuner/model/patcher.py
@@ -17,7 +17,7 @@
from ..extras.misc import get_current_device, infer_optim_dtype
from ..extras.packages import is_flash_attn2_available
from ..extras.patches.llama_patch import apply_llama_patch
-from .utils import QuantizationMethod, add_z3_leaf_module
+from .utils import QuantizationMethod, add_z3_leaf_module, gradient_checkpointing_enable


if TYPE_CHECKING:
@@ -133,7 +133,9 @@ def _configure_quantization(
if is_deepspeed_zero3_enabled():
raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantized models.")

init_kwargs["device_map"] = {"": get_current_device()}
if model_args.quantization_device_map != "auto":
init_kwargs["device_map"] = {"": get_current_device()}

quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
quant_method = quantization_config.get("quant_method", "")

@@ -266,8 +268,8 @@ def _prepare_model_for_training(
else:
# use_reentrant=False might increase VRAM usage (this has not been empirically verified yet)
# According to: https://github.com/huggingface/transformers/issues/28339
model.gradient_checkpointing_enable = MethodType(gradient_checkpointing_enable, model)
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
model.enable_input_require_grads()
setattr(model.config, "use_cache", False) # turn off when gradient checkpointing is enabled
logger.info("Gradient checkpointing enabled.")

34 changes: 33 additions & 1 deletion src/llmtuner/model/utils.py
@@ -1,5 +1,6 @@
from enum import Enum, unique
-from typing import TYPE_CHECKING, Dict, List
+from functools import partial
+from typing import TYPE_CHECKING, Any, Dict, List, Optional

import torch
from transformers import PreTrainedModel
@@ -100,6 +101,37 @@ def find_expanded_modules(model: "PreTrainedModel", target_modules: List[str], n
return module_names


def gradient_checkpointing_enable(
self: "PreTrainedModel", gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
) -> None:
r"""
Activates gradient checkpointing for the current model.

Modification of the original method to enable gradient checkpointing for the block-wise optimizer.
"""
from torch.utils.checkpoint import checkpoint

if not self.supports_gradient_checkpointing:
raise ValueError("{} does not support gradient checkpointing.".format(self.__class__.__name__))

if gradient_checkpointing_kwargs is None:
gradient_checkpointing_kwargs = {"use_reentrant": True}

gradient_checkpointing_func = partial(checkpoint, **gradient_checkpointing_kwargs)

def custom_gradient_checkpointing_func(func, *args, **kwargs):
module: "torch.nn.Module" = func.__self__

if any(param.requires_grad for param in module.parameters()):
for arg in args:
if torch.is_tensor(arg) and torch.is_floating_point(arg):
arg.requires_grad_(True)

return gradient_checkpointing_func(func, *args, **kwargs)

self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=custom_gradient_checkpointing_func)


def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]:
r"""
Loads value head parameters from Hugging Face Hub or local disk.
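
For context on how the `gradient_checkpointing_enable` helper above is consumed, the sketch below mirrors the `MethodType` rebinding that patcher.py performs in this PR. The custom checkpoint function matters for BAdam because most parameters have `requires_grad=False`, so marking a floating-point input of each checkpointed module as requiring grad keeps the backward pass flowing into whichever block is currently active. The model name is only an example.

# Sketch mirroring _prepare_model_for_training in src/llmtuner/model/patcher.py.
from types import MethodType

from transformers import AutoModelForCausalLM

from llmtuner.model.utils import gradient_checkpointing_enable  # the function defined above

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")  # example model

# Rebind the patched helper as a bound method, then enable checkpointing through it.
model.gradient_checkpointing_enable = MethodType(gradient_checkpointing_enable, model)
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
model.enable_input_require_grads()
model.config.use_cache = False  # the KV cache is not used while training with checkpointing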
5 changes: 5 additions & 0 deletions src/llmtuner/train/sft/trainer.py
@@ -1,5 +1,6 @@
import json
import os
from types import MethodType
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

import numpy as np
@@ -28,6 +29,10 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
def __init__(self, finetuning_args: "FinetuningArguments", **kwargs) -> None:
super().__init__(**kwargs)
self.finetuning_args = finetuning_args
if finetuning_args.use_badam:
from badam import clip_grad_norm_for_sparse_tensor

self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)

def create_optimizer(self) -> "torch.optim.Optimizer":
if self.optimizer is None:
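
The constructor patch above only adjusts gradient clipping; the block-wise optimizer itself would be built in `create_optimizer`. As a rough sketch only: the `BlockOptimizer` name and its keyword arguments below are assumptions inferred from this PR's `badam_*` flags rather than a verified badam API, so treat it as pseudo-configuration, not the PR's actual implementation.

# Hedged sketch: wrapping AdamW in a block-wise optimizer driven by the badam_* flags.
# Assumption: the badam package exposes a BlockOptimizer with roughly these arguments.
import torch
from badam import BlockOptimizer  # assumed export


def build_badam_optimizer(model: torch.nn.Module, finetuning_args, lr: float = 5e-5):
    base_optimizer = torch.optim.AdamW(
        [p for p in model.parameters() if p.requires_grad], lr=lr
    )
    return BlockOptimizer(
        base_optimizer=base_optimizer,                        # Adam state is kept only for the active block
        named_parameters_list=list(model.named_parameters()),
        switch_block_every=finetuning_args.badam_switch_block_every,
        switch_mode=finetuning_args.badam_switch_mode,        # "ascending", "descending", "random" or "fixed"
        start_block=finetuning_args.badam_start_block,
        verbose=finetuning_args.badam_verbose,
    )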