Skip to content

Commit

Permalink
[VLM] Enable tokenized inputs for merged multi-modal processor (vllm-…
Browse files Browse the repository at this point in the history
…project#11900)

Signed-off-by: DarkLight1337 <[email protected]>
  • Loading branch information
DarkLight1337 authored and frreiss committed Jan 10, 2025
1 parent 1f524c7 commit d4284f4
Show file tree
Hide file tree
Showing 12 changed files with 207 additions and 77 deletions.
31 changes: 25 additions & 6 deletions tests/multimodal/test_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,7 +649,7 @@ def test_limit_mm_per_prompt_apply(model_id, num_images, limit, is_valid):
)


def _test_processing_cache_correctness(
def _test_processing_correctness(
model_id: str,
modalities: dict[str, bool],
hit_rate: float,
Expand Down Expand Up @@ -691,6 +691,7 @@ def _test_processing_cache_correctness(
baseline_processor = factories.build_processor(ctx, cache=None)
cached_processor = factories.build_processor(ctx, cache=cache)
dummy_inputs = baseline_processor.dummy_inputs
tokenizer = baseline_processor.info.get_tokenizer()

rng = np.random.RandomState(0)

Expand Down Expand Up @@ -747,7 +748,25 @@ def _test_processing_cache_correctness(
)

assert baseline_result == cached_result, (
f"Failed ({batch_idx=}, {mm_data=})")
f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")

baseline_tokenized_result = baseline_processor.apply(
tokenizer.encode(prompt),
mm_data=mm_data,
hf_processor_mm_kwargs={},
)

assert baseline_result == baseline_tokenized_result, (
f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")

cached_tokenized_result = cached_processor.apply(
tokenizer.encode(prompt),
mm_data=mm_data,
hf_processor_mm_kwargs={},
)

assert cached_result == cached_tokenized_result, (
f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")


# yapf: disable
Expand All @@ -771,14 +790,14 @@ def _test_processing_cache_correctness(
@pytest.mark.parametrize("num_batches", [32])
@pytest.mark.parametrize("simplify_rate", [1.0])
# yapf: enable
def test_processing_cache_correctness(
def test_processing_correctness(
model_id: str,
modalities: dict[str, bool],
hit_rate: float,
num_batches: int,
simplify_rate: float,
):
_test_processing_cache_correctness(
_test_processing_correctness(
model_id,
modalities,
hit_rate=hit_rate,
Expand All @@ -795,7 +814,7 @@ def test_processing_cache_correctness(
@pytest.mark.parametrize("num_batches", [32])
@pytest.mark.parametrize("simplify_rate", [1.0])
# yapf: enable
def test_processing_cache_correctness_phi3v(
def test_processing_correctness_phi3v(
model_id: str,
modalities: dict[str, bool],
hit_rate: float,
Expand All @@ -809,7 +828,7 @@ def test_processing_cache_correctness_phi3v(

AutoImageProcessor.from_pretrained(model_id, trust_remote_code=True)

_test_processing_cache_correctness(
_test_processing_correctness(
model_id,
modalities,
hit_rate=hit_rate,
Expand Down
4 changes: 2 additions & 2 deletions vllm/inputs/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,13 @@ class TokensPrompt(TypedDict):

multi_modal_data: NotRequired["MultiModalDataDict"]
"""
DEPRECATED: Optional multi-modal data to pass to the model,
Optional multi-modal data to pass to the model,
if the model supports it.
"""

mm_processor_kwargs: NotRequired[Dict[str, Any]]
"""
DEPRECATED: Optional multi-modal processor kwargs to be forwarded to the
Optional multi-modal processor kwargs to be forwarded to the
multimodal input mapper & processor. Note that if multiple modalities
have registered mappers etc for the model being considered, we attempt
to pass the mm_processor_kwargs to each of them.
Expand Down
4 changes: 0 additions & 4 deletions vllm/inputs/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,10 +279,6 @@ async def _process_multimodal_async(

mm_processor = self.mm_registry.create_processor(
self.model_config, tokenizer)
if isinstance(prompt, list):
logger.warning("Passing `multi_modal_data` in TokensPrompt is"
"deprecated and will be removed in a future update")
prompt = tokenizer.decode(prompt)
if mm_processor_kwargs is None:
mm_processor_kwargs = {}

Expand Down
22 changes: 20 additions & 2 deletions vllm/model_executor/models/blip2.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,24 @@ def get_dummy_processor_inputs(

class Blip2MultiModalProcessor(BaseMultiModalProcessor[Blip2ProcessingInfo]):

def _call_hf_processor(
self,
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
) -> BatchFeature:
if not mm_data:
# HF processor always adds placeholders even when there's no image
tokenizer = self.info.get_tokenizer()
prompt_ids = tokenizer.encode(prompt)
return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")

return super()._call_hf_processor(
prompt=prompt,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
)

def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
Expand Down Expand Up @@ -469,11 +487,11 @@ def _get_prompt_replacements(

def apply(
self,
prompt_text: str,
prompt: Union[str, list[int]],
mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object],
) -> MultiModalInputsV2:
result = super().apply(prompt_text, mm_data, hf_processor_mm_kwargs)
result = super().apply(prompt, mm_data, hf_processor_mm_kwargs)

# Only <image> tokens should be considered as placeholders,
# so we ignore the trailing bos_token
Expand Down
32 changes: 30 additions & 2 deletions vllm/model_executor/models/chameleon.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,34 @@ def get_dummy_processor_inputs(
class ChameleonMultiModalProcessor(
BaseMultiModalProcessor[ChameleonProcessingInfo]):

def _call_hf_processor(
self,
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
) -> BatchFeature:
if not mm_data:
prompt_ids = self.info.get_tokenizer().encode(prompt)
prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")

return super()._call_hf_processor(
prompt=prompt,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
)

def _apply_hf_processor_tokens_only(
self,
prompt_tokens: list[int],
) -> list[int]:
# HF processor adds sep token for chat mode
tokenizer = self.info.get_tokenizer()
sep_token_id: int = \
tokenizer.vocab[tokenizer.sep_token] # type: ignore

return prompt_tokens + [sep_token_id]

def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
Expand Down Expand Up @@ -128,11 +156,11 @@ def _get_prompt_replacements(

def apply(
self,
prompt_text: str,
prompt: Union[str, list[int]],
mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object],
) -> MultiModalInputsV2:
result = super().apply(prompt_text, mm_data, hf_processor_mm_kwargs)
result = super().apply(prompt, mm_data, hf_processor_mm_kwargs)

# Only <image> tokens should be considered as placeholders,
# so we ignore the image_start_token and image_end_token
Expand Down
24 changes: 15 additions & 9 deletions vllm/model_executor/models/fuyu.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
""" PyTorch Fuyu model."""
import math
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
TypedDict)
TypedDict, Union)

import torch
import torch.nn as nn
Expand Down Expand Up @@ -149,14 +149,10 @@ def _call_hf_processor(
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
) -> BatchFeature:

if not mm_data:
# Avoid warning from HF logger for text-only input
# Input_ids format: bos_token_id + prompt_token_ids + boa_token_id
# Tokenizer won't add boa_token_id by default, we add it manually.
tokenizer = self.info.get_tokenizer()
boa_token_id: int = tokenizer.vocab["<0x04>"] # type: ignore
prompt_ids = tokenizer.encode(prompt) + [boa_token_id]
prompt_ids = self.info.get_tokenizer().encode(prompt)
prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")

processed_outputs = super()._call_hf_processor(
Expand All @@ -181,6 +177,16 @@ def _call_hf_processor(

return processed_outputs

def _apply_hf_processor_tokens_only(
self,
prompt_tokens: list[int],
) -> list[int]:
# HF processor adds boa_token_id
tokenizer = self.info.get_tokenizer()
boa_token_id: int = tokenizer.vocab["<0x04>"] # type: ignore

return prompt_tokens + [boa_token_id]

def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
Expand Down Expand Up @@ -223,11 +229,11 @@ def get_replacement_fuyu(item_idx: int):

def apply(
self,
prompt_text: str,
prompt: Union[str, list[int]],
mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object],
) -> MultiModalInputsV2:
result = super().apply(prompt_text, mm_data, hf_processor_mm_kwargs)
result = super().apply(prompt, mm_data, hf_processor_mm_kwargs)

# Only |SPEAKER| (image) tokens should be considered as placeholders,
# so we ignore the trailing bos_token_id
Expand Down
8 changes: 4 additions & 4 deletions vllm/model_executor/models/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,13 @@ def get_multimodal_embeddings(self, **kwargs) -> Optional[T]:
The output embeddings must be one of the following formats:
- A list or tuple of 2D tensors, where each tensor corresponds to
each input multimodal data item (e.g, image).
- A list or tuple of 2D tensors, where each tensor corresponds to
each input multimodal data item (e.g, image).
- A single 3D tensor, with the batch dimension grouping the 2D tensors.
Note:
The returned multimodal embeddings must be in the same order as
the appearances of their corresponding multimodal data item in the
The returned multimodal embeddings must be in the same order as
the appearances of their corresponding multimodal data item in the
input prompt.
"""
...
Expand Down
8 changes: 4 additions & 4 deletions vllm/model_executor/models/llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -724,7 +724,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):

def apply(
self,
prompt_text: str,
prompt: Union[str, list[int]],
mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object],
) -> MultiModalInputsV2:
Expand All @@ -737,7 +737,7 @@ def apply(
image_height=-1,
)

result = super().apply(prompt_text, mm_data, hf_processor_mm_kwargs)
result = super().apply(prompt, mm_data, hf_processor_mm_kwargs)

mm_items = self._to_mm_items(mm_data)
mm_item_counts = mm_items.get_all_counts()
Expand All @@ -760,7 +760,7 @@ def get_replacement_mantis(item_idx: int):
)
])

prompt_ids, prompt_text, _ = self._apply_prompt_replacements(
prompt_ids, prompt, _ = self._apply_prompt_replacements(
result["prompt_token_ids"],
mantis_mm_repls,
mm_item_counts,
Expand Down Expand Up @@ -788,7 +788,7 @@ def get_replacement_mantis(item_idx: int):

return MultiModalInputsV2(
type="multimodal",
prompt=prompt_text,
prompt=prompt,
prompt_token_ids=prompt_ids,
mm_kwargs=mm_kwargs,
mm_placeholders=mm_placeholder_ranges,
Expand Down
4 changes: 2 additions & 2 deletions vllm/model_executor/models/phi3v.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,11 +481,11 @@ def _apply_prompt_replacements(

def apply(
self,
prompt_text: str,
prompt: Union[str, list[int]],
mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object],
) -> MultiModalInputsV2:
result = super().apply(prompt_text, mm_data, hf_processor_mm_kwargs)
result = super().apply(prompt, mm_data, hf_processor_mm_kwargs)

# Only <|image|> tokens should be considered as placeholders,
# so we ignore the trailing bos_token_id
Expand Down
18 changes: 12 additions & 6 deletions vllm/model_executor/models/ultravox.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,12 +138,8 @@ def _call_hf_processor(
) -> BatchFeature:
# Text-only input not supported in composite processor
if not mm_data:
tokenizer = self.info.get_tokenizer()

prompt_ids = tokenizer.encode(
prompt,
add_special_tokens=False, # type: ignore
)
prompt_ids = self.info.get_tokenizer().encode(prompt)
prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")

mm_data = dict(mm_data)
Expand Down Expand Up @@ -188,6 +184,16 @@ def _call_hf_processor(
)
return BatchFeature(combined_outputs)

def _apply_hf_processor_tokens_only(
self,
prompt_tokens: list[int],
) -> list[int]:
# HF processor omits bos_token_id by setting add_special_tokens=False
tokenizer = self.info.get_tokenizer()
assert prompt_tokens[0] == tokenizer.bos_token_id

return prompt_tokens[1:]

def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
Expand Down
Loading

0 comments on commit d4284f4

Please sign in to comment.