update eval and fix example #260
@@ -17,7 +17,6 @@
 import argparse

 import torch
-import subprocess
 import transformers
 os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
 torch.use_deterministic_algorithms(True, warn_only=True)
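For context on the two lines kept above: CUBLAS_WORKSPACE_CONFIG must be set before any cuBLAS work for torch.use_deterministic_algorithms to make those kernels reproducible, and warn_only=True turns missing deterministic implementations into warnings instead of errors. A minimal, self-contained sketch of that pattern (the tensor shapes are made up for illustration):

```python
import os

# Must be exported before CUDA/cuBLAS is initialized for determinism to take effect.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

import torch

torch.manual_seed(0)                                      # fix the RNG state
torch.use_deterministic_algorithms(True, warn_only=True)  # warn instead of raising

x = torch.randn(4, 8)
w = torch.randn(8, 2)
print(x @ w)  # repeated runs of this script print identical values
```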
@@ -26,7 +25,7 @@
 from auto_round import AutoRoundConfig
 from auto_round.eval.evaluation import simple_evaluate
-from auto_round.utils import detect_device
+from auto_round.utils import detect_device, get_library_version


 def setup_parser():
     parser = argparse.ArgumentParser()
@@ -36,7 +35,7 @@ def setup_parser():
     )

     parser.add_argument('--eval', action='store_true',
-                        help="whether to use eval mode.")
+                        help="whether to use eval only mode.")

     parser.add_argument("--bits", default=4, type=int,
                         help="number of bits")
@@ -109,7 +108,7 @@ def setup_parser():
                         help="Where to store the final model.")

     parser.add_argument("--disable_eval", action='store_true',
-                        help="Whether to do lmeval evaluation.")
+                        help="Whether to do lm-eval evaluation after tuning.")

     parser.add_argument("--disable_amp", action='store_true',
                         help="disable amp")
@@ -177,7 +176,6 @@ def tune(args):
         args.low_cpu_mem_tmp_dir = os.path.join(args.output_dir, "low_cpu_mem_tmp")
     if args.low_cpu_mem_mode == 2:
         from auto_round.low_cpu_mem.utils import load_model_with_hooks
-
         model = load_model_with_hooks(
             model_name,
             model_cls,
@@ -189,7 +187,6 @@ def tune(args):
         )
     elif args.low_cpu_mem_mode == 1:
         from auto_round.low_cpu_mem.utils import load_empty_model
-
         low_cpu_mem_usage = True
         model = load_empty_model(
             model_name,
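For readers unfamiliar with the low_cpu_mem modes: the underlying idea of instantiating a model skeleton without materializing its weights in RAM can be sketched with Hugging Face accelerate as below. This is only an illustration of the concept, not what auto_round.low_cpu_mem.utils.load_empty_model does internally, and the checkpoint name is just an example.

```python
from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

name = "facebook/opt-125m"  # example checkpoint, not taken from this PR
config = AutoConfig.from_pretrained(name)

# Parameters are created on the "meta" device: shapes and dtypes only, no real storage.
with init_empty_weights():
    model = AutoModelForCausalLM.from_config(config)

print(next(model.parameters()).device)  # meta
```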
@@ -205,18 +202,16 @@ def tune(args):
         trust_remote_code=not args.disable_trust_remote_code
     )

-    from auto_round import (AutoRound,
-                            AutoAdamRound)
+    from auto_round import AutoRound, AutoAdamRound

     model = model.eval()
     # align with GPTQ to eval ppl
     seqlen = args.seqlen
     if "opt" in model_name:
         seqlen = model.config.max_position_embeddings
-        model.seqlen = model.config.max_position_embeddings
     else:
         seqlen = 2048
-        model.seqlen = seqlen
+    seqlen = args.seqlen

     if args.model_dtype != None:
         if args.model_dtype == "float16" or args.model_dtype == "fp16":
@@ -262,7 +257,6 @@ def tune(args):
             lm_head_layer_name = n
     if args.quant_lm_head:
         from transformers import AutoConfig
-
         config = AutoConfig.from_pretrained(model_name, trust_remote_code=not args.disable_trust_remote_code)
         if config.tie_word_embeddings and hasattr(model, "_tied_weights_keys"):
             tied_keys = model._tied_weights_keys
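The tie_word_embeddings check above matters because a model with tied input/output embeddings shares one weight tensor between the embedding and lm_head, so quantizing lm_head implicitly quantizes the embedding as well. A small sketch of the same inspection on a stock checkpoint (the model name is only an example):

```python
from transformers import AutoConfig, AutoModelForCausalLM

name = "facebook/opt-125m"  # example checkpoint, not taken from this PR
config = AutoConfig.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name)

print(config.tie_word_embeddings)                  # True for many decoder-only models
print(getattr(model, "_tied_weights_keys", None))  # e.g. ['lm_head.weight'] when tied
```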
@@ -299,7 +293,6 @@ def tune(args):
     model_name = args.model.rstrip("/")
     if args.low_cpu_mem_mode == 1 or args.low_cpu_mem_mode == 2:
         import shutil
-
         shutil.rmtree(args.low_cpu_mem_tmp_dir, ignore_errors=True)

     model.eval()
@@ -308,22 +301,12 @@ def tune(args):

     export_dir = args.output_dir + "/" + model_name.split('/')[-1] + f"-w{args.bits}g{args.group_size}"

     format_list = args.format.replace(' ', '').split(',')
     inplace = False if len(format_list) > 1 else True
     for format_ in format_list:
         eval_folder = f'{export_dir}-{format_}'
         autoround.save_quantized(eval_folder, format=format_, inplace=inplace)

-    def get_library_version(library_name):
-        try:
-            version = subprocess.check_output(['pip', 'show', library_name]).decode().split('\n')[1].split(': ')[1]
-            return version
-        except subprocess.CalledProcessError:
-            return "Library not found"

     lm_eval_version = get_library_version("lm-eval")

     if isinstance(tasks, str):
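As an aside, the removed helper shelled out to pip and parsed its text output. An installed package's version can also be read through the standard library; the sketch below is an alternative shown for illustration and is not necessarily how auto_round.utils.get_library_version is implemented.

```python
from importlib.metadata import PackageNotFoundError, version  # Python 3.8+

def get_installed_version(package_name: str) -> str:
    """Return the installed version of a package without invoking pip."""
    try:
        return version(package_name)
    except PackageNotFoundError:
        return "Library not found"

print(get_installed_version("lm-eval"))  # e.g. "0.4.2" when lm-eval is installed
```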
@@ -343,28 +326,22 @@ def get_library_version(library_name):
         tasks=tasks,
         batch_size=args.eval_bs,
         user_model=user_model)

     print(make_table(res))


 def eval(args):
     quantization_config = AutoRoundConfig(backend=args.device)
     device_str = detect_device(args.device)
     user_model = AutoModelForCausalLM.from_pretrained(
         args.model,
         device_map=device_str, quantization_config=quantization_config)
     model_args = f"pretrained={args.model},trust_remote_code={not args.disable_trust_remote_code}"
Reviewer comment: reformat this file; there are many extra blank lines between the code.
     if isinstance(args.tasks, str):
         tasks = args.tasks.split(',')
     res = simple_evaluate(
         model="hf",
         model_args=model_args,
         user_model=user_model,
Reviewer comment: remove the 'model.seqlen = seqlen' setting; it is only needed for ppl evaluation without lm-eval.
         tasks=tasks,
         device=device_str,
         batch_size=args.eval_bs)

     from lm_eval.utils import make_table  # pylint: disable=E0401

     print(make_table(res))
Reviewer comment: delete.
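For reference, the new eval() path boils down to the standard lm-eval harness flow; a minimal sketch of that flow using the upstream API is shown below. The user_model keyword in the diff is auto_round's own extension of simple_evaluate for passing an already-quantized model, so it does not appear here, and the checkpoint and task names are only examples.

```python
import lm_eval
from lm_eval.utils import make_table

results = lm_eval.simple_evaluate(
    model="hf",                                 # Hugging Face backend
    model_args="pretrained=facebook/opt-125m",  # example checkpoint, not from this PR
    tasks=["lambada_openai"],                   # any list of lm-eval task names
    batch_size=8,
)
print(make_table(results))
```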