From 5776bc040ea0e85384c39ac0412b8a1cff0e62ae Mon Sep 17 00:00:00 2001
From: jainapurva
Date: Tue, 24 Sep 2024 11:52:38 -0700
Subject: [PATCH] Add new evaluation metrics

---
 test/integration/test_integration.py | 11 ++++++++++-
 torchao/_models/llama/evals.sh       | 18 ++++++++++++------
 2 files changed, 22 insertions(+), 7 deletions(-)

diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py
index 6a5ea8ef9d..8e047985c5 100644
--- a/test/integration/test_integration.py
+++ b/test/integration/test_integration.py
@@ -72,7 +72,7 @@
     AQInt8WeightOnlyQuantizedLinearWeight2,
     AQInt8WeightOnlyQuantizedLinearWeight3,
     AutoQuantizableLinearWeight,
-
+    AQFloat8WeightOnlyQuantizedLinearWeight,
 )
 from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_fx
 import os
@@ -98,6 +98,7 @@
 COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16]
 
 COMMON_DEVICE_DTYPE = list(itertools.product(COMMON_DEVICES, COMMON_DTYPES)).copy()
+is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
 
 def _int8wo_api(mod):
     if TORCH_VERSION_AT_LEAST_2_4:
@@ -744,6 +745,14 @@ def test_aq_int8_weight_only_quant_3_subclass(self, device, dtype):
             AQInt8WeightOnlyQuantizedLinearWeight3.from_float, device, 35, test_dtype=dtype
         )
 
+    @parameterized.expand(COMMON_DEVICE_DTYPE)
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch")
+    @unittest.skipIf(not is_H100, "Need H100 to run")
+    def test_aq_float8_weight_only_quant_subclass(self, device, dtype):
+        self._test_lin_weight_subclass_impl(
+            AQFloat8WeightOnlyQuantizedLinearWeight.from_float, device, 30, test_dtype=dtype
+        )
+
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
     # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
diff --git a/torchao/_models/llama/evals.sh b/torchao/_models/llama/evals.sh
index 253d1dfeec..7863720181 100644
--- a/torchao/_models/llama/evals.sh
+++ b/torchao/_models/llama/evals.sh
@@ -1,9 +1,15 @@
 export CHECKPOINT_PATH=../../../checkpoints # path to checkpoints folder
-export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf
-python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization autoround # auto-round w/o quant_lm_head
-python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization autoround-cuda-1 # auto-round w/ quant_lm_head
+# export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf
+# python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization autoround # auto-round w/o quant_lm_head
+# python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization autoround-cuda-1 # auto-round w/ quant_lm_head
 
-export MODEL_REPO=meta-llama/Meta-Llama-3-8B
-python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantizatio autoround-cpu # auto-round w/o quant_lm_head
-python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization autoround-cpu-1 # auto-round w/ quant_lm_head
+# export MODEL_REPO=meta-llama/Meta-Llama-3-8B
+# python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization autoround-cpu # auto-round w/o quant_lm_head
+# python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization autoround-cuda-1 # auto-round w/ quant_lm_head
+
+export MODEL_REPO=meta-llama/Meta-Llama-3.1-8B
+# python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization autoround-cpu # auto-round w/o quant_lm_head
+# python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization autoround-cuda-1 # auto-round w/ quant_lm_head
+python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization autoquant --tasks 'mmlu' 'truthfulqa_mc2'
+python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization autoquant --tasks 'winogrande' 'arc_challenge'
 
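
For reviewers, a minimal sketch of what the new test exercises: swapping a linear
layer's float weight for the fp8 weight-only tensor subclass via from_float and
checking output SQNR against the float reference, which is what
_test_lin_weight_subclass_impl appears to do with its threshold argument.
The torchao.quantization.autoquant import path and the hand-rolled SQNR helper
are assumptions for illustration, not part of this patch.

    # Sketch only; assumes a CUDA GPU that passes the patch's is_H100 guard
    # (compute capability >= 8.9) and PyTorch >= 2.5. The import path below
    # mirrors the test file's import block and is assumed here.
    import torch
    from torchao.quantization.autoquant import AQFloat8WeightOnlyQuantizedLinearWeight

    torch.manual_seed(0)
    lin = torch.nn.Linear(128, 256, bias=False, device="cuda", dtype=torch.bfloat16)
    x = torch.randn(32, 128, device="cuda", dtype=torch.bfloat16)
    ref = lin(x)  # float reference output

    # Replace the float weight with the fp8 weight-only subclass, using the
    # same from_float constructor the parameterized test passes to the harness.
    lin.weight = torch.nn.Parameter(
        AQFloat8WeightOnlyQuantizedLinearWeight.from_float(lin.weight),
        requires_grad=False,
    )
    out = lin(x)  # quantized output

    # Signal-to-quantization-noise ratio in dB; the test asserts it stays
    # above the threshold passed to _test_lin_weight_subclass_impl.
    sqnr = 10 * torch.log10(
        ref.float().pow(2).mean() / (ref - out).float().pow(2).mean()
    )
    print(f"SQNR: {sqnr.item():.1f} dB")

Note the fp8 threshold here is 30 dB, looser than the 35 dB used for the int8
subclasses earlier in the file, presumably reflecting fp8's coarser mantissa.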