Skip to content

Commit

Permalink
[Misc] Move some model utils into vision file (vllm-project#11848)
Browse files Browse the repository at this point in the history
Signed-off-by: DarkLight1337 <[email protected]>
  • Loading branch information
DarkLight1337 authored and frreiss committed Jan 10, 2025
1 parent 5064a0f commit 43e04ad
Show file tree
Hide file tree
Showing 8 changed files with 94 additions and 92 deletions.
5 changes: 2 additions & 3 deletions vllm/model_executor/models/clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,10 @@
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal.utils import (cached_get_tokenizer,
consecutive_placeholder_ranges,
repeat_and_pad_placeholder_tokens,
resolve_visual_encoder_outputs)
repeat_and_pad_placeholder_tokens)
from vllm.sequence import SequenceData

from .vision import VisionEncoderInfo
from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs


def get_clip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
Expand Down
5 changes: 2 additions & 3 deletions vllm/model_executor/models/pixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,13 @@
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.inputs import NestedTensors, PlaceholderRange
from vllm.multimodal.utils import (cached_get_tokenizer,
consecutive_placeholder_ranges,
resolve_visual_encoder_outputs)
consecutive_placeholder_ranges)
from vllm.sequence import IntermediateTensors, SequenceData

from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings)
from .vision import VisionEncoderInfo
from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs

try:
from xformers import ops as xops
Expand Down
3 changes: 2 additions & 1 deletion vllm/model_executor/models/qwen2_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,9 @@
from vllm.transformers_utils.config import uses_mrope

from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, get_vit_attn_backend,
from .utils import (AutoWeightsLoader, WeightsMapper,
init_vllm_registered_model, maybe_prefix)
from .vision import get_vit_attn_backend

logger = init_logger(__name__)

Expand Down
5 changes: 2 additions & 3 deletions vllm/model_executor/models/siglip.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,10 @@
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal.utils import (cached_get_tokenizer,
consecutive_placeholder_ranges,
repeat_and_pad_placeholder_tokens,
resolve_visual_encoder_outputs)
repeat_and_pad_placeholder_tokens)
from vllm.sequence import SequenceData

from .vision import VisionEncoderInfo
from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs


def get_siglip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
Expand Down
37 changes: 1 addition & 36 deletions vllm/model_executor/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,12 @@
from torch.func import functional_call
from transformers import PretrainedConfig

import vllm.envs as envs
from vllm.attention.selector import (backend_name_to_enum,
get_global_forced_attn_backend)
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal import MultiModalPlaceholderMap, NestedTensors
from vllm.platforms import _Backend, current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils import is_pin_memory_available, print_warning_once
from vllm.utils import is_pin_memory_available

logger = init_logger(__name__)

