Commit 90c2547

Fix Phi3V and Qwen2-VL tests

Signed-off-by: DarkLight1337 <[email protected]>
DarkLight1337 committed Jan 8, 2025
1 parent 1d0fab0 commit 90c2547
Showing 4 changed files with 35 additions and 31 deletions.
24 changes: 10 additions & 14 deletions tests/models/decoder_only/vision_language/processing/test_phi3v.py
@@ -1,21 +1,13 @@
 """Tests for phi3v's multimodal preprocessing kwargs."""
 import pytest
-from transformers import AutoTokenizer
 
-from vllm.inputs import InputProcessingContext
-from vllm.model_executor.models.phi3v import _IMAGE_TOKEN_ID
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.utils import cached_get_tokenizer
 
 from .....conftest import _ImageAssets
 from ....utils import build_model_context
 
 
-# Wrap lazy imports to avoid initializing CUDA during test collection
-@pytest.fixture()
-def processor_for_phi3v():
-    from vllm.model_executor.models.phi3v import Phi3VMultiModalProcessor
-    return Phi3VMultiModalProcessor
-
-
 @pytest.mark.parametrize("model_id", ["microsoft/Phi-3.5-vision-instruct"])
 # yapf: disable
 @pytest.mark.parametrize(
@@ -29,29 +21,33 @@ def processor_for_phi3v():
 # yapf: enable
 @pytest.mark.parametrize("num_imgs", [1, 2])
 def test_processor_override(
-    processor_for_phi3v,
     image_assets: _ImageAssets,
     model_id: str,
     mm_processor_kwargs: dict[str, int],
     expected_toks_per_img: int,
     num_imgs: int,
 ):
     """Ensure input_processor_for_phi3v handles num_crops properly."""
+    # Avoid initializing CUDA early
+    from vllm.model_executor.models.phi3v import _IMAGE_TOKEN_ID
+
     ctx = build_model_context(
         model_name=model_id,
         tokenizer_name=model_id,
         trust_remote_code=True,
         limit_mm_per_prompt={"image": num_imgs},
     )
-    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
-    ctx = InputProcessingContext(ctx.model_config, tokenizer)
+    tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer)
+    processor = MULTIMODAL_REGISTRY.create_processor(
+        ctx.model_config,
+        tokenizer=tokenizer,
+    )
 
     # Build the image str / prompt based on the number of images we pass
     img_str = "".join([f"<|image_{idx}|>\n" for idx in range(1, num_imgs + 1)])
     prompt = f"<|user|>\n{img_str}<|end|>\n<|assistant|>\n"
     mm_data = {"image": [image_assets[0].pil_image] * num_imgs}
 
-    processor = processor_for_phi3v(ctx)
     processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs)
 
     # Ensure we have the right number of placeholders per num_crops size
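Note: both tests now build their processor the same way, through the multimodal registry rather than a lazily imported processor class. A minimal sketch of the shared setup flow, using only calls that appear in the diff above (the make_processor wrapper itself is illustrative, not part of the commit):

    # Sketch: constructing a multimodal processor the way the updated tests do.
    # Only the helper name is invented; the calls are taken from the diff.
    from vllm.multimodal import MULTIMODAL_REGISTRY
    from vllm.multimodal.utils import cached_get_tokenizer

    def make_processor(ctx):
        # ctx is the result of build_model_context(...) in the test body
        tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer)
        return MULTIMODAL_REGISTRY.create_processor(
            ctx.model_config,
            tokenizer=tokenizer,
        )

With nothing imported from vllm.model_executor.models at module level, collecting the test no longer loads the model file, which is what previously risked initializing CUDA before the test ran.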
22 changes: 8 additions & 14 deletions tests/models/decoder_only/vision_language/processing/test_qwen2_vl.py
@@ -1,19 +1,12 @@
 import pytest
-from transformers import AutoTokenizer
 
-from vllm.inputs import InputProcessingContext
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.utils import cached_get_tokenizer
 
 from .....conftest import _ImageAssets
 from ....utils import build_model_context
 
 
-# Fixtures lazy import to avoid initializing CUDA during test collection
-@pytest.fixture()
-def processor_for_qwen2_vl():
-    from vllm.model_executor.models.qwen2_vl import Qwen2VLMultiModalProcessor
-    return Qwen2VLMultiModalProcessor
-
-
 @pytest.mark.parametrize("model_id", ["Qwen/Qwen2-VL-2B-Instruct"])
 # yapf: disable
 @pytest.mark.parametrize(
@@ -24,7 +17,6 @@ def processor_for_qwen2_vl():
 # yapf: enable
 @pytest.mark.parametrize("num_imgs", [1, 2])
 def test_processor_override(
-    processor_for_qwen2_vl,
     image_assets: _ImageAssets,
     model_id: str,
     mm_processor_kwargs: dict[str, object],
@@ -39,18 +31,20 @@ def test_processor_override(
         mm_processor_kwargs=None,
         limit_mm_per_prompt={"image": num_imgs},
     )
-    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
-    ctx = InputProcessingContext(ctx.model_config, tokenizer)
+    tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer)
+    processor = MULTIMODAL_REGISTRY.create_processor(
+        ctx.model_config,
+        tokenizer=tokenizer,
+    )
 
     # Build the image str / prompt based on the number of images we pass
     prompt = "<|vision_start|><|image_pad|><|vision_end|>" * num_imgs
     mm_data = {"image": [image_assets[0].pil_image] * num_imgs}
 
-    processor = processor_for_qwen2_vl(ctx)
     processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs)
 
     # Ensure we have the right number of placeholders per num_crops size
-    hf_processor = processor._get_hf_processor(**mm_processor_kwargs)
+    hf_processor = processor.info.get_hf_processor(**mm_processor_kwargs)
     image_token_id = tokenizer.convert_tokens_to_ids(hf_processor.image_token)
     img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id)
     pixel_shape = processed_inputs["mm_kwargs"]["pixel_values"].shape
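Besides the shared setup change, this file swaps the private processor._get_hf_processor(...) accessor for the public processor.info.get_hf_processor(...). The final assertions sit below the fold of this diff; a plausible form of the check, given the variables computed above (the assertion lines and the expected_toks_per_img name are hypothetical, not shown in the commit):

    # Hypothetical closing assertions: one image token per placeholder the
    # processor expanded, and pixel_values scaling with the image count.
    assert img_tok_count == expected_toks_per_img * num_imgs
    assert pixel_shape[0] % num_imgs == 0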
6 changes: 5 additions & 1 deletion vllm/model_executor/models/phi3v.py
@@ -322,6 +322,7 @@ def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
         max_image_tokens = self.get_num_image_tokens(
             image_width=target_width,
             image_height=target_height,
+            processor=None,
         )
 
         return {"image": max_image_tokens}
@@ -331,8 +332,10 @@ def get_num_image_tokens(
         *,
         image_width: int,
         image_height: int,
+        processor: Optional[ProcessorMixin],
     ) -> int:
-        processor = self.get_hf_processor()
+        if processor is None:
+            processor = self.get_hf_processor()
 
         return processor.calc_num_image_tokens_from_image_size(  # type: ignore
             width=image_width,
@@ -431,6 +434,7 @@ def get_replacement_phi3v(item_idx: int):
             num_image_tokens = self.info.get_num_image_tokens(
                 image_width=image_size.width,
                 image_height=image_size.height,
+                processor=hf_processor,
             )
 
             return [_IMAGE_TOKEN_ID] * num_image_tokens + [bos_token_id]
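The phi3v.py change gives get_num_image_tokens a keyword-only processor parameter with no default, so every call site must state whether it already holds an HF processor (the prompt-replacement path passes processor=hf_processor) or wants the default fetched (max-token estimation passes processor=None). A runnable toy sketch of the pattern; the class names and the token formula are invented stand-ins, not vLLM code:

    from typing import Optional

    class FakeProcessor:
        def calc_num_image_tokens_from_image_size(self, *, width: int,
                                                  height: int) -> int:
            # Toy formula standing in for the real HF computation.
            return (width // 336 + 1) * (height // 336 + 1) * 144

    class ProcessingInfoSketch:
        def get_hf_processor(self) -> FakeProcessor:
            # Stand-in for the (potentially expensive) default lookup.
            return FakeProcessor()

        def get_num_image_tokens(
            self,
            *,
            image_width: int,
            image_height: int,
            processor: Optional[FakeProcessor],
        ) -> int:
            if processor is None:
                processor = self.get_hf_processor()
            return processor.calc_num_image_tokens_from_image_size(
                width=image_width, height=image_height)

    info = ProcessingInfoSketch()
    # Caller with no processor in hand (e.g. max-token estimation):
    print(info.get_num_image_tokens(image_width=672, image_height=672,
                                    processor=None))
    # Caller that already fetched one (e.g. prompt replacement):
    print(info.get_num_image_tokens(image_width=672, image_height=672,
                                    processor=FakeProcessor()))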
14 changes: 12 additions & 2 deletions vllm/model_executor/models/qwen2_vl.py
@@ -763,15 +763,17 @@ def _get_vision_info(
         image_height: int,
         num_frames: int = 1,
         do_resize: bool = True,
+        image_processor: Optional[Qwen2VLImageProcessor],
     ) -> tuple[ImageSize, int]:
+        if image_processor is None:
+            image_processor = self.get_image_processor()
+
         hf_config = self.get_hf_config()
         vision_config = hf_config.vision_config
         patch_size = vision_config.patch_size
         merge_size = vision_config.spatial_merge_size
         temporal_patch_size = vision_config.temporal_patch_size
 
-        image_processor = self.get_image_processor()
-
         if do_resize:
             resized_height, resized_width = smart_resize(
                 height=image_height,
@@ -800,10 +802,12 @@ def get_num_image_tokens(
         *,
         image_width: int,
         image_height: int,
+        image_processor: Optional[Qwen2VLImageProcessor],
     ) -> int:
         _, num_image_tokens = self._get_vision_info(
             image_width=image_width,
             image_height=image_height,
+            image_processor=image_processor,
         )
         return num_image_tokens
 
@@ -813,18 +817,21 @@ def get_num_video_tokens(
         image_width: int,
         image_height: int,
         num_frames: int,
+        image_processor: Optional[Qwen2VLImageProcessor],
     ) -> int:
         _, num_video_tokens = self._get_vision_info(
             image_width=image_width,
             image_height=image_height,
             num_frames=num_frames,
+            image_processor=image_processor,
         )
         return num_video_tokens
 
     def get_image_size_with_most_features(self) -> ImageSize:
         max_image_size, _ = self._get_vision_info(
             image_width=9999999,
             image_height=9999999,
+            image_processor=None,
         )
         return max_image_size
 
@@ -834,6 +841,7 @@ def get_max_image_tokens(self) -> int:
         return self.get_num_image_tokens(
             image_width=target_width,
             image_height=target_height,
+            image_processor=None,
         )
 
     def _get_max_video_frames(self, max_tokens: int) -> int:
@@ -847,6 +855,7 @@ def _get_max_video_frames(self, max_tokens: int) -> int:
             image_width=target_width,
             image_height=target_height,
             num_frames=next_num_frames,
+            image_processor=None,
         )
 
         if next_max_tokens > max_tokens:
@@ -880,6 +889,7 @@ def get_max_video_tokens(self, seq_len: int) -> int:
             image_width=target_width,
             image_height=target_height,
             num_frames=self.get_num_frames_with_most_features(seq_len),
+            image_processor=None,
         )
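qwen2_vl.py applies the same convention one level deeper: _get_vision_info resolves a None image_processor to self.get_image_processor() exactly once, and the public helpers simply forward whatever the caller supplied. A toy sketch of the resolve-then-forward shape (invented class and trivial patch math; only the structure mirrors the diff):

    from typing import Optional

    class FakeImageProcessor:
        patch_size = 14

    class VisionInfoSketch:
        def get_image_processor(self) -> FakeImageProcessor:
            # Stand-in for the expensive default lookup.
            return FakeImageProcessor()

        def _get_vision_info(
            self,
            *,
            image_width: int,
            image_height: int,
            num_frames: int = 1,
            image_processor: Optional[FakeImageProcessor],
        ) -> int:
            if image_processor is None:
                image_processor = self.get_image_processor()
            # Toy token count standing in for the real grid math.
            patches = (image_width // image_processor.patch_size) * (
                image_height // image_processor.patch_size)
            return patches * num_frames

        def get_num_image_tokens(
            self,
            *,
            image_width: int,
            image_height: int,
            image_processor: Optional[FakeImageProcessor],
        ) -> int:
            # Forward the caller's choice so resolution happens exactly once.
            return self._get_vision_info(
                image_width=image_width,
                image_height=image_height,
                image_processor=image_processor,
            )

    info = VisionInfoSketch()
    print(info.get_num_image_tokens(image_width=448, image_height=448,
                                    image_processor=None))

Passing image_processor=None at the estimation call sites (get_max_image_tokens, _get_max_video_frames, get_max_video_tokens) keeps their behavior unchanged while still satisfying the now-mandatory keyword argument.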
