diff --git a/tests/multimodal/test_processing.py b/tests/multimodal/test_processing.py
index d98bd9736b65f..d18909a4197b6 100644
--- a/tests/multimodal/test_processing.py
+++ b/tests/multimodal/test_processing.py
@@ -649,7 +649,7 @@ def test_limit_mm_per_prompt_apply(model_id, num_images, limit, is_valid):
     )
 
 
-def _test_processing_cache_correctness(
+def _test_processing_correctness(
     model_id: str,
     modalities: dict[str, bool],
     hit_rate: float,
@@ -691,6 +691,7 @@ def _test_processing_cache_correctness(
     baseline_processor = factories.build_processor(ctx, cache=None)
     cached_processor = factories.build_processor(ctx, cache=cache)
     dummy_inputs = baseline_processor.dummy_inputs
+    tokenizer = baseline_processor.info.get_tokenizer()
 
     rng = np.random.RandomState(0)
 
@@ -747,7 +748,25 @@ def _test_processing_cache_correctness(
         )
 
         assert baseline_result == cached_result, (
-            f"Failed ({batch_idx=}, {mm_data=})")
+            f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")
+
+        baseline_tokenized_result = baseline_processor.apply(
+            tokenizer.encode(prompt),
+            mm_data=mm_data,
+            hf_processor_mm_kwargs={},
+        )
+
+        assert baseline_result == baseline_tokenized_result, (
+            f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")
+
+        cached_tokenized_result = cached_processor.apply(
+            tokenizer.encode(prompt),
+            mm_data=mm_data,
+            hf_processor_mm_kwargs={},
+        )
+
+        assert cached_result == cached_tokenized_result, (
+            f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")
 
 
 # yapf: disable
@@ -771,14 +790,14 @@ def _test_processing_cache_correctness(
 @pytest.mark.parametrize("num_batches", [32])
 @pytest.mark.parametrize("simplify_rate", [1.0])
 # yapf: enable
-def test_processing_cache_correctness(
+def test_processing_correctness(
     model_id: str,
     modalities: dict[str, bool],
     hit_rate: float,
     num_batches: int,
     simplify_rate: float,
 ):
-    _test_processing_cache_correctness(
+    _test_processing_correctness(
         model_id,
         modalities,
         hit_rate=hit_rate,
@@ -795,7 +814,7 @@ def test_processing_cache_correctness(
 @pytest.mark.parametrize("num_batches", [32])
 @pytest.mark.parametrize("simplify_rate", [1.0])
 # yapf: enable
-def test_processing_cache_correctness_phi3v(
+def test_processing_correctness_phi3v(
     model_id: str,
     modalities: dict[str, bool],
     hit_rate: float,
@@ -809,7 +828,7 @@ def test_processing_cache_correctness_phi3v(
 
     AutoImageProcessor.from_pretrained(model_id, trust_remote_code=True)
 
-    _test_processing_cache_correctness(
+    _test_processing_correctness(
         model_id,
         modalities,
         hit_rate=hit_rate,
diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py
index cdaf6dd76eaa1..b8163a7acde1d 100644
--- a/vllm/inputs/data.py
+++ b/vllm/inputs/data.py
@@ -44,13 +44,13 @@ class TokensPrompt(TypedDict):
 
     multi_modal_data: NotRequired["MultiModalDataDict"]
     """
-    DEPRECATED: Optional multi-modal data to pass to the model,
+    Optional multi-modal data to pass to the model,
     if the model supports it.
     """
 
     mm_processor_kwargs: NotRequired[Dict[str, Any]]
     """
-    DEPRECATED: Optional multi-modal processor kwargs to be forwarded to the
+    Optional multi-modal processor kwargs to be forwarded to the
     multimodal input mapper & processor. Note that if multiple modalities
     have registered mappers etc for the model being considered, we attempt
     to pass the mm_processor_kwargs to each of them.
diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py
index a738ffe18e3ae..0890883cc984f 100644
--- a/vllm/inputs/preprocess.py
+++ b/vllm/inputs/preprocess.py
@@ -279,10 +279,6 @@ async def _process_multimodal_async(
 
         mm_processor = self.mm_registry.create_processor(
             self.model_config, tokenizer)
-        if isinstance(prompt, list):
-            logger.warning("Passing `multi_modal_data` in TokensPrompt is"
-                           "deprecated and will be removed in a future update")
-            prompt = tokenizer.decode(prompt)
 
         if mm_processor_kwargs is None:
             mm_processor_kwargs = {}
diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py
index 7dfc0b687c6e3..917b88e802071 100644
--- a/vllm/model_executor/models/blip2.py
+++ b/vllm/model_executor/models/blip2.py
@@ -441,6 +441,24 @@ def get_dummy_processor_inputs(
 
 class Blip2MultiModalProcessor(BaseMultiModalProcessor[Blip2ProcessingInfo]):
 
+    def _call_hf_processor(
+        self,
+        prompt: str,
+        mm_data: Mapping[str, object],
+        mm_kwargs: Mapping[str, object],
+    ) -> BatchFeature:
+        if not mm_data:
+            # HF processor always adds placeholders even when there's no image
+            tokenizer = self.info.get_tokenizer()
+            prompt_ids = tokenizer.encode(prompt)
+            return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
+
+        return super()._call_hf_processor(
+            prompt=prompt,
+            mm_data=mm_data,
+            mm_kwargs=mm_kwargs,
+        )
+
     def _get_mm_fields_config(
         self,
         hf_inputs: BatchFeature,
@@ -469,11 +487,11 @@ def _get_prompt_replacements(
 
     def apply(
         self,
-        prompt_text: str,
+        prompt: Union[str, list[int]],
         mm_data: MultiModalDataDict,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> MultiModalInputsV2:
-        result = super().apply(prompt_text, mm_data, hf_processor_mm_kwargs)
+        result = super().apply(prompt, mm_data, hf_processor_mm_kwargs)
 
         # Only tokens should be considered as placeholders,
         # so we ignore the trailing bos_token
diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py
index 452fe727875fe..a6634204699c9 100644
--- a/vllm/model_executor/models/chameleon.py
+++ b/vllm/model_executor/models/chameleon.py
@@ -99,6 +99,34 @@ def get_dummy_processor_inputs(
 class ChameleonMultiModalProcessor(
         BaseMultiModalProcessor[ChameleonProcessingInfo]):
 
+    def _call_hf_processor(
+        self,
+        prompt: str,
+        mm_data: Mapping[str, object],
+        mm_kwargs: Mapping[str, object],
+    ) -> BatchFeature:
+        if not mm_data:
+            prompt_ids = self.info.get_tokenizer().encode(prompt)
+            prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
+            return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
+
+        return super()._call_hf_processor(
+            prompt=prompt,
+            mm_data=mm_data,
+            mm_kwargs=mm_kwargs,
+        )
+
+    def _apply_hf_processor_tokens_only(
+        self,
+        prompt_tokens: list[int],
+    ) -> list[int]:
+        # HF processor adds sep token for chat mode
+        tokenizer = self.info.get_tokenizer()
+        sep_token_id: int = \
+            tokenizer.vocab[tokenizer.sep_token]  # type: ignore
+
+        return prompt_tokens + [sep_token_id]
+
     def _get_mm_fields_config(
         self,
         hf_inputs: BatchFeature,
@@ -128,11 +156,11 @@ def _get_prompt_replacements(
 
     def apply(
         self,
-        prompt_text: str,
+        prompt: Union[str, list[int]],
         mm_data: MultiModalDataDict,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> MultiModalInputsV2:
-        result = super().apply(prompt_text, mm_data, hf_processor_mm_kwargs)
+        result = super().apply(prompt, mm_data, hf_processor_mm_kwargs)
 
         # Only tokens should be considered as placeholders,
         # so we ignore the image_start_token and image_end_token
diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py
index 59af5f0b3ae98..63e7147f84e03 100644
--- a/vllm/model_executor/models/fuyu.py
+++ b/vllm/model_executor/models/fuyu.py
@@ -16,7 +16,7 @@
 """ PyTorch Fuyu model."""
 import math
 from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
-                    TypedDict)
+                    TypedDict, Union)
 
 import torch
 import torch.nn as nn
@@ -149,14 +149,10 @@ def _call_hf_processor(
         mm_data: Mapping[str, object],
         mm_kwargs: Mapping[str, object],
     ) -> BatchFeature:
-
         if not mm_data:
             # Avoid warning from HF logger for text-only input
-            # Input_ids format: bos_token_id + prompt_token_ids + boa_token_id
-            # Tokenizer won't add boa_token_id by default, we add it manually.
-            tokenizer = self.info.get_tokenizer()
-            boa_token_id: int = tokenizer.vocab["<0x04>"]  # type: ignore
-            prompt_ids = tokenizer.encode(prompt) + [boa_token_id]
+            prompt_ids = self.info.get_tokenizer().encode(prompt)
+            prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
             return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
 
         processed_outputs = super()._call_hf_processor(
@@ -181,6 +177,16 @@ def _call_hf_processor(
 
         return processed_outputs
 
+    def _apply_hf_processor_tokens_only(
+        self,
+        prompt_tokens: list[int],
+    ) -> list[int]:
+        # HF processor adds boa_token_id
+        tokenizer = self.info.get_tokenizer()
+        boa_token_id: int = tokenizer.vocab["<0x04>"]  # type: ignore
+
+        return prompt_tokens + [boa_token_id]
+
     def _get_mm_fields_config(
         self,
         hf_inputs: BatchFeature,
@@ -223,11 +229,11 @@ def get_replacement_fuyu(item_idx: int):
 
     def apply(
         self,
-        prompt_text: str,
+        prompt: Union[str, list[int]],
         mm_data: MultiModalDataDict,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> MultiModalInputsV2:
-        result = super().apply(prompt_text, mm_data, hf_processor_mm_kwargs)
+        result = super().apply(prompt, mm_data, hf_processor_mm_kwargs)
 
         # Only |SPEAKER| (image) tokens should be considered as placeholders,
         # so we ignore the trailing bos_token_id
diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py
index b51cba86ec1a4..c5fd0d9332379 100644
--- a/vllm/model_executor/models/interfaces.py
+++ b/vllm/model_executor/models/interfaces.py
@@ -39,13 +39,13 @@ def get_multimodal_embeddings(self, **kwargs) -> Optional[T]:
         The output embeddings must be one of the following formats:
 
-        - A list or tuple of 2D tensors, where each tensor corresponds to 
-          each input multimodal data item (e.g, image).
+        - A list or tuple of 2D tensors, where each tensor corresponds to
+          each input multimodal data item (e.g, image).
         - A single 3D tensor, with the batch dimension grouping the 2D tensors.
 
         Note:
-            The returned multimodal embeddings must be in the same order as 
-            the appearances of their corresponding multimodal data item in the 
+            The returned multimodal embeddings must be in the same order as
+            the appearances of their corresponding multimodal data item in the
             input prompt.
         """
         ...
 
diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py
index 8d94acf3b21d5..bb3db60c7d8ed 100644
--- a/vllm/model_executor/models/llava.py
+++ b/vllm/model_executor/models/llava.py
@@ -724,7 +724,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
 
     def apply(
         self,
-        prompt_text: str,
+        prompt: Union[str, list[int]],
         mm_data: MultiModalDataDict,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> MultiModalInputsV2:
@@ -737,7 +737,7 @@ def apply(
             image_height=-1,
         )
 
-        result = super().apply(prompt_text, mm_data, hf_processor_mm_kwargs)
+        result = super().apply(prompt, mm_data, hf_processor_mm_kwargs)
 
         mm_items = self._to_mm_items(mm_data)
         mm_item_counts = mm_items.get_all_counts()
@@ -760,7 +760,7 @@ def get_replacement_mantis(item_idx: int):
                 )
             ])
 
-        prompt_ids, prompt_text, _ = self._apply_prompt_replacements(
+        prompt_ids, prompt, _ = self._apply_prompt_replacements(
             result["prompt_token_ids"],
             mantis_mm_repls,
             mm_item_counts,
@@ -788,7 +788,7 @@ def get_replacement_mantis(item_idx: int):
 
         return MultiModalInputsV2(
             type="multimodal",
-            prompt=prompt_text,
+            prompt=prompt,
             prompt_token_ids=prompt_ids,
             mm_kwargs=mm_kwargs,
             mm_placeholders=mm_placeholder_ranges,
diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py
index a1b1af35604db..7a230e5beb367 100644
--- a/vllm/model_executor/models/phi3v.py
+++ b/vllm/model_executor/models/phi3v.py
@@ -481,11 +481,11 @@ def _apply_prompt_replacements(
 
     def apply(
         self,
-        prompt_text: str,
+        prompt: Union[str, list[int]],
         mm_data: MultiModalDataDict,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> MultiModalInputsV2:
-        result = super().apply(prompt_text, mm_data, hf_processor_mm_kwargs)
+        result = super().apply(prompt, mm_data, hf_processor_mm_kwargs)
 
         # Only <|image|> tokens should be considered as placeholders,
         # so we ignore the trailing bos_token_id
diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py
index fada22d685dd6..3edfb5107683a 100644
--- a/vllm/model_executor/models/ultravox.py
+++ b/vllm/model_executor/models/ultravox.py
@@ -138,12 +138,8 @@ def _call_hf_processor(
     ) -> BatchFeature:
         # Text-only input not supported in composite processor
         if not mm_data:
-            tokenizer = self.info.get_tokenizer()
-
-            prompt_ids = tokenizer.encode(
-                prompt,
-                add_special_tokens=False,  # type: ignore
-            )
+            prompt_ids = self.info.get_tokenizer().encode(prompt)
+            prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
             return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
 
         mm_data = dict(mm_data)
@@ -188,6 +184,16 @@ def _call_hf_processor(
         )
         return BatchFeature(combined_outputs)
 
+    def _apply_hf_processor_tokens_only(
+        self,
+        prompt_tokens: list[int],
+    ) -> list[int]:
+        # HF processor omits bos_token_id by setting add_special_tokens=False
+        tokenizer = self.info.get_tokenizer()
+        assert prompt_tokens[0] == tokenizer.bos_token_id
+
+        return prompt_tokens[1:]
+
     def _get_mm_fields_config(
         self,
         hf_inputs: BatchFeature,
diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py
index 07d883d5d7295..8b47dfb07387f 100644
--- a/vllm/multimodal/processing.py
+++ b/vllm/multimodal/processing.py
@@ -725,15 +725,15 @@ def _call_hf_processor(
             mm_kwargs,
         )
 
-    def _apply_hf_processor(
+    def _apply_hf_processor_text_mm(
         self,
         prompt_text: str,
         mm_items: MultiModalDataItems,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> tuple[list[int], MultiModalKwargs]:
         """
-        Wrapper of :meth:`_call_hf_processor` that applies
-        additional pre-processing and post-processing.
+        Apply the HF processor on the prompt text and multi-modal data
+        together.
         """
         processor_data, passthrough_data = self._get_hf_mm_data(mm_items)
 
@@ -753,40 +753,93 @@ def _apply_hf_processor(
 
         return prompt_ids, mm_kwargs
 
-    def _apply_hf_processor_missing(
-        self,
-        prompt_text: str,
-        mm_missing_data_items: MultiModalDataItems,
-        hf_processor_mm_kwargs: Mapping[str, object],
-    ):
+    def _apply_hf_processor_text_only(self, prompt_text: str) -> list[int]:
         """
-        Apply the HF processor on the full prompt text, but only on the
-        multi-modal data that are missing from the cache.
+        Apply the HF processor on the prompt text only.
 
-        Note:
-            We pass prompt text and multi-modal data into the HF processor
-            in separate calls to avoid HF prompt replacement being done for
-            cached items; instead, we rely on our own prompt replacement logic
-            (:meth:`_get_prompt_replacements`) for the full text.
+        Since HF processor requires that text and multi-modal items
+        correspond to each other, we create dummy multi-modal items
+        to go along with the text.
         """
-        mm_missing_counts = mm_missing_data_items.get_all_counts()
-
-        prompt_ids, _ = self._apply_hf_processor(
+        prompt_ids, _ = self._apply_hf_processor_text_mm(
             prompt_text=prompt_text,
             mm_items=MultiModalDataItems({}),
             hf_processor_mm_kwargs={},
         )
 
-        # Some HF processors (e.g. Qwen2-VL) expect corresponding
-        # multi-modal tokens to be in the prompt text
+        return prompt_ids
+
+    def _apply_hf_processor_tokens_only(
+        self,
+        prompt_tokens: list[int],
+    ) -> list[int]:
+        """
+        Apply the HF processor on the prompt tokens only.
+
+        Most HF processors accept prompt text but not prompt tokens.
+        If the HF processor adds or removes tokens that are not related to
+        multi-modal data, you should override this method so it is consistent
+        with the output of :meth:`_apply_hf_processor_text_only` on the
+        corresponding text.
+        """
+        return prompt_tokens
+
+    def _apply_hf_processor_mm_only(
+        self,
+        mm_items: MultiModalDataItems,
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> MultiModalKwargs:
+        """
+        Apply the HF processor on the multi-modal data only.
+
+        Since HF processor requires that text and multi-modal items
+        correspond to each other, we generate dummy text using
+        :class:`DummyInputsBuilder` to go along with the multi-modal data.
+        """
+        mm_counts = mm_items.get_all_counts()
+
         dummy_inputs = self.dummy_inputs.get_dummy_processor_inputs(
             self.info.ctx.model_config.max_model_len,
-            mm_missing_counts,
+            mm_counts,
         )
 
-        _, mm_missing_kwargs = self._apply_hf_processor(
+        _, mm_kwargs = self._apply_hf_processor_text_mm(
             prompt_text=dummy_inputs.prompt_text,
-            mm_items=mm_missing_data_items,
+            mm_items=mm_items,
+            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
+        )
+
+        return mm_kwargs
+
+    def _apply_hf_processor_main(
+        self,
+        prompt: Union[str, list[int]],
+        mm_items: MultiModalDataItems,
+        hf_processor_mm_kwargs: Mapping[str, object],
+        *,
+        enable_hf_prompt_replacement: bool,
+    ) -> tuple[list[int], MultiModalKwargs]:
+        """
+        Apply the HF processor on the prompt text and multi-modal data.
+
+        Note:
+            If :code:`enable_hf_prompt_replacement=False`, the prompt should
+            correspond to the multi-modal items.
+ """ + if isinstance(prompt, str): + if enable_hf_prompt_replacement: + return self._apply_hf_processor_text_mm( + prompt_text=prompt, + mm_items=mm_items, + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + ) + + prompt_ids = self._apply_hf_processor_text_only(prompt) + else: + prompt_ids = self._apply_hf_processor_tokens_only(prompt) + + mm_missing_kwargs = self._apply_hf_processor_mm_only( + mm_items=mm_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, ) @@ -794,7 +847,7 @@ def _apply_hf_processor_missing( def _cached_apply_hf_processor( self, - prompt_text: str, + prompt: Union[str, list[int]], mm_data_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], ) -> tuple[list[int], MultiModalKwargs]: @@ -807,10 +860,11 @@ def _cached_apply_hf_processor( _, passthrough_data = self._get_hf_mm_data(mm_data_items) if cache is None or passthrough_data: - return self._apply_hf_processor( - prompt_text=prompt_text, + return self._apply_hf_processor_main( + prompt=prompt, mm_items=mm_data_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, + enable_hf_prompt_replacement=True, ) mm_maybe_cached_kw_items = { @@ -832,10 +886,13 @@ def _cached_apply_hf_processor( } mm_missing_data_items = self._to_mm_items(mm_missing_data) - prompt_ids, mm_missing_kwargs = self._apply_hf_processor_missing( - prompt_text=prompt_text, - mm_missing_data_items=mm_missing_data_items, + # NOTE: `prompt` does not correspond to `mm_missing_data_items`, + # so we need to pass `enable_hf_prompt_replacement=False` + prompt_ids, mm_missing_kwargs = self._apply_hf_processor_main( + prompt=prompt, + mm_items=mm_missing_data_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, + enable_hf_prompt_replacement=False, ) mm_missing_next_idx = { @@ -1018,7 +1075,7 @@ def _validate_mm_placeholders( def apply( self, - prompt_text: str, + prompt: Union[str, list[int]], mm_data: MultiModalDataDict, hf_processor_mm_kwargs: Mapping[str, object], ) -> MultiModalInputsV2: @@ -1056,7 +1113,7 @@ def apply( mm_hashes = None prompt_ids, mm_kwargs = self._cached_apply_hf_processor( - prompt_text, + prompt, mm_items, hf_processor_mm_kwargs, ) @@ -1101,12 +1158,12 @@ def apply( # there is no need for us to insert them if all(len(repls) == 0 for repls in mm_missing_repls.items()): tokenizer = self.info.get_tokenizer() - prompt_text = decode_tokens(tokenizer, prompt_ids) + prompt = decode_tokens(tokenizer, prompt_ids) mm_placeholders = hf_mm_placeholders else: ( prompt_ids, - prompt_text, + prompt, missing_mm_placeholders, ) = self._apply_prompt_replacements( prompt_ids, @@ -1125,7 +1182,7 @@ def apply( return MultiModalInputsV2( type="multimodal", - prompt=prompt_text, + prompt=prompt, prompt_token_ids=prompt_ids, mm_kwargs=mm_kwargs, mm_hashes=mm_hashes, diff --git a/vllm/multimodal/profiling.py b/vllm/multimodal/profiling.py index 6f7da1509990f..ec580cd6ecddd 100644 --- a/vllm/multimodal/profiling.py +++ b/vllm/multimodal/profiling.py @@ -137,7 +137,7 @@ def _get_dummy_mm_inputs( seq_len, mm_counts) return self.processor.apply( - prompt_text=processor_inputs.prompt_text, + prompt=processor_inputs.prompt_text, mm_data=processor_inputs.mm_data, hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs, )