diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index feb33bb373c3e..89992de7e238d 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -187,31 +187,6 @@ **_SPECULATIVE_DECODING_MODELS, } -# Models not supported by ROCm. -_ROCM_UNSUPPORTED_MODELS: List[str] = [] - -# Models partially supported by ROCm. -# Architecture -> Reason. -_ROCM_SWA_REASON = ("Sliding window attention (SWA) is not yet supported in " - "Triton flash attention. For half-precision SWA support, " - "please use CK flash attention by setting " - "`VLLM_USE_TRITON_FLASH_ATTN=0`") -_ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = { - "Qwen2ForCausalLM": - _ROCM_SWA_REASON, - "MistralForCausalLM": - _ROCM_SWA_REASON, - "MixtralForCausalLM": - _ROCM_SWA_REASON, - "PaliGemmaForConditionalGeneration": - ("ROCm flash attention does not yet " - "fully support 32-bit precision on PaliGemma"), - "Phi3VForCausalLM": - ("ROCm Triton flash attention may run into compilation errors due to " - "excessive use of shared memory. If this happens, disable Triton FA " - "by setting `VLLM_USE_TRITON_FLASH_ATTN=0`") -} - @dataclass(frozen=True) class _ModelInfo: @@ -297,17 +272,7 @@ def _try_load_model_cls( model_arch: str, model: _BaseRegisteredModel, ) -> Optional[Type[nn.Module]]: - if current_platform.is_rocm(): - if model_arch in _ROCM_UNSUPPORTED_MODELS: - raise ValueError(f"Model architecture '{model_arch}' is not " - "supported by ROCm for now.") - - if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS: - msg = _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch] - logger.warning( - "Model architecture '%s' is partially " - "supported by ROCm: %s", model_arch, msg) - + current_platform.verify_model_arch(model_arch) try: return model.load_model_cls() except Exception: diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 4150b0cdf836a..ddccaa2ce0148 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -199,6 +199,18 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: """ pass + @classmethod + def verify_model_arch(cls, model_arch: str) -> None: + """ + Verify whether the current platform supports the specified model + architecture. + + - This will raise an Error or Warning based on the model support on + the current platform. + - By default all models are considered supported. + """ + pass + @classmethod def verify_quantization(cls, quant: str) -> None: """ diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 7778b565372cb..aa779f265135f 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -1,6 +1,6 @@ import os from functools import lru_cache -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Dict, List, Optional import torch @@ -33,6 +33,31 @@ " `spawn` instead.") os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" +# Models not supported by ROCm. +_ROCM_UNSUPPORTED_MODELS: List[str] = [] + +# Models partially supported by ROCm. +# Architecture -> Reason. +_ROCM_SWA_REASON = ("Sliding window attention (SWA) is not yet supported in " + "Triton flash attention. For half-precision SWA support, " + "please use CK flash attention by setting " + "`VLLM_USE_TRITON_FLASH_ATTN=0`") +_ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = { + "Qwen2ForCausalLM": + _ROCM_SWA_REASON, + "MistralForCausalLM": + _ROCM_SWA_REASON, + "MixtralForCausalLM": + _ROCM_SWA_REASON, + "PaliGemmaForConditionalGeneration": + ("ROCm flash attention does not yet " + "fully support 32-bit precision on PaliGemma"), + "Phi3VForCausalLM": + ("ROCm Triton flash attention may run into compilation errors due to " + "excessive use of shared memory. If this happens, disable Triton FA " + "by setting `VLLM_USE_TRITON_FLASH_ATTN=0`") +} + class RocmPlatform(Platform): _enum = PlatformEnum.ROCM @@ -102,6 +127,18 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: else: parallel_config.worker_cls = "vllm.worker.worker.Worker" + @classmethod + def verify_model_arch(cls, model_arch: str) -> None: + if model_arch in _ROCM_UNSUPPORTED_MODELS: + raise ValueError(f"Model architecture '{model_arch}' is not " + "supported by ROCm for now.") + + if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS: + msg = _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch] + logger.warning( + "Model architecture '%s' is partially " + "supported by ROCm: %s", model_arch, msg) + @classmethod def verify_quantization(cls, quant: str) -> None: super().verify_quantization(quant)