-
-
Notifications
You must be signed in to change notification settings - Fork 5.2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[TPU][Quantization] TPU
W8A8
(#11785)
Co-authored-by: Woosuk Kwon <[email protected]>
- Loading branch information
1 parent
47de882
commit 56fe4c2
Showing
18 changed files
with
565 additions
and
190 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
from dataclasses import dataclass | ||
|
||
import lm_eval | ||
import pytest | ||
|
||
TASK = "gsm8k" | ||
FILTER = "exact_match,strict-match" | ||
RTOL = 0.03 | ||
|
||
|
||
@dataclass | ||
class GSM8KAccuracyTestConfig: | ||
model_name: str | ||
excepted_value: float | ||
|
||
def get_model_args(self) -> str: | ||
return (f"pretrained={self.model_name}," | ||
"max_model_len=4096,max_num_seqs=32") | ||
|
||
|
||
# NOTE: Accuracy scores measured on GPUs. | ||
ACCURACY_CONFIGS = [ | ||
GSM8KAccuracyTestConfig( | ||
model_name="neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", | ||
excepted_value=0.76), # no bias | ||
# NOTE(rob): We cannot re-initialize VLLM in the same process for TPU, | ||
# so only one of these tests can run in a single call to pytest. As | ||
# a follow up, move this into the LM-EVAL section of the CI. | ||
# GSM8KAccuracyTestConfig( | ||
# model_name="neuralmagic/Qwen2-7B-Instruct-quantized.w8a8", | ||
# excepted_value=0.66), # bias in QKV layers | ||
] | ||
|
||
|
||
@pytest.mark.parametrize("config", ACCURACY_CONFIGS) | ||
def test_gsm8k_correctness(config: GSM8KAccuracyTestConfig): | ||
|
||
results = lm_eval.simple_evaluate( | ||
model="vllm", | ||
model_args=config.get_model_args(), | ||
tasks="gsm8k", | ||
batch_size="auto", | ||
) | ||
|
||
EXPECTED_VALUE = config.excepted_value | ||
measured_value = results["results"][TASK][FILTER] | ||
assert (measured_value - RTOL < EXPECTED_VALUE | ||
and measured_value + RTOL > EXPECTED_VALUE | ||
), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,74 +0,0 @@ | ||
from typing import List, Optional, Type | ||
|
||
import vllm.envs as envs | ||
from vllm.model_executor.layers.quantization.kernels.exllama import ( | ||
ExllamaLinearKernel) | ||
from vllm.model_executor.layers.quantization.kernels.machete import ( | ||
MacheteLinearKernel) | ||
from vllm.model_executor.layers.quantization.kernels.marlin import ( | ||
MarlinLinearKernel) | ||
from vllm.model_executor.layers.quantization.kernels.MPLinearKernel import ( | ||
MPLinearKernel, MPLinearLayerConfig) | ||
from vllm.platforms import current_platform | ||
|
||
# in priority/performance order (when available) | ||
_POSSIBLE_KERNELS: List[Type[MPLinearKernel]] = [ | ||
MacheteLinearKernel, | ||
MarlinLinearKernel, | ||
ExllamaLinearKernel, | ||
] | ||
|
||
|
||
def choose_mp_linear_kernel( | ||
config: MPLinearLayerConfig, | ||
compute_capability: Optional[int] = None) -> Type[MPLinearKernel]: | ||
""" | ||
Choose an MPLinearKernel that can implement the given config for the given | ||
compute capability. Attempts to choose the best kernel in terms of | ||
performance. | ||
Args: | ||
config (MPLinearLayerConfig): Description of the linear layer to be | ||
implemented. | ||
compute_capability (Optional[int], optional): The compute capability of | ||
the target device, if None uses `current_platform` to get the compute | ||
capability. Defaults to None. | ||
Raises: | ||
ValueError: If no kernel can implement the given config. | ||
Returns: | ||
Type[MPLinearKernel]: Chosen kernel. | ||
""" | ||
if compute_capability is None: | ||
if current_platform is None: | ||
raise ValueError("Cannot determine compute capability") | ||
_cc = current_platform.get_device_capability() | ||
compute_capability = _cc[0] * 10 + _cc[1] | ||
|
||
failure_reasons = [] | ||
for kernel in _POSSIBLE_KERNELS: | ||
if kernel.__name__ in envs.VLLM_DISABLED_KERNELS: | ||
failure_reasons.append( | ||
f' {kernel.__name__} disabled by environment variable') | ||
continue | ||
|
||
if kernel.get_min_capability() > compute_capability: | ||
failure_reasons.append( | ||
f"{kernel.__name__} requires capability " | ||
f"{kernel.get_min_capability()}, current compute capability " | ||
f"is {compute_capability}") | ||
continue | ||
|
||
can_implement, failure_reason = kernel.can_implement(config) | ||
if can_implement: | ||
return kernel | ||
else: | ||
failure_reasons.append( | ||
f' {kernel.__name__} cannot implement due to: {failure_reason}' | ||
) | ||
|
||
raise ValueError( | ||
"Failed to find a kernel that can implement the "\ | ||
"WNA16 linear layer. Reasons: \n" | ||
+ '\n'.join(failure_reasons)) | ||
File renamed without changes.
74 changes: 74 additions & 0 deletions
74
vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
from typing import List, Optional, Type | ||
|
||
import vllm.envs as envs | ||
from vllm.model_executor.layers.quantization.kernels.mixed_precision.exllama import ( # noqa: E501 | ||
ExllamaLinearKernel) | ||
from vllm.model_executor.layers.quantization.kernels.mixed_precision.machete import ( # noqa: E501 | ||
MacheteLinearKernel) | ||
from vllm.model_executor.layers.quantization.kernels.mixed_precision.marlin import ( # noqa: E501 | ||
MarlinLinearKernel) | ||
from vllm.model_executor.layers.quantization.kernels.mixed_precision.MPLinearKernel import ( # noqa: E501 | ||
MPLinearKernel, MPLinearLayerConfig) | ||
from vllm.platforms import current_platform | ||
|
||
# in priority/performance order (when available) | ||
_POSSIBLE_KERNELS: List[Type[MPLinearKernel]] = [ | ||
MacheteLinearKernel, | ||
MarlinLinearKernel, | ||
ExllamaLinearKernel, | ||
] | ||
|
||
|
||
def choose_mp_linear_kernel( | ||
config: MPLinearLayerConfig, | ||
compute_capability: Optional[int] = None) -> Type[MPLinearKernel]: | ||
""" | ||
Choose an MPLinearKernel that can implement the given config for the given | ||
compute capability. Attempts to choose the best kernel in terms of | ||
performance. | ||
Args: | ||
config (MPLinearLayerConfig): Description of the linear layer to be | ||
implemented. | ||
compute_capability (Optional[int], optional): The compute capability of | ||
the target device, if None uses `current_platform` to get the compute | ||
capability. Defaults to None. | ||
Raises: | ||
ValueError: If no kernel can implement the given config. | ||
Returns: | ||
Type[MPLinearKernel]: Chosen kernel. | ||
""" | ||
if compute_capability is None: | ||
if current_platform is None: | ||
raise ValueError("Cannot determine compute capability") | ||
_cc = current_platform.get_device_capability() | ||
compute_capability = _cc[0] * 10 + _cc[1] | ||
|
||
failure_reasons = [] | ||
for kernel in _POSSIBLE_KERNELS: | ||
if kernel.__name__ in envs.VLLM_DISABLED_KERNELS: | ||
failure_reasons.append( | ||
f' {kernel.__name__} disabled by environment variable') | ||
continue | ||
|
||
if kernel.get_min_capability() > compute_capability: | ||
failure_reasons.append( | ||
f"{kernel.__name__} requires capability " | ||
f"{kernel.get_min_capability()}, current compute capability " | ||
f"is {compute_capability}") | ||
continue | ||
|
||
can_implement, failure_reason = kernel.can_implement(config) | ||
if can_implement: | ||
return kernel | ||
else: | ||
failure_reasons.append( | ||
f' {kernel.__name__} cannot implement due to: {failure_reason}' | ||
) | ||
|
||
raise ValueError( | ||
"Failed to find a kernel that can implement the "\ | ||
"WNA16 linear layer. Reasons: \n" | ||
+ '\n'.join(failure_reasons)) |
File renamed without changes.
File renamed without changes.
File renamed without changes.
Oops, something went wrong.