Skip to content

Commit

Permalink
support HQQ/EETQ hiyouga#4113
Browse files Browse the repository at this point in the history
  • Loading branch information
hiyouga authored and zhangzh committed Jul 1, 2024
1 parent a29bb3b commit 7354c02
Show file tree
Hide file tree
Showing 16 changed files with 138 additions and 61 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ Choose your path:

- **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Yi, Gemma, Baichuan, ChatGLM, Phi, etc.
- **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO, KTO, ORPO, etc.
- **Scalable resources**: 32-bit full-tuning, 16-bit freeze-tuning, 16-bit LoRA and 2/4/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8.
- **Scalable resources**: 16-bit full-tuning, freeze-tuning, LoRA and 2/3/4/5/6/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ.
- **Advanced algorithms**: GaLore, BAdam, DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ, PiSSA and Agent tuning.
- **Practical tricks**: FlashAttention-2, Unsloth, RoPE scaling, NEFTune and rsLoRA.
- **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, etc.
Expand Down Expand Up @@ -341,7 +341,7 @@ cd LLaMA-Factory
pip install -e ".[torch,metrics]"
```

Extra dependencies available: torch, torch_npu, metrics, deepspeed, bitsandbytes, vllm, galore, badam, gptq, awq, aqlm, qwen, modelscope, quality
Extra dependencies available: torch, torch-npu, metrics, deepspeed, bitsandbytes, hqq, eetq, gptq, awq, aqlm, vllm, galore, badam, qwen, modelscope, quality

> [!TIP]
> Use `pip install --no-deps -e .` to resolve package conflicts.
Expand Down
4 changes: 2 additions & 2 deletions README_zh.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd

- **多种模型**:LLaMA、LLaVA、Mistral、Mixtral-MoE、Qwen、Yi、Gemma、Baichuan、ChatGLM、Phi 等等。
- **集成方法**:(增量)预训练、(多模态)指令监督微调、奖励模型训练、PPO 训练、DPO 训练、KTO 训练、ORPO 训练等等。
- **多种精度**32 比特全参数微调、16 比特冻结微调、16 比特 LoRA 微调和基于 AQLM/AWQ/GPTQ/LLM.int8 的 2/4/8 比特 QLoRA 微调。
- **多种精度**16 比特全参数微调、冻结微调、LoRA 微调和基于 AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ 的 2/3/4/5/6/8 比特 QLoRA 微调。
- **先进算法**:GaLore、BAdam、DoRA、LongLoRA、LLaMA Pro、Mixture-of-Depths、LoRA+、LoftQ、PiSSA 和 Agent 微调。
- **实用技巧**:FlashAttention-2、Unsloth、RoPE scaling、NEFTune 和 rsLoRA。
- **实验监控**:LlamaBoard、TensorBoard、Wandb、MLflow 等等。
Expand Down Expand Up @@ -341,7 +341,7 @@ cd LLaMA-Factory
pip install -e ".[torch,metrics]"
```

可选的额外依赖项:torch、torch_npu、metrics、deepspeed、bitsandbytes、vllm、galore、badam、gptq、awq、aqlm、qwen、modelscope、quality
可选的额外依赖项:torch、torch-npu、metrics、deepspeed、bitsandbytes、hqq、eetq、gptq、awq、aqlm、vllm、galore、badam、qwen、modelscope、quality

