Adding Mixture of Depth #3338

Merged 2 commits on Apr 21, 2024
8 changes: 5 additions & 3 deletions README.md
@@ -46,7 +46,7 @@ Choose your path:
- **Various models**: LLaMA, Mistral, Mixtral-MoE, Qwen, Yi, Gemma, Baichuan, ChatGLM, Phi, etc.
- **Integrated methods**: (Continuous) pre-training, supervised fine-tuning, reward modeling, PPO, DPO and ORPO.
- **Scalable resources**: 32-bit full-tuning, 16-bit freeze-tuning, 16-bit LoRA and 2/4/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8.
- **Advanced algorithms**: GaLore, BAdam, DoRA, LongLoRA, LLaMA Pro, LoRA+, LoftQ and Agent tuning.
- **Advanced algorithms**: GaLore, Mixture of Depths, BAdam, DoRA, LongLoRA, LLaMA Pro, LoRA+, LoftQ and Agent tuning.
- **Practical tricks**: FlashAttention-2, Unsloth, RoPE scaling, NEFTune and rsLoRA.
- **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, etc.
- **Faster inference**: OpenAI-style API, Gradio UI and CLI with vLLM worker.
@@ -68,14 +68,16 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/

## Changelog

[24/04/19] We integrated **[Mixture of Depths](https://github.com/astramind-ai/Mixture-of-depths)**. See `examples/extras/MoD` for usage.

[24/04/19] We supported the **Meta Llama 3** model series.

[24/04/16] We supported **[BAdam](https://arxiv.org/abs/2404.02827)**. See `examples/extras/badam` for usage.

[24/04/16] We supported **[unsloth](https://github.com/unslothai/unsloth)**'s long-sequence training (Llama-2-7B-56k within 24GB). It achieves **117%** speed and **50%** memory compared with FlashAttention-2; more benchmarks can be found on [this page](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison).

<details><summary>Full Changelog</summary>

[24/04/16] We supported **[unsloth](https://github.com/unslothai/unsloth)**'s long-sequence training (Llama-2-7B-56k within 24GB). It achieves **117%** speed and **50%** memory compared with FlashAttention-2; more benchmarks can be found on [this page](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison).

[24/03/31] We supported **[ORPO](https://arxiv.org/abs/2403.07691)**. See `examples/lora_single_gpu` for usage.

[24/03/21] Our paper "[LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models](https://arxiv.org/abs/2403.13372)" is available at arXiv!
8 changes: 5 additions & 3 deletions README_zh.md
@@ -46,7 +46,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
- **Various models**: LLaMA, Mistral, Mixtral-MoE, Qwen, Yi, Gemma, Baichuan, ChatGLM, Phi, etc.
- **Integrated methods**: (Continuous) pre-training, supervised fine-tuning, reward modeling, PPO, DPO and ORPO training.
- **Scalable precisions**: 32-bit full-parameter tuning, 16-bit freeze-tuning, 16-bit LoRA and 2/4/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8.
- **Advanced algorithms**: GaLore, BAdam, DoRA, LongLoRA, LLaMA Pro, LoRA+, LoftQ and Agent tuning.
- **Advanced algorithms**: GaLore, Mixture of Depths, BAdam, DoRA, LongLoRA, LLaMA Pro, LoRA+, LoftQ and Agent tuning.
- **Practical tricks**: FlashAttention-2, Unsloth, RoPE scaling, NEFTune and rsLoRA.
- **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, etc.
- **Faster inference**: OpenAI-style API, Gradio UI and CLI with vLLM worker.
@@ -68,14 +68,16 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd

## Changelog

[24/04/19] We integrated **[Mixture of Depths](https://github.com/astramind-ai/Mixture-of-depths)**. See `examples/extras/MoD` for usage.

[24/04/19] We supported the **Meta Llama 3** model series.

[24/04/16] We supported **[BAdam](https://arxiv.org/abs/2404.02827)**. See `examples/extras/badam` for usage.

[24/04/16] We supported **[unsloth](https://github.com/unslothai/unsloth)**'s long-sequence training (Llama-2-7B-56k within 24GB). It achieves **117%** speed and **50%** memory compared with FlashAttention-2; more benchmarks can be found on [this page](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison).

<details><summary>Full Changelog</summary>

[24/04/16] We supported **[unsloth](https://github.com/unslothai/unsloth)**'s long-sequence training (Llama-2-7B-56k within 24GB). It achieves **117%** speed and **50%** memory compared with FlashAttention-2; more benchmarks can be found on [this page](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison).

[24/03/31] We supported **[ORPO](https://arxiv.org/abs/2403.07691)**. See `examples/lora_single_gpu` for usage.

[24/03/21] Our paper "[LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models](https://arxiv.org/abs/2403.13372)" is available at arXiv!
3 changes: 3 additions & 0 deletions examples/README.md
@@ -41,6 +41,9 @@ examples/
├── llama_pro/
│ ├── expand.sh: Expand layers in the model
│ └── sft.sh: Fine-tune the expanded model
├── MoD/
│ ├── freeze_sft.sh: Freeze fine-tune the model, updating only the MoD router
│ └── sft.sh: Fine-tune the MoD model
└── fsdp_qlora/
└── sft.sh: Fine-tune quantized model with FSDP+QLoRA
```
3 changes: 3 additions & 0 deletions examples/README_zh.md
@@ -41,6 +41,9 @@ examples/
├── llama_pro/
│ ├── expand.sh: Expand layers in the model
│ └── sft.sh: Fine-tune the expanded model
├── MoD/
│ ├── freeze_sft.sh: Freeze fine-tune the model, updating only the MoD router
│ └── sft.sh: Fine-tune the MoD model
└── fsdp_qlora/
└── sft.sh: Fine-tune a quantized model with FSDP+QLoRA
```
33 changes: 33 additions & 0 deletions examples/extras/MoD/freeze_sft.sh
@@ -0,0 +1,33 @@
#!/bin/bash

CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
--stage sft \
--do_train \
--model_name_or_path TinyLlama/TinyLlama-1.1B-Chat-v1.0 \
--dataset alpaca_gpt4_en,glaive_toolcall \
--dataset_dir ../../../data \
--template default \
--finetuning_type freeze \
--name_module_trainable router \
--output_dir ../../../saves/TinyLlama/TinyLlama-1.1B-Chat-v1.0/sft \
--mixture_of_depths convert \
--overwrite_cache \
--overwrite_output_dir \
--cutoff_len 1024 \
--preprocessing_num_workers 16 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 1 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--warmup_steps 20 \
--save_steps 100 \
--eval_steps 100 \
--evaluation_strategy steps \
--load_best_model_at_end \
--learning_rate 5e-5 \
--num_train_epochs 3.0 \
--max_samples 3000 \
--val_size 0.1 \
--plot_loss \
--pure_bf16
32 changes: 32 additions & 0 deletions examples/extras/MoD/sft.sh
@@ -0,0 +1,32 @@
#!/bin/bash

CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
--stage sft \
--do_train \
--model_name_or_path TinyLlama/TinyLlama-1.1B-Chat-v1.0 \
--dataset alpaca_gpt4_en,glaive_toolcall \
--dataset_dir ../../../data \
--template default \
--finetuning_type full \
--output_dir ../../../saves/TinyLlama/TinyLlama-1.1B-Chat-v1.0/sft \
--mixture_of_depths convert \
--overwrite_cache \
--overwrite_output_dir \
--cutoff_len 1024 \
--preprocessing_num_workers 16 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 1 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--warmup_steps 20 \
--save_steps 100 \
--eval_steps 100 \
--evaluation_strategy steps \
--load_best_model_at_end \
--learning_rate 5e-5 \
--num_train_epochs 3.0 \
--max_samples 3000 \
--val_size 0.1 \
--plot_loss \
--pure_bf16
4 changes: 4 additions & 0 deletions src/llmtuner/hparams/model_args.py
@@ -69,6 +69,10 @@ class ModelArguments:
default=False,
metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."},
)
mixture_of_depths: Optional[Literal["convert", "continue"]] = field(
default=None,
metadata={"help": "Whether or not to use MoD in the model."},
)
use_unsloth: bool = field(
default=False,
metadata={"help": "Whether or not to use unsloth's optimization for the LoRA training."},
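As a hedged aside, here is a runnable sketch of how such an `Optional[Literal[...]]` field parses from the command line; the `Args` dataclass below is a stripped-down stand-in for the real `ModelArguments`, not part of this PR:

```python
from dataclasses import dataclass, field
from typing import Literal, Optional

from transformers import HfArgumentParser


@dataclass
class Args:
    # Stand-in for ModelArguments.mixture_of_depths; Literal values become
    # argparse choices, so anything other than "convert"/"continue" is rejected.
    mixture_of_depths: Optional[Literal["convert", "continue"]] = field(
        default=None,
        metadata={"help": "Convert to MoD or continue training a MoD checkpoint."},
    )


parser = HfArgumentParser(Args)
(args,) = parser.parse_args_into_dataclasses(["--mixture_of_depths", "convert"])
print(args.mixture_of_depths)  # convert
```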
3 changes: 3 additions & 0 deletions src/llmtuner/hparams/parser.py
@@ -82,6 +82,9 @@ def _check_extra_dependencies(
if model_args.use_unsloth:
require_version("unsloth", "Please install unsloth: https://github.com/unslothai/unsloth")

if model_args.mixture_of_depths is not None:
require_version("mixture-of-depth", "To fix: pip install mixture-of-depth")

if model_args.infer_backend == "vllm":
require_version("vllm>=0.3.3", "To fix: pip install vllm>=0.3.3")

2 changes: 2 additions & 0 deletions src/llmtuner/model/adapter.py
@@ -69,6 +69,8 @@ def init_adapter(
for name, _ in model.named_modules():
if ".0." in name:
freeze_modules.add(name.split(".0.")[-1].split(".")[0])
elif ".1." in name: # here since MoD starts from layer 1
freeze_modules.add(name.split(".1.")[-1].split(".")[0])

trainable_layers = []
for module_name in finetuning_args.name_module_trainable:
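For illustration, a minimal, self-contained sketch of what this name-splitting does. The module names below are hypothetical stand-ins for `model.named_modules()` output on a MoD-converted LLaMA-style model, whose wrapped decoder stack begins at index 1 (the `router` name matches the `--name_module_trainable router` flag used in the example scripts):

```python
# Hypothetical module names; a vanilla model exposes "...layers.0...." while a
# MoD-converted model begins at "...layers.1....".
names = [
    "model.layers.1.self_attn.q_proj",
    "model.layers.1.mlp.gate_proj",
    "model.layers.1.router",
]

freeze_modules = set()
for name in names:
    if ".0." in name:
        freeze_modules.add(name.split(".0.")[-1].split(".")[0])
    elif ".1." in name:  # MoD-converted models start at layer 1
        freeze_modules.add(name.split(".1.")[-1].split(".")[0])

print(freeze_modules)  # {'self_attn', 'mlp', 'router'} (set order may vary)
```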
13 changes: 13 additions & 0 deletions src/llmtuner/model/loader.py
@@ -71,6 +71,12 @@ def load_model(
patch_config(config, tokenizer, model_args, init_kwargs, is_trainable)

model = None
if model_args.mixture_of_depths == 'continue':
from MoD import AutoMoDModelForCausalLM
# Load a checkpoint that was already converted to MoD
model = AutoMoDModelForCausalLM.from_pretrained(model_args.model_name_or_path, config=config)
if model.config.model_type == 'qwen2':
raise RuntimeError("Qwen models are not supported for MoD training.")

if is_trainable and model_args.use_unsloth:
from unsloth import FastLanguageModel # type: ignore

@@ -100,6 +106,13 @@
init_kwargs["pretrained_model_name_or_path"] = model_args.model_name_or_path
model: "PreTrainedModel" = AutoModelForCausalLM.from_pretrained(**init_kwargs)

if model_args.mixture_of_depths == 'convert':
from MoD import convert_hf_model
if model.config.model_type == 'qwen2':
raise RuntimeError("Qwen models are not supported for MoD training.")
# Wrap the freshly loaded vanilla model with MoD routers
model = convert_hf_model(model)

patch_model(model, tokenizer, model_args, is_trainable)
register_autoclass(config, model, tokenizer)

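To summarize the two code paths above, here is a minimal sketch of using the `mixture-of-depth` package directly, outside LLaMA-Factory. It assumes `pip install mixture-of-depth`; the TinyLlama checkpoint is the one from the example scripts and is only illustrative:

```python
from transformers import AutoModelForCausalLM
from MoD import AutoMoDModelForCausalLM, convert_hf_model

# --mixture_of_depths convert: wrap a vanilla checkpoint with MoD routers
model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
model = convert_hf_model(model)

# --mixture_of_depths continue: resume from a checkpoint already saved as MoD
# (path is hypothetical)
# model = AutoMoDModelForCausalLM.from_pretrained("saves/path/to/mod-checkpoint")
```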