Expand Down Expand Up @@ -612,37 +608,6 @@ def make_empty_intermediate_tensors(
return make_empty_intermediate_tensors


def get_vit_attn_backend(support_fa: bool = False) -> _Backend:
"""
Get the available attention backend for Vision Transformer.
"""
# TODO(Isotr0py): Remove `support_fa` after support FA for all ViTs attn.
selected_backend: Optional[_Backend] = get_global_forced_attn_backend()
if selected_backend is None:
backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
if backend_by_env_var is not None:
selected_backend = backend_name_to_enum(backend_by_env_var)
if selected_backend is None:
# For Volta and Turing GPUs, use xformers instead.
device_available = current_platform.has_device_capability(80)
if device_available and support_fa:
from transformers.utils import is_flash_attn_2_available
if is_flash_attn_2_available():
selected_backend = _Backend.FLASH_ATTN
else:
print_warning_once(
"Current `vllm-flash-attn` has a bug inside vision module, "
"so we use xformers backend instead. You can run "
"`pip install flash-attn` to use flash-attention backend.")
selected_backend = _Backend.XFORMERS
elif current_platform.is_cpu() or current_platform.is_rocm():
# ROCM doesn't support xformers
selected_backend = _Backend.TORCH_SDPA
else:
selected_backend = _Backend.XFORMERS
return selected_backend


def maybe_prefix(prefix: str, name: str) -> str:
"""Add a prefix to a name if the prefix is non-empty.
Expand Down
83 changes: 82 additions & 1 deletion vllm/model_executor/models/vision.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
from abc import ABC, abstractmethod
from typing import Final, Generic, Protocol, TypeVar
from typing import Final, Generic, Optional, Protocol, TypeVar, Union

import torch
from transformers import PretrainedConfig

import vllm.envs as envs
from vllm.attention.selector import (backend_name_to_enum,
get_global_forced_attn_backend)
from vllm.platforms import _Backend, current_platform
from vllm.utils import print_warning_once

_C = TypeVar("_C", bound=PretrainedConfig)


Expand Down Expand Up @@ -60,3 +67,77 @@ def get_vision_encoder_info(

msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)


def get_vit_attn_backend(support_fa: bool = False) -> _Backend:
"""
Get the available attention backend for Vision Transformer.
"""
# TODO(Isotr0py): Remove `support_fa` after support FA for all ViTs attn.
selected_backend: Optional[_Backend] = get_global_forced_attn_backend()
if selected_backend is None:
backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
if backend_by_env_var is not None:
selected_backend = backend_name_to_enum(backend_by_env_var)
if selected_backend is None:
# For Volta and Turing GPUs, use xformers instead.
device_available = current_platform.has_device_capability(80)
if device_available and support_fa:
from transformers.utils import is_flash_attn_2_available
if is_flash_attn_2_available():
selected_backend = _Backend.FLASH_ATTN
else:
print_warning_once(
"Current `vllm-flash-attn` has a bug inside vision module, "
"so we use xformers backend instead. You can run "
"`pip install flash-attn` to use flash-attention backend.")
selected_backend = _Backend.XFORMERS
elif current_platform.is_cpu() or current_platform.is_rocm():
# ROCM doesn't support xformers
selected_backend = _Backend.TORCH_SDPA
else:
selected_backend = _Backend.XFORMERS
return selected_backend


def resolve_visual_encoder_outputs(
encoder_outputs: Union[torch.Tensor, list[torch.Tensor]],
feature_sample_layers: Optional[list[int]],
post_layer_norm: Optional[torch.nn.LayerNorm],
max_possible_layers: int,
) -> torch.Tensor:
"""Given the outputs a visual encoder module that may correspond to the
output of the last layer, or a list of hidden states to be stacked,
handle post normalization and resolve it into a single output tensor.
Args:
encoder_outputs: Output of encoder's last layer or all hidden states.
feature_sample_layers: Optional layer indices to grab from the encoder
outputs; if provided, encoder outputs must be a list.
post_layer_norm: Post norm to apply to the output of the encoder.
max_possible_layers: Total layers in the fully loaded visual encoder.
"""
if feature_sample_layers is None:
if post_layer_norm is not None:
return post_layer_norm(encoder_outputs)
return encoder_outputs

# Get the hidden states corresponding to the layer indices.
# Negative values are relative to the full visual encoder,
# so offset them depending on how many layers were loaded.
# NOTE: this assumes that encoder_outputs contains a list
# of hidden states in the same order as the encoder layers
# that produced them.
offset = max_possible_layers - len(encoder_outputs)
hs_pool = [
encoder_outputs[layer_idx]
if layer_idx >= 0 else encoder_outputs[layer_idx + offset]
for layer_idx in feature_sample_layers
]

# Apply post-norm on the final hidden state if we are using it
uses_last_layer = feature_sample_layers[-1] in (len(hs_pool) - 1, -1)
if post_layer_norm is not None and uses_last_layer:
hs_pool[-1] = post_layer_norm(encoder_outputs)
return torch.cat(hs_pool, dim=-1)
4 changes: 3 additions & 1 deletion vllm/multimodal/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ class MultiModalDataBuiltins(TypedDict, total=False):
MultiModalDataDict: TypeAlias = Mapping[str, ModalityData[Any]]
"""
A dictionary containing an entry for each modality type to input.
The built-in modalities are defined by :class:`MultiModalDataBuiltins`.
"""


Expand Down Expand Up @@ -485,7 +487,7 @@ def get_items(self, modality: str) -> Sequence[MultiModalKwargsItem]:

MultiModalPlaceholderDict = Mapping[str, Sequence[PlaceholderRange]]
"""
A dictionary containing placeholder ranges.
A dictionary containing placeholder ranges for each modality.
"""


Expand Down
44 changes: 0 additions & 44 deletions vllm/multimodal/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import numpy as np
import numpy.typing as npt
import torch
from PIL import Image

import vllm.envs as envs
Expand Down Expand Up @@ -285,49 +284,6 @@ def encode_video_base64(frames: npt.NDArray) -> str:
return video_io.encode_base64(frames)


def resolve_visual_encoder_outputs(
encoder_outputs: Union[torch.Tensor, list[torch.Tensor]],
feature_sample_layers: Optional[list[int]],
post_layer_norm: Optional[torch.nn.LayerNorm],
max_possible_layers: int,
) -> torch.Tensor:
"""Given the outputs a visual encoder module that may correspond to the
output of the last layer, or a list of hidden states to be stacked,
handle post normalization and resolve it into a single output tensor.
Args:
encoder_outputs: Output of encoder's last layer or all hidden states.
feature_sample_layers: Optional layer indices to grab from the encoder
outputs; if provided, encoder outputs must be a list.
post_layer_norm: Post norm to apply to the output of the encoder.
max_possible_layers: Total layers in the fully loaded visual encoder.
"""
if feature_sample_layers is None:
if post_layer_norm is not None:
return post_layer_norm(encoder_outputs)
return encoder_outputs

# Get the hidden states corresponding to the layer indices.
# Negative values are relative to the full visual encoder,
# so offset them depending on how many layers were loaded.
# NOTE: this assumes that encoder_outputs contains a list
# of hidden states in the same order as the encoder layers
# that produced them.
offset = max_possible_layers - len(encoder_outputs)
hs_pool = [
encoder_outputs[layer_idx]
if layer_idx >= 0 else encoder_outputs[layer_idx + offset]
for layer_idx in feature_sample_layers
]

# Apply post-norm on the final hidden state if we are using it
uses_last_layer = feature_sample_layers[-1] in (len(hs_pool) - 1, -1)
if post_layer_norm is not None and uses_last_layer:
hs_pool[-1] = post_layer_norm(encoder_outputs)
return torch.cat(hs_pool, dim=-1)


# Utilities for input processors
_T = TypeVar("_T", str, int)

Expand Down

0 comments on commit 43e04ad

Please sign in to comment.