> [!TIP]
> 遇到包冲突时,可使用 `pip install --no-deps -e .` 解决。
Expand Down
8 changes: 5 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,14 @@ def get_requires():
"metrics": ["nltk", "jieba", "rouge-chinese"],
"deepspeed": ["deepspeed>=0.10.0"],
"bitsandbytes": ["bitsandbytes>=0.39.0"],
"vllm": ["vllm>=0.4.3"],
"galore": ["galore-torch"],
"badam": ["badam>=1.2.1"],
"hqq": ["hqq"],
"eetq": ["eetq"],
"gptq": ["optimum>=1.17.0", "auto-gptq>=0.5.0"],
"awq": ["autoawq"],
"aqlm": ["aqlm[gpu]>=1.1.0"],
"vllm": ["vllm>=0.4.3"],
"galore": ["galore-torch"],
"badam": ["badam>=1.2.1"],
"qwen": ["transformers_stream_generator"],
"modelscope": ["modelscope"],
"dev": ["ruff", "pytest"],
Expand Down
5 changes: 4 additions & 1 deletion src/llamafactory/extras/env.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
# Copyright 2024 the LlamaFactory team.
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/commands/env.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
7 changes: 4 additions & 3 deletions src/llamafactory/hparams/model_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,10 @@ class ModelArguments:
default=True,
metadata={"help": "Whether or not to use memory-efficient model loading."},
)
quantization_method: Literal["bitsandbytes", "hqq", "eetq"] = field(
default="bitsandbytes",
metadata={"help": "Quantization method to use for on-the-fly quantization."},
)
quantization_bit: Optional[int] = field(
default=None,
metadata={"help": "The number of bits to quantize the model using bitsandbytes."},
Expand Down Expand Up @@ -235,9 +239,6 @@ def __post_init__(self):
if self.new_special_tokens is not None: # support multiple special tokens
self.new_special_tokens = [token.strip() for token in self.new_special_tokens.split(",")]

assert self.quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
assert self.export_quantization_bit in [None, 8, 4, 3, 2], "We only accept 2/3/4/8-bit quantization."

if self.export_quantization_bit is not None and self.export_quantization_dataset is None:
raise ValueError("Quantization dataset is necessary for exporting.")

Expand Down
2 changes: 2 additions & 0 deletions src/llamafactory/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@

from .loader import load_config, load_model, load_tokenizer
from .model_utils.misc import find_all_linear_modules
from .model_utils.quantization import QuantizationMethod
from .model_utils.valuehead import load_valuehead_params


__all__ = [
"QuantizationMethod",
"load_config",
"load_model",
"load_tokenizer",
Expand Down
4 changes: 2 additions & 2 deletions src/llamafactory/model/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,11 +186,11 @@ def load_model(

trainable_params, all_param = count_parameters(model)
if is_trainable:
param_stats = "trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
param_stats = "trainable params: {:,} || all params: {:,} || trainable%: {:.4f}".format(
trainable_params, all_param, 100 * trainable_params / all_param
)
else:
param_stats = "all params: {:d}".format(all_param)
param_stats = "all params: {:,}".format(all_param)

logger.info(param_stats)

Expand Down
83 changes: 52 additions & 31 deletions src/llamafactory/model/model_utils/quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

import torch
from datasets import load_dataset
from transformers import BitsAndBytesConfig, GPTQConfig
from transformers import BitsAndBytesConfig, EetqConfig, GPTQConfig, HqqConfig
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.modeling_utils import is_fsdp_enabled
from transformers.utils.versions import require_version
Expand Down Expand Up @@ -59,7 +59,7 @@ class QuantizationMethod(str, Enum):

def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> List[Dict[str, Any]]:
r"""
Prepares the dataset to perform AutoGPTQ.
Prepares the tokenized dataset to perform AutoGPTQ. Do not use tensor output for JSON serialization.
"""
if os.path.isfile(model_args.export_quantization_dataset):
data_path = FILEEXT2TYPE.get(model_args.export_quantization_dataset.split(".")[-1], None)
Expand Down Expand Up @@ -93,7 +93,7 @@ def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "Mod
word_idx = random.randint(0, sample["input_ids"].size(1) - maxlen - 1)
input_ids = sample["input_ids"][:, word_idx : word_idx + maxlen]
attention_mask = sample["attention_mask"][:, word_idx : word_idx + maxlen]
samples.append({"input_ids": input_ids, "attention_mask": attention_mask})
samples.append({"input_ids": input_ids.tolist(), "attention_mask": attention_mask.tolist()})

return samples

Expand All @@ -105,7 +105,7 @@ def configure_quantization(
init_kwargs: Dict[str, Any],
) -> None:
r"""
Priority: PTQ-quantized (training) > AutoGPTQ (export) > Bitsandbytes (training)
Priority: PTQ-quantized (train/infer) > AutoGPTQ (export) > On-the-fly quantization (train/infer)
"""
if getattr(config, "quantization_config", None): # ptq
if is_deepspeed_zero3_enabled():
Expand All @@ -131,6 +131,9 @@ def configure_quantization(
logger.info("Loading {}-bit {}-quantized model.".format(quant_bits, quant_method.upper()))

elif model_args.export_quantization_bit is not None: # auto-gptq
if model_args.export_quantization_bit not in [8, 4, 3, 2]:
raise ValueError("AutoGPTQ only accepts 2/3/4/8-bit quantization.")

require_version("optimum>=1.17.0", "To fix: pip install optimum>=1.17.0")
require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
from accelerate.utils import get_max_memory
Expand All @@ -146,30 +149,48 @@ def configure_quantization(
init_kwargs["max_memory"] = get_max_memory()
logger.info("Quantizing model to {} bit with AutoGPTQ.".format(model_args.export_quantization_bit))

elif model_args.quantization_bit is not None: # bnb
if model_args.quantization_bit == 8:
require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
init_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)

elif model_args.quantization_bit == 4:
require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
init_kwargs["quantization_config"] = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=model_args.compute_dtype,
bnb_4bit_use_double_quant=model_args.double_quantization,
bnb_4bit_quant_type=model_args.quantization_type,
bnb_4bit_quant_storage=model_args.compute_dtype, # crucial for fsdp+qlora
)

# Do not assign device map if:
# 1. deepspeed zero3 or fsdp (train)
# 2. auto quantization device map (inference)
if is_deepspeed_zero3_enabled() or is_fsdp_enabled() or model_args.quantization_device_map == "auto":
if model_args.quantization_bit != 4:
raise ValueError("Only 4-bit quantized model can use fsdp+qlora or auto device map.")

require_version("bitsandbytes>=0.43.0", "To fix: pip install bitsandbytes>=0.43.0")
else:
init_kwargs["device_map"] = {"": get_current_device()} # change auto device map for inference

logger.info("Quantizing model to {} bit with bitsandbytes.".format(model_args.quantization_bit))
elif model_args.quantization_bit is not None: # on-the-fly
if model_args.quantization_method == QuantizationMethod.BITS_AND_BYTES.value:
if model_args.quantization_bit == 8:
require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
init_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
elif model_args.quantization_bit == 4:
require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
init_kwargs["quantization_config"] = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=model_args.compute_dtype,
bnb_4bit_use_double_quant=model_args.double_quantization,
bnb_4bit_quant_type=model_args.quantization_type,
bnb_4bit_quant_storage=model_args.compute_dtype, # crucial for fsdp+qlora
)
else:
raise ValueError("Bitsandbytes only accepts 4-bit or 8-bit quantization.")

# Do not assign device map if:
# 1. deepspeed zero3 or fsdp (train)
# 2. auto quantization device map (inference)
if is_deepspeed_zero3_enabled() or is_fsdp_enabled() or model_args.quantization_device_map == "auto":
if model_args.quantization_bit != 4:
raise ValueError("Only 4-bit quantized model can use fsdp+qlora or auto device map.")

require_version("bitsandbytes>=0.43.0", "To fix: pip install bitsandbytes>=0.43.0")
else:
init_kwargs["device_map"] = {"": get_current_device()} # change auto device map for inference

logger.info("Quantizing model to {} bit with bitsandbytes.".format(model_args.quantization_bit))
elif model_args.quantization_method == QuantizationMethod.HQQ.value:
if model_args.quantization_bit not in [8, 6, 5, 4, 3, 2, 1]:
raise ValueError("HQQ only accepts 1/2/3/4/5/6/8-bit quantization.")

require_version("hqq", "To fix: pip install hqq")
init_kwargs["quantization_config"] = HqqConfig(
nbits=model_args.quantization_bit, quant_zero=False, quant_scale=False, axis=0
) # use ATEN kernel (axis=0) for performance
logger.info("Quantizing model to {} bit with HQQ.".format(model_args.quantization_bit))
elif model_args.quantization_method == QuantizationMethod.EETQ.value:
if model_args.quantization_bit != 8:
raise ValueError("EETQ only accepts 8-bit quantization.")

require_version("eetq", "To fix: pip install eetq")
init_kwargs["quantization_config"] = EetqConfig()
logger.info("Quantizing model to {} bit with EETQ.".format(model_args.quantization_bit))
10 changes: 8 additions & 2 deletions src/llamafactory/webui/chatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from ..extras.constants import PEFT_METHODS
from ..extras.misc import torch_gc
from ..extras.packages import is_gradio_available
from .common import get_save_dir
from .common import QUANTIZATION_BITS, get_save_dir
from .locales import ALERTS


Expand Down Expand Up @@ -76,11 +76,17 @@ def load_model(self, data) -> Generator[str, None, None]:
yield error
return

if get("top.quantization_bit") in QUANTIZATION_BITS:
quantization_bit = int(get("top.quantization_bit"))
else:
quantization_bit = None

yield ALERTS["info_loading"][lang]
args = dict(
model_name_or_path=model_path,
finetuning_type=finetuning_type,
quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
quantization_bit=quantization_bit,
quantization_method=get("top.quantization_method"),
template=get("top.template"),
flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
use_unsloth=(get("top.booster") == "unsloth"),
Expand Down
2 changes: 2 additions & 0 deletions src/llamafactory/webui/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@
DEFAULT_DATA_DIR = "data"
DEFAULT_SAVE_DIR = "saves"
USER_CONFIG = "user_config.yaml"
QUANTIZATION_BITS = ["8", "6", "5", "4", "3", "2", "1"]
GPTQ_BITS = ["8", "4", "3", "2"]


def get_save_dir(*paths: str) -> os.PathLike:
Expand Down
5 changes: 1 addition & 4 deletions src/llamafactory/webui/components/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from ...extras.misc import torch_gc
from ...extras.packages import is_gradio_available
from ...train.tuner import export_model
from ..common import get_save_dir
from ..common import GPTQ_BITS, get_save_dir
from ..locales import ALERTS


Expand All @@ -32,9 +32,6 @@
from ..engine import Engine


GPTQ_BITS = ["8", "4", "3", "2"]


def can_quantize(checkpoint_path: Union[str, List[str]]) -> "gr.Dropdown":
if isinstance(checkpoint_path, list) and len(checkpoint_path) != 0:
return gr.Dropdown(value="none", interactive=False)
Expand Down
13 changes: 8 additions & 5 deletions src/llamafactory/webui/components/top.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from ...extras.constants import METHODS, SUPPORTED_MODELS
from ...extras.packages import is_gradio_available
from ..common import get_model_info, list_checkpoints, save_config
from ..utils import can_quantize
from ..utils import can_quantize, can_quantize_to


if is_gradio_available():
Expand All @@ -43,10 +43,11 @@ def create_top() -> Dict[str, "Component"]:

with gr.Accordion(open=False) as advanced_tab:
with gr.Row():
quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none", scale=2)
template = gr.Dropdown(choices=list(TEMPLATES.keys()), value="default", scale=2)
rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none", scale=3)
booster = gr.Radio(choices=["none", "flashattn2", "unsloth"], value="none", scale=3)
quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none", scale=1)
quantization_method = gr.Dropdown(choices=["bitsandbytes", "hqq", "eetq"], value="bitsandbytes", scale=1)
template = gr.Dropdown(choices=list(TEMPLATES.keys()), value="default", scale=1)
rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none", scale=2)
booster = gr.Radio(choices=["auto", "flashattn2", "unsloth"], value="auto", scale=2)
visual_inputs = gr.Checkbox(scale=1)

model_name.change(get_model_info, [model_name], [model_path, template, visual_inputs], queue=False).then(
Expand All @@ -58,6 +59,7 @@ def create_top() -> Dict[str, "Component"]:
list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False
)
checkpoint_path.focus(list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False)
quantization_method.change(can_quantize_to, [quantization_method], [quantization_bit], queue=False)

return dict(
lang=lang,
Expand All @@ -67,6 +69,7 @@ def create_top() -> Dict[str, "Component"]:
checkpoint_path=checkpoint_path,
advanced_tab=advanced_tab,
quantization_bit=quantization_bit,
quantization_method=quantization_method,
template=template,
rope_scaling=rope_scaling,
booster=booster,
Expand Down
20 changes: 17 additions & 3 deletions src/llamafactory/webui/locales.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,15 +85,29 @@
"quantization_bit": {
"en": {
"label": "Quantization bit",
"info": "Enable 4/8-bit model quantization (QLoRA).",
"info": "Enable quantization (QLoRA).",
},
"ru": {
"label": "Уровень квантования",
"info": "Включить 4/8-битное квантование модели (QLoRA).",
"info": "Включить квантование (QLoRA).",
},
"zh": {
"label": "量化等级",
"info": "启用 4/8 比特模型量化(QLoRA)。",
"info": "启用量化(QLoRA)。",
},
},
"quantization_method": {
"en": {
"label": "Quantization method",
"info": "Quantization algorithm to use.",
},
"ru": {
"label": "Метод квантования",
"info": "Алгоритм квантования, который следует использовать.",
},
"zh": {
"label": "量化方法",
"info": "使用的量化算法。",
},
},
"template": {
Expand Down
1 change: 1 addition & 0 deletions src/llamafactory/webui/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def get_base_elems(self) -> Set["Component"]:
self._id_to_elem["top.finetuning_type"],
self._id_to_elem["top.checkpoint_path"],
self._id_to_elem["top.quantization_bit"],
self._id_to_elem["top.quantization_method"],
self._id_to_elem["top.template"],
self._id_to_elem["top.rope_scaling"],
self._id_to_elem["top.booster"],
Expand Down
Loading

0 comments on commit 7354c02

Please sign in to comment.