[Platform] Move model arch check to platform #11503

Merged · 1 commit · Dec 27, 2024
37 changes: 1 addition & 36 deletions vllm/model_executor/models/registry.py
@@ -186,31 +186,6 @@
     **_SPECULATIVE_DECODING_MODELS,
 }
 
-# Models not supported by ROCm.
-_ROCM_UNSUPPORTED_MODELS: List[str] = []
-
-# Models partially supported by ROCm.
-# Architecture -> Reason.
-_ROCM_SWA_REASON = ("Sliding window attention (SWA) is not yet supported in "
-                    "Triton flash attention. For half-precision SWA support, "
-                    "please use CK flash attention by setting "
-                    "`VLLM_USE_TRITON_FLASH_ATTN=0`")
-_ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
-    "Qwen2ForCausalLM":
-    _ROCM_SWA_REASON,
-    "MistralForCausalLM":
-    _ROCM_SWA_REASON,
-    "MixtralForCausalLM":
-    _ROCM_SWA_REASON,
-    "PaliGemmaForConditionalGeneration":
-    ("ROCm flash attention does not yet "
-     "fully support 32-bit precision on PaliGemma"),
-    "Phi3VForCausalLM":
-    ("ROCm Triton flash attention may run into compilation errors due to "
-     "excessive use of shared memory. If this happens, disable Triton FA "
-     "by setting `VLLM_USE_TRITON_FLASH_ATTN=0`")
-}
-
 
 @dataclass(frozen=True)
 class _ModelInfo:
@@ -296,17 +271,7 @@ def _try_load_model_cls(
     model_arch: str,
     model: _BaseRegisteredModel,
 ) -> Optional[Type[nn.Module]]:
-    if current_platform.is_rocm():
-        if model_arch in _ROCM_UNSUPPORTED_MODELS:
-            raise ValueError(f"Model architecture '{model_arch}' is not "
-                             "supported by ROCm for now.")
-
-        if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
-            msg = _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch]
-            logger.warning(
-                "Model architecture '%s' is partially "
-                "supported by ROCm: %s", model_arch, msg)
-
+    current_platform.verify_model_arch(model_arch)
     try:
         return model.load_model_cls()
     except Exception:
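For context, the control flow after this change can be illustrated with a minimal, self-contained sketch. `DummyPlatform`, `try_load_model_cls`, and the `loaders` dict below are hypothetical stand-ins rather than vLLM's actual classes; the point is only that the registry calls a platform hook and no longer carries any ROCm-specific tables itself:

```python
import logging
from typing import Callable, Dict, Optional, Type

logger = logging.getLogger("registry_sketch")


class DummyPlatform:
    """Hypothetical platform: by default every architecture is accepted."""

    def verify_model_arch(self, model_arch: str) -> None:
        # No-op, mirroring the default added to the platform interface;
        # platform-specific subclasses may warn or raise here.
        pass


def try_load_model_cls(platform: DummyPlatform, model_arch: str,
                       loaders: Dict[str, Callable[[], Type]]) -> Optional[Type]:
    """Registry-style loader: ask the platform first, then load the class."""
    platform.verify_model_arch(model_arch)  # may raise for unsupported archs
    try:
        return loaders[model_arch]()
    except Exception:
        logger.exception("Error loading model class for '%s'", model_arch)
        return None


print(try_load_model_cls(DummyPlatform(), "LlamaForCausalLM",
                         {"LlamaForCausalLM": lambda: object}))
```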
12 changes: 12 additions & 0 deletions vllm/platforms/interface.py
@@ -199,6 +199,18 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         """
         pass
 
+    @classmethod
+    def verify_model_arch(cls, model_arch: str) -> None:
+        """
+        Verify whether the current platform supports the specified model
+        architecture.
+
+        - This will raise an Error or Warning based on the model support on
+          the current platform.
+        - By default all models are considered supported.
+        """
+        pass
+
     @classmethod
     def verify_quantization(cls, quant: str) -> None:
         """
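As a rough illustration of how a backend might use the new hook, here is a hedged sketch of an override in the same classmethod style. `FooPlatform` and its architecture names are made up for the example and are not part of this PR; only the base-class default shown above comes from the diff:

```python
import logging
from typing import Dict, List

logger = logging.getLogger(__name__)


class Platform:
    """Trimmed stand-in for the base class in vllm/platforms/interface.py."""

    @classmethod
    def verify_model_arch(cls, model_arch: str) -> None:
        # Default: every architecture is considered supported.
        pass


class FooPlatform(Platform):
    """Hypothetical backend restricting a couple of (made-up) architectures."""

    _UNSUPPORTED: List[str] = ["SomeUnsupportedArch"]
    _PARTIAL: Dict[str, str] = {
        "SomePartialArch": "attention backend falls back to a slower path",
    }

    @classmethod
    def verify_model_arch(cls, model_arch: str) -> None:
        if model_arch in cls._UNSUPPORTED:
            raise ValueError(
                f"Model architecture '{model_arch}' is not supported here.")
        if model_arch in cls._PARTIAL:
            logger.warning("Model architecture '%s' is partially supported: %s",
                           model_arch, cls._PARTIAL[model_arch])
```

Keeping the check behind this hook means the model registry stays platform-agnostic, and each platform module owns its own support tables.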
39 changes: 38 additions & 1 deletion vllm/platforms/rocm.py
@@ -1,6 +1,6 @@
 import os
 from functools import lru_cache
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Dict, List, Optional
 
 import torch
 
@@ -33,6 +33,31 @@
                    " `spawn` instead.")
     os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
 
+# Models not supported by ROCm.
+_ROCM_UNSUPPORTED_MODELS: List[str] = []
+
+# Models partially supported by ROCm.
+# Architecture -> Reason.
+_ROCM_SWA_REASON = ("Sliding window attention (SWA) is not yet supported in "
+                    "Triton flash attention. For half-precision SWA support, "
+                    "please use CK flash attention by setting "
+                    "`VLLM_USE_TRITON_FLASH_ATTN=0`")
+_ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
+    "Qwen2ForCausalLM":
+    _ROCM_SWA_REASON,
+    "MistralForCausalLM":
+    _ROCM_SWA_REASON,
+    "MixtralForCausalLM":
+    _ROCM_SWA_REASON,
+    "PaliGemmaForConditionalGeneration":
+    ("ROCm flash attention does not yet "
+     "fully support 32-bit precision on PaliGemma"),
+    "Phi3VForCausalLM":
+    ("ROCm Triton flash attention may run into compilation errors due to "
+     "excessive use of shared memory. If this happens, disable Triton FA "
+     "by setting `VLLM_USE_TRITON_FLASH_ATTN=0`")
+}
+
 
 class RocmPlatform(Platform):
     _enum = PlatformEnum.ROCM
@@ -102,6 +127,18 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         else:
             parallel_config.worker_cls = "vllm.worker.worker.Worker"
 
+    @classmethod
+    def verify_model_arch(cls, model_arch: str) -> None:
+        if model_arch in _ROCM_UNSUPPORTED_MODELS:
+            raise ValueError(f"Model architecture '{model_arch}' is not "
+                             "supported by ROCm for now.")
+
+        if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
+            msg = _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch]
+            logger.warning(
+                "Model architecture '%s' is partially "
+                "supported by ROCm: %s", model_arch, msg)
+
     @classmethod
     def verify_quantization(cls, quant: str) -> None:
         super().verify_quantization(quant)
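The expected behavior of the new ROCm hook can be exercised with a small standalone mimic. It copies (in trimmed form) the tables added above and uses standard logging instead of importing vLLM, since the real `RocmPlatform` requires a ROCm build of torch; the function below is an illustration of the same logic, not the class itself:

```python
import logging
from typing import Dict, List

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger("rocm_sketch")

# Trimmed copies of the tables added to vllm/platforms/rocm.py above.
_ROCM_UNSUPPORTED_MODELS: List[str] = []
_ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
    "Qwen2ForCausalLM": "SWA is not yet supported in Triton flash attention",
}


def verify_model_arch(model_arch: str) -> None:
    """Standalone mimic of RocmPlatform.verify_model_arch."""
    if model_arch in _ROCM_UNSUPPORTED_MODELS:
        raise ValueError(f"Model architecture '{model_arch}' is not "
                         "supported by ROCm for now.")
    if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
        logger.warning("Model architecture '%s' is partially supported by ROCm: %s",
                       model_arch, _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch])


verify_model_arch("LlamaForCausalLM")  # silent: treated as fully supported
verify_model_arch("Qwen2ForCausalLM")  # logs a warning but does not raise
```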