[VLM] Add max-count checking in data parser for single image models #11661

Merged 5 commits on Jan 1, 2025
2 changes: 1 addition & 1 deletion docs/source/models/supported_models.md
@@ -566,7 +566,7 @@ See [this page](#generative-models) for more information on how to use generativ
- [V1](gh-issue:8779)
* - `AriaForConditionalGeneration`
- Aria
- T + I
- T + I<sup>+</sup>
Comment from the PR author:
This model actually supports multi-image input, so I've updated the table here.

- `rhymes-ai/Aria`
-
- ✅︎
3 changes: 2 additions & 1 deletion tests/multimodal/test_processing.py
@@ -622,10 +622,11 @@ def _test_processing_cache_correctness(


# yapf: disable
# True if the model supports multiple data items of the modality per request
@pytest.mark.parametrize(("model_id", "modalities"), [
("rhymes-ai/Aria", {"image": True}),
("Salesforce/blip2-opt-2.7b", {"image": False}),
("facebook/chameleon-7b", {"image": True}),
("facebook/chameleon-7b", {"image": False}),
("adept/fuyu-8b", {"image": False}),
("llava-hf/llava-1.5-7b-hf", {"image": True}),
("TIGER-Lab/Mantis-8B-siglip-llama3", {"image": True}),
4 changes: 4 additions & 0 deletions vllm/model_executor/models/blip2.py
@@ -18,6 +18,7 @@
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputsV2, MultiModalKwargs,
NestedTensors, PlaceholderRange)
from vllm.multimodal.parse import MultiModalDataParser
from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessorInputs,
PromptReplacement)
@@ -404,6 +405,9 @@ def get_max_blip2_image_tokens(ctx: InputContext):

class Blip2MultiModalProcessor(BaseMultiModalProcessor):

def _get_data_parser(self) -> MultiModalDataParser:
return MultiModalDataParser(max_mm_counts={"image": 1})

def _get_hf_processor(self) -> Blip2Processor:
return self.ctx.get_hf_processor(Blip2Processor)

4 changes: 4 additions & 0 deletions vllm/model_executor/models/chameleon.py
@@ -31,6 +31,7 @@
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputsV2, MultiModalKwargs,
NestedTensors, PlaceholderRange)
from vllm.multimodal.parse import MultiModalDataParser
from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessorInputs,
PromptReplacement)
@@ -60,6 +61,9 @@ def get_max_chameleon_image_tokens(ctx: InputContext):

class ChameleonMultiModalProcessor(BaseMultiModalProcessor):

def _get_data_parser(self) -> MultiModalDataParser:
return MultiModalDataParser(max_mm_counts={"image": 1})

def _get_hf_processor(self) -> ChameleonProcessor:
return self.ctx.get_hf_processor(ChameleonProcessor)

18 changes: 11 additions & 7 deletions vllm/model_executor/models/fuyu.py
@@ -34,7 +34,7 @@
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputsV2, MultiModalKwargs,
NestedTensors, PlaceholderRange)
from vllm.multimodal.parse import ImageProcessorItems
from vllm.multimodal.parse import ImageProcessorItems, MultiModalDataParser
from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessorInputs,
PromptReplacement)
@@ -54,7 +54,7 @@

class FuyuImagePatchInputs(TypedDict):
type: Literal["image_patches"]
data: torch.Tensor
flat_data: torch.Tensor
"""
Shape:
`(batch_size * num_patches, patch_size_x * patch_size_y * num_channels)`
@@ -63,7 +63,7 @@ class FuyuImagePatchInputs(TypedDict):
patches_per_image: List[int]
"""
List of number of total patches for each image in the batch.
This is used to restore the first two dimensions of `data`.
This is used to restore the first two dimensions of `flat_data`.
"""


@@ -102,6 +102,9 @@ def get_max_fuyu_image_tokens(ctx: InputContext):

class FuyuMultiModalProcessor(BaseMultiModalProcessor):

def _get_data_parser(self) -> MultiModalDataParser:
return MultiModalDataParser(max_mm_counts={"image": 1})

def _get_hf_processor(self) -> FuyuProcessor:
return self.ctx.get_hf_processor(FuyuProcessor)

@@ -304,7 +307,7 @@ def _parse_and_validate_image_input(

return FuyuImagePatchInputs(
type="image_patches",
data=self._validate_pixel_values(
flat_data=self._validate_pixel_values(
flatten_bn(image_patches_flat, concat=True)),
patches_per_image=[x.size(0) for x in image_patches_flat],
)
@@ -313,12 +316,13 @@ def _parse_and_validate_image_input(

def _process_image_input(
self, image_input: FuyuImagePatchInputs) -> NestedTensors:
image_patches = image_input["data"]
image_patches_flat = image_input["flat_data"]
patches_per_image = image_input["patches_per_image"]

assert self.vision_embed_tokens is not None
vision_embeddings, _ = self.vision_embed_tokens(image_patches)
return vision_embeddings.split(patches_per_image, dim=0)
vision_embeddings_flat, _ = self.vision_embed_tokens(
image_patches_flat)
return vision_embeddings_flat.split(patches_per_image, dim=0)

def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
image_input = self._parse_and_validate_image_input(**kwargs)
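The rename from `data` to `flat_data` makes the layout explicit: patches from all images in the batch are concatenated along the first dimension before entering the vision encoder, and `patches_per_image` is what lets `_process_image_input` split the result back into one embedding tensor per image. A minimal sketch of that round trip, using stand-in shapes and a random projection in place of `vision_embed_tokens`:

```python
import torch

# Illustrative shapes only: two images contributing 3 and 5 patches,
# each patch flattened to 8 values, projected to hidden size 16.
patches_per_image = [3, 5]
flat_data = torch.randn(sum(patches_per_image), 8)

# Stand-in for self.vision_embed_tokens: any per-patch projection is
# enough to show the reshaping logic.
vision_embeddings_flat = flat_data @ torch.randn(8, 16)

# Recover one embedding tensor per image, mirroring the updated
# _process_image_input.
per_image = vision_embeddings_flat.split(patches_per_image, dim=0)
assert [t.shape[0] for t in per_image] == patches_per_image
```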
28 changes: 26 additions & 2 deletions vllm/multimodal/parse.py
@@ -220,11 +220,24 @@ def get_items(
class MultiModalDataParser:
"""
Parses :class:`MultiModalDataDict` into :class:`MultiModalDataItems`.

Args:
max_mm_counts (Mapping[str, int]): The maximum allowed number of items
belonging to each modality. This effectively sets a hard limit over
`--limit-mm-per-prompt`.
target_sr (float, optional): Enables automatic resampling of audio
items to the model's expected sampling rate.
"""

def __init__(self, *, target_sr: Optional[float] = None) -> None:
def __init__(
self,
*,
max_mm_counts: Mapping[str, int] = {},
target_sr: Optional[float] = None,
) -> None:
super().__init__()

self.max_mm_counts = max_mm_counts
self.target_sr = target_sr

def _is_embeddings(self, data: object) -> TypeGuard[NestedTensors]:
@@ -332,13 +345,24 @@ def _get_subparsers(self) -> Mapping[str, ModalityDataParser]:

def parse_mm_data(self,
mm_data: MultiModalDataDict) -> MultiModalDataItems:
max_mm_counts = self.max_mm_counts
subparsers = self._get_subparsers()

mm_items = MultiModalDataItems()
for k, v in mm_data.items():
if k not in subparsers:
raise ValueError(f"Unsupported modality: {k}")

mm_items[k] = subparsers[k](v)
modality_items = subparsers[k](v)

if k in max_mm_counts:
max_count = max_mm_counts[k]
if len(modality_items) > max_count:
raise ValueError(
f"This model supports at most {max_count} {k} items "
f"per prompt, but {len(modality_items)} {k} items "
"were given or set as its limit_mm_per_prompt.")

mm_items[k] = modality_items

return mm_items
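For context on how the new check surfaces to callers, here is a small usage sketch. It assumes the parser is configured the way the BLIP-2, Chameleon, and Fuyu processors above configure it; the PIL images are placeholders:

```python
from PIL import Image

from vllm.multimodal.parse import MultiModalDataParser

# Configured like the single-image processors in this PR.
parser = MultiModalDataParser(max_mm_counts={"image": 1})

image = Image.new("RGB", (64, 64))

# One image per prompt parses as before.
items = parser.parse_mm_data({"image": [image]})
print(len(items["image"]))  # 1

# A second image now fails fast with the ValueError added above,
# rather than surfacing a less obvious error later in processing.
try:
    parser.parse_mm_data({"image": [image, image]})
except ValueError as exc:
    print(exc)
```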