From 8c38ee7007c50ac5aef9ed43ae91c6f031799c40 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Fri, 3 Jan 2025 00:39:27 +0800 Subject: [PATCH] [VLM] Merged multi-modal processor for LLaVA-NeXT (#11682) Signed-off-by: DarkLight1337 --- .../mm_processor_kwargs/test_llava_next.py | 70 ---- tests/multimodal/test_mapper.py | 118 ------- tests/multimodal/test_processing.py | 97 +++++ .../vllm_add_dummy_model/my_llava.py | 4 +- vllm/model_executor/models/clip.py | 25 ++ vllm/model_executor/models/fuyu.py | 6 +- vllm/model_executor/models/llava.py | 334 +++++++++++------- vllm/model_executor/models/llava_next.py | 321 ++++++----------- vllm/model_executor/models/phi3v.py | 24 +- vllm/model_executor/models/pixtral.py | 66 +++- vllm/model_executor/models/siglip.py | 25 ++ vllm/model_executor/models/utils.py | 2 +- vllm/model_executor/models/vision.py | 52 +++ vllm/multimodal/parse.py | 12 +- 14 files changed, 605 insertions(+), 551 deletions(-) delete mode 100644 tests/models/decoder_only/vision_language/mm_processor_kwargs/test_llava_next.py delete mode 100644 tests/multimodal/test_mapper.py create mode 100644 vllm/model_executor/models/vision.py diff --git a/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_llava_next.py b/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_llava_next.py deleted file mode 100644 index 51c0085101dd0..0000000000000 --- a/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_llava_next.py +++ /dev/null @@ -1,70 +0,0 @@ -import pytest - -from vllm.inputs import InputContext - -from ....utils import build_model_context - - -@pytest.fixture() -def get_max_llava_next_image_tokens(): - from vllm.model_executor.models.llava_next import ( - get_max_llava_next_image_tokens) - return get_max_llava_next_image_tokens - - -@pytest.fixture() -def dummy_data_for_llava_next(): - from vllm.model_executor.models.llava_next import dummy_data_for_llava_next - return dummy_data_for_llava_next - - -@pytest.mark.parametrize("gridpoints,expected_max_tokens", [ - ([[336, 336]], 1176), - ([[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]], 2928), -]) -def test_get_max_llava_next_image_tokens(gridpoints, expected_max_tokens, - get_max_llava_next_image_tokens): - ctx = build_model_context(model_name="llava-hf/llava-v1.6-mistral-7b-hf") - - # Update the config image_grid_pinpoints - # and calculate the resulting max tokens - ctx.model_config.hf_config.image_grid_pinpoints = gridpoints - - actual_max_tokens = get_max_llava_next_image_tokens( - InputContext(ctx.model_config)) - - assert expected_max_tokens == actual_max_tokens - - -@pytest.mark.parametrize( - "gridpoints,expected_size", - [ - # One point; it has to be the largest - ([[336, 336]], (336, 336)), - # Default for most llava next models; the 2x2 tile is the largest - ([[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]], - (672, 672)), - # If two rectangular gridpoints are the same, the more vertical - # one has the higher feature count due to newline features - ([[336, 672], [672, 336]], (672, 336)) - ]) -def test_dummy_data_for_llava_next_feature_size(dummy_data_for_llava_next, - gridpoints, expected_size): - ctx = build_model_context(model_name="llava-hf/llava-v1.6-mistral-7b-hf") - - # Update the config image_grid_pinpoints - ctx.model_config.hf_config.image_grid_pinpoints = gridpoints - seq_len = 5000 # bigger than the max feature size for any image - - dummy_data = dummy_data_for_llava_next( - ctx, - seq_len=seq_len, - mm_counts={"image": 1}, - ) - seq_data = 
dummy_data.seq_data - mm_data = dummy_data.multi_modal_data - - # The dummy data dims should match the gridpoint with the biggest feat size - assert mm_data["image"].height == expected_size[0] - assert mm_data["image"].width == expected_size[1] - assert len(seq_data.get_token_ids()) >= seq_len diff --git a/tests/multimodal/test_mapper.py b/tests/multimodal/test_mapper.py deleted file mode 100644 index 81f2a06182bcc..0000000000000 --- a/tests/multimodal/test_mapper.py +++ /dev/null @@ -1,118 +0,0 @@ -from contextlib import nullcontext - -import numpy as np -import pytest -from transformers import LlavaNextImageProcessor - -from vllm.config import ModelConfig -from vllm.multimodal import MultiModalRegistry -from vllm.multimodal.image import rescale_image_size - - -@pytest.fixture -def mm_registry(): - return MultiModalRegistry() - - -@pytest.mark.parametrize("dtype", ["half", "float"]) -@pytest.mark.parametrize("size_factor", [0.25, 0.5, 1.0]) -def test_llava_next_image_processor(image_assets, mm_registry, dtype, - size_factor): - MODEL_NAME = "llava-hf/llava-v1.6-vicuna-7b-hf" - - hf_processor = LlavaNextImageProcessor.from_pretrained(MODEL_NAME) - assert isinstance(hf_processor, LlavaNextImageProcessor) - - model_config = ModelConfig( - model=MODEL_NAME, - task="auto", - tokenizer=MODEL_NAME, - tokenizer_mode="auto", - trust_remote_code=False, - seed=0, - dtype=dtype, - revision=None, - limit_mm_per_prompt={"image": 1}, - ) - - mm_registry.init_mm_limits_per_prompt(model_config) - - for asset in image_assets: - image = rescale_image_size(asset.pil_image, size_factor) - - hf_result = hf_processor.preprocess( - image, - return_tensors="pt", - ) - vllm_result = mm_registry.map_input( - model_config, - {"image": image}, - ) - - assert hf_result.keys() == vllm_result.keys() - for key, hf_tensor in hf_result.items(): - hf_arr: np.ndarray = hf_tensor.numpy() - vllm_arr: np.ndarray = vllm_result[key].numpy() - - assert hf_arr.shape == vllm_arr.shape, f"Failed for key={key}" - assert np.allclose(hf_arr, vllm_arr), f"Failed for key={key}" - - -@pytest.mark.parametrize( - ("num_images", "limit", "is_valid"), - [(0, 0, True), (0, 1, True), (1, 0, False), (1, 1, True), (1, 2, True), - (2, 1, False), (2, 2, True)], -) -def test_mm_limits(image_assets, mm_registry, num_images, limit, is_valid): - MODEL_NAME = "llava-hf/llava-v1.6-mistral-7b-hf" - - model_config = ModelConfig( - model=MODEL_NAME, - task="auto", - tokenizer=MODEL_NAME, - tokenizer_mode="auto", - trust_remote_code=False, - seed=0, - dtype="half", - revision=None, - limit_mm_per_prompt={"image": limit}, - ) - - mm_registry.init_mm_limits_per_prompt(model_config) - - image = image_assets[0].pil_image - if num_images == 0: - mm_inputs = {} - elif num_images == 1: - mm_inputs = {"image": image} - else: - mm_inputs = {"image": [image] * num_images} - - with nullcontext() if is_valid else pytest.raises(ValueError): - mm_registry.map_input(model_config, mm_inputs) - - -# NOTE: We don't test zero images since the HF processor doesn't support it -@pytest.mark.parametrize("num_images", [1, 2]) -def test_image_mapper_multi(image_assets, mm_registry, num_images): - MODEL_NAME = "llava-hf/llava-v1.6-mistral-7b-hf" - - model_config = ModelConfig( - model=MODEL_NAME, - task="auto", - tokenizer=MODEL_NAME, - tokenizer_mode="auto", - trust_remote_code=False, - seed=0, - dtype="half", - revision=None, - limit_mm_per_prompt={"image": num_images}, - ) - - mm_registry.init_mm_limits_per_prompt(model_config) - - image = image_assets[0].pil_image - mm_inputs = 
{"image": [image] * num_images} - - mapped_inputs = mm_registry.map_input(model_config, mm_inputs) - assert len(mapped_inputs["pixel_values"]) == num_images diff --git a/tests/multimodal/test_processing.py b/tests/multimodal/test_processing.py index 9573351b4dff1..f99d7556b27f9 100644 --- a/tests/multimodal/test_processing.py +++ b/tests/multimodal/test_processing.py @@ -1,5 +1,7 @@ +from contextlib import nullcontext from functools import partial from typing import cast +from unittest.mock import MagicMock import numpy as np import pytest @@ -526,6 +528,100 @@ def _rand_audio( return rng.rand(audio_len), sr +@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"]) +@pytest.mark.parametrize( + ("limit", "num_supported", "is_valid"), + [(0, 0, True), (0, 1, True), (1, 0, False), (1, 1, True), (1, 2, True), + (2, 1, False), (2, 2, True)], +) +def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid): + limit_mm_per_prompt = {"image": limit} + + model_config = ModelConfig( + model=model_id, + task="auto", + tokenizer=model_id, + tokenizer_mode="auto", + trust_remote_code=False, + seed=0, + dtype="half", + revision=None, + limit_mm_per_prompt=limit_mm_per_prompt, + ) + model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config) + + processor_factory = MULTIMODAL_REGISTRY._processor_factories[model_cls] + ctx = InputProcessingContext( + model_config, + tokenizer=cached_get_tokenizer(model_config.tokenizer), + ) + + processor = processor_factory(ctx, cache=None) + + mock_supported_mm_limits = MagicMock(return_value={"image": num_supported}) + processor.get_supported_mm_limits = mock_supported_mm_limits + + if is_valid: + exc_ctx = nullcontext() + else: + exc_ctx = pytest.raises(ValueError, match="this model only supports") + + with exc_ctx: + processor._get_and_validate_dummy_mm_counts() + + +@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"]) +@pytest.mark.parametrize( + ("num_images", "limit", "is_valid"), + [(0, 0, True), (0, 1, True), (1, 0, False), (1, 1, True), (1, 2, True), + (2, 1, False), (2, 2, True)], +) +def test_limit_mm_per_prompt_apply(model_id, num_images, limit, is_valid): + limit_mm_per_prompt = {"image": limit} + + model_config = ModelConfig( + model=model_id, + task="auto", + tokenizer=model_id, + tokenizer_mode="auto", + trust_remote_code=False, + seed=0, + dtype="half", + revision=None, + limit_mm_per_prompt=limit_mm_per_prompt, + ) + model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config) + + processor_factory = MULTIMODAL_REGISTRY._processor_factories[model_cls] + ctx = InputProcessingContext( + model_config, + tokenizer=cached_get_tokenizer(model_config.tokenizer), + ) + + processor = processor_factory(ctx, cache=None) + + rng = np.random.RandomState(0) + image = _rand_img(rng, min_wh=128, max_wh=256) + if num_images == 0: + mm_data = {} + elif num_images == 1: + mm_data = {"image": image} + else: + mm_data = {"image": [image] * num_images} + + if is_valid: + exc_ctx = nullcontext() + else: + exc_ctx = pytest.raises(ValueError, match=f"passed {num_images} image") + + with exc_ctx: + processor.apply( + "" * num_images, + mm_data=mm_data, + hf_processor_mm_kwargs={}, + ) + + def _test_processing_cache_correctness( model_id: str, modalities: dict[str, bool], @@ -631,6 +727,7 @@ def _test_processing_cache_correctness( ("facebook/chameleon-7b", {"image": False}), ("adept/fuyu-8b", {"image": False}), ("llava-hf/llava-1.5-7b-hf", {"image": True}), + ("llava-hf/llava-v1.6-mistral-7b-hf", {"image": True}), 
("TIGER-Lab/Mantis-8B-siglip-llama3", {"image": True}), ("mistral-community/pixtral-12b", {"image": True}), ("Qwen/Qwen2-VL-2B-Instruct", {"image": True, "video": True}), diff --git a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py index 0d90635093ac7..06dfebbb95527 100644 --- a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py +++ b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py @@ -3,13 +3,11 @@ import torch from vllm.model_executor.models.llava import (LlavaForConditionalGeneration, - LlavaMultiModalProcessor, - get_max_llava_image_tokens) + LlavaMultiModalProcessor) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens) @MULTIMODAL_REGISTRY.register_processor(LlavaMultiModalProcessor) class MyLlava(LlavaForConditionalGeneration): diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py index a5300dfd986f3..0188452054b8c 100644 --- a/vllm/model_executor/models/clip.py +++ b/vllm/model_executor/models/clip.py @@ -24,6 +24,8 @@ resolve_visual_encoder_outputs) from vllm.sequence import SequenceData +from .vision import VisionEncoderInfo + def get_clip_patch_grid_length(*, image_size: int, patch_size: int) -> int: assert image_size % patch_size == 0 @@ -149,6 +151,29 @@ def input_processor_for_clip( multi_modal_placeholders={"image": ranges}) +class CLIPEncoderInfo(VisionEncoderInfo[CLIPVisionConfig]): + + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + ) -> int: + return get_clip_image_feature_size(self.vision_config) + + def get_max_image_tokens(self) -> int: + return get_max_clip_image_tokens(self.vision_config) + + def get_num_patches(self) -> int: + return get_clip_patch_grid_length( + image_size=self.vision_config.image_size, + patch_size=self.vision_config.patch_size, + ) + + def get_image_size(self) -> int: + return self.vision_config.image_size + + # Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa class CLIPVisionEmbeddings(nn.Module): diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index 7fb8c5d1ab09c..3680d01725238 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -76,7 +76,7 @@ def _get_image_target_size(self) -> ImageSize: return ImageSize(width=target_size["width"], height=target_size["height"]) - def _get_image_grid_size( + def _get_image_feature_grid_size( self, *, image_width: int, @@ -99,7 +99,7 @@ def _get_image_grid_size( def get_mm_max_tokens_per_item(self) -> Mapping[str, int]: target_width, target_height = self._get_image_target_size() - max_ncols, max_nrows = self._get_image_grid_size( + max_ncols, max_nrows = self._get_image_feature_grid_size( image_width=target_width, image_height=target_height, ) @@ -172,7 +172,7 @@ def get_replacement_fuyu(item_idx: int): images = mm_items.get_items("image", ImageProcessorItems) image_size = images.get_image_size(item_idx) - ncols, nrows = self._get_image_grid_size( + ncols, nrows = self._get_image_feature_grid_size( image_width=image_size.width, image_height=image_size.height, ) diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 808e61edb6fb4..78de27cd821c6 100644 --- a/vllm/model_executor/models/llava.py +++ 
b/vllm/model_executor/models/llava.py @@ -1,6 +1,7 @@ +from abc import abstractmethod from functools import cached_property -from typing import (Iterable, List, Literal, Mapping, Optional, Protocol, Set, - Tuple, TypedDict, Union) +from typing import (Final, Iterable, List, Literal, Mapping, Optional, + Protocol, Set, Tuple, TypedDict, Union) import torch import torch.nn as nn @@ -12,7 +13,6 @@ from vllm.attention import AttentionMetadata from vllm.config import VllmConfig -from vllm.inputs import InputContext from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) @@ -23,23 +23,23 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalInputsV2, MultiModalKwargs, NestedTensors) -from vllm.multimodal.parse import ImageProcessorItems +from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, + ImageSize) from vllm.multimodal.processing import (BaseMultiModalProcessor, - MultiModalDataItems, ProcessorInputs, - PromptReplacement, + InputProcessingContext, + MultiModalDataItems, ProcessingCache, + ProcessorInputs, PromptReplacement, full_groupby_modality) from vllm.sequence import IntermediateTensors -from .clip import (CLIPVisionModel, dummy_image_for_clip, - get_max_clip_image_tokens) +from .clip import CLIPVisionModel from .interfaces import SupportsMultiModal, SupportsPP -from .pixtral import (PixtralHFVisionModel, dummy_image_for_pixtral_hf, - get_max_pixtral_hf_image_tokens, - get_pixtral_hf_image_feature_size) -from .siglip import (SiglipVisionModel, dummy_image_for_siglip, - get_max_siglip_image_tokens) +from .pixtral import (PixtralHFVisionModel, + get_pixtral_hf_image_feature_grid_size) +from .siglip import SiglipVisionModel from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, maybe_prefix, merge_multimodal_embeddings) +from .vision import vision_encoder_info class LlavaImagePixelInputs(TypedDict): @@ -94,39 +94,167 @@ def forward(self, image_features: torch.Tensor) -> torch.Tensor: return hidden_states -def get_max_llava_image_tokens(ctx: InputContext): - hf_config = ctx.get_hf_config(LlavaConfig) - vision_config = hf_config.vision_config +class LlavaLikeConfig(Protocol): + vision_config: Final[PretrainedConfig] + vision_feature_select_strategy: Final[str] + vision_feature_layer: Final[Union[int, List[int]]] - if isinstance(vision_config, CLIPVisionConfig): - num_image_tokens = get_max_clip_image_tokens(vision_config) - elif isinstance(vision_config, SiglipVisionConfig): - num_image_tokens = get_max_siglip_image_tokens(vision_config) - elif isinstance(vision_config, PixtralVisionConfig): - num_image_tokens = get_max_pixtral_hf_image_tokens(vision_config) - else: - msg = f"Unsupported vision config: {type(vision_config)}" - raise NotImplementedError(msg) - strategy = hf_config.vision_feature_select_strategy - if strategy == "default": - return num_image_tokens - 1 - elif strategy == "full": - return num_image_tokens - else: - raise ValueError(f"Unexpected select feature strategy: {strategy}") +class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor): + def __init__(self, + ctx: InputProcessingContext, + *, + cache: Optional[ProcessingCache] = None, + enable_sanity_checks: bool = True) -> None: + super().__init__(ctx, + cache=cache, + enable_sanity_checks=enable_sanity_checks) + + vision_config = self._get_hf_config().vision_config + self._vision_encoder_info = vision_encoder_info(vision_config) -class 
LlavaMultiModalProcessor(BaseMultiModalProcessor): + @abstractmethod + def _get_hf_config(self) -> LlavaLikeConfig: + raise NotImplementedError def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": None} + def _apply_feature_select_strategy( + self, + strategy: str, + encoder_num_image_tokens: int, + ) -> int: + if strategy == "default": + return encoder_num_image_tokens - 1 + if strategy == "full": + return encoder_num_image_tokens + + msg = f"Unexpected feature select strategy: {strategy!r}" + raise NotImplementedError(msg) + + def _get_max_image_tokens(self) -> int: + hf_config = self._get_hf_config() + + return self._apply_feature_select_strategy( + hf_config.vision_feature_select_strategy, + self._vision_encoder_info.get_max_image_tokens(), + ) + def get_mm_max_tokens_per_item(self) -> Mapping[str, int]: - return {"image": get_max_llava_image_tokens(self.ctx)} + return {"image": self._get_max_image_tokens()} + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict( + pixel_values=MultiModalFieldConfig.batched("image"), + image_embeds=MultiModalFieldConfig.batched("image"), + ) + + def _get_dummy_image_size(self) -> ImageSize: + image_size = self._vision_encoder_info.get_image_size() + return ImageSize(image_size, image_size) + + @abstractmethod + def _get_image_token(self) -> str: + raise NotImplementedError + + def _get_dummy_mm_inputs( + self, + mm_counts: Mapping[str, int], + ) -> ProcessorInputs: + num_images = mm_counts.get("image", 0) + + image_token = self._get_image_token() + target_width, target_height = self._get_dummy_image_size() + + mm_data = { + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images) + } + + return ProcessorInputs( + prompt_text=image_token * num_images, + mm_data=mm_data, + ) + + +class LlavaMultiModalProcessor(BaseLlavaMultiModalProcessor): + + def _get_hf_config(self) -> LlavaConfig: + return self.ctx.get_hf_config(LlavaConfig) + + def _get_hf_processor(self) -> LlavaProcessor: + return self.ctx.get_hf_processor(LlavaProcessor) + + def _get_image_token(self) -> str: + return self._get_hf_processor().image_token + + def _get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + ) -> int: + hf_config = self._get_hf_config() + + return self._apply_feature_select_strategy( + hf_config.vision_feature_select_strategy, + self._vision_encoder_info.get_num_image_tokens( + image_width=image_width, + image_height=image_height, + ), + ) + + def _get_prompt_replacements( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> list[PromptReplacement]: + hf_config = self._get_hf_config() + image_token_id = hf_config.image_token_index - def _get_hf_processor(self) -> Union[LlavaProcessor, PixtralProcessor]: - return self.ctx.get_hf_processor((LlavaProcessor, PixtralProcessor)) + def get_replacement(item_idx: int): + images = mm_items.get_items( + "image", (ImageEmbeddingItems, ImageProcessorItems)) + + if isinstance(images, ImageEmbeddingItems): + num_image_tokens = images.get_feature_size(item_idx) + else: + image_size = images.get_image_size(item_idx) + num_image_tokens = self._get_num_image_tokens( + image_width=image_size.width, + image_height=image_size.height, + ) + + return [image_token_id] * num_image_tokens + + return [ + PromptReplacement( + modality="image", + 
target=[image_token_id], + replacement=get_replacement, + ), + ] + + +class PixtralHFMultiModalProcessor(BaseLlavaMultiModalProcessor): + + def _get_hf_config(self) -> LlavaConfig: + return self.ctx.get_hf_config(LlavaConfig) + + def _get_hf_processor(self) -> PixtralProcessor: + return self.ctx.get_hf_processor(PixtralProcessor) + + def _get_image_token(self) -> str: + return self._get_hf_processor().image_token def _call_hf_processor( self, @@ -140,119 +268,82 @@ def _call_hf_processor( mm_kwargs=mm_kwargs, ) - # NOTE: pixel_values=None for MLlavaProcessor pixel_values = processed_outputs.get("pixel_values") if pixel_values is not None: images = mm_data["images"] assert isinstance(images, list) - if isinstance(self._get_hf_processor(), PixtralProcessor): - # Original output: (1, num_images, C, H, W) - # New output: (num_images, C, H, W) - assert (isinstance(pixel_values, list) - and len(pixel_values) == 1) - assert (isinstance(pixel_values[0], list) - and len(pixel_values[0]) == len(images)) + # Original output: (1, num_images, C, H, W) + # New output: (num_images, C, H, W) + assert (isinstance(pixel_values, list) and len(pixel_values) == 1) + assert (isinstance(pixel_values[0], list) + and len(pixel_values[0]) == len(images)) - processed_outputs["pixel_values"] = pixel_values[0] + processed_outputs["pixel_values"] = pixel_values[0] return processed_outputs - def _get_mm_fields_config( - self, - hf_inputs: BatchFeature, - hf_processor_mm_kwargs: Mapping[str, object], - ) -> Mapping[str, MultiModalFieldConfig]: - return dict( - pixel_values=MultiModalFieldConfig.batched("image"), - image_embeds=MultiModalFieldConfig.batched("image"), - ) - def _get_prompt_replacements( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, ) -> list[PromptReplacement]: - hf_config = self.ctx.get_hf_config(LlavaConfig) + hf_config = self._get_hf_config() image_token_id = hf_config.image_token_index processor = self._get_hf_processor() - if isinstance(processor, PixtralProcessor): - image_token = processor.image_token - image_break_token = processor.image_break_token - image_end_token = processor.image_end_token - - vision_config = hf_config.vision_config - assert isinstance(vision_config, PixtralVisionConfig) + image_token = processor.image_token + image_break_token = processor.image_break_token + image_end_token = processor.image_end_token - def get_replacement_pixtral(item_idx: int): - images = mm_items.get_items("image", ImageProcessorItems) - image_size = images.get_image_size(item_idx) - - ( - num_width_tokens, - num_height_tokens, - ) = get_pixtral_hf_image_feature_size( - vision_config, - image_width=image_size.width, - image_height=image_size.height, - ) + vision_config = hf_config.vision_config + assert isinstance(vision_config, PixtralVisionConfig) - tokens = ([image_token] * num_width_tokens + - [image_break_token]) * num_height_tokens - tokens[-1] = image_end_token + def get_replacement(item_idx: int): + images = mm_items.get_items("image", ImageProcessorItems) + image_size = images.get_image_size(item_idx) - return "".join(tokens) + ncols, nrows = get_pixtral_hf_image_feature_grid_size( + vision_config, + image_width=image_size.width, + image_height=image_size.height, + ) - return [ - PromptReplacement( - modality="image", - target=[image_token_id], - replacement=get_replacement_pixtral, - ), - ] + tokens = ([image_token] * ncols + [image_break_token]) * nrows + tokens[-1] = image_end_token - max_image_tokens = 
get_max_llava_image_tokens(self.ctx) + return "".join(tokens) return [ PromptReplacement( modality="image", target=[image_token_id], - replacement=[image_token_id] * max_image_tokens, - ) + replacement=get_replacement, + ), ] - def _get_dummy_mm_inputs( - self, - mm_counts: Mapping[str, int], - ) -> ProcessorInputs: - hf_config = self.ctx.get_hf_config(LlavaConfig) - vision_config = hf_config.vision_config - num_images = mm_counts.get("image", 0) - - if isinstance(vision_config, CLIPVisionConfig): - data = dummy_image_for_clip(vision_config, num_images) - elif isinstance(vision_config, SiglipVisionConfig): - data = dummy_image_for_siglip(vision_config, num_images) - elif isinstance(vision_config, PixtralVisionConfig): - data = dummy_image_for_pixtral_hf(vision_config, num_images) - else: - msg = f"Unsupported vision config: {type(vision_config)}" - raise NotImplementedError(msg) - hf_processor = self._get_hf_processor() - image_token = hf_processor.image_token +def _build_llava_or_pixtral_hf_processor( + ctx: InputProcessingContext, + *, + cache: Optional[ProcessingCache] = None, + enable_sanity_checks: bool = True, +) -> BaseLlavaMultiModalProcessor: + hf_config = ctx.get_hf_config(LlavaConfig) - return ProcessorInputs( - prompt_text=image_token * num_images, - mm_data=data, + if isinstance(hf_config.vision_config, PixtralVisionConfig): + return PixtralHFMultiModalProcessor( + ctx, + cache=cache, + enable_sanity_checks=enable_sanity_checks, ) - -class LlavaLikeConfig(Protocol): - vision_config: PretrainedConfig - vision_feature_layer: Union[int, List[int]] + return LlavaMultiModalProcessor( + ctx, + cache=cache, + enable_sanity_checks=enable_sanity_checks, + ) def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int: @@ -330,7 +421,7 @@ def init_vision_tower_for_llava( raise NotImplementedError(msg) -@MULTIMODAL_REGISTRY.register_processor(LlavaMultiModalProcessor) +@MULTIMODAL_REGISTRY.register_processor(_build_llava_or_pixtral_hf_processor) class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): # BitandBytes specific attributes bitsandbytes_stacked_params_mapping = { @@ -596,7 +687,12 @@ def apply( ) -> MultiModalInputsV2: hf_config = self.ctx.get_hf_config(LlavaConfig) image_token_id = hf_config.image_token_index - max_image_tokens = get_max_llava_image_tokens(self.ctx) + + # Assume that it doesn't depend on the image size + num_image_tokens = self._get_num_image_tokens( + image_width=-1, + image_height=-1, + ) result = super().apply(prompt_text, mm_data, hf_processor_mm_kwargs) @@ -609,14 +705,14 @@ def apply( def get_replacement_mantis(item_idx: int): return "".join([ f"(image {item_idx+1}: ", # 7 tokens - "" * max_image_tokens, + "" * num_image_tokens, ")", # 3 tokens ]) mantis_repls = self._bind_prompt_replacements([ PromptReplacement( modality="image", - target=[image_token_id] * max_image_tokens, + target=[image_token_id] * num_image_tokens, replacement=get_replacement_mantis, ) ]) diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index 5e70c11363c83..24debd1cbf3fe 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -4,31 +4,25 @@ import torch import torch.nn as nn -from PIL import Image -from transformers import CLIPVisionConfig, LlavaNextConfig, SiglipVisionConfig +from transformers import BatchFeature, LlavaNextConfig, LlavaNextProcessor from transformers.models.llava_next.modeling_llava_next import ( get_anyres_image_grid_shape, unpad_image) from 
typing_extensions import NotRequired from vllm.attention import AttentionMetadata from vllm.config import VllmConfig -from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, - InputContext) from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import NestedTensors +from vllm.multimodal.inputs import MultiModalFieldConfig, NestedTensors +from vllm.multimodal.parse import ImageSize from vllm.sequence import IntermediateTensors -from vllm.utils import is_list_of -from .clip import (CLIPVisionModel, dummy_image_for_clip, - dummy_seq_data_for_clip, get_clip_image_feature_size, - get_clip_patch_grid_length, input_processor_for_clip) +from .clip import CLIPVisionModel from .interfaces import SupportsMultiModal, SupportsPP -from .llava import LlavaMultiModalProjector, init_vision_tower_for_llava -from .siglip import (SiglipVisionModel, dummy_image_for_siglip, - dummy_seq_data_for_siglip, get_siglip_image_feature_size, - get_siglip_patch_grid_length, input_processor_for_siglip) +from .llava import (LlavaMultiModalProcessor, LlavaMultiModalProjector, + init_vision_tower_for_llava) +from .siglip import SiglipVisionModel from .utils import (AutoWeightsLoader, embed_multimodal, flatten_bn, init_vllm_registered_model, maybe_prefix) @@ -65,218 +59,127 @@ class LlavaNextImageEmbeddingInputs(TypedDict): LlavaNextImageEmbeddingInputs] -# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L79 -def _get_llava_next_num_unpadded_features( - original_height: int, - original_width: int, - npatches: int, - num_patch_height: int, - num_patch_width: int, -) -> Tuple[int, int]: - current_height = npatches * num_patch_height - current_width = npatches * num_patch_width - - original_aspect_ratio = original_width / original_height - current_aspect_ratio = current_width / current_height - - if original_aspect_ratio > current_aspect_ratio: - scale_factor = current_width / original_width - new_height = int(original_height * scale_factor) - padding = (current_height - new_height) // 2 - current_height -= 2 * padding - else: - scale_factor = current_height / original_height - new_width = int(original_width * scale_factor) - padding = (current_width - new_width) // 2 - current_width -= 2 * padding - - unpadded_features = current_height * current_width - newline_features = current_height - return (unpadded_features, newline_features) - - -# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L106 -def get_llava_next_image_feature_size( - hf_config: LlavaNextConfig, - *, - input_height: int, - input_width: int, -) -> int: - vision_config = hf_config.vision_config - - if isinstance(vision_config, CLIPVisionConfig): - num_patches = get_clip_patch_grid_length( - image_size=vision_config.image_size, - patch_size=vision_config.patch_size, - ) - base_feature_size = get_clip_image_feature_size(vision_config) - elif isinstance(vision_config, SiglipVisionConfig): - num_patches = get_siglip_patch_grid_length( - image_size=vision_config.image_size, - patch_size=vision_config.patch_size, - ) - base_feature_size = get_siglip_image_feature_size(vision_config) - else: - msg = f"Unsupported vision config: {type(vision_config)}" - raise NotImplementedError(msg) - - strategy = 
hf_config.vision_feature_select_strategy - if strategy == "default": - base_feature_size -= 1 - elif strategy == "full": - pass - else: - raise ValueError(f"Unexpected select feature strategy: {strategy}") +class LlavaNextMultiModalProcessor(LlavaMultiModalProcessor): - num_patch_height, num_patch_width = get_anyres_image_grid_shape( - image_size=(input_height, input_width), - grid_pinpoints=hf_config.image_grid_pinpoints, - patch_size=vision_config.image_size, - ) - - ( - unpadded_feature_size, - newline_feature_size, - ) = _get_llava_next_num_unpadded_features(input_height, input_width, - num_patches, num_patch_height, - num_patch_width) - - return unpadded_feature_size + newline_feature_size + base_feature_size - - -def get_max_llava_next_image_tokens(ctx: InputContext): - """Compute the max feature size for all possible image grid pinpoints.""" - return _get_pinpoint_with_largest_features(ctx)[0] - - -def _get_pinpoint_with_largest_features( - ctx: InputContext) -> Tuple[int, Tuple[int, int]]: - """Get the grid pinpoint with the largest features & its feature size.""" - hf_config = ctx.get_hf_config(LlavaNextConfig) - largest_feature_size = 0 - largest_feature_pinpoint = None - for (height, width) in hf_config.image_grid_pinpoints: - feat_size = get_llava_next_image_feature_size( - hf_config, - input_height=height, - input_width=width, - ) - if feat_size > largest_feature_size: - largest_feature_size = feat_size - largest_feature_pinpoint = (height, width) - if not largest_feature_size or largest_feature_pinpoint is None: - raise ValueError("Cannot have a largest feature size of 0!") - return largest_feature_size, largest_feature_pinpoint - - -def dummy_data_for_llava_next(ctx: InputContext, seq_len: int, - mm_counts: Mapping[str, int]): - hf_config = ctx.get_hf_config(LlavaNextConfig) - vision_config = hf_config.vision_config - num_images = mm_counts["image"] - - image_feature_size, pinpoint = _get_pinpoint_with_largest_features(ctx) - max_feat_height, max_feat_width = pinpoint - - if isinstance(vision_config, CLIPVisionConfig): - seq_data, ranges = dummy_seq_data_for_clip( - vision_config, - seq_len, - num_images, - image_token_id=hf_config.image_token_index, - image_feature_size_override=image_feature_size, - ) + def _get_hf_config(self) -> LlavaNextConfig: + return self.ctx.get_hf_config(LlavaNextConfig) + + def _get_hf_processor(self) -> LlavaNextProcessor: + return self.ctx.get_hf_processor(LlavaNextProcessor) - mm_data = dummy_image_for_clip( - vision_config, - num_images, - image_width_override=max_feat_width, - image_height_override=max_feat_height, + def _get_image_token(self) -> str: + return self._get_hf_processor().image_token + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict( + pixel_values=MultiModalFieldConfig.batched("image"), + image_sizes=MultiModalFieldConfig.batched("image"), + image_embeds=MultiModalFieldConfig.batched("image"), ) - return DummyData(seq_data, mm_data, ranges) - elif isinstance(vision_config, SiglipVisionConfig): - seq_data, ranges = dummy_seq_data_for_siglip( - vision_config, - seq_len, - num_images, - image_token_id=hf_config.image_token_index, - image_feature_size_override=image_feature_size, + def _get_max_image_tokens(self) -> int: + largest_feature_size, _ = self._get_pinpoint_with_most_features() + return largest_feature_size + + def _get_dummy_image_size(self) -> ImageSize: + _, pinpoint = 
self._get_pinpoint_with_most_features() + return pinpoint + + # Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L106 + def _get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + ) -> int: + hf_config = self._get_hf_config() + + base_feature_size = self._apply_feature_select_strategy( + hf_config.vision_feature_select_strategy, + self._vision_encoder_info.get_num_image_tokens( + image_width=image_width, + image_height=image_height, + ), ) + num_patches = self._vision_encoder_info.get_num_patches() - mm_data = dummy_image_for_siglip( - vision_config, - num_images, - image_width_override=max_feat_width, - image_height_override=max_feat_height, + num_patch_height, num_patch_width = get_anyres_image_grid_shape( + image_size=(image_height, image_width), + grid_pinpoints=hf_config.image_grid_pinpoints, + patch_size=self._vision_encoder_info.get_image_size(), ) - return DummyData(seq_data, mm_data, ranges) + ( + unpadded_feature_size, + newline_feature_size, + ) = self._get_num_unpadded_features( + original_height=image_height, + original_width=image_width, + npatches=num_patches, + num_patch_height=num_patch_height, + num_patch_width=num_patch_width, + ) - msg = f"Unsupported vision config: {type(vision_config)}" - raise NotImplementedError(msg) + return unpadded_feature_size + newline_feature_size + base_feature_size + # Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L79 + def _get_num_unpadded_features( + self, + *, + original_height: int, + original_width: int, + npatches: int, + num_patch_height: int, + num_patch_width: int, + ) -> tuple[int, int]: + current_height = npatches * num_patch_height + current_width = npatches * num_patch_width + + original_aspect_ratio = original_width / original_height + current_aspect_ratio = current_width / current_height + + if original_aspect_ratio > current_aspect_ratio: + scale_factor = current_width / original_width + new_height = int(original_height * scale_factor) + padding = (current_height - new_height) // 2 + current_height -= 2 * padding + else: + scale_factor = current_height / original_height + new_width = int(original_width * scale_factor) + padding = (current_width - new_width) // 2 + current_width -= 2 * padding -def input_processor_for_llava_next(ctx: InputContext, - inputs: DecoderOnlyInputs): - multi_modal_data = inputs.get("multi_modal_data") - if multi_modal_data is None or "image" not in multi_modal_data: - return inputs + unpadded_features = current_height * current_width + newline_features = current_height + return (unpadded_features, newline_features) - model_config = ctx.model_config - hf_config = ctx.get_hf_config(LlavaNextConfig) - vision_config = hf_config.vision_config + def _get_pinpoint_with_most_features(self) -> tuple[int, ImageSize]: + """ + Get the grid pinpoint with the most features and + the corresponding feature size. 
+ """ + hf_config = self._get_hf_config() - image_data = multi_modal_data["image"] - if isinstance(image_data, Image.Image): - width, height = image_data.size + largest_feature_size, largest_feature_pinpoint = 0, None + for (height, width) in hf_config.image_grid_pinpoints: + feat_size = self._get_num_image_tokens(image_width=width, + image_height=height) + if feat_size > largest_feature_size: + largest_feature_size = feat_size + largest_feature_pinpoint = ImageSize(width=width, + height=height) - image_feature_size = get_llava_next_image_feature_size( - hf_config, - input_height=height, - input_width=width, - ) - elif is_list_of(image_data, Image.Image): - image_feature_size = [ - get_llava_next_image_feature_size(hf_config, - input_height=img.height, - input_width=img.width) - for img in image_data - ] - elif isinstance(image_data, torch.Tensor): - num_images, image_feature_size, hidden_size = image_data.shape - elif is_list_of(image_data, torch.Tensor): - image_feature_size = [item.shape[1] for item in image_data] - else: - raise TypeError(f"Invalid image type: {type(image_data)}") - - vision_config = hf_config.vision_config - - if isinstance(vision_config, CLIPVisionConfig): - return input_processor_for_clip( - model_config, - vision_config, - inputs, - image_token_id=hf_config.image_token_index, - image_feature_size_override=image_feature_size, - ) - elif isinstance(vision_config, SiglipVisionConfig): - return input_processor_for_siglip( - model_config, - vision_config, - inputs, - image_token_id=hf_config.image_token_index, - image_feature_size_override=image_feature_size, - ) + if largest_feature_size == 0 or largest_feature_pinpoint is None: + raise ValueError("Cannot have a largest feature size of 0!") - msg = f"Unsupported vision config: {type(vision_config)}" - raise NotImplementedError(msg) + return largest_feature_size, largest_feature_pinpoint -@MULTIMODAL_REGISTRY.register_image_input_mapper() -@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_next_image_tokens) -@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next) -@INPUT_REGISTRY.register_input_processor(input_processor_for_llava_next) +@MULTIMODAL_REGISTRY.register_processor(LlavaNextMultiModalProcessor) class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): @@ -507,7 +410,7 @@ def _merge_image_patch_embeddings(self, image_size: torch.Tensor, def _process_image_pixels( self, inputs: LlavaNextImagePixelInputs, - ) -> Union[torch.Tensor, List[torch.Tensor]]: + ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]: assert self.vision_tower is not None pixel_values = inputs["data"] diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index d855e7d2d36f8..f2e49d8e4848d 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -34,7 +34,7 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalInputsV2, MultiModalKwargs, NestedTensors, PlaceholderRange) -from vllm.multimodal.parse import ImageProcessorItems +from vllm.multimodal.parse import ImageEmbeddingItems, ImageProcessorItems from vllm.multimodal.processing import (BaseMultiModalProcessor, MultiModalDataItems, ProcessorInputs, PromptReplacement, @@ -388,15 +388,19 @@ def _get_prompt_replacements( assert isinstance(bos_token_id, int) def get_replacement_phi3v(item_idx: int): - images = mm_items.get_items("image", ImageProcessorItems) - image_size = images.get_image_size(item_idx) - - num_tokens = 
self._get_num_image_tokens( - image_width=image_size.width, - image_height=image_size.height, - ) - - return [_IMAGE_TOKEN_ID] * num_tokens + [bos_token_id] + images = mm_items.get_items( + "image", (ImageEmbeddingItems, ImageProcessorItems)) + + if isinstance(images, ImageEmbeddingItems): + num_image_tokens = images.get_feature_size(item_idx) + else: + image_size = images.get_image_size(item_idx) + num_image_tokens = self._get_num_image_tokens( + image_width=image_size.width, + image_height=image_size.height, + ) + + return [_IMAGE_TOKEN_ID] * num_image_tokens + [bos_token_id] num_images = mm_items.get_count("image", strict=False) diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index 2bce13792a88d..d7233bd6028ed 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -38,6 +38,7 @@ from .interfaces import SupportsMultiModal, SupportsPP from .utils import (init_vllm_registered_model, maybe_prefix, merge_multimodal_embeddings) +from .vision import VisionEncoderInfo try: from xformers import ops as xops @@ -697,10 +698,18 @@ def get_pixtral_hf_patch_grid_length(*, image_size: int, return image_size // patch_size -def get_pixtral_hf_num_patches(*, image_size: int, patch_size: int) -> int: - grid_length = get_pixtral_hf_patch_grid_length(image_size=image_size, - patch_size=patch_size) - return grid_length * grid_length +def get_pixtral_hf_image_feature_size( + *, + image_size: int, + patch_size: int, +) -> int: + grid_length = get_pixtral_hf_patch_grid_length( + image_size=image_size, + patch_size=patch_size, + ) + + # Consider the image_break_token + return (grid_length + 1) * grid_length def get_max_pixtral_hf_image_tokens(hf_config: PixtralVisionConfig) -> int: @@ -730,13 +739,16 @@ def dummy_image_for_pixtral_hf( return {"image": image if num_images == 1 else [image] * num_images} -def get_pixtral_hf_image_feature_size(hf_config: PixtralVisionConfig, - image_width: int, - image_height: int) -> Tuple[int, int]: - # Adapted from transformers.models.pixtral.image_processing_pixtral.get_resize_output_image_size # noqa: E501 - # https://github.com/huggingface/transformers/blob/2bd4d5897dc73e8b172832070a6f9e567a0df017/src/transformers/models/pixtral/image_processing_pixtral.py#L180 # noqa: E501 - max_width, max_height = hf_config.image_size, hf_config.image_size - patch_width, patch_height = hf_config.patch_size, hf_config.patch_size +# Adapted from transformers.models.pixtral.image_processing_pixtral.get_resize_output_image_size # noqa: E501 +# https://github.com/huggingface/transformers/blob/2bd4d5897dc73e8b172832070a6f9e567a0df017/src/transformers/models/pixtral/image_processing_pixtral.py#L180 +def get_pixtral_hf_image_feature_grid_size( + hf_config: PixtralVisionConfig, + *, + image_width: int, + image_height: int, +) -> tuple[int, int]: + max_width = max_height = hf_config.image_size + patch_width = patch_height = hf_config.patch_size ratio = max(image_width / max_width, image_height / max_height) @@ -744,12 +756,38 @@ def get_pixtral_hf_image_feature_size(hf_config: PixtralVisionConfig, image_width = int(math.ceil(image_width / ratio)) image_height = int(math.ceil(image_height / ratio)) - num_height_tokens, num_width_tokens = _get_pixtral_hf_num_image_tokens( + nrows, ncols = _get_pixtral_hf_num_image_tokens( (image_height, image_width), (patch_height, patch_width), - ) + ) # type: ignore + + return ncols, nrows + + +class PixtralHFEncoderInfo(VisionEncoderInfo[PixtralVisionConfig]): + + def 
get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + ) -> int: + return get_pixtral_hf_image_feature_size( + image_size=self.vision_config.image_size, + patch_size=self.get_image_size(), + ) + + def get_max_image_tokens(self) -> int: + return get_max_pixtral_hf_image_tokens(self.vision_config) + + def get_num_patches(self) -> int: + return get_pixtral_hf_patch_grid_length( + image_size=self.vision_config.image_size, + patch_size=self.vision_config.patch_size, + ) - return num_width_tokens, num_height_tokens + def get_image_size(self) -> int: + return self.vision_config.image_size class PixtralHFMLP(nn.Module): diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py index 6fb9e2cc4584f..115eaaac900e0 100644 --- a/vllm/model_executor/models/siglip.py +++ b/vllm/model_executor/models/siglip.py @@ -28,6 +28,8 @@ resolve_visual_encoder_outputs) from vllm.sequence import SequenceData +from .vision import VisionEncoderInfo + def get_siglip_patch_grid_length(*, image_size: int, patch_size: int) -> int: # Since interpolation is applied, the image size need not be divisible @@ -156,6 +158,29 @@ def input_processor_for_siglip( multi_modal_placeholders={"image": ranges}) +class SiglipEncoderInfo(VisionEncoderInfo[SiglipVisionConfig]): + + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + ) -> int: + return get_siglip_image_feature_size(self.vision_config) + + def get_max_image_tokens(self) -> int: + return get_max_siglip_image_tokens(self.vision_config) + + def get_num_patches(self) -> int: + return get_siglip_patch_grid_length( + image_size=self.vision_config.image_size, + patch_size=self.vision_config.patch_size, + ) + + def get_image_size(self) -> int: + return self.vision_config.image_size + + # Adapted from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L249 # noqa class SiglipVisionEmbeddings(nn.Module): diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 269b66806adf4..31017f16d3c97 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -373,7 +373,7 @@ def embed_multimodal( input_ids: torch.Tensor, multimodal_token_id: int, get_text_embeds: Callable[[torch.Tensor], torch.Tensor], - multimodal_embeds: Union[torch.Tensor, List[torch.Tensor]], + multimodal_embeds: NestedTensors, ) -> torch.Tensor: """ Embed token IDs and multimodal inputs and combine their embeddings. 
diff --git a/vllm/model_executor/models/vision.py b/vllm/model_executor/models/vision.py new file mode 100644 index 0000000000000..65a773480d2a1 --- /dev/null +++ b/vllm/model_executor/models/vision.py @@ -0,0 +1,52 @@ +from abc import ABC, abstractmethod +from typing import Generic, TypeVar + +from transformers import PretrainedConfig + +_C = TypeVar("_C", bound=PretrainedConfig) + + +class VisionEncoderInfo(ABC, Generic[_C]): + + def __init__(self, vision_config: _C) -> None: + super().__init__() + + self.vision_config = vision_config + + @abstractmethod + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + ) -> int: + raise NotImplementedError + + @abstractmethod + def get_max_image_tokens(self) -> int: + raise NotImplementedError + + @abstractmethod + def get_num_patches(self) -> int: + raise NotImplementedError + + @abstractmethod + def get_image_size(self) -> int: + raise NotImplementedError + + +def vision_encoder_info(vision_config: PretrainedConfig) -> VisionEncoderInfo: + # Avoid circular imports + from .clip import CLIPEncoderInfo, CLIPVisionConfig + from .pixtral import PixtralHFEncoderInfo, PixtralVisionConfig + from .siglip import SiglipEncoderInfo, SiglipVisionConfig + + if isinstance(vision_config, CLIPVisionConfig): + return CLIPEncoderInfo(vision_config) + if isinstance(vision_config, PixtralVisionConfig): + return PixtralHFEncoderInfo(vision_config) + if isinstance(vision_config, SiglipVisionConfig): + return SiglipEncoderInfo(vision_config) + + msg = f"Unsupported vision config: {type(vision_config)}" + raise NotImplementedError(msg) diff --git a/vllm/multimodal/parse.py b/vllm/multimodal/parse.py index 4e1b78ab2c59d..00acb77435163 100644 --- a/vllm/multimodal/parse.py +++ b/vllm/multimodal/parse.py @@ -1,7 +1,8 @@ from abc import ABC, abstractmethod from collections import UserDict from collections.abc import Callable, Iterator, Mapping, Sequence -from typing import TYPE_CHECKING, Any, Generic, NamedTuple, Optional, TypeVar +from typing import (TYPE_CHECKING, Any, Generic, NamedTuple, Optional, TypeVar, + Union) import numpy as np import torch @@ -87,7 +88,7 @@ class EmbeddingItems(ModalityDataItems[NestedTensors, torch.Tensor]): def get_count(self) -> int: return len(self.data) - def get(self, index: int) -> object: + def get(self, index: int) -> torch.Tensor: return self.data[index] def get_processor_data(self) -> Mapping[str, object]: @@ -96,6 +97,9 @@ def get_processor_data(self) -> Mapping[str, object]: def get_passthrough_data(self) -> Mapping[str, object]: return {f"{self.modality}_embeds": self.data} + def get_feature_size(self, item_idx: int) -> int: + return len(self.get(item_idx)) + class AudioProcessorItems(ProcessorBatchItems[HfAudioItem]): @@ -182,7 +186,7 @@ def get_all_counts(self) -> Mapping[str, int]: def get_items( self, modality: str, - typ: type[_D], + typ: Union[type[_D], tuple[type[_D], ...]], ) -> _D: """ Get the data items belonging to a modality, @@ -199,7 +203,7 @@ def get_items( f"Expected type: {typ}, but " f"found type: {type(items)}") - return items + return items # type: ignore[return-value] ModalityDataParser: TypeAlias = Callable[[ModalityData[Any]],
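Usage sketch (not part of the patch; added for review context only): the new VisionEncoderInfo abstraction in vllm/model_executor/models/vision.py replaces the per-encoder helper functions that the processors previously dispatched on by config type. The snippet below assumes the CLIP ViT-L/14-336 tower used by llava-hf/llava-v1.6-mistral-7b-hf (image_size=336, patch_size=14 are assumptions, not stated in the diff) and reproduces by hand the 2928-token figure asserted by the removed test_get_max_llava_next_image_tokens for the (672, 672) grid pinpoint.

    from transformers import CLIPVisionConfig

    from vllm.model_executor.models.vision import vision_encoder_info

    # Assumed CLIP ViT-L/14-336 settings for the default LLaVA-NeXT tower.
    clip_config = CLIPVisionConfig(image_size=336, patch_size=14)
    encoder_info = vision_encoder_info(clip_config)  # dispatches to CLIPEncoderInfo

    npatches = encoder_info.get_num_patches()        # 336 // 14 = 24
    base = encoder_info.get_max_image_tokens() - 1   # 577 - 1 = 576 ("default" strategy drops the CLS token)

    # For a (672, 672) pinpoint, anyres tiling yields a 2x2 grid of 336x336 tiles.
    current_height = current_width = npatches * 2    # 48
    unpadded = current_height * current_width        # 2304 (square image, so no unpadding occurs)
    newline = current_height                         # one newline feature per row of the unpadded grid
    print(unpadded + newline + base)                 # 2928

The same arithmetic now lives in LlavaNextMultiModalProcessor._get_num_image_tokens and _get_num_unpadded_features, while _build_llava_or_pixtral_hf_processor selects the Pixtral or LLaVA processor variant from the vision config type at registration time instead of branching inside a single processor.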