update eval and fix example #260

Merged · 5 commits · Sep 14, 2024
35 changes: 6 additions & 29 deletions auto_round/__main__.py
@@ -17,7 +17,6 @@
import argparse

import torch
import subprocess
import transformers
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
torch.use_deterministic_algorithms(True, warn_only=True)
@@ -26,7 +25,7 @@

from auto_round import AutoRoundConfig
from auto_round.eval.evaluation import simple_evaluate
from auto_round.utils import detect_device
from auto_round.utils import detect_device, get_library_version

def setup_parser():
parser = argparse.ArgumentParser()
@@ -36,7 +35,7 @@ def setup_parser():
)

parser.add_argument('--eval', action='store_true',
help="whether to use eval mode.")
help="whether to use eval only mode.")

parser.add_argument("--bits", default=4, type=int,
help="number of bits")
@@ -109,7 +108,7 @@ def setup_parser():
help="Where to store the final model.")

parser.add_argument("--disable_eval", action='store_true',
help="Whether to do lmeval evaluation.")
help="Whether to do lm-eval evaluation after tuning.")

parser.add_argument("--disable_amp", action='store_true',
help="disable amp")
@@ -177,7 +176,6 @@ def tune(args):
args.low_cpu_mem_tmp_dir = os.path.join(args.output_dir, "low_cpu_mem_tmp")
if args.low_cpu_mem_mode == 2:
from auto_round.low_cpu_mem.utils import load_model_with_hooks

model = load_model_with_hooks(
model_name,
model_cls,
@@ -189,7 +187,6 @@
)
elif args.low_cpu_mem_mode == 1:
from auto_round.low_cpu_mem.utils import load_empty_model

low_cpu_mem_usage = True
model = load_empty_model(
model_name,
@@ -205,18 +202,16 @@
trust_remote_code=not args.disable_trust_remote_code
)

from auto_round import (AutoRound,
AutoAdamRound)
from auto_round import AutoRound, AutoAdamRound

model = model.eval()
# align with GPTQ to eval ppl
seqlen = args.seqlen
if "opt" in model_name:
seqlen = model.config.max_position_embeddings
model.seqlen = model.config.max_position_embeddings
Review comment (Contributor): delete
else:
seqlen = 2048
model.seqlen = seqlen
seqlen = args.seqlen

if args.model_dtype != None:
if args.model_dtype == "float16" or args.model_dtype == "fp16":
@@ -262,7 +257,6 @@ def tune(args):
lm_head_layer_name = n
if args.quant_lm_head:
from transformers import AutoConfig

config = AutoConfig.from_pretrained(model_name, trust_remote_code=not args.disable_trust_remote_code)
if config.tie_word_embeddings and hasattr(model, "_tied_weights_keys"):
tied_keys = model._tied_weights_keys
@@ -299,7 +293,6 @@ def tune(args):
model_name = args.model.rstrip("/")
if args.low_cpu_mem_mode == 1 or args.low_cpu_mem_mode == 2:
import shutil

shutil.rmtree(args.low_cpu_mem_tmp_dir, ignore_errors=True)

model.eval()
@@ -308,22 +301,12 @@

export_dir = args.output_dir + "/" + model_name.split('/')[-1] + f"-w{args.bits}g{args.group_size}"


format_list = args.format.replace(' ', '').split(',')
inplace = False if len(format_list) > 1 else True
for format_ in format_list:
eval_folder = f'{export_dir}-{format_}'
autoround.save_quantized(eval_folder, format=format_, inplace=inplace)


def get_library_version(library_name):
try:
version = subprocess.check_output(['pip', 'show', library_name]).decode().split('\n')[1].split(': ')[1]
return version
except subprocess.CalledProcessError:
return "Library not found"


lm_eval_version = get_library_version("lm-eval")

if isinstance(tasks, str):
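Note: the helper deleted in the hunk above obtained versions by shelling out to 'pip show' and parsing its stdout, which is fragile whenever pip's output format shifts; the PR instead imports get_library_version from auto_round.utils. The relocated implementation is not shown in this diff, so the following is only a minimal sketch of an equivalent helper using importlib.metadata, an assumption rather than the repo's actual code:

    # Hypothetical stand-in for auto_round.utils.get_library_version;
    # the real implementation is not part of this diff.
    from importlib.metadata import PackageNotFoundError, version


    def get_library_version(library_name: str) -> str:
        # Query installed package metadata directly instead of
        # spawning a `pip show` subprocess and parsing its output.
        try:
            return version(library_name)
        except PackageNotFoundError:
            return "Library not found"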
@@ -343,28 +326,22 @@ def get_library_version(library_name):
tasks=tasks,
batch_size=args.eval_bs,
user_model=user_model)

print(make_table(res))


def eval(args):
quantization_config = AutoRoundConfig(backend=args.device)
device_str = detect_device(args.device)
user_model = AutoModelForCausalLM.from_pretrained(
args.model,
device_map=device_str, quantization_config=quantization_config)
model_args = f"pretrained={args.model},trust_remote_code={not args.disable_trust_remote_code}"
Review comment (Contributor): reformat this file; there are many extra blank lines between blocks of code.
if isinstance(args.tasks, str):
tasks = args.tasks.split(',')
res = simple_evaluate(
model="hf",
model_args=model_args,
user_model=user_model,
Review comment (@wenhuach21, Contributor, Sep 13, 2024), on the block

    if "opt" in model_name:
        seqlen = model.config.max_position_embeddings
        model.seqlen = model.config.max_position_embeddings
    else:
        seqlen = 2048
        model.seqlen = seqlen
    seqlen = args.seqlen

Remove the 'model.seqlen = seqlen' setting; it is only needed for ppl evaluation without lm-eval.

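To illustrate the reviewer's point: a GPTQ-style perplexity harness is typically the only consumer of model.seqlen, chunking one long token stream by that length, whereas lm-eval manages context length on its own. A minimal sketch of such a loop, hypothetical and not code from this repo:

    import torch

    def eval_ppl(model, input_ids: torch.Tensor) -> float:
        # GPTQ-style ppl: split one long token stream into model.seqlen chunks.
        seqlen = model.seqlen  # the only place this attribute is consumed
        n_chunks = input_ids.shape[1] // seqlen
        nlls = []
        for i in range(n_chunks):
            chunk = input_ids[:, i * seqlen:(i + 1) * seqlen]
            with torch.no_grad():
                loss = model(chunk, labels=chunk).loss  # mean NLL per token
            nlls.append(loss.float() * seqlen)
        return torch.exp(torch.stack(nlls).sum() / (n_chunks * seqlen)).item()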
tasks=tasks,
device=device_str,
batch_size=args.eval_bs)

from lm_eval.utils import make_table # pylint: disable=E0401

print(make_table(res))

