[V1] Extend beyond image modality and support mixed-modality inference with Llava-OneVision (#11685)

Signed-off-by: Roger Wang <[email protected]>
Signed-off-by: DarkLight1337 <[email protected]>
Co-authored-by: DarkLight1337 <[email protected]>
ywang96 and DarkLight1337 authored Jan 6, 2025
1 parent e20c92b commit 91b361a
Showing 17 changed files with 636 additions and 282 deletions.
2 changes: 1 addition & 1 deletion docs/source/models/supported_models.md
@@ -647,7 +647,7 @@ See [this page](#generative-models) for more information on how to use generativ
- `llava-hf/llava-onevision-qwen2-7b-ov-hf`, `llava-hf/llava-onevision-qwen2-0.5b-ov-hf`, etc.
-
- ✅︎
-
- ✅︎
* - `MiniCPMV`
- MiniCPM-V
- T + I<sup>E+</sup>
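The documentation change above updates the Llava-OneVision row in the supported-models table. As a rough illustration of the mixed-modality (image + video) inference this PR enables, here is a hedged offline-inference sketch; the prompt template, frame count, and option values are assumptions for illustration, not taken from this commit:

```python
# Hypothetical sketch of mixed-modality (image + video) inference with
# Llava-OneVision in vLLM. Prompt format and option values are assumptions;
# see the official vLLM multimodal examples for the exact usage.
import numpy as np
from PIL import Image

from vllm import LLM, SamplingParams

llm = LLM(
    model="llava-hf/llava-onevision-qwen2-0.5b-ov-hf",
    limit_mm_per_prompt={"image": 1, "video": 1},  # assumed per-prompt limits
)

image = Image.open("frame.jpg")                     # any RGB image
video = np.zeros((8, 384, 384, 3), dtype=np.uint8)  # dummy 8-frame clip

outputs = llm.generate(
    {
        "prompt": "<|im_start|>user <image><video>\n"
                  "Describe the image and the video.<|im_end|>\n"
                  "<|im_start|>assistant\n",
        "multi_modal_data": {"image": image, "video": video},
    },
    SamplingParams(max_tokens=64),
)
print(outputs[0].outputs[0].text)
```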
209 changes: 208 additions & 1 deletion tests/multimodal/test_utils.py
@@ -2,16 +2,22 @@
import mimetypes
import os
from tempfile import NamedTemporaryFile, TemporaryDirectory
from typing import Dict, Tuple
from typing import TYPE_CHECKING, Dict, NamedTuple, Optional, Tuple

import numpy as np
import pytest
from PIL import Image, ImageChops
from transformers import AutoConfig, AutoTokenizer

from vllm.multimodal.inputs import PlaceholderRange
from vllm.multimodal.utils import (MediaConnector,
                                   merge_and_sort_multimodal_metadata,
                                   repeat_and_pad_placeholder_tokens)

if TYPE_CHECKING:
    from vllm.multimodal.hasher import MultiModalHashDict
    from vllm.multimodal.inputs import MultiModalPlaceholderDict

# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA)
TEST_IMAGE_URLS = [
"https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg",
@@ -191,3 +197,204 @@ def test_repeat_and_pad_placeholder_tokens(model):
    assert new_prompt == expected_prompt
    assert new_token_ids == expected_token_ids
    assert ranges == expected_ranges


# Used for the next two tests related to `merge_and_sort_multimodal_metadata`.
class TestCase(NamedTuple):
    mm_positions: "MultiModalPlaceholderDict"
    mm_hashes: Optional["MultiModalHashDict"]
    expected_modalities: list[str]
    expected_ranges: list[PlaceholderRange]
    expected_hashes: Optional[list[str]]


def test_merge_and_sort_multimodal_metadata():

    test_cases = [
        # Single modality should return result as is but flattened
        TestCase(
            mm_positions={
                "image": [
                    PlaceholderRange(offset=0, length=2),
                    PlaceholderRange(offset=3, length=2),
                ]
            },
            mm_hashes={"image": ["hash1", "hash2"]},
            expected_modalities=["image"],
            expected_ranges=[
                PlaceholderRange(offset=0, length=2),
                PlaceholderRange(offset=3, length=2),
            ],
            expected_hashes=["hash1", "hash2"],
        ),

        # Single modality without hashes returns None for mm hash.
        TestCase(
            mm_positions={
                "image": [
                    PlaceholderRange(offset=0, length=2),
                    PlaceholderRange(offset=2, length=2),
                ]
            },
            mm_hashes=None,
            expected_modalities=["image"],
            expected_ranges=[
                PlaceholderRange(offset=0, length=2),
                PlaceholderRange(offset=2, length=2),
            ],
            expected_hashes=None,
        ),

        # Multiple modalities with hashes should return sorted modalities
        # and flattened ranges and hashes.
        TestCase(
            mm_positions={
                "image": [
                    PlaceholderRange(offset=7, length=4),
                    PlaceholderRange(offset=11, length=5),
                ],
                "audio": [
                    PlaceholderRange(offset=0, length=2),
                    PlaceholderRange(offset=2, length=3),
                ]
            },
            mm_hashes={
                "image": ["image_hash1", "image_hash2"],
                "audio": ["audio_hash1", "audio_hash2"],
            },
            expected_modalities=["audio", "image"],
            expected_ranges=[
                PlaceholderRange(offset=0, length=2),
                PlaceholderRange(offset=2, length=3),
                PlaceholderRange(offset=7, length=4),
                PlaceholderRange(offset=11, length=5),
            ],
            expected_hashes=[
                "audio_hash1", "audio_hash2", "image_hash1", "image_hash2"
            ],
        ),

        # Multiple modalities without hashes should return sorted modalities
        # and flattened ranges and None.
        TestCase(
            mm_positions={
                "image": [
                    PlaceholderRange(offset=7, length=4),
                    PlaceholderRange(offset=11, length=5),
                ],
                "audio": [
                    PlaceholderRange(offset=0, length=2),
                    PlaceholderRange(offset=2, length=3),
                ]
            },
            mm_hashes=None,
            expected_modalities=["audio", "image"],
            expected_ranges=[
                PlaceholderRange(offset=0, length=2),
                PlaceholderRange(offset=2, length=3),
                PlaceholderRange(offset=7, length=4),
                PlaceholderRange(offset=11, length=5),
            ],
            expected_hashes=None,
        ),

        # Three modalities
        TestCase(
            mm_positions={
                "image": [
                    PlaceholderRange(offset=15, length=7),
                    PlaceholderRange(offset=22, length=8),
                ],
                "audio": [
                    PlaceholderRange(offset=0, length=2),
                ],
                "video": [
                    PlaceholderRange(offset=3, length=4),
                    PlaceholderRange(offset=7, length=5),
                    PlaceholderRange(offset=12, length=6),
                ]
            },
            mm_hashes={
                "image": ["image_hash1", "image_hash2"],
                "audio": ["audio_hash1"],
                "video": ["video_hash1", "video_hash2", "video_hash3"]
            },
            expected_modalities=["audio", "video", "image"],
            expected_ranges=[
                PlaceholderRange(offset=0, length=2),
                PlaceholderRange(offset=3, length=4),
                PlaceholderRange(offset=7, length=5),
                PlaceholderRange(offset=12, length=6),
                PlaceholderRange(offset=15, length=7),
                PlaceholderRange(offset=22, length=8),
            ],
            expected_hashes=[
                "audio_hash1", "video_hash1", "video_hash2", "video_hash3",
                "image_hash1", "image_hash2"
            ],
        ),
    ]

    for (mm_positions, mm_hashes, expected_modalities, expected_ranges,
         expected_hashes) in test_cases:
        modalities, ranges, hashes = merge_and_sort_multimodal_metadata(
            mm_positions, mm_hashes)

        assert modalities == expected_modalities
        assert ranges == expected_ranges
        assert hashes == expected_hashes


def test_merge_and_sort_multimodal_metadata_with_interleaving():

    test_cases = [

        # <image> <audio> <image> <audio>
        TestCase(
            mm_positions={
                "image": [
                    PlaceholderRange(offset=0, length=4),
                    PlaceholderRange(offset=8, length=2),
                ],
                "audio": [
                    PlaceholderRange(offset=5, length=2),
                    PlaceholderRange(offset=11, length=4),
                ]
            },
            mm_hashes={
                "image": ["image_hash1", "image_hash2"],
                "audio": ["audio_hash1", "audio_hash2"],
            },
            expected_modalities=[],
            expected_ranges=[],
            expected_hashes=None,
        ),

        # <image> <image> <video> <audio> <image>
        TestCase(
            mm_positions={
                "image": [
                    PlaceholderRange(offset=0, length=2),
                    PlaceholderRange(offset=2, length=3),
                    PlaceholderRange(offset=20, length=4),
                ],
                "audio": [
                    PlaceholderRange(offset=5, length=2),
                ],
                "video": [
                    PlaceholderRange(offset=8, length=5),
                ]
            },
            mm_hashes=None,
            expected_modalities=[],
            expected_ranges=[],
            expected_hashes=None,
        ),
    ]

    for case in test_cases:
        with pytest.raises(ValueError) as ex_info:
            merge_and_sort_multimodal_metadata(case.mm_positions,
                                               case.mm_hashes)

        assert "Interleaved mixed-modality" in str(ex_info.value)
18 changes: 11 additions & 7 deletions tests/v1/core/test_kv_cache_utils.py
@@ -1,6 +1,6 @@
 import pytest
 
-from vllm.inputs import token_inputs
+from vllm.multimodal.inputs import MultiModalKwargs
 from vllm.sampling_params import SamplingParams
 from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue,
                                          KVCacheBlock,
@@ -14,14 +14,18 @@ def make_request(request_id,
                  prompt_token_ids,
                  mm_positions=None,
                  mm_hashes=None):
+    if mm_positions is None:
+        multi_modal_inputs = None
+    else:
+        multi_modal_inputs = [MultiModalKwargs({})] * len(mm_positions)
+
     return Request(
         request_id=request_id,
-        inputs=token_inputs(
-            prompt_token_ids=prompt_token_ids,
-            multi_modal_placeholders={"image": mm_positions}
-            if mm_positions else None,
-            multi_modal_hashes=mm_hashes,
-        ),
+        prompt=None,
+        prompt_token_ids=prompt_token_ids,
+        multi_modal_inputs=multi_modal_inputs,
+        multi_modal_hashes=mm_hashes,
+        multi_modal_placeholders=mm_positions,
         sampling_params=SamplingParams(max_tokens=17),
         eos_token_id=100,
         arrival_time=0,
17 changes: 11 additions & 6 deletions tests/v1/core/test_prefix_caching.py
@@ -1,8 +1,7 @@
"""Compare the with and without prefix caching."""
import pytest

from vllm.inputs import token_inputs
from vllm.multimodal.inputs import PlaceholderRange
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.sampling_params import SamplingParams
from vllm.utils import cdiv
from vllm.v1.core.kv_cache_manager import KVCacheManager, Request
@@ -13,12 +12,18 @@ def make_request(request_id,
                  prompt_token_ids,
                  mm_positions=None,
                  mm_hashes=None):
+    if mm_positions is None:
+        multi_modal_inputs = None
+    else:
+        multi_modal_inputs = [MultiModalKwargs({})] * len(mm_positions)
+
     return Request(
         request_id=request_id,
-        inputs=token_inputs(prompt_token_ids=prompt_token_ids,
-                            multi_modal_placeholders={"image": mm_positions}
-                            if mm_positions else None,
-                            multi_modal_hashes=mm_hashes),
+        prompt=None,
+        prompt_token_ids=prompt_token_ids,
+        multi_modal_inputs=multi_modal_inputs,
+        multi_modal_hashes=mm_hashes,
+        multi_modal_placeholders=mm_positions,
         sampling_params=SamplingParams(max_tokens=17),
         eos_token_id=100,
         arrival_time=0,
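Both test helpers above now build the V1 `Request` directly instead of going through a `token_inputs(...)` wrapper. A hedged, standalone sketch of the new construction pattern (the import path mirrors the one used in these tests, field values are illustrative, and additional keyword arguments such as `lora_request` may be required depending on the vLLM version):

```python
# Sketch of the direct Request construction used by the V1 tests in this
# commit. Values are illustrative only; extra kwargs may be needed.
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.sampling_params import SamplingParams
from vllm.v1.core.kv_cache_manager import Request  # re-export used by the tests

mm_positions = [
    PlaceholderRange(offset=0, length=4),
    PlaceholderRange(offset=8, length=4),
]

request = Request(
    request_id="request-0",
    prompt=None,
    prompt_token_ids=list(range(16)),
    multi_modal_inputs=[MultiModalKwargs({})] * len(mm_positions),
    multi_modal_hashes=["hash1", "hash2"],
    multi_modal_placeholders=mm_positions,
    sampling_params=SamplingParams(max_tokens=17),
    eos_token_id=100,
    arrival_time=0,
)
```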
6 changes: 5 additions & 1 deletion vllm/model_executor/models/interfaces.py
@@ -39,8 +39,12 @@ def get_multimodal_embeddings(self, **kwargs) -> Optional[T]:
         The output embeddings must be one of the following formats:
         - A list or tuple of 2D tensors, where each tensor corresponds to
-          each input image.
+          each input multimodal data item (e.g., image).
         - A single 3D tensor, with the batch dimension grouping the 2D tensors.
+
+        NOTE: The returned multimodal embeddings must be in the same order as
+        the appearances of their corresponding multimodal data item in the
+        input prompt.
         """
         ...
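The new docstring requirement is that embeddings come back in prompt order across modalities. A minimal sketch of one way an implementation might honor that contract, assuming per-modality encoders that already produce one 2D tensor per item; everything here is illustrative and not the pattern used by any particular vLLM model:

```python
# Illustrative only: interleave per-modality embeddings back into the order
# in which the corresponding items appear in the prompt.
from typing import Iterator, Optional

import torch


def order_multimodal_embeddings(
    embeddings_by_modality: dict[str, list[torch.Tensor]],  # prompt-ordered per modality
    prompt_modality_order: list[str],  # e.g. ["image", "video", "image"]
) -> Optional[list[torch.Tensor]]:
    if not prompt_modality_order:
        return None

    iterators: dict[str, Iterator[torch.Tensor]] = {
        modality: iter(items)
        for modality, items in embeddings_by_modality.items()
    }
    # Walk the prompt's placeholder sequence, pulling the next embedding of
    # each modality so the output order matches the placeholder order.
    return [next(iterators[m]) for m in prompt_modality_order]


# Two images and one video, appearing in the prompt as image, video, image.
embs = {
    "image": [torch.zeros(4, 8), torch.ones(4, 8)],
    "video": [torch.full((6, 8), 2.0)],
}
ordered = order_multimodal_embeddings(embs, ["image", "video", "image"])
print([tuple(e.shape) for e in ordered])  # [(4, 8), (6, 8), (4, 8)]
```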

