[Feature] Add BAdam algorithm #3287

Merged: 14 commits (Apr 16, 2024)
36 changes: 36 additions & 0 deletions examples/extras/badam/sft.sh
@@ -0,0 +1,36 @@
# BAdam layer-wise
export CUDA_VISIBLE_DEVICES=0
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
python ../../../src/train_bash.py \
--stage sft \
--do_train \
--model_name_or_path meta-llama/Llama-2-7b-hf \
--dataset alpaca_gpt4_en,glaive_toolcall \
--dataset_dir ../../../data \
--template default \
--finetuning_type full \
--output_dir ../../../saves/LLaMA2-7B/badam \
--overwrite_cache \
--overwrite_output_dir \
--cutoff_len 1024 \
--preprocessing_num_workers 32 \
--per_device_train_batch_size 8 \
--per_device_eval_batch_size 5 \
--gradient_accumulation_steps 2 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--warmup_steps 20 \
--save_steps 100 \
--eval_steps 100 \
--evaluation_strategy steps \
--load_best_model_at_end \
--learning_rate 5e-5 \
--num_train_epochs 3.0 \
--val_size 0.1 \
--plot_loss \
--use_badam \
--switch_mode descending \
--badam_verbose 2 \
--switch_block_every 50 \
    --pure_bf16

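Note (editor's sketch, not part of the PR): the same layer-wise run can also be launched from Python, since `get_train_args` accepts a dict of arguments (see the `parser.py` change below). The snippet assumes `run_exp` is exported from `llmtuner` as in `src/train_bash.py`; the dataset and output paths are placeholders.

```python
# Rough sketch: launching the BAdam layer-wise SFT run programmatically.
# Assumes llmtuner is importable and that run_exp forwards the dict to get_train_args.
from llmtuner import run_exp

run_exp(dict(
    stage="sft",
    do_train=True,
    model_name_or_path="meta-llama/Llama-2-7b-hf",
    dataset="alpaca_gpt4_en,glaive_toolcall",
    dataset_dir="data",                      # placeholder path
    template="default",
    finetuning_type="full",
    output_dir="saves/LLaMA2-7B/badam",      # placeholder path
    per_device_train_batch_size=8,
    gradient_accumulation_steps=2,
    learning_rate=5e-5,
    num_train_epochs=3.0,
    pure_bf16=True,
    # BAdam-specific flags, mirroring sft.sh above
    use_badam=True,
    switch_mode="descending",
    switch_block_every=50,
    badam_verbose=2,
))
```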
1 change: 1 addition & 0 deletions requirements.txt
@@ -15,3 +15,4 @@ fastapi
sse-starlette
matplotlib
fire
badam
43 changes: 42 additions & 1 deletion src/llmtuner/hparams/finetuning_args.py
@@ -163,6 +163,47 @@ class RLHFArguments:
metadata={"help": "The type of the reward model in PPO training. Lora model only supports lora training."},
)

@dataclass
class BAdamArgument:
r"""
Arguments for BAdam optimizer.
"""
use_badam: bool = field(
default=False,
metadata={"help": "Whether or not to use BAdam optimizer."},
)
badam_mode: Literal["layer", "ratio"] = field(
default="layer",
metadata={"help": "The mode of BAdam optimizer. 'layer' for layer-wise, 'ratio' for ratio-wise."},
)

# ======== Arguments for layer-wise update ========
start_block: Optional[int] = field(
default=None,
metadata={"help": "The starting block index for block-wise fine-tuning."}
)
switch_block_every: Optional[int] = field(
default=50,
metadata={"help": "How often to switch the block being updated. Set to -1 to disable block switching."}
)
switch_mode: Optional[Literal["ascending", "descending", "random", "fixed"]] = field(
default="ascending",
metadata={"help": "The strategy for picking the block to update."}
)

# ======== Arguments for ratio-wise update ========
badam_update_ratio: float = field(
default=0.,
metadata={"help": "The ratio of parameters to update when using the BAdam ratio mode."}
)
badam_mask_mode: Literal["adjacent", "scatter"] = field(
default="adjacent",
metadata={"help": "The mode of the mask for BAdam optimizer. `adjacent` means that the trainable parameters are adjacent to each other; `scatter` means that trainable parameters are randomly chosen from the weights."}
)
badam_verbose: int = field(
default=0,
metadata={"help": "The verbosity level of BAdam optimizer. 0 for no output, 1 for printing the block prefix, 2 for printing trainable parameters."}
)

@dataclass
class GaloreArguments:
@@ -204,7 +245,7 @@ class GaloreArguments:


@dataclass
class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments):
class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments, BAdamArgument):
r"""
Arguments pertaining to which techniques we are going to fine-tuning with.
"""
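For clarity, here is a small sketch (not part of the PR) of how the two modes map onto the new `BAdamArgument` fields. It constructs the dataclass directly and assumes `src/` is on the Python path; in practice these fields are parsed into `FinetuningArguments` from the command line.

```python
# Illustrative only: which fields belong to which BAdam mode.
from llmtuner.hparams.finetuning_args import BAdamArgument

# Layer-wise mode: one block of layers is trainable at a time; the active block
# switches every `switch_block_every` optimizer steps (cf. examples/extras/badam/sft.sh).
layer_args = BAdamArgument(
    use_badam=True,
    badam_mode="layer",
    switch_mode="descending",
    switch_block_every=50,
)

# Ratio-wise mode: a fixed fraction of each weight is trainable, selected
# according to `badam_mask_mode` ("adjacent" or "scatter").
ratio_args = BAdamArgument(
    use_badam=True,
    badam_mode="ratio",
    badam_update_ratio=0.05,
    badam_mask_mode="adjacent",
)
```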
6 changes: 6 additions & 0 deletions src/llmtuner/hparams/parser.py
@@ -171,6 +171,12 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
if finetuning_args.use_galore and training_args.deepspeed is not None:
raise ValueError("GaLore is incompatible with DeepSpeed.")

if (finetuning_args.use_badam
and finetuning_args.badam_mode == "layer"
and training_args.parallel_mode.value == "distributed"
):
raise ValueError("BAdam with layer-wise mode is currently not supported in distributed training; use the ratio mode instead.")

if model_args.infer_backend == "vllm":
raise ValueError("vLLM backend is only available for API, CLI and Web.")

6 changes: 3 additions & 3 deletions src/llmtuner/model/adapter.py
@@ -37,7 +37,7 @@ def init_adapter(

if finetuning_args.finetuning_type == "full" and is_trainable:
logger.info("Fine-tuning method: Full")
if not finetuning_args.pure_bf16:
if (not finetuning_args.pure_bf16) and (not finetuning_args.use_badam):
model = model.float()

if finetuning_args.finetuning_type == "freeze" and is_trainable:
@@ -82,7 +82,7 @@ def init_adapter(

for name, param in model.named_parameters():
if any(trainable_layer in name for trainable_layer in trainable_layers):
if not finetuning_args.pure_bf16:
if (not finetuning_args.pure_bf16) and (not finetuning_args.use_badam):
param.data = param.data.to(torch.float32)
else:
param.requires_grad_(False)
@@ -162,7 +162,7 @@ def init_adapter(
)
model = get_peft_model(model, lora_config)

if not finetuning_args.pure_bf16:
if (not finetuning_args.pure_bf16) and (not finetuning_args.use_badam):
for param in filter(lambda p: p.requires_grad, model.parameters()):
param.data = param.data.to(torch.float32)

5 changes: 3 additions & 2 deletions src/llmtuner/model/patcher.py
Expand Up @@ -17,7 +17,7 @@
from ..extras.misc import get_current_device, infer_optim_dtype
from ..extras.packages import is_flash_attn2_available
from ..extras.patches.llama_patch import apply_llama_patch
from .utils import QuantizationMethod, add_z3_leaf_module
from .utils import QuantizationMethod, add_z3_leaf_module, gradient_checkpointing_enable


if TYPE_CHECKING:
@@ -266,8 +266,9 @@ def _prepare_model_for_training(
else:
# use_reentrant=False might increase VRAM usage (has not been empirically verified yet)
# According to: https://github.com/huggingface/transformers/issues/28339
model.gradient_checkpointing_enable = MethodType(gradient_checkpointing_enable, model)
hiyouga marked this conversation as resolved.
Show resolved Hide resolved
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
model.enable_input_require_grads()
# model.enable_input_require_grads()
setattr(model.config, "use_cache", False) # turn off when gradient checkpointing is enabled
logger.info("Gradient checkpointing enabled.")

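The patch above swaps in the custom `gradient_checkpointing_enable` on a single model instance via `types.MethodType`. A minimal, self-contained illustration of that rebinding pattern (plain Python, independent of transformers):

```python
# Toy example of per-instance method rebinding with types.MethodType.
from types import MethodType

class Widget:
    def describe(self) -> str:
        return "original behaviour"

def patched_describe(self) -> str:
    return "patched behaviour"

w = Widget()
w.describe = MethodType(patched_describe, w)  # only this instance is affected
print(w.describe())         # -> patched behaviour
print(Widget().describe())  # -> original behaviour
```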
42 changes: 42 additions & 0 deletions src/llmtuner/model/utils.py
@@ -135,3 +135,45 @@ def register_autoclass(config: "PretrainedConfig", model: "PreTrainedModel", tok
model.__class__.register_for_auto_class()
if "AutoTokenizer" in tokenizer.init_kwargs.get("auto_map", {}):
tokenizer.__class__.register_for_auto_class()

def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
"""
Modification of the original method to enable gradient checkpointing for the block-wise optimizer.

Activates gradient checkpointing for the current model.

We pass the `__call__` method of the modules instead of `forward` because `__call__` attaches all the hooks of
the module. https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2

Args:
gradient_checkpointing_kwargs (dict, *optional*):
Additional keyword arguments passed along to the `torch.utils.checkpoint.checkpoint` function.
"""
from torch.utils.checkpoint import checkpoint

if not self.supports_gradient_checkpointing:
raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")

if gradient_checkpointing_kwargs is None:
gradient_checkpointing_kwargs = {}
hiyouga marked this conversation as resolved.
Show resolved Hide resolved

# gradient_checkpointing_func = functools.partial(checkpoint, **gradient_checkpointing_kwargs)

def gradient_checkpointing_func(func, *args, **kwargs):
module = func.__self__

if any([p.requires_grad for p in module.parameters()]):
for arg in args:
if torch.is_tensor(arg) and torch.is_floating_point(arg):
arg.requires_grad_(True)

return checkpoint(func, *args, **kwargs)
hiyouga marked this conversation as resolved.
Show resolved Hide resolved

self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func)

if getattr(self, "_hf_peft_config_loaded", False):
hiyouga marked this conversation as resolved.
Show resolved Hide resolved
# When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True
# we do it also on PEFT: https://github.com/huggingface/peft/blob/85013987aa82aa1af3da1236b6902556ce3e483e/src/peft/peft_model.py#L334
# When training with PEFT, only LoRA layers will have requires grad set to True, but the output of frozen layers need to propagate
# the gradients to make sure the gradient flows.
self.enable_input_require_grads()
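The custom `gradient_checkpointing_func` above marks floating-point inputs as requiring gradients whenever the checkpointed module holds trainable parameters. With BAdam most blocks are frozen, so the input reaching the currently trainable block typically has `requires_grad=False`, and `torch.utils.checkpoint` would then produce no gradients for that block. A minimal sketch of the effect (illustration only, not from the PR):

```python
# Minimal sketch: a checkpointed block whose input comes from frozen layers.
import torch
from torch.utils.checkpoint import checkpoint

trainable_block = torch.nn.Linear(8, 8)  # the block BAdam is currently updating
frozen_output = torch.randn(2, 8)        # output of frozen layers: requires_grad=False

# Without the fix, checkpoint() warns that none of the inputs require grad and the
# trainable block receives no parameter gradients. The wrapper's fix, applied manually:
frozen_output.requires_grad_(True)
out = checkpoint(trainable_block, frozen_output, use_reentrant=True)
out.sum().backward()
print(trainable_block.weight.grad is not None)  # True: gradients reach the trainable block
```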
6 changes: 5 additions & 1 deletion src/llmtuner/train/sft/trainer.py
@@ -9,7 +9,8 @@
from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
from ..utils import create_custom_optimzer, create_custom_scheduler

from types import MethodType
from packaging import version

if TYPE_CHECKING:
from transformers.trainer import PredictionOutput
@@ -28,6 +29,9 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
def __init__(self, finetuning_args: "FinetuningArguments", **kwargs) -> None:
super().__init__(**kwargs)
self.finetuning_args = finetuning_args
if version.parse(torch.__version__) >= version.parse("1.13"):
hiyouga marked this conversation as resolved.
Show resolved Hide resolved
from badam import clip_grad_norm_for_sparse_tensor
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)

def create_optimizer(self) -> "torch.optim.Optimizer":
if self.optimizer is None:
57 changes: 57 additions & 0 deletions src/llmtuner/train/utils.py
@@ -287,12 +287,69 @@ def _create_loraplus_optimizer(
logger.info("Using LoRA+ optimizer with loraplus lr ratio {:.2f}.".format(finetuning_args.loraplus_lr_ratio))
return optimizer

def _create_badam_optimizer(
model: "PreTrainedModel",
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
) -> "torch.optim.Optimizer":

from transformers.trainer_pt_utils import get_parameter_names
decay_parameters = list(filter(lambda n: "bias" not in n, get_parameter_names(model, ALL_LAYERNORM_LAYERS)))
# filter out the embedding layers when using badam ratio mode
if finetuning_args.badam_mode == "ratio":
decay_parameters = list(filter(lambda n: "embed" not in n, decay_parameters)) # TODO: make it more general
optimizer_grouped_parameters = [
{
"params": [p for n, p in model.named_parameters() if n in decay_parameters],
"weight_decay": training_args.weight_decay,
},
{
"params": [p for n, p in model.named_parameters() if n not in decay_parameters],
"weight_decay": 0.0,
},
]

optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args)

# create BlockOptimizer
if finetuning_args.badam_mode == "layer":
from badam import BlockOptimizer
base_optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
optimizer = BlockOptimizer(base_optimizer=base_optimizer,
named_parameters_list=list(model.named_parameters()),
block_prefix_list=None,
switch_block_every=finetuning_args.switch_block_every,
start_block=finetuning_args.start_block,
switch_mode=finetuning_args.switch_mode,
verbose=finetuning_args.badam_verbose)

logger.info(f"Using BAdam optimizer with layer-wise update, switch mode is {finetuning_args.switch_mode}, "
f"switch block every {finetuning_args.switch_block_every} steps, "
f"default start block is {finetuning_args.start_block}")

elif finetuning_args.badam_mode == "ratio":
assert finetuning_args.badam_update_ratio > 0.
from badam import BlockOptimizerRatio
optimizer = BlockOptimizerRatio(param_groups=optimizer_grouped_parameters,
named_parameters_list=list(model.named_parameters()),
update_ratio=finetuning_args.badam_update_ratio,
mask_mode=finetuning_args.badam_mask_mode,
verbose=finetuning_args.badam_verbose,
**optimizer_kwargs)

logger.info(f"Using BAdam optimizer with ratio update, update ratio is {finetuning_args.badam_update_ratio}, "
f"mask mode is {finetuning_args.badam_mask_mode}")

return optimizer

def create_custom_optimzer(
model: "PreTrainedModel",
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
) -> Optional["torch.optim.Optimizer"]:
if finetuning_args.use_badam:
return _create_badam_optimizer(model, training_args, finetuning_args)

if finetuning_args.use_galore:
return _create_galore_optimizer(model, training_args, finetuning_args)

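As a small self-contained illustration (not from the PR) of the parameter grouping in `_create_badam_optimizer`: biases and LayerNorm weights go into a group with `weight_decay=0.0`, everything else gets the configured decay, and in ratio mode the embedding parameters are additionally moved to the no-decay group. The toy model and decay value below are placeholders.

```python
# Self-contained sketch of the weight-decay grouping used above, on a toy model.
import torch
from transformers.trainer_pt_utils import get_parameter_names
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.LayerNorm(8))

# Parameter names that should receive weight decay: not inside LayerNorm, not a bias.
decay_parameters = [
    name for name in get_parameter_names(model, ALL_LAYERNORM_LAYERS) if "bias" not in name
]
optimizer_grouped_parameters = [
    {"params": [p for n, p in model.named_parameters() if n in decay_parameters],
     "weight_decay": 0.01},
    {"params": [p for n, p in model.named_parameters() if n not in decay_parameters],
     "weight_decay": 0.0},
]
# The grouped parameters are then handed to the base optimizer (AdamW here) and,
# in the PR, wrapped by BlockOptimizer / BlockOptimizerRatio.
base_optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=5e-5)
```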