
Commit

better support quant_lm_head for larger models (#263)
wenhuach21 authored Sep 19, 2024
1 parent 6539d50 commit 82322ac
Showing 4 changed files with 135 additions and 69 deletions.
53 changes: 28 additions & 25 deletions auto_round/__main__.py
@@ -18,14 +18,16 @@

import torch
import transformers

os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
torch.use_deterministic_algorithms(True, warn_only=True)
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel
from lm_eval.utils import make_table # pylint: disable=E0401

from auto_round import AutoRoundConfig
from auto_round.eval.evaluation import simple_evaluate
from auto_round.utils import detect_device, get_library_version
from auto_round.utils import detect_device, get_library_version, detect_device_count


def setup_parser():
parser = argparse.ArgumentParser()
@@ -34,7 +36,7 @@ def setup_parser():
"--model", default="facebook/opt-125m"
)

parser.add_argument('--eval', action='store_true',
help="whether to use eval only mode.")

parser.add_argument("--bits", default=4, type=int,
@@ -89,8 +91,8 @@ def setup_parser():

parser.add_argument("--format", default=None, type=str,
help="The format in which to save the model. "
"The options are 'auto_round', 'auto_gptq', 'auto_awq', 'itrex', 'itrex_xpu' and 'fake'."
"default to 'auto_round."
"The options are 'auto_round', 'auto_gptq', 'auto_awq', 'itrex', 'itrex_xpu' and 'fake'."
"default to 'auto_round."
)

parser.add_argument("--data_type", default='int',
@@ -130,29 +132,30 @@ def setup_parser():

parser.add_argument("--low_cpu_mem_mode", default=0, type=int,
help="Choose which low cpu memory mode to use. "
"Can significantly reduce cpu memory footprint but cost more time."
"1 means choose block-wise mode, load the weights of each block"
" from disk when tuning and release the memory of the block after tuning."
"2 means choose layer-wise mode, load the weights of each layer from disk when tuning,"
" minimum memory consumption and also slowest running speed."
"others means not use low cpu memory. Default to 0, not use low cpu memory.")
"Can significantly reduce cpu memory footprint but cost more time."
"1 means choose block-wise mode, load the weights of each block"
" from disk when tuning and release the memory of the block after tuning."
"2 means choose layer-wise mode, load the weights of each layer from disk when tuning,"
" minimum memory consumption and also slowest running speed."
"others means not use low cpu memory. Default to 0, not use low cpu memory.")

parser.add_argument("--low_cpu_mem_tmp_dir", default=None, type=str,
help="temp work space to store the temporary files "
"when using low cpu memory mode. Will remove after tuning.")
"when using low cpu memory mode. Will remove after tuning.")

parser.add_argument("--model_dtype", default=None, type=str,
help="force to convert the dtype, some backends supports fp16 dtype better")

parser.add_argument("--act_bits", default=32, type=int,
help="activation bits")

parser.add_argument("--fp_layers_list", default="", type=str,
help="List of Layers to maintain original data type")

args = parser.parse_args()
return args


def tune(args):
tasks = args.tasks
if args.format is None:
@@ -163,7 +166,6 @@ def tune(args):
model_name = model_name[:-1]
print(model_name, flush=True)


device_str = detect_device(args.device)
torch_dtype = "auto"
if "hpu" in device_str:
@@ -197,10 +199,16 @@ def tune(args):
trust_remote_code=not args.disable_trust_remote_code
)
else:
model = model_cls.from_pretrained(
model_name, low_cpu_mem_usage=True, torch_dtype=torch_dtype,
trust_remote_code=not args.disable_trust_remote_code
)
if detect_device_count() > 1:
model = model_cls.from_pretrained(
model_name, low_cpu_mem_usage=True, torch_dtype=torch_dtype,
trust_remote_code=not args.disable_trust_remote_code, device_map="auto"
)
else:
model = model_cls.from_pretrained(
model_name, low_cpu_mem_usage=True, torch_dtype=torch_dtype,
trust_remote_code=not args.disable_trust_remote_code
)

from auto_round import AutoRound, AutoAdamRound

@@ -274,18 +282,13 @@ def tune(args):
error_message = "Please upgrade transformers>=4.38.0 to support lm-head quantization."
raise EnvironmentError(error_message)

if args.quant_lm_head and args.low_gpu_mem_usage:
print(
f"warning, low_gpu_mem_usage=False is strongly recommended"
" if the whole model could be loaded to gpu")

autoround = round(
model, tokenizer, args.bits, args.group_size, sym=args.sym, batch_size=args.batch_size,
dataset=args.dataset, seqlen=seqlen, nblocks=args.nblocks, iters=args.iters, lr=args.lr,
minmax_lr=args.minmax_lr, enable_quanted_input=not args.disable_quanted_input,
device=device_str, amp=not args.disable_amp, nsamples=args.nsamples, seed=args.seed,
low_gpu_mem_usage=args.low_gpu_mem_usage, scale_dtype=args.scale_dtype,
gradient_accumulate_steps=args.gradient_accumulate_steps, layer_config=layer_config,
enable_minmax_tuning=not args.disable_minmax_tuning, act_bits=args.act_bits,
low_cpu_mem_usage=low_cpu_mem_usage, data_type=args.data_type,
enable_norm_bias_tuning=args.enable_norm_bias_tuning)
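Note: the loading branch added above passes device_map="auto" only when detect_device_count() reports more than one accelerator, so single-device runs keep the previous behavior. A minimal standalone sketch of the same pattern, assuming only torch and transformers and reusing the parser's default model name; load_for_tuning is a hypothetical helper, and unlike detect_device_count() this sketch checks CUDA devices only, not Habana HPUs:

import torch
from transformers import AutoModelForCausalLM

def load_for_tuning(model_name, torch_dtype="auto", trust_remote_code=True):
    # Shard the model across all visible GPUs only when more than one is available;
    # otherwise load it normally and let the tuner move modules as needed.
    kwargs = dict(low_cpu_mem_usage=True, torch_dtype=torch_dtype,
                  trust_remote_code=trust_remote_code)
    if torch.cuda.is_available() and torch.cuda.device_count() > 1:
        kwargs["device_map"] = "auto"
    return AutoModelForCausalLM.from_pretrained(model_name, **kwargs)

model = load_for_tuning("facebook/opt-125m")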
42 changes: 30 additions & 12 deletions auto_round/autoround.py
@@ -50,10 +50,11 @@
to_dtype,
get_layer_names_in_block,
mv_module_from_gpu,
unsupport_meta_device,
unsupport_meta_device, detect_device_count,
)

from .low_cpu_mem.utils import get_layers_before_block
import accelerate


class AutoRound(object):
@@ -164,7 +165,6 @@ def __init__(
"please do not using device_map='auto' in model loading, "
"or follow examples/language-modeling/main.py to enable low_cpu_mem_usage")
self.model = model.eval()
self.model = mv_module_from_gpu(self.model, self.low_cpu_mem_usage)
self.amp = amp
self.enable_quanted_input = enable_quanted_input
self.enable_minmax_tuning = enable_minmax_tuning
@@ -283,6 +283,8 @@ def quantize(self):
self.start_time = time.time()
all_first_block_names = [block[0] for block in all_blocks]
all_inputs = self.try_cache_inter_data_gpucpu(all_first_block_names, self.nsamples, layer_names=layer_names)
if hasattr(self.model, "hf_device_map") and len(self.model.hf_device_map) > 1:
accelerate.hooks.remove_hook_from_submodules(self.model) ##self.model.hf_device_map has not been changed
self.model = mv_module_from_gpu(self.model, self.low_cpu_mem_usage)
for block_names in all_blocks:
inputs = all_inputs[block_names[0]]
@@ -374,15 +376,24 @@ def quant_layers(self, layer_names, layer_inputs):
if len(layer_names) == 0:
return
q_layer_inputs = None
if self.enable_quanted_input:
enable_quanted_input = self.enable_quanted_input
if hasattr(self.model, "hf_device_map") and len(self.model.hf_device_map) > 1 and enable_quanted_input:
from accelerate.big_modeling import dispatch_model

dispatch_model(self.model, self.model.hf_device_map)

if enable_quanted_input:
q_layer_inputs = self.try_cache_inter_data_gpucpu([], self.nsamples, layer_names=layer_names)
if hasattr(self.model, "hf_device_map") and len(self.model.hf_device_map) > 1:
accelerate.hooks.remove_hook_from_submodules(
self.model) ##self.model.hf_device_map has not been changed

self.model = mv_module_from_gpu(self.model, self.low_cpu_mem_usage)
torch.cuda.empty_cache()
for layer_name in layer_names:
layer_input = layer_inputs[layer_name]
layer_input = to_device(layer_input, self.cache_device)
q_layer_input = q_layer_inputs[layer_name] if self.enable_quanted_input else None
q_layer_input = q_layer_inputs[layer_name] if enable_quanted_input else None
q_layer_input = to_device(q_layer_input, self.cache_device)
self.quant_layer(layer_name, layer_input, q_layer_input, device=self.device)
for i in range(len(layer_input)):
@@ -561,7 +572,7 @@ def calib(self, nsamples, bs):
if self.low_cpu_mem_usage:
for n, m in embed_layers:
m = m.to("meta")
torch.cuda.empty_cache()
# torch.cuda.empty_cache()

@torch.no_grad()
def try_cache_inter_data_gpucpu(self, block_names, nsamples, layer_names=[], last_cache_name=None):
@@ -581,14 +592,22 @@ def try_cache_inter_data_gpucpu(self, block_names, nsamples, layer_names=[], last_cache_name=None):
"""
try:
if not self.model.device.type == "meta":
self.model = self.model.to(self.device)
if hasattr(self.model, "hf_device_map") and len(self.model.hf_device_map) > 1:
pass
else:
self.model = self.model.to(self.device)
all_inputs = self.cache_inter_data(
block_names, nsamples, layer_names=layer_names, last_cache_name=last_cache_name
)
self.model = mv_module_from_gpu(self.model, self.low_cpu_mem_usage)
torch.cuda.empty_cache()
except:
logger.info("switch to cpu to cache inputs")
if "lm_head" in self.layer_config and self.layer_config["lm_head"]["bits"]<8:
logger.warning(f"we strongly recommend using additional CUDA/HPU devices,e.g. "
f"'CUDA_VISIBLE_DEVICES=0,1 python xxx',"
f" for optimal performance during calibration when enabling lm-head quantization. "
f"Otherwise, the process may be significantly slower.")
self.model = mv_module_from_gpu(self.model, self.low_cpu_mem_usage)
torch.cuda.empty_cache()
all_inputs = self.cache_inter_data(
@@ -670,7 +689,7 @@ def forward(m, hidden_states, *positional_args, **kwargs):
self.inputs[name]["input_ids"].append(hidden_states.to("cpu"))
else:
self.inputs[name]["input_ids"].extend(
list(torch.split(hidden_states.to("cpu"), 1, dim=self.input_dim)))
else:
self.inputs[name] = {}
if self.train_bs == 1 and self.not_share_rotary_pos_emb_flag:
@@ -708,13 +727,13 @@ def forward(m, hidden_states, *positional_args, **kwargs):
elif "position_ids" in key or 'cache_position' in key:
if self.train_bs == 1 and self.not_share_rotary_pos_emb_flag:
if key not in self.inputs[name].keys():
self.inputs[name][key] = [to_device(kwargs[key], device=torch.device("cpu"))]
else:
self.inputs[name][key].append(to_device(kwargs[key], device=torch.device("cpu")))
elif key not in self.inputs[name].keys():
self.inputs[name][key] = list(torch.split(kwargs[key].to("cpu"), 1, dim=0)) \
    if self.not_share_position_ids_flag \
    else to_device(kwargs[key], device=torch.device("cpu"))
elif kwargs[key] is not None and self.not_share_position_ids_flag:
self.inputs[name][key].extend(list(torch.split(kwargs[key].to("cpu"), 1, dim=0)))
elif 'rotary_pos_emb' in key or 'cu_seqlens' in key:
@@ -1588,4 +1607,3 @@ def __init__(
optimizer,
**kwargs,
)
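Note: the autoround.py changes above follow one recurring pattern: re-dispatch a sharded model onto its recorded hf_device_map before a forward-heavy caching step, then strip the accelerate hooks and fall back to the usual module movement before tuning. A hedged sketch of that cycle, assuming accelerate is installed; with_dispatch and run_calibration are hypothetical names, not part of the repository:

import accelerate
import torch
from accelerate.big_modeling import dispatch_model

def with_dispatch(model, run_calibration):
    # Models loaded with device_map="auto" carry a multi-entry hf_device_map.
    sharded = hasattr(model, "hf_device_map") and len(model.hf_device_map) > 1
    if sharded:
        dispatch_model(model, model.hf_device_map)  # re-attach accelerate's dispatch hooks
    try:
        return run_calibration(model)
    finally:
        if sharded:
            # hf_device_map itself stays unchanged; only the hooks are removed.
            accelerate.hooks.remove_hook_from_submodules(model)
        if torch.cuda.is_available():
            torch.cuda.empty_cache()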

53 changes: 50 additions & 3 deletions auto_round/utils.py
@@ -420,6 +420,16 @@ def block_forward(block, input_ids, input_others, amp=False, amp_dtype=torch.flo


def check_to_quantized(config):
"""Checks if the configuration is valid for quantization.
Args:
config (dict or object): The configuration to check. It can be either a
dictionary with a 'bits' key or an object with a 'bits' attribute.
Returns:
bool: True if the configuration is valid for quantization (bits <= 8),
False otherwise.
"""
if isinstance(config, dict):
if config["bits"] > 8:
return False
@@ -430,7 +440,42 @@ def check_to_quantized(config):
return True


def detect_device_count():
"""Detects the number of available computation devices.
This function checks if CUDA is available. If it is, it returns the count
of available CUDA devices. If not, it attempts to import the Habana
device framework to return the count of Habana devices. If the import
fails or no devices are found, it returns 0.
Returns:
int: The number of available devices (CUDA or Habana).
"""
if torch.cuda.is_available():
return torch.cuda.device_count()
else:
try:
import habana_frameworks.torch.hpu as hthpu # pylint: disable=E0401
return hthpu.device_count()
except ImportError:
return 0


def detect_device(device=None):
"""Detects the appropriate computation device.
This function determines the device to use for computations. It can take
a specific device index or default to 'auto'. The function checks for
available devices in the following order: CUDA, Habana, and finally CPU.
Args:
device (str, int, or torch.device, optional): The desired device.
If 'auto' or None, the function will determine the best device
automatically.
Returns:
str: The device to use for computations, formatted as a string.
"""
def is_valid_digit(s):
try:
num = int(s)
@@ -830,26 +875,28 @@ def dynamic_import_inference_linear(backend, bits, group_size, sym):
from auto_round_extension.cuda.qlinear_tritonv2 import QuantLinear
return QuantLinear


def get_library_version(library_name):
from packaging.version import Version
python_vesion = Version(sys.version.split()[0])
if python_vesion < Version("3.8"):
import warnings
warnings.filterwarnings('ignore', category=DeprecationWarning)
import pkg_resources  # pylint: disable=E0401
try:
version = pkg_resources.get_distribution(library_name).version
return version
except pkg_resources.DistributionNotFound:
return f"{library_name} is not installed"
else:
import importlib_metadata  # pylint: disable=E0401
try:
version = importlib_metadata.version(library_name)
return version
except importlib_metadata.PackageNotFoundError:
return f"{library_name} is not installed"


def get_autogptq_packing_qlinear(backend, bits=4, group_size=128, sym=False):
"""
Configures and returns a QuantLinear class based on the specified backend and parameters.
@@ -919,4 +966,4 @@ def get_autogptq_packing_qlinear(backend, bits=4, group_size=128, sym=False):
use_qigen=use_qigen,
use_marlin=not disable_marlin,
)
return QuantLinear
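Note: the two helpers above are what the CLI and tuner rely on for device selection. A hedged usage sketch, assuming auto_round is installed; the printed values depend on the machine:

from auto_round.utils import detect_device, detect_device_count

n_devices = detect_device_count()   # CUDA device count, Habana HPU count as a fallback, else 0
device_str = detect_device("auto")  # a CUDA, HPU or CPU device string, per availability
print(f"{n_devices} accelerator(s) found, tuning device: {device_str}")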