Commit 5776bc0
Add new evaluation metrics
jainapurva committed Sep 27, 2024
1 parent 26e790d commit 5776bc0
Showing 2 changed files with 22 additions and 7 deletions.
11 changes: 10 additions & 1 deletion test/integration/test_integration.py
@@ -72,7 +72,7 @@
     AQInt8WeightOnlyQuantizedLinearWeight2,
     AQInt8WeightOnlyQuantizedLinearWeight3,
     AutoQuantizableLinearWeight,
-
+    AQFloat8WeightOnlyQuantizedLinearWeight,
 )
 from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_fx
 import os
@@ -98,6 +98,7 @@
 COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16]
 
 COMMON_DEVICE_DTYPE = list(itertools.product(COMMON_DEVICES, COMMON_DTYPES)).copy()
+is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
 
 def _int8wo_api(mod):
     if TORCH_VERSION_AT_LEAST_2_4:
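
Note on the is_H100 flag added above: it gates the new float8 test, but the condition >= (8, 9) admits any FP8-capable GPU, not only H100, since H100 reports compute capability (9, 0) while Ada-class parts such as L4 or RTX 4090 report (8, 9). A quick way to check a local machine against the same condition (a sketch, not part of the commit):

# prints True on FP8-capable GPUs (hypothetical check, not part of the commit)
python -c "import torch; print(torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9))"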
@@ -744,6 +745,14 @@ def test_aq_int8_weight_only_quant_3_subclass(self, device, dtype):
             AQInt8WeightOnlyQuantizedLinearWeight3.from_float, device, 35, test_dtype=dtype
         )
 
+    @parameterized.expand(COMMON_DEVICE_DTYPE)
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch")
+    @unittest.skipIf(not is_H100, "Need H100 to run")
+    def test_aq_float8_weight_only_quant_subclass(self, device, dtype):
+        self._test_lin_weight_subclass_impl(
+            AQFloat8WeightOnlyQuantizedLinearWeight.from_float, device, 30, test_dtype=dtype
+        )
+
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
     # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
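To exercise only the new test locally, something like the following should work, assuming a PyTorch 2.5+ CUDA build, pytest installed, and a GPU that passes the is_H100 check (a sketch, not part of the commit):

# run only the new float8 autoquant test (hypothetical invocation)
pytest test/integration/test_integration.py -k test_aq_float8_weight_only_quant_subclass -v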
18 changes: 12 additions & 6 deletions torchao/_models/llama/evals.sh
@@ -1,9 +1,15 @@
 export CHECKPOINT_PATH=../../../checkpoints # path to checkpoints folder
 
-export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf
-python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization autoround # auto-round w/o quant_lm_head
-python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization autoround-cuda-1 # auto-round w/ quant_lm_head
+# export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf
+# python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization autoround # auto-round w/o quant_lm_head
+# python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization autoround-cuda-1 # auto-round w/ quant_lm_head
 
-export MODEL_REPO=meta-llama/Meta-Llama-3-8B
-python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization autoround-cpu # auto-round w/o quant_lm_head
-python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization autoround-cpu-1 # auto-round w/ quant_lm_head
+# export MODEL_REPO=meta-llama/Meta-Llama-3-8B
+# python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization autoround-cpu # auto-round w/o quant_lm_head
+# python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization autoround-cuda-1 # auto-round w/ quant_lm_head
+
+export MODEL_REPO=meta-llama/Meta-Llama-3.1-8B
+# python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization autoround-cpu # auto-round w/o quant_lm_head
+# python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization autoround-cuda-1 # auto-round w/ quant_lm_head
+python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization autoquant --tasks 'mmlu' 'truthfulqa_mc2'
+python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization autoquant --tasks 'winogrande' 'arc_challenge'
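
As a usage note, the updated script is meant to be run from the llama model directory; a minimal sketch, assuming the Meta-Llama-3.1-8B checkpoint has already been downloaded and converted into the checkpoints folder that CHECKPOINT_PATH points to:

# hypothetical invocation; checkpoint download/conversion is not covered by this commit
cd torchao/_models/llama
sh evals.sh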
