[Model] Implement merged input processor for Phi-3-Vision models #10977

Merged (14 commits) on Dec 9, 2024
@@ -2,12 +2,10 @@
 from typing import Optional
 
 import pytest
-import torch
-from transformers import AutoImageProcessor, AutoTokenizer
+from transformers import AutoTokenizer
 
-from vllm.inputs import InputContext, token_inputs
+from vllm.inputs import InputContext, InputProcessingContext
 from vllm.model_executor.models.phi3v import _IMAGE_TOKEN_ID
-from vllm.multimodal import MultiModalRegistry
 
 from .....conftest import _ImageAssets
 from ....utils import build_model_context
@@ -17,15 +15,9 @@

 # Wrap lazy imports to avoid initializing CUDA during test collection
 @pytest.fixture()
-def input_processor_for_phi3v():
-    from vllm.model_executor.models.phi3v import input_processor_for_phi3v
-    return input_processor_for_phi3v
-
-
-@pytest.fixture()
-def dummy_data_for_phi3v():
-    from vllm.model_executor.models.phi3v import dummy_data_for_phi3v
-    return dummy_data_for_phi3v
+def processor_for_phi3v():
+    from vllm.model_executor.models.phi3v import Phi3VProcessor
+    return Phi3VProcessor
 
 
 @pytest.fixture()
@@ -34,53 +26,6 @@ def get_max_phi3v_image_tokens():
     return get_max_phi3v_image_tokens
 
 
-@pytest.mark.parametrize("model", models)
-@pytest.mark.parametrize("num_crops", [4, 16, None])
-def test_input_mapper_override(model: str, image_assets: _ImageAssets,
-                               num_crops: Optional[int]):
-    """Ensure that the [default] input mapper handles num_crops properly."""
-    # We pass the processor kwargs here since for this model, we fall back to
-    # the default mapper; this will fall back to the HF mapper and forward
-    # mm_processor_kwargs to it.
-    mm_processor_kwargs = {
-        "num_crops": num_crops
-    } if num_crops is not None else {}
-    ctx = build_model_context(
-        model_name=model,
-        tokenizer_name=model,
-        trust_remote_code=True,
-        mm_processor_kwargs=mm_processor_kwargs,
-    )
-
-    hf_processor = AutoImageProcessor.from_pretrained(model,
-                                                      trust_remote_code=True,
-                                                      **mm_processor_kwargs)
-
-    mm_registry = MultiModalRegistry()
-    mm_registry.init_mm_limits_per_prompt(ctx.model_config)
-
-    image = image_assets[0].pil_image
-    hf_result = hf_processor.preprocess(
-        image,
-        return_tensors="pt",
-    )
-
-    vllm_result = mm_registry.map_input(
-        ctx.model_config,
-        {"image": image},
-    )
-
-    assert torch.all(hf_result["image_sizes"] == vllm_result["image_sizes"])
-    assert torch.all(
-        hf_result["num_img_tokens"] == vllm_result["num_img_tokens"])
-
-    # For pixel values, the second axis should be the num_crops + 1
-    # for the rescaled original image. The default value in VLLM falls
-    # back to the HF config, which is why we compare to the processor num_crops
-    assert torch.all(hf_result["pixel_values"] == vllm_result["pixel_values"])
-    assert vllm_result["pixel_values"].shape[1] == hf_processor.num_crops + 1
-
-
 @pytest.mark.parametrize("model", models)
 @pytest.mark.parametrize("num_crops,expected_max_tokens", [
     (4, 781),
@@ -112,48 +57,20 @@ def test_max_tokens_override(get_max_phi3v_image_tokens, model: str,


 @pytest.mark.parametrize("model", models)
-@pytest.mark.parametrize("num_crops,toks_per_img,num_imgs", [
-    (4, 781, 1),
-    (4, 781, 2),
-    (16, 2653, 1),
-    (16, 2653, 2),
-])
-def test_dummy_data_override(dummy_data_for_phi3v, model: str, num_crops: int,
-                             toks_per_img: int, num_imgs: int):
-    """Ensure dummy_data_for_phi3v handles num_crops properly."""
-    # Same as the previous test - don't initialize mm_processor_kwargs
-    # in this test and assume that the kwargs will be correctly expanded by
-    # the partial when calling the dummy data func.
-    ctx = build_model_context(
-        model_name=model,
-        tokenizer_name=model,
-        trust_remote_code=True,
-        mm_processor_kwargs=None,
-    )
-
-    dummy_data = dummy_data_for_phi3v(
-        ctx=ctx,
-        seq_len=8192,  # Should be bigger than num_imgs * toks_per_img
-        mm_counts={"image": num_imgs},
-        num_crops=num_crops,
-    )
-    sequence_data = dummy_data.seq_data
-    # Ensure we have the right number of placeholders per num_crops size
-    img_tok_count = sequence_data.get_token_ids().count(_IMAGE_TOKEN_ID)
-    assert img_tok_count == toks_per_img * num_imgs
-
-
-@pytest.mark.parametrize("model", models)
-@pytest.mark.parametrize("num_crops,expected_toks_per_img,num_imgs", [
-    (4, 757, 1),
-    (4, 757, 2),
-    (16, 1921, 1),
-    (16, 1921, 2),
-])
-def test_input_processor_override(input_processor_for_phi3v,
-                                  image_assets: _ImageAssets, model: str,
-                                  num_crops: int, expected_toks_per_img: int,
-                                  num_imgs: int):
+@pytest.mark.parametrize(
+    "num_crops,expected_toks_per_img,num_imgs",
+    [
+        (4, 757, 1),
+        (4, 757, 2),
+        (16, 1921, 1),
+        (16, 1921, 2),
+        # the default num_crops of phi-3.5-vision is 4
+        (None, 757, 2),
+        (None, 757, 2),
+    ])
+def test_processor_override(processor_for_phi3v, image_assets: _ImageAssets,
+                            model: str, num_crops: Optional[int],
+                            expected_toks_per_img: int, num_imgs: int):
     """Ensure input_processor_for_phi3v handles num_crops properly."""
     # Same as the previous test - don't initialize mm_processor_kwargs
     # in this test and assume that the kwargs will be correctly expanded by
@@ -163,19 +80,20 @@ def test_input_processor_override(input_processor_for_phi3v,
         tokenizer_name=model,
         trust_remote_code=True,
     )
-    tokenizer = AutoTokenizer.from_pretrained(model)
+    tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)
+    ctx = InputProcessingContext(ctx.model_config, tokenizer)
     # Build the image str / prompt based on the number of images we pass
     img_str = "".join([f"<|image_{idx}|>\n" for idx in range(1, num_imgs + 1)])
     prompt = f"<|user|>\n{img_str}<|end|>\n<|assistant|>\n"
     images = [image_assets[0].pil_image] * num_imgs
 
-    inputs = token_inputs(prompt_token_ids=tokenizer.encode(prompt),
-                          prompt=prompt,
-                          multi_modal_data={"image": images})
+    mm_data = {"image": images}
+    mm_processor_kwargs = {}
+    if num_crops is not None:
+        mm_processor_kwargs = {"num_crops": num_crops}
 
-    processed_inputs = input_processor_for_phi3v(ctx,
-                                                 inputs,
-                                                 num_crops=num_crops)
+    processor = processor_for_phi3v(ctx)
+    processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs)
 
     # Ensure we have the right number of placeholders per num_crops size
     img_tok_count = processed_inputs["prompt_token_ids"].count(_IMAGE_TOKEN_ID)
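For context beyond the unit test, here is a minimal end-to-end usage sketch (not part of this diff). It assumes vLLM's LLM entrypoint forwards mm_processor_kwargs to the merged Phi-3-Vision processor, so num_crops can be overridden the same way the test does through InputProcessingContext; the model name, image path, and sampling settings are illustrative.

```python
from PIL import Image

from vllm import LLM, SamplingParams

# Assumption: mm_processor_kwargs reaches the merged processor, overriding the
# HF default of num_crops=4 for Phi-3.5-vision.
llm = LLM(
    model="microsoft/Phi-3.5-vision-instruct",
    trust_remote_code=True,
    max_model_len=8192,
    mm_processor_kwargs={"num_crops": 16},
)

# Prompt format mirrors the one built in the test above.
prompt = "<|user|>\n<|image_1|>\nDescribe the image.<|end|>\n<|assistant|>\n"
image = Image.open("example.jpg")  # hypothetical local image

outputs = llm.generate(
    {"prompt": prompt, "multi_modal_data": {"image": image}},
    SamplingParams(max_tokens=64),
)
print(outputs[0].outputs[0].text)
```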
4 changes: 2 additions & 2 deletions vllm/inputs/registry.py
@@ -69,12 +69,12 @@ class InputProcessingContext(InputContext):
     tokenizer: AnyTokenizer
     """The tokenizer used to tokenize the inputs."""
 
-    def get_hf_processor(self) -> ProcessorMixin:
+    def get_hf_processor(self, **kwargs) -> ProcessorMixin:
         return cached_get_processor(
             self.model_config.tokenizer,
             tokenizer=self.tokenizer,  # Override the tokenizer with ours
             trust_remote_code=self.model_config.trust_remote_code,
-        )
+            **kwargs)
 
 
 N = TypeVar("N", bound=Type[nn.Module])
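The extra **kwargs are forwarded to cached_get_processor (and ultimately to the underlying HF processor constructor), which is what lets a model-specific processor pass per-model overrides such as num_crops. A minimal sketch of the intended call pattern; ExampleProcessor is a hypothetical stand-in, not the actual Phi3VProcessor implementation.

```python
from typing import Optional

from transformers import ProcessorMixin


class ExampleProcessor:  # hypothetical wrapper around InputProcessingContext
    def __init__(self, ctx):
        self.ctx = ctx  # an InputProcessingContext

    def _get_hf_processor(self,
                          num_crops: Optional[int] = None) -> ProcessorMixin:
        if num_crops is not None:
            # Extra kwargs now flow through get_hf_processor into
            # cached_get_processor / the HF processor constructor.
            return self.ctx.get_hf_processor(num_crops=num_crops)
        return self.ctx.get_hf_processor()
```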
6 changes: 4 additions & 2 deletions vllm/model_executor/models/llava.py
@@ -188,7 +188,8 @@ def preprocess(__self, *args, **kwargs):

         hf_processor.__is_patched__ = True  # type: ignore
 
-    def _get_hf_processor(self) -> ProcessorMixin:
+    def _get_hf_processor(
+            self, mm_processor_kwargs: Mapping[str, object]) -> ProcessorMixin:
         hf_processor = self.ctx.get_hf_processor()
 
         if isinstance(hf_processor, PixtralProcessor):
@@ -590,7 +591,8 @@ def load_weights(self, weights: Iterable[Tuple[str,

 class MantisProcessor(LlavaProcessor):
 
-    def _get_hf_processor(self) -> ProcessorMixin:
+    def _get_hf_processor(
+            self, mm_processor_kwargs: Mapping[str, object]) -> ProcessorMixin:
         try:
             from mantis.models.mllava import MLlavaProcessor
         except ModuleNotFoundError as exc:
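With the new signature, subclasses receive the caller's mm_processor_kwargs and can decide which keys to forward to the HF processor. A hedged sketch of that pattern under assumed names: ExampleLlavaLikeProcessor is illustrative only, and the allow-list filtering policy is an assumption, not how the LLaVA or Mantis processors themselves behave.

```python
from typing import Mapping

from transformers import ProcessorMixin


class ExampleLlavaLikeProcessor:  # hypothetical subclass, for illustration only
    def __init__(self, ctx):
        self.ctx = ctx  # an InputProcessingContext

    def _get_hf_processor(
            self, mm_processor_kwargs: Mapping[str, object]) -> ProcessorMixin:
        # Forward only the keys the underlying HF processor is known to accept
        # (assumed allow-list; real model processors define their own handling).
        allowed = {"num_crops"}
        kwargs = {k: v for k, v in mm_processor_kwargs.items() if k in allowed}
        return self.ctx.get_hf_processor(**kwargs)
```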