Commit d0f953b

Merge pull request #4352 from Ledzy/main
[Enhancement] Support ZeRO-3 when using BAdam

hiyouga authored Jun 24, 2024
2 parents 4108605 + 5c2ff1b
Showing 12 changed files with 149 additions and 24 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -163,3 +163,5 @@ cython_debug/
 user.config
 saves/
 cache/
+wandb
+ds_badam_exp
40 changes: 40 additions & 0 deletions examples/extras/badam/llama3_badam_sft.yaml
@@ -0,0 +1,40 @@
### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct

### method
stage: sft
do_train: true
finetuning_type: full
use_badam: true
badam_switch_mode: ascending
badam_switch_interval: 50
badam_verbose: 2

### dataset
dataset: identity,alpaca_en_demo
template: llama3
cutoff_len: 1024
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16

### output
output_dir: saves/llama3-8b/full/sft
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true

### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 1.0e-6
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1

### eval
val_size: 0.1
per_device_eval_batch_size: 1
eval_strategy: steps
eval_steps: 500
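
This recipe can be launched with the LLaMA-Factory CLI. A minimal usage sketch (not part of the commit), assuming the repository root as the working directory and a single visible GPU:

# Run the BAdam full-parameter SFT recipe added above.
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/badam/llama3_badam_sft.yaml
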
37 changes: 37 additions & 0 deletions examples/extras/badam/train_single_gpu.sh
@@ -0,0 +1,37 @@
#!/bin/bash
export CUDA_VISIBLE_DEVICES=0

cd ../../..

llamafactory-cli train \
--stage sft \
--do_train True \
--model_name_or_path meta-llama/Llama-2-13b-hf \
--preprocessing_num_workers 16 \
--finetuning_type full \
--template default \
--flash_attn auto \
--dataset_dir data \
--dataset alpaca_en_demo \
--cutoff_len 1024 \
--learning_rate 1e-6 \
--num_train_epochs 3.0 \
--max_samples 100000 \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 8 \
--lr_scheduler_type cosine \
--max_grad_norm 1.0 \
--logging_steps 5 \
--save_steps 100 \
--warmup_steps 0 \
--optim adamw_torch \
--packing False \
--report_to none \
--use_badam True \
--output_dir saves/LLaMA2-13B/full/BAdam \
--plot_loss True \
--ddp_timeout 180000000 \
--include_num_input_tokens_seen True \
--badam_mode layer \
--badam_switch_mode ascending \
--badam_switch_interval 50
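
Both example scripts (and the YAML recipe above) assume the badam package is importable. The symbols used elsewhere in this diff (clip_grad_norm_old_version, BAdamCallback, the ds_zero3_enabled option) only exist in recent releases, so a hedged install step, assuming the PyPI package name is badam:

# Install or upgrade BAdam before running the examples (package name assumed).
pip install --upgrade badam
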
39 changes: 39 additions & 0 deletions examples/extras/badam/train_zero3.sh
@@ -0,0 +1,39 @@
#!/bin/bash
export CUDA_VISIBLE_DEVICES=0,1,2,3

cd ../../..

llamafactory-cli train \
--stage sft \
--do_train True \
--model_name_or_path meta-llama/Llama-2-13b-hf \
--preprocessing_num_workers 16 \
--finetuning_type full \
--template default \
--flash_attn auto \
--dataset_dir data \
--dataset alpaca_en_demo \
--cutoff_len 1024 \
--learning_rate 1e-6 \
--num_train_epochs 3.0 \
--max_samples 100000 \
--per_device_train_batch_size 8 \
--gradient_accumulation_steps 2 \
--lr_scheduler_type cosine \
--max_grad_norm 1.0 \
--logging_steps 5 \
--save_steps 100 \
--warmup_steps 0 \
--optim adamw_torch \
--packing False \
--report_to none \
--use_badam True \
--output_dir saves/LLaMA2-13B/full/BAdam \
--fp16 True \
--plot_loss True \
--ddp_timeout 180000000 \
--include_num_input_tokens_seen True \
--badam_mode layer \
--badam_switch_mode ascending \
--badam_switch_interval 50 \
--deepspeed cache/ds_z3_config.json
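
The script above passes --deepspeed cache/ds_z3_config.json, a file that is not part of this commit. A minimal ZeRO-3 config it could point at is sketched below; the values are assumptions, and the stock DeepSpeed configs shipped elsewhere in the repository may differ:

#!/bin/bash
# Write a minimal DeepSpeed ZeRO-3 config to the path referenced by train_zero3.sh.
mkdir -p cache
cat > cache/ds_z3_config.json <<'EOF'
{
  "train_batch_size": "auto",
  "train_micro_batch_size_per_gpu": "auto",
  "gradient_accumulation_steps": "auto",
  "gradient_clipping": "auto",
  "zero_allow_untested_optimizer": true,
  "fp16": { "enabled": "auto" },
  "bf16": { "enabled": "auto" },
  "zero_optimization": {
    "stage": 3,
    "overlap_comm": true,
    "contiguous_gradients": true,
    "stage3_gather_16bit_weights_on_model_save": true
  }
}
EOF
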
12 changes: 7 additions & 5 deletions src/llamafactory/hparams/parser.py
@@ -214,13 +214,15 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:

     if (
         finetuning_args.use_badam
-        and finetuning_args.badam_mode == "layer"
-        and training_args.parallel_mode == ParallelMode.DISTRIBUTED
+        and training_args.parallel_mode.value == "distributed"
     ):
-        raise ValueError("Layer-wise BAdam does not yet support distributed training, use ratio-wise BAdam.")
+        if finetuning_args.badam_mode == "ratio":
+            raise ValueError("Ratio-wise BAdam does not yet support distributed training, use layer-wise BAdam: --badam_mode layer")
+        if finetuning_args.badam_mode == "layer" and (not is_deepspeed_zero3_enabled()):
+            raise ValueError(f"Layer-wise BAdam only supports DeepSpeed ZeRO 3 stage.")
 
-    if (finetuning_args.use_galore or finetuning_args.use_badam) and training_args.deepspeed is not None:
-        raise ValueError("GaLore and BAdam are incompatible with DeepSpeed yet.")
+    if (finetuning_args.use_galore) and training_args.deepspeed is not None:
+        raise ValueError("GaLore are incompatible with DeepSpeed yet.")
 
     if model_args.infer_backend == "vllm":
         raise ValueError("vLLM backend is only available for API, CLI and Web.")
6 changes: 3 additions & 3 deletions src/llamafactory/train/dpo/trainer.py
@@ -96,9 +96,9 @@ def __init__(
             self.save_model(os.path.join(self.args.output_dir, "pissa_init"))
 
         if finetuning_args.use_badam:
-            from badam import clip_grad_norm_for_sparse_tensor
-
-            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
+            from badam import clip_grad_norm_old_version, BAdamCallback
+            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
+            self.callback_handler.add_callback(BAdamCallback)
 
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
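The same substitution is repeated below for the KTO, PPO, PT, RM, and SFT trainers: the old clip_grad_norm_for_sparse_tensor patch is replaced by clip_grad_norm_old_version plus a registered BAdamCallback. A quick check that the installed badam release exposes the new names (a sketch, not part of the commit):

# Fails with ImportError on badam releases that predate these symbols.
python -c "from badam import BAdamCallback, clip_grad_norm_old_version; print('badam OK')"
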
6 changes: 3 additions & 3 deletions src/llamafactory/train/kto/trainer.py
@@ -91,9 +91,9 @@ def __init__(
                 self.ref_model.eval()
 
         if finetuning_args.use_badam:
-            from badam import clip_grad_norm_for_sparse_tensor
-
-            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
+            from badam import clip_grad_norm_old_version, BAdamCallback
+            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
+            self.callback_handler.add_callback(BAdamCallback)
 
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
6 changes: 3 additions & 3 deletions src/llamafactory/train/ppo/trainer.py
@@ -166,9 +166,9 @@ def __init__(
                 self.reward_model = self.accelerator.prepare_model(self.reward_model, evaluation_mode=True)
 
         if finetuning_args.use_badam:
-            from badam import clip_grad_norm_for_sparse_tensor
-
-            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
+            from badam import clip_grad_norm_old_version, BAdamCallback
+            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
+            self.callback_handler.add_callback(BAdamCallback)
 
     def ppo_train(self, resume_from_checkpoint: Optional[str] = None) -> None:
         r"""
6 changes: 3 additions & 3 deletions src/llamafactory/train/pt/trainer.py
@@ -48,9 +48,9 @@ def __init__(
             self.save_model(os.path.join(self.args.output_dir, "pissa_init"))
 
         if finetuning_args.use_badam:
-            from badam import clip_grad_norm_for_sparse_tensor
-
-            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
+            from badam import clip_grad_norm_old_version, BAdamCallback
+            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
+            self.callback_handler.add_callback(BAdamCallback)
 
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
6 changes: 3 additions & 3 deletions src/llamafactory/train/rm/trainer.py
@@ -72,9 +72,9 @@ def __init__(
         self.processor = processor
         self.can_return_loss = True  # override property to return eval_loss
         if finetuning_args.use_badam:
-            from badam import clip_grad_norm_for_sparse_tensor
-
-            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
+            from badam import clip_grad_norm_old_version, BAdamCallback
+            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
+            self.callback_handler.add_callback(BAdamCallback)
 
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
6 changes: 3 additions & 3 deletions src/llamafactory/train/sft/trainer.py
@@ -56,9 +56,9 @@ def __init__(
             self.save_model(os.path.join(self.args.output_dir, "pissa_init"))
 
         if finetuning_args.use_badam:
-            from badam import clip_grad_norm_for_sparse_tensor
-
-            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
+            from badam import clip_grad_norm_old_version, BAdamCallback
+            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
+            self.callback_handler.add_callback(BAdamCallback)
 
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
7 changes: 6 additions & 1 deletion src/llamafactory/train/trainer_utils.py
@@ -372,6 +372,9 @@ def _create_badam_optimizer(
         dict(params=decay_params, weight_decay=training_args.weight_decay),
     ]
 
+    from transformers.integrations import is_deepspeed_zero3_enabled
+    ds_zero3_enabled = is_deepspeed_zero3_enabled()
+
     if finetuning_args.badam_mode == "layer":
         from badam import BlockOptimizer
 
@@ -384,6 +387,7 @@
             start_block=finetuning_args.badam_start_block,
             switch_mode=finetuning_args.badam_switch_mode,
             verbose=finetuning_args.badam_verbose,
+            ds_zero3_enabled=ds_zero3_enabled
         )
         logger.info(
             f"Using BAdam optimizer with layer-wise update, switch mode is {finetuning_args.badam_switch_mode}, "
@@ -394,6 +398,7 @@
     elif finetuning_args.badam_mode == "ratio":
         from badam import BlockOptimizerRatio
 
+        assert not ds_zero3_enabled, "BAdam with ratio-based update does not support Deepspeed ZeRO-3 yet, use layer-wise update instead: --badam_mode layer."
         assert finetuning_args.badam_update_ratio > 1e-6
         optimizer = BlockOptimizerRatio(
             param_groups=param_groups,
@@ -405,7 +410,7 @@
             **optim_kwargs,
         )
         logger.info(
-            f"Using BAdam optimizer with ratio-wise update, update ratio is {finetuning_args.badam_update_ratio}, "
+            f"Using BAdam optimizer with ratio-based update, update ratio is {finetuning_args.badam_update_ratio}, "
             f"mask mode is {finetuning_args.badam_mask_mode}"
         )
 
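The new assert makes the ratio-based branch refuse to start under ZeRO-3 and points users to layer-wise mode instead; without DeepSpeed it keeps working as before. For reference, a hedged single-GPU sketch of the ratio-based path, trimmed to the BAdam-relevant flags (reuse the remaining options from train_single_gpu.sh; the update ratio and output path are only example values):

#!/bin/bash
# Ratio-based BAdam on one GPU, without DeepSpeed (ZeRO-3 is rejected for this mode).
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train \
    --stage sft \
    --do_train True \
    --model_name_or_path meta-llama/Llama-2-13b-hf \
    --dataset alpaca_en_demo \
    --dataset_dir data \
    --template default \
    --finetuning_type full \
    --use_badam True \
    --badam_mode ratio \
    --badam_update_ratio 0.05 \
    --output_dir saves/LLaMA2-13B/full/BAdam-ratio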
