multiple gpu evaluation/calibration refine #312

Merged: 16 commits, Nov 8, 2024
10 changes: 5 additions & 5 deletions README.md
@@ -57,12 +57,12 @@ pip install auto-round

### Basic Usage (Gaudi2/CPU/GPU)

A user guide detailing the full list of supported arguments is provided by calling ```auto-round -h``` on the terminal.
[//]: # (A user guide detailing the full list of supported arguments is provided by calling ```auto-round -h``` on the terminal.)
Alternatively, you can use ```auto_round``` instead of ```auto-round```. Set the format you want in `format` and
multiple formats exporting has been supported.
exporting to multiple formats is supported. Please check out the [step-by-step instruction](./docs/step_by_step.md) for more details on the calibration dataset and evaluation.

```bash
CUDA_VISIBLE_DEVICES=0 auto-round \
auto-round \
--model facebook/opt-125m \
--bits 4 \
--group_size 128 \
@@ -77,7 +77,7 @@ We provide two recipes for best accuracy and fast running speed with low memory.

```bash
## best accuracy, 3X slower, low_gpu_mem_usage could save ~20G but ~30% slower
CUDA_VISIBLE_DEVICES=0 auto-round \
auto-round \
--model facebook/opt-125m \
--bits 4 \
--group_size 128 \
@@ -89,7 +89,7 @@ CUDA_VISIBLE_DEVICES=0 auto-round \

```bash
## fast and low memory, 2-3X speedup, slight accuracy drop at W4G128
CUDA_VISIBLE_DEVICES=0 auto-round \
auto-round \
--model facebook/opt-125m \
--bits 4 \
--group_size 128 \
12 changes: 8 additions & 4 deletions auto_round/autoround.py
@@ -264,7 +264,7 @@ def quantize(self):
Returns:
The quantized model and layer configurations.
"""
# logger.info("cache block input")

if bool(self.quant_block_list):
all_blocks = self.quant_block_list
else:
@@ -280,10 +280,12 @@ def quantize(self):
layer_names = self.get_quantized_layer_names_outside_blocks()
self.start_time = time.time()
all_first_block_names = [block[0] for block in all_blocks]
logger.info("start calibration")
all_inputs = self.try_cache_inter_data_gpucpu(all_first_block_names, self.nsamples, layer_names=layer_names)
if hasattr(self.model, "hf_device_map") and len(self.model.hf_device_map) > 1:
accelerate.hooks.remove_hook_from_submodules(self.model) ##self.model.hf_device_map has not been changed
self.model = mv_module_from_gpu(self.model, self.low_cpu_mem_usage)
logger.info("calibration done")
for block_names in all_blocks:
inputs = all_inputs[block_names[0]]
all_inputs.pop(block_names[0])
@@ -611,10 +613,12 @@ def try_cache_inter_data_gpucpu(self, block_names, nsamples, layer_names=None, l
except RuntimeError as e:
if "CUDA out of memory" in str(e):
logger.info("switch to cpu to cache inputs")
if "lm_head" in self.layer_config and self.layer_config["lm_head"]["bits"] < 8:
if (("lm_head" in self.layer_config and self.layer_config["lm_head"]["bits"] <= 16) or
self.__class__.__name__ == "AutoRoundMLLM"):
logger.warning(f"we strongly recommend using additional CUDA/HPU devices,e.g. "
f"'CUDA_VISIBLE_DEVICES=0,1 python xxx',"
f" for optimal performance during calibration when enabling lm-head quantization. "
f"set `--device '0,1'` in our cmd line usage or "
f"load the model with `device_mapping=auto`,"
f" for optimal performance during calibration "
f"Otherwise, the process may be significantly slower.")
self.model = mv_module_from_gpu(self.model, self.low_cpu_mem_usage)
clear_memory()
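For context, the `try_cache_inter_data_gpucpu` change above keeps the existing try-on-GPU, fall-back-to-CPU behaviour and only rewords the warning. Below is a minimal sketch of that fallback pattern; the helper name and its arguments are illustrative, not the actual auto-round API.

```python
import torch


def cache_calib_inputs_with_fallback(cache_fn, model, nsamples):
    """Try to cache calibration inputs on the GPU; fall back to CPU on OOM.

    `cache_fn` is assumed to run the calibration forward passes and return
    the collected block inputs (names here are hypothetical).
    """
    try:
        return cache_fn(model.to("cuda"), nsamples)
    except RuntimeError as e:
        if "CUDA out of memory" not in str(e):
            raise
        # Release cached GPU memory, then redo the caching on CPU. This path
        # is much slower, which is why the warning above recommends extra
        # devices (e.g. --device '0,1') when lm-head quantization is enabled.
        model = model.to("cpu")
        torch.cuda.empty_cache()
        return cache_fn(model, nsamples)
```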
60 changes: 37 additions & 23 deletions auto_round/script/llm.py
@@ -25,24 +25,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import re
import argparse

import torch
import transformers

os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
torch.use_deterministic_algorithms(True, warn_only=True)
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel, AutoConfig, AutoProcessor
from lm_eval.utils import make_table # pylint: disable=E0401

from auto_round import AutoRoundConfig
from auto_round.eval.evaluation import simple_evaluate
from auto_round.utils import detect_device, get_library_version, detect_device_count
from auto_round.utils import logger


class BasicArgumentParser(argparse.ArgumentParser):
def __init__(self, *args, **kwargs):
@@ -59,7 +59,7 @@ def __init__(self, *args, **kwargs):
self.add_argument("--eval_bs", default=None, type=int,
help="batch size in evaluation")

self.add_argument("--device", default="auto", type=str,
self.add_argument("--device", "--devices", default="auto", type=str,
help="the device to be used for tuning. The default is set to auto,"
"allowing for automatic detection."
"Currently, device settings support CPU, GPU, and HPU.")
@@ -234,7 +218,7 @@ def tune(args):
supported_formats = ["auto_round", "auto_gptq", "auto_awq", "auto_round:gptq", "auto_round:auto_gptq",
"auto_round:auto_gptq:marlin", "auto_round:gptq:marlin", "auto_round:auto_awq",
"auto_round:awq", "auto_gptq:marlin", "itrex", "iterx_xpu", "fake"]
formats = args.format.replace(' ', '').split(",")
formats = args.format.replace(' ', '').split(",")
for format in formats:
if format not in supported_formats:
raise ValueError(f"{format} is not supported, we only support {supported_formats}")
@@ -247,12 +231,33 @@
if "marlin" in args.format and args.asym is True:
assert False, "marlin backend only supports sym quantization, please remove --asym"

## must set CUDA_VISIBLE_DEVICES before importing torch
import os
devices = args.device.split(',')
use_auto_mapping = False
if all(s.isdigit() for s in devices):
os.environ["CUDA_VISIBLE_DEVICES"] = args.device
use_auto_mapping = True

import re
import torch
import transformers

os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
torch.use_deterministic_algorithms(True, warn_only=True)
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel, AutoConfig, AutoProcessor
from lm_eval.utils import make_table # pylint: disable=E0401

from auto_round import AutoRoundConfig
from auto_round.eval.evaluation import simple_evaluate
from auto_round.utils import detect_device, get_library_version, detect_device_count
from auto_round.utils import logger

model_name = args.model
if model_name[-1] == "/":
model_name = model_name[:-1]
logger.info(f"start to quantize {model_name}")

device_str = detect_device(args.device)
device_str = detect_device(devices[0])
torch_dtype = "auto"
if "hpu" in device_str:
torch_dtype = torch.bfloat16
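The block above works because CUDA_VISIBLE_DEVICES is exported before torch and transformers are imported, so only the GPUs named in `--device` are visible to the process and `device_map="auto"` then shards across exactly those. A rough standalone sketch of the same normalization, with an illustrative helper name, might look like:

```python
import os


def normalize_device_arg(device: str):
    """Map a --device value to (device_str, use_auto_mapping).

    "0,1" -> expose only GPUs 0 and 1 and let transformers shard the model
             with device_map="auto"; "cpu"/"hpu"/"cuda"/"auto" keep the
             single-device behaviour.
    """
    devices = device.split(",")
    if all(s.isdigit() for s in devices):
        # Set before torch initializes CUDA (the script does this before the
        # torch import) so only the listed GPUs are enumerated.
        os.environ["CUDA_VISIBLE_DEVICES"] = device
        return "cuda", True
    return device, False


# Example: normalize_device_arg("0,1") -> ("cuda", True)
```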
@@ -289,7 +294,7 @@ def tune(args):
trust_remote_code=not args.disable_trust_remote_code
)
else:
if detect_device_count() > 1:
if use_auto_mapping:
model = model_cls.from_pretrained(
model_name, low_cpu_mem_usage=True, torch_dtype=torch_dtype,
trust_remote_code=not args.disable_trust_remote_code, device_map="auto"
@@ -370,7 +375,6 @@ def tune(args):
raise ValueError(
f"{format} is not supported for lm-head quantization, please change to {auto_round_formats}")


autoround = round(
model, tokenizer, args.bits, args.group_size, sym=not args.asym, batch_size=args.batch_size,
dataset=args.dataset, seqlen=seqlen, nblocks=args.nblocks, iters=args.iters, lr=args.lr,
@@ -433,8 +437,18 @@


def eval(args):
device_str = detect_device(args.device)
import os
devices = args.device.split(",")
parallelism = False
if all(s.isdigit() for s in devices):
os.environ["CUDA_VISIBLE_DEVICES"] = args.device
parallelism = True
device_str = None
from auto_round.eval.evaluation import simple_evaluate

model_args = f"pretrained={args.model},trust_remote_code={not args.disable_trust_remote_code}"
if parallelism:
model_args += ",parallelize=True"
if isinstance(args.tasks, str):
tasks = args.tasks.split(',')
res = simple_evaluate(
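On the evaluation side, multi-GPU support comes from appending `parallelize=True` to the lm-eval `model_args` string, as in the `eval()` change above. A hedged example of what the resulting call could look like (the model name, task list, and batch size are placeholders, and `simple_evaluate` is assumed to mirror lm-eval's signature):

```python
from auto_round.eval.evaluation import simple_evaluate
from lm_eval.utils import make_table

model_args = "pretrained=facebook/opt-125m,trust_remote_code=True,parallelize=True"
res = simple_evaluate(
    model="hf",                 # lm-eval's Hugging Face backend
    model_args=model_args,
    tasks=["lambada_openai"],   # placeholder task list
    batch_size=16,
)
print(make_table(res))
```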
21 changes: 16 additions & 5 deletions auto_round/script/mllm.py
@@ -42,7 +42,7 @@ def __init__(self, *args, **kwargs):
self.add_argument("--eval_bs", default=None, type=int,
help="batch size in evaluation")

self.add_argument("--device", default="auto", type=str,
self.add_argument("--device", "--devices", default="auto", type=str,
help="the device to be used for tuning. The default is set to auto,"
"allowing for automatic detection."
"Currently, device settings support CPU, GPU, and HPU.")
@@ -206,7 +206,7 @@ def tune(args):
"auto_round:auto_gptq:marlin", "auto_round:gptq:marlin", "auto_round:auto_awq",
"auto_round:awq"]
if not args.quant_nontext_module:
supported_formats.append("auto_gptq","auto_gptq:marlin")
supported_formats.extend(["auto_gptq", "auto_gptq:marlin"])

formats = args.format.replace(' ', '').split(",")
for format in formats:
@@ -220,7 +220,13 @@ def tune(args):

assert args.dataset is not None, "dataset should not be None."

device_str = detect_device(args.device)
devices = args.device.split(',')
use_auto_mapping = False
if torch.cuda.is_available() and all(s.isdigit() for s in devices):
os.environ["CUDA_VISIBLE_DEVICES"] = args.device
use_auto_mapping = True
device_str = detect_device(devices[0])

torch_dtype = "auto"
if "hpu" in device_str:
torch_dtype = torch.bfloat16
@@ -239,8 +245,13 @@
cls = MllamaForConditionalGeneration
else:
cls = AutoModelForCausalLM
model = cls.from_pretrained(
model_name, trust_remote_code=not args.disable_trust_remote_code, torch_dtype=torch_dtype)
if use_auto_mapping:
model = cls.from_pretrained(
model_name, trust_remote_code=not args.disable_trust_remote_code, torch_dtype=torch_dtype,
device_map="auto")
else:
model = cls.from_pretrained(
model_name, trust_remote_code=not args.disable_trust_remote_code, torch_dtype=torch_dtype)

if "cogvlm2" in model_name:
model.config.model_type = "cogvlm2"
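Finally, the mllm.py switch from `append` to `extend` fixes a genuine bug: `list.append` takes exactly one argument, so the old call with two format strings would raise a TypeError instead of registering either format. A quick illustration:

```python
supported_formats = ["auto_round", "auto_round:auto_gptq"]

# Buggy: list.append() accepts a single element.
# supported_formats.append("auto_gptq", "auto_gptq:marlin")
#   -> TypeError: list.append() takes exactly one argument (2 given)

# Fixed: list.extend() adds every element of the given iterable.
supported_formats.extend(["auto_gptq", "auto_gptq:marlin"])
print(supported_formats)
# ['auto_round', 'auto_round:auto_gptq', 'auto_gptq', 'auto_gptq:marlin']
```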