[VLM] Enable tokenized inputs for merged multi-modal processor #11900

Merged (2 commits) on Jan 10, 2025
31 changes: 25 additions & 6 deletions tests/multimodal/test_processing.py
@@ -649,7 +649,7 @@ def test_limit_mm_per_prompt_apply(model_id, num_images, limit, is_valid):
)


def _test_processing_cache_correctness(
def _test_processing_correctness(
model_id: str,
modalities: dict[str, bool],
hit_rate: float,
@@ -691,6 +691,7 @@ def _test_processing_cache_correctness(
baseline_processor = factories.build_processor(ctx, cache=None)
cached_processor = factories.build_processor(ctx, cache=cache)
dummy_inputs = baseline_processor.dummy_inputs
tokenizer = baseline_processor.info.get_tokenizer()

rng = np.random.RandomState(0)

@@ -747,7 +748,25 @@ def _test_processing_cache_correctness(
)

assert baseline_result == cached_result, (
f"Failed ({batch_idx=}, {mm_data=})")
f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")

baseline_tokenized_result = baseline_processor.apply(
tokenizer.encode(prompt),
mm_data=mm_data,
hf_processor_mm_kwargs={},
)

assert baseline_result == baseline_tokenized_result, (
f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")

cached_tokenized_result = cached_processor.apply(
tokenizer.encode(prompt),
mm_data=mm_data,
hf_processor_mm_kwargs={},
)

assert cached_result == cached_tokenized_result, (
f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")


# yapf: disable
@@ -771,14 +790,14 @@ def _test_processing_cache_correctness(
@pytest.mark.parametrize("num_batches", [32])
@pytest.mark.parametrize("simplify_rate", [1.0])
# yapf: enable
def test_processing_cache_correctness(
def test_processing_correctness(
model_id: str,
modalities: dict[str, bool],
hit_rate: float,
num_batches: int,
simplify_rate: float,
):
_test_processing_cache_correctness(
_test_processing_correctness(
model_id,
modalities,
hit_rate=hit_rate,
@@ -795,7 +814,7 @@ def test_processing_cache_correctness(
@pytest.mark.parametrize("num_batches", [32])
@pytest.mark.parametrize("simplify_rate", [1.0])
# yapf: enable
def test_processing_cache_correctness_phi3v(
def test_processing_correctness_phi3v(
model_id: str,
modalities: dict[str, bool],
hit_rate: float,
@@ -809,7 +828,7 @@ def test_processing_cache_correctness_phi3v(

AutoImageProcessor.from_pretrained(model_id, trust_remote_code=True)

_test_processing_cache_correctness(
_test_processing_correctness(
model_id,
modalities,
hit_rate=hit_rate,
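The new assertions boil down to one invariant: applying the merged processor to a prompt string and to that prompt's token IDs must produce identical multimodal inputs. A minimal sketch of that check outside the test harness (the processor is assumed to be built the same way as in the test above; construction is omitted):

```python
# Sketch of the text-vs-tokens equivalence the updated test asserts.
# `processor` stands in for a merged multi-modal processor built via the
# registry/test factories above (construction omitted).
def check_text_token_equivalence(processor, prompt: str, mm_data: dict) -> None:
    tokenizer = processor.info.get_tokenizer()

    text_result = processor.apply(
        prompt,
        mm_data=mm_data,
        hf_processor_mm_kwargs={},
    )
    token_result = processor.apply(
        tokenizer.encode(prompt),
        mm_data=mm_data,
        hf_processor_mm_kwargs={},
    )

    # Tokenizing up front must not change the processed output.
    assert text_result == token_result
```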
4 changes: 2 additions & 2 deletions vllm/inputs/data.py
@@ -44,13 +44,13 @@ class TokensPrompt(TypedDict):

multi_modal_data: NotRequired["MultiModalDataDict"]
"""
DEPRECATED: Optional multi-modal data to pass to the model,
Optional multi-modal data to pass to the model,
if the model supports it.
"""

mm_processor_kwargs: NotRequired[Dict[str, Any]]
"""
DEPRECATED: Optional multi-modal processor kwargs to be forwarded to the
Optional multi-modal processor kwargs to be forwarded to the
multimodal input mapper & processor. Note that if multiple modalities
have registered mappers etc for the model being considered, we attempt
to pass the mm_processor_kwargs to each of them.
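With the DEPRECATED notes dropped, `TokensPrompt` once again officially carries multi-modal data alongside a pre-tokenized prompt. A minimal usage sketch; the model name, chat template, and image path are illustrative only:

```python
from PIL import Image
from vllm import LLM
from vllm.inputs import TokensPrompt

# Illustrative model; any model served by a merged multi-modal processor
# should accept the same prompt structure.
llm = LLM(model="llava-hf/llava-1.5-7b-hf")
tokenizer = llm.get_tokenizer()

prompt_ids = tokenizer.encode(
    "USER: <image>\nWhat is shown in this image? ASSISTANT:")

outputs = llm.generate(
    TokensPrompt(
        prompt_token_ids=prompt_ids,
        multi_modal_data={"image": Image.open("example.jpg")},
    ))
```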
4 changes: 0 additions & 4 deletions vllm/inputs/preprocess.py
@@ -279,10 +279,6 @@ async def _process_multimodal_async(

mm_processor = self.mm_registry.create_processor(
self.model_config, tokenizer)
if isinstance(prompt, list):
logger.warning("Passing `multi_modal_data` in TokensPrompt is"
"deprecated and will be removed in a future update")
prompt = tokenizer.decode(prompt)
if mm_processor_kwargs is None:
mm_processor_kwargs = {}

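With this branch gone, a tokenized prompt is no longer round-tripped through `tokenizer.decode()` before reaching the merged processor. A rough before/after sketch (not the actual `_process_multimodal_async` body):

```python
# Rough sketch of the behavioural change (not the real vLLM code path).
def process_multimodal_before(mm_processor, tokenizer, prompt, mm_data, kwargs):
    if isinstance(prompt, list):
        # Old behaviour: decode back to text, a potentially lossy round-trip.
        prompt = tokenizer.decode(prompt)
    return mm_processor.apply(prompt, mm_data, kwargs)


def process_multimodal_after(mm_processor, tokenizer, prompt, mm_data, kwargs):
    # New behaviour: the merged processor accepts str or list[int] directly.
    return mm_processor.apply(prompt, mm_data, kwargs)
```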
22 changes: 20 additions & 2 deletions vllm/model_executor/models/blip2.py
@@ -441,6 +441,24 @@ def get_dummy_processor_inputs(

class Blip2MultiModalProcessor(BaseMultiModalProcessor[Blip2ProcessingInfo]):

def _call_hf_processor(
self,
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
) -> BatchFeature:
if not mm_data:
# HF processor always adds placeholders even when there's no image
tokenizer = self.info.get_tokenizer()
prompt_ids = tokenizer.encode(prompt)
return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")

return super()._call_hf_processor(
prompt=prompt,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
)

def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
@@ -469,11 +487,11 @@ def _get_prompt_replacements(

def apply(
self,
prompt_text: str,
prompt: Union[str, list[int]],
mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object],
) -> MultiModalInputsV2:
result = super().apply(prompt_text, mm_data, hf_processor_mm_kwargs)
result = super().apply(prompt, mm_data, hf_processor_mm_kwargs)

# Only <image> tokens should be considered as placeholders,
# so we ignore the trailing bos_token
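The text-only guard above exists because the HF BLIP-2 processor inserts image placeholders even when no image is supplied. The same pattern recurs in the processors below; its generic shape looks roughly like this (a sketch, not shared vLLM code):

```python
from transformers import BatchFeature

# Generic shape of the text-only fallback (a sketch; each model keeps its
# own copy with model-specific special-token handling).
def encode_text_only(processor, prompt: str) -> BatchFeature:
    # Bypass the HF processor so it cannot inject multi-modal placeholders
    # for a prompt that carries no multi-modal data.
    prompt_ids = processor.info.get_tokenizer().encode(prompt)
    return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
```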
32 changes: 30 additions & 2 deletions vllm/model_executor/models/chameleon.py
@@ -99,6 +99,34 @@ def get_dummy_processor_inputs(
class ChameleonMultiModalProcessor(
BaseMultiModalProcessor[ChameleonProcessingInfo]):

def _call_hf_processor(
self,
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
) -> BatchFeature:
if not mm_data:
prompt_ids = self.info.get_tokenizer().encode(prompt)
prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")

return super()._call_hf_processor(
prompt=prompt,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
)

def _apply_hf_processor_tokens_only(
self,
prompt_tokens: list[int],
) -> list[int]:
# HF processor adds sep token for chat mode
tokenizer = self.info.get_tokenizer()
sep_token_id: int = \
tokenizer.vocab[tokenizer.sep_token] # type: ignore

return prompt_tokens + [sep_token_id]

def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
@@ -128,11 +156,11 @@ def _get_prompt_replacements(

def apply(
self,
prompt_text: str,
prompt: Union[str, list[int]],
mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object],
) -> MultiModalInputsV2:
result = super().apply(prompt_text, mm_data, hf_processor_mm_kwargs)
result = super().apply(prompt, mm_data, hf_processor_mm_kwargs)

# Only <image> tokens should be considered as placeholders,
# so we ignore the image_start_token and image_end_token
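The two overrides cover the two input types: `_call_hf_processor` handles raw text (including the text-only fallback), while `_apply_hf_processor_tokens_only` replicates just the special-token handling for prompts that arrive already tokenized. A hypothetical sketch of how a caller could dispatch between them; the real dispatch lives in `BaseMultiModalProcessor`, outside this diff:

```python
# Hypothetical dispatch between the text path and the tokens-only path
# (the actual logic is in BaseMultiModalProcessor, not shown in this diff).
def normalize_prompt_ids(processor, prompt, mm_kwargs):
    if isinstance(prompt, str):
        hf_out = processor._call_hf_processor(prompt=prompt,
                                              mm_data={},
                                              mm_kwargs=mm_kwargs)
        return hf_out["input_ids"][0]
    # Already tokenized: only mimic the special tokens HF would have added.
    return processor._apply_hf_processor_tokens_only(prompt)
```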
24 changes: 15 additions & 9 deletions vllm/model_executor/models/fuyu.py
@@ -16,7 +16,7 @@
""" PyTorch Fuyu model."""
import math
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
TypedDict)
TypedDict, Union)

import torch
import torch.nn as nn
@@ -149,14 +149,10 @@ def _call_hf_processor(
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
) -> BatchFeature:

if not mm_data:
# Avoid warning from HF logger for text-only input
# Input_ids format: bos_token_id + prompt_token_ids + boa_token_id
# Tokenizer won't add boa_token_id by default, we add it manually.
tokenizer = self.info.get_tokenizer()
boa_token_id: int = tokenizer.vocab["<0x04>"] # type: ignore
prompt_ids = tokenizer.encode(prompt) + [boa_token_id]
prompt_ids = self.info.get_tokenizer().encode(prompt)
prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")

processed_outputs = super()._call_hf_processor(
@@ -181,6 +177,16 @@ def _call_hf_processor(

return processed_outputs

def _apply_hf_processor_tokens_only(
self,
prompt_tokens: list[int],
) -> list[int]:
# HF processor adds boa_token_id
tokenizer = self.info.get_tokenizer()
boa_token_id: int = tokenizer.vocab["<0x04>"] # type: ignore

return prompt_tokens + [boa_token_id]

def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
@@ -223,11 +229,11 @@ def get_replacement_fuyu(item_idx: int):

def apply(
self,
prompt_text: str,
prompt: Union[str, list[int]],
mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object],
) -> MultiModalInputsV2:
result = super().apply(prompt_text, mm_data, hf_processor_mm_kwargs)
result = super().apply(prompt, mm_data, hf_processor_mm_kwargs)

# Only |SPEAKER| (image) tokens should be considered as placeholders,
# so we ignore the trailing bos_token_id
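For Fuyu the text path ends every prompt with the `<0x04>` "boa" token, so the tokens-only fix-up appends that same token. An illustrative check, assuming a Fuyu processor instance built elsewhere (e.g. via the multimodal registry):

```python
# Illustrative check for the Fuyu fix-up (processor construction omitted).
tokenizer = processor.info.get_tokenizer()
boa_token_id = tokenizer.vocab["<0x04>"]

prompt_ids = tokenizer.encode("Describe the chart.")
assert processor._apply_hf_processor_tokens_only(prompt_ids)[-1] == boa_token_id
```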
8 changes: 4 additions & 4 deletions vllm/model_executor/models/interfaces.py
@@ -39,13 +39,13 @@ def get_multimodal_embeddings(self, **kwargs) -> Optional[T]:

The output embeddings must be one of the following formats:

- A list or tuple of 2D tensors, where each tensor corresponds to
each input multimodal data item (e.g, image).
- A list or tuple of 2D tensors, where each tensor corresponds to
each input multimodal data item (e.g, image).
- A single 3D tensor, with the batch dimension grouping the 2D tensors.

Note:
The returned multimodal embeddings must be in the same order as
the appearances of their corresponding multimodal data item in the
The returned multimodal embeddings must be in the same order as
the appearances of their corresponding multimodal data item in the
input prompt.
"""
...
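The re-indented docstring describes two accepted return formats; shown here with made-up sizes for concreteness:

```python
import torch

# Two equivalent ways to return embeddings for 3 images, each contributing
# 576 patch embeddings of hidden size 4096 (sizes are illustrative only).
as_sequence = [torch.zeros(576, 4096) for _ in range(3)]  # list/tuple of 2D tensors
as_batch = torch.zeros(3, 576, 4096)  # single 3D tensor; batch dim groups the items
```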
8 changes: 4 additions & 4 deletions vllm/model_executor/models/llava.py
@@ -724,7 +724,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):

def apply(
self,
prompt_text: str,
prompt: Union[str, list[int]],
mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object],
) -> MultiModalInputsV2:
@@ -737,7 +737,7 @@ def apply(
image_height=-1,
)

result = super().apply(prompt_text, mm_data, hf_processor_mm_kwargs)
result = super().apply(prompt, mm_data, hf_processor_mm_kwargs)

mm_items = self._to_mm_items(mm_data)
mm_item_counts = mm_items.get_all_counts()
@@ -760,7 +760,7 @@ def get_replacement_mantis(item_idx: int):
)
])

prompt_ids, prompt_text, _ = self._apply_prompt_replacements(
prompt_ids, prompt, _ = self._apply_prompt_replacements(
result["prompt_token_ids"],
mantis_mm_repls,
mm_item_counts,
@@ -788,7 +788,7 @@ def get_replacement_mantis(item_idx: int):

return MultiModalInputsV2(
type="multimodal",
prompt=prompt_text,
prompt=prompt,
prompt_token_ids=prompt_ids,
mm_kwargs=mm_kwargs,
mm_placeholders=mm_placeholder_ranges,
4 changes: 2 additions & 2 deletions vllm/model_executor/models/phi3v.py
@@ -481,11 +481,11 @@ def _apply_prompt_replacements(

def apply(
self,
prompt_text: str,
prompt: Union[str, list[int]],
mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object],
) -> MultiModalInputsV2:
result = super().apply(prompt_text, mm_data, hf_processor_mm_kwargs)
result = super().apply(prompt, mm_data, hf_processor_mm_kwargs)

# Only <|image|> tokens should be considered as placeholders,
# so we ignore the trailing bos_token_id
18 changes: 12 additions & 6 deletions vllm/model_executor/models/ultravox.py
@@ -138,12 +138,8 @@ def _call_hf_processor(
) -> BatchFeature:
# Text-only input not supported in composite processor
if not mm_data:
tokenizer = self.info.get_tokenizer()

prompt_ids = tokenizer.encode(
prompt,
add_special_tokens=False, # type: ignore
)
prompt_ids = self.info.get_tokenizer().encode(prompt)
prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")

mm_data = dict(mm_data)
@@ -188,6 +184,16 @@ def _call_hf_processor(
)
return BatchFeature(combined_outputs)

def _apply_hf_processor_tokens_only(
self,
prompt_tokens: list[int],
) -> list[int]:
# HF processor omits bos_token_id by setting add_special_tokens=False
tokenizer = self.info.get_tokenizer()
assert prompt_tokens[0] == tokenizer.bos_token_id

return prompt_tokens[1:]

def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
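Ultravox goes the opposite way from Chameleon and Fuyu: instead of appending a token, the fix-up drops the leading bos that `tokenizer.encode()` adds, matching the `add_special_tokens=False` behaviour of the HF processor. An illustrative check (processor construction omitted):

```python
# Illustrative check for the Ultravox fix-up (processor built elsewhere).
tokenizer = processor.info.get_tokenizer()

prompt_ids = tokenizer.encode("Transcribe the audio clip.")
assert prompt_ids[0] == tokenizer.bos_token_id

# The tokens-only path must not carry the extra bos token.
assert processor._apply_hf_processor_tokens_only(prompt_ids) == prompt_ids[1:]
```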