[Enhancement] Support ZeRO-3 when using BAdam #4352

Merged 7 commits on Jun 24, 2024
Changes from all commits
2 changes: 2 additions & 0 deletions .gitignore
@@ -163,3 +163,5 @@ cython_debug/
user.config
saves/
cache/
wandb
ds_badam_exp
40 changes: 40 additions & 0 deletions examples/extras/badam/llama3_badam_sft.yaml
@@ -0,0 +1,40 @@
### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct

### method
stage: sft
do_train: true
finetuning_type: full
use_badam: true
badam_switch_mode: ascending
badam_switch_interval: 50
badam_verbose: 2

### dataset
dataset: identity,alpaca_en_demo
template: llama3
cutoff_len: 1024
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16

### output
output_dir: saves/llama3-8b/full/sft
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true

### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 1.0e-6
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1

### eval
val_size: 0.1
per_device_eval_batch_size: 1
eval_strategy: steps
eval_steps: 500
37 changes: 37 additions & 0 deletions examples/extras/badam/train_single_gpu.sh
@@ -0,0 +1,37 @@
#!/bin/bash
export CUDA_VISIBLE_DEVICES=0

cd ../../..

llamafactory-cli train \
--stage sft \
--do_train True \
--model_name_or_path meta-llama/Llama-2-13b-hf \
--preprocessing_num_workers 16 \
--finetuning_type full \
--template default \
--flash_attn auto \
--dataset_dir data \
--dataset alpaca_en_demo \
--cutoff_len 1024 \
--learning_rate 1e-6 \
--num_train_epochs 3.0 \
--max_samples 100000 \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 8 \
--lr_scheduler_type cosine \
--max_grad_norm 1.0 \
--logging_steps 5 \
--save_steps 100 \
--warmup_steps 0 \
--optim adamw_torch \
--packing False \
--report_to none \
--use_badam True \
--output_dir saves/LLaMA2-13B/full/BAdam \
--plot_loss True \
--ddp_timeout 180000000 \
--include_num_input_tokens_seen True \
--badam_mode layer \
--badam_switch_mode ascending \
--badam_switch_interval 50
39 changes: 39 additions & 0 deletions examples/extras/badam/train_zero3.sh
@@ -0,0 +1,39 @@
#!/bin/bash
export CUDA_VISIBLE_DEVICES=0,1,2,3

cd ../../..

llamafactory-cli train \
--stage sft \
--do_train True \
--model_name_or_path meta-llama/Llama-2-13b-hf \
--preprocessing_num_workers 16 \
--finetuning_type full \
--template default \
--flash_attn auto \
--dataset_dir data \
--dataset alpaca_en_demo \
--cutoff_len 1024 \
--learning_rate 1e-6 \
--num_train_epochs 3.0 \
--max_samples 100000 \
--per_device_train_batch_size 8 \
--gradient_accumulation_steps 2 \
--lr_scheduler_type cosine \
--max_grad_norm 1.0 \
--logging_steps 5 \
--save_steps 100 \
--warmup_steps 0 \
--optim adamw_torch \
--packing False \
--report_to none \
--use_badam True \
--output_dir saves/LLaMA2-13B/full/BAdam \
--fp16 True \
--plot_loss True \
--ddp_timeout 180000000 \
--include_num_input_tokens_seen True \
--badam_mode layer \
--badam_switch_mode ascending \
--badam_switch_interval 50 \
--deepspeed cache/ds_z3_config.json
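Note: the cache/ds_z3_config.json referenced above is not included in this diff. As a rough orientation, a minimal ZeRO-3 config written to that path could look like the Python sketch below; the concrete keys and "auto" placeholders are assumptions based on common DeepSpeed usage, not taken from this PR.

# Sketch only: writes an assumed minimal ZeRO-3 config to the path train_zero3.sh expects.
import json
import os

ds_z3_config = {
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "zero_allow_untested_optimizer": True,  # BAdam is not a stock DeepSpeed optimizer
    "fp16": {"enabled": "auto"},
    "bf16": {"enabled": "auto"},
    "zero_optimization": {
        "stage": 3,
        "overlap_comm": True,
        "contiguous_gradients": True,
        "stage3_gather_16bit_weights_on_model_save": True,
    },
}

os.makedirs("cache", exist_ok=True)
with open("cache/ds_z3_config.json", "w") as f:
    json.dump(ds_z3_config, f, indent=2)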
12 changes: 7 additions & 5 deletions src/llamafactory/hparams/parser.py
@@ -211,13 +211,15 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:

if (
finetuning_args.use_badam
and finetuning_args.badam_mode == "layer"
and training_args.parallel_mode == ParallelMode.DISTRIBUTED
and training_args.parallel_mode.value == "distributed"
):
raise ValueError("Layer-wise BAdam does not yet support distributed training, use ratio-wise BAdam.")
if finetuning_args.badam_mode == "ratio":
raise ValueError("Ratio-wise BAdam does not yet support distributed training, use layer-wise BAdam: --badam_mode layer")
if finetuning_args.badam_mode == "layer" and (not is_deepspeed_zero3_enabled()):
raise ValueError(f"Layer-wise BAdam only supports DeepSpeed ZeRO 3 stage.")

if (finetuning_args.use_galore or finetuning_args.use_badam) and training_args.deepspeed is not None:
raise ValueError("GaLore and BAdam are incompatible with DeepSpeed yet.")
if finetuning_args.use_galore and training_args.deepspeed is not None:
raise ValueError("GaLore is incompatible with DeepSpeed yet.")

if model_args.infer_backend == "vllm":
raise ValueError("vLLM backend is only available for API, CLI and Web.")
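For readability, the checks added to get_train_args above assemble to roughly the following sketch (the helper name _check_badam_compat is invented here for illustration; the PR adds these checks inline):

def _check_badam_compat(finetuning_args, training_args) -> None:
    # Reconstruction of the added validation, gathered into one place for readability.
    from transformers.integrations import is_deepspeed_zero3_enabled

    if finetuning_args.use_badam and training_args.parallel_mode.value == "distributed":
        if finetuning_args.badam_mode == "ratio":
            raise ValueError(
                "Ratio-wise BAdam does not yet support distributed training, "
                "use layer-wise BAdam: --badam_mode layer"
            )
        if finetuning_args.badam_mode == "layer" and not is_deepspeed_zero3_enabled():
            raise ValueError("Layer-wise BAdam only supports DeepSpeed ZeRO-3 stage.")

    if finetuning_args.use_galore and training_args.deepspeed is not None:
        raise ValueError("GaLore is incompatible with DeepSpeed yet.")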
6 changes: 3 additions & 3 deletions src/llamafactory/train/dpo/trainer.py
@@ -96,9 +96,9 @@ def __init__(
self.save_model(os.path.join(self.args.output_dir, "pissa_init"))

if finetuning_args.use_badam:
from badam import clip_grad_norm_for_sparse_tensor

self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
from badam import clip_grad_norm_old_version, BAdamCallback
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.callback_handler.add_callback(BAdamCallback)

def create_optimizer(self) -> "torch.optim.Optimizer":
if self.optimizer is None:
6 changes: 3 additions & 3 deletions src/llamafactory/train/kto/trainer.py
@@ -91,9 +91,9 @@ def __init__(
self.ref_model.eval()

if finetuning_args.use_badam:
from badam import clip_grad_norm_for_sparse_tensor

self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
from badam import clip_grad_norm_old_version, BAdamCallback
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.callback_handler.add_callback(BAdamCallback)

def create_optimizer(self) -> "torch.optim.Optimizer":
if self.optimizer is None:
6 changes: 3 additions & 3 deletions src/llamafactory/train/ppo/trainer.py
@@ -166,9 +166,9 @@ def __init__(
self.reward_model = self.accelerator.prepare_model(self.reward_model, evaluation_mode=True)

if finetuning_args.use_badam:
from badam import clip_grad_norm_for_sparse_tensor

self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
from badam import clip_grad_norm_old_version, BAdamCallback
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.callback_handler.add_callback(BAdamCallback)

def ppo_train(self, resume_from_checkpoint: Optional[str] = None) -> None:
r"""
6 changes: 3 additions & 3 deletions src/llamafactory/train/pt/trainer.py
@@ -48,9 +48,9 @@ def __init__(
self.save_model(os.path.join(self.args.output_dir, "pissa_init"))

if finetuning_args.use_badam:
from badam import clip_grad_norm_for_sparse_tensor

self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
from badam import clip_grad_norm_old_version, BAdamCallback
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.callback_handler.add_callback(BAdamCallback)

def create_optimizer(self) -> "torch.optim.Optimizer":
if self.optimizer is None:
6 changes: 3 additions & 3 deletions src/llamafactory/train/rm/trainer.py
@@ -72,9 +72,9 @@ def __init__(
self.processor = processor
self.can_return_loss = True # override property to return eval_loss
if finetuning_args.use_badam:
from badam import clip_grad_norm_for_sparse_tensor

self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
from badam import clip_grad_norm_old_version, BAdamCallback
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.callback_handler.add_callback(BAdamCallback)

def create_optimizer(self) -> "torch.optim.Optimizer":
if self.optimizer is None:
6 changes: 3 additions & 3 deletions src/llamafactory/train/sft/trainer.py
@@ -56,9 +56,9 @@ def __init__(
self.save_model(os.path.join(self.args.output_dir, "pissa_init"))

if finetuning_args.use_badam:
from badam import clip_grad_norm_for_sparse_tensor

self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
from badam import clip_grad_norm_old_version, BAdamCallback
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.callback_handler.add_callback(BAdamCallback)

def create_optimizer(self) -> "torch.optim.Optimizer":
if self.optimizer is None:
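The same three-line patch recurs in the DPO, KTO, PPO, PT, RM, and SFT trainers above. A hypothetical helper (not part of this PR; the name _apply_badam_patches is invented here) could factor it out:

from types import MethodType


def _apply_badam_patches(trainer) -> None:
    # Mirror of the per-trainer change above: swap in BAdam's gradient-clipping
    # function on the accelerator and register BAdam's callback on the trainer.
    from badam import BAdamCallback, clip_grad_norm_old_version

    trainer.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, trainer.accelerator)
    trainer.callback_handler.add_callback(BAdamCallback)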
7 changes: 6 additions & 1 deletion src/llamafactory/train/trainer_utils.py
@@ -371,6 +371,9 @@ def _create_badam_optimizer(
dict(params=decay_params, weight_decay=training_args.weight_decay),
]

from transformers.integrations import is_deepspeed_zero3_enabled
ds_zero3_enabled = is_deepspeed_zero3_enabled()

if finetuning_args.badam_mode == "layer":
from badam import BlockOptimizer

@@ -383,6 +386,7 @@
start_block=finetuning_args.badam_start_block,
switch_mode=finetuning_args.badam_switch_mode,
verbose=finetuning_args.badam_verbose,
ds_zero3_enabled=ds_zero3_enabled,
)
logger.info(
f"Using BAdam optimizer with layer-wise update, switch mode is {finetuning_args.badam_switch_mode}, "
@@ -393,6 +397,7 @@
elif finetuning_args.badam_mode == "ratio":
from badam import BlockOptimizerRatio

assert not ds_zero3_enabled, "BAdam with ratio-based update does not support DeepSpeed ZeRO-3 yet, use layer-wise update instead: --badam_mode layer."
assert finetuning_args.badam_update_ratio > 1e-6
optimizer = BlockOptimizerRatio(
param_groups=param_groups,
@@ -404,7 +409,7 @@
**optim_kwargs,
)
logger.info(
f"Using BAdam optimizer with ratio-wise update, update ratio is {finetuning_args.badam_update_ratio}, "
f"Using BAdam optimizer with ratio-based update, update ratio is {finetuning_args.badam_update_ratio}, "
f"mask mode is {finetuning_args.badam_mask_mode}"
)

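For context, a standalone use of BAdam's layer-wise BlockOptimizer with the new flag would look roughly like the sketch below. The base_optimizer, named_parameters_list, and switch_block_every arguments follow BAdam's published examples, while ds_zero3_enabled is the argument threaded through by this PR; treat the exact signature as an assumption rather than a guarantee.

import torch
from badam import BlockOptimizer
from transformers.integrations import is_deepspeed_zero3_enabled

# Toy stand-in for the fine-tuned model; in LLaMA-Factory this is the loaded HF model.
model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.Linear(16, 16))
base_optimizer = torch.optim.AdamW(model.parameters(), lr=1.0e-6)

optimizer = BlockOptimizer(
    base_optimizer=base_optimizer,                  # inner optimizer for the active block
    named_parameters_list=list(model.named_parameters()),
    switch_block_every=50,                          # --badam_switch_interval
    switch_mode="ascending",                        # --badam_switch_mode
    verbose=2,                                      # --badam_verbose
    ds_zero3_enabled=is_deepspeed_zero3_enabled(),  # mirrors _create_badam_optimizer above
)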