
Commit

[Bugfix] Comprehensively test and fix LLaVA-NeXT feature size calculation (#11800)

Signed-off-by: DarkLight1337 <[email protected]>
DarkLight1337 authored Jan 7, 2025
1 parent 8082ad7 commit 8f37be3
Showing 6 changed files with 253 additions and 89 deletions.
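
At a high level, the new tests sweep a large set of image sizes, validate each size in a worker thread, and collect every failure before reporting, instead of aborting on the first mismatch. A minimal sketch of that collect-then-report pattern, assuming only that pqdm is installed (the validate function here is a hypothetical stand-in for the real per-size check):

from pqdm.threads import pqdm

def validate(size: tuple[int, int]) -> None:
    # Hypothetical per-size check; raises if the size is invalid.
    w, h = size
    if w <= 0 or h <= 0:
        raise ValueError(f"bad size {size}")

failures: list[tuple[tuple[int, int], Exception]] = []

def check_one(size: tuple[int, int]) -> None:
    try:
        validate(size)
    except Exception as exc:  # collect instead of failing fast
        failures.append((size, exc))

sizes = [(183, 488), (488, 183), (176, 198)]
# Run the checks in 8 worker threads with a progress bar, then report
# every failing size at once.
pqdm(sizes, check_one, n_jobs=8, desc="Validating image sizes")
if failures:
    raise AssertionError(f"Found failing image sizes: {failures}")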
1 change: 1 addition & 0 deletions requirements-test.in
@@ -13,6 +13,7 @@ einops # required for MPT, qwen-vl and Mamba
httpx
librosa # required for audio tests
peft
pqdm
ray[adag]==2.40.0
sentence-transformers # required for embedding tests
soundfile # required for audio tests
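pqdm, newly added above, is a small wrapper around concurrent.futures that maps a function over an iterable in parallel while showing a tqdm progress bar; the updated tests use its thread-based variant, as seen in the diffs below. A minimal usage sketch (the square function is illustrative):

from pqdm.threads import pqdm

def square(x: int) -> int:
    return x * x

# Map `square` over 0..9 in 4 worker threads, with a progress bar.
results = pqdm(range(10), square, n_jobs=4)
print(results)  # squares of 0..9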
4 changes: 4 additions & 0 deletions requirements-test.txt
@@ -48,6 +48,8 @@ botocore==1.35.57
# awscli
# boto3
# s3transfer
bounded-pool-executor==0.0.3
# via pqdm
buildkite-test-collector==0.1.9
# via -r requirements-test.in
certifi==2024.8.30
@@ -342,6 +344,8 @@ pooch==1.8.2
# via librosa
portalocker==2.10.1
# via sacrebleu
pqdm==0.2.0
# via -r requirements-test.in
propcache==0.2.0
# via yarl
protobuf==5.28.3
129 changes: 108 additions & 21 deletions tests/models/decoder_only/vision_language/processing/test_llava_next.py
@@ -1,8 +1,13 @@
import itertools
from functools import partial

import pytest
from PIL import Image
from pqdm.threads import pqdm
from transformers import AutoTokenizer

from vllm.inputs import InputProcessingContext
from vllm.multimodal.parse import ImageSize

from ....utils import build_model_context

@@ -15,20 +20,69 @@ def processor_for_llava_next():
return LlavaNextMultiModalProcessor


def _validate_image_prompt_replacements_one(
processor,
num_imgs: int,
failed_size_excs: list[tuple[ImageSize, Exception]],
image_size: ImageSize,
) -> None:
prompt = "<image>" * num_imgs
image = Image.new("RGB", size=image_size)
mm_data = {"image": [image] * num_imgs}

try:
# The processor will throw an error if there is a mismatch
# in the prompt replacements
processed_inputs = processor.apply(prompt, mm_data, {})

image_placeholders = processed_inputs["mm_placeholders"]["image"]
assert len(image_placeholders) == num_imgs

first_placeholder = image_placeholders[0]

# NOTE: There is a BOS token
assert first_placeholder["offset"] == 1
assert first_placeholder["length"] == (
len(processed_inputs["prompt_token_ids"]) - 1) // num_imgs

except Exception as exc:
failed_size_excs.append((image_size, exc))


def _test_image_prompt_replacements(
processor,
*,
num_imgs: int,
image_sizes: list[ImageSize],
) -> None:
"""
Ensure LlavaNextMultiModalProcessor
handles prompt replacement properly for input images.
"""
failed_size_excs = list[tuple[ImageSize, Exception]]()

validate_one = partial(
_validate_image_prompt_replacements_one,
processor,
num_imgs,
failed_size_excs,
)
pqdm(image_sizes, validate_one, n_jobs=8, desc="Validating image sizes")

if failed_size_excs:
msg = "Found failing image sizes:" \
+ "\n========\n".join(f"[{size}]\n{exc}"
for size, exc in failed_size_excs)
raise AssertionError(msg)


@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
@pytest.mark.parametrize("image_size", [(1669, 2560), (2560, 1669), (183, 488),
(488, 183), (198, 176), (176, 198),
(161, 184), (184, 161)])
@pytest.mark.parametrize("num_imgs", [1, 2])
def test_processor_prompt_replacements(
def test_processor_prompt_replacements_regression(
processor_for_llava_next,
model_id: str,
image_size: tuple[int, int],
num_imgs: int,
):
"""
Ensure LlavaNextMultiModalProcessor handles prompt replacement properly.
"""
ctx = build_model_context(
model_name=model_id,
tokenizer_name=model_id,
@@ -37,22 +91,55 @@ def test_processor_prompt_replacements(
)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
ctx = InputProcessingContext(ctx.model_config, tokenizer)
processor = processor_for_llava_next(ctx)

image_ratios = [(171, 152), (184, 161), (198, 176), (333, 296), (369, 328),
(488, 183), (2560, 1669)]
image_sizes = [
size for w, h in image_ratios
for size in [ImageSize(w, h), ImageSize(h, w)]
]

_test_image_prompt_replacements(
processor,
num_imgs=num_imgs,
image_sizes=image_sizes,
)

# Build the image str / prompt based on the number of images we pass
prompt = "<image>" * num_imgs
mm_data = {"image": [Image.new("RGB", size=image_size)] * num_imgs}

# The processor will throw an error if there is a mismatch
# in the prompt replacements
@pytest.mark.skip("This test takes around 2 hours to run. "
"Comment this out to run it manually.")
@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
@pytest.mark.parametrize("num_imgs", [1])
def test_processor_prompt_replacements_all(
processor_for_llava_next,
model_id: str,
num_imgs: int,
):
ctx = build_model_context(
model_name=model_id,
tokenizer_name=model_id,
mm_processor_kwargs=None,
limit_mm_per_prompt={"image": num_imgs},
)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
ctx = InputProcessingContext(ctx.model_config, tokenizer)
processor = processor_for_llava_next(ctx)
processed_inputs = processor.apply(prompt, mm_data, {})

image_placeholders = processed_inputs["mm_placeholders"]["image"]
assert len(image_placeholders) == num_imgs
seen_aspect_ratios = set[float]()
image_sizes = list[ImageSize]()

first_placeholder = image_placeholders[0]
# The aspect ratio of the grid layout is between 1 and 2
# NOTE: Assumes that feature size calculation is the same if we
# swap the width and height of the image
for w, h in itertools.product(range(64, 1024), repeat=2):
aspect_ratio = w / h
if 1 <= aspect_ratio <= 2 and aspect_ratio not in seen_aspect_ratios:
image_sizes.append(ImageSize(w, h))
seen_aspect_ratios.add(aspect_ratio)

# NOTE: There is a BOS token
assert first_placeholder["offset"] == 1
assert first_placeholder["length"] == (
len(processed_inputs["prompt_token_ids"]) - 1) // num_imgs
_test_image_prompt_replacements(
processor,
num_imgs=num_imgs,
image_sizes=image_sizes,
)
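
Rather than exercising every (width, height) pair, the exhaustive test above keeps one representative size per aspect ratio. This standalone snippet reproduces the selection logic to count how many unique sizes the sweep actually validates (same 64..1023 range and [1, 2] ratio bound as the test):

import itertools

seen_aspect_ratios: set[float] = set()
image_size_count = 0

for w, h in itertools.product(range(64, 1024), repeat=2):
    aspect_ratio = w / h
    if 1 <= aspect_ratio <= 2 and aspect_ratio not in seen_aspect_ratios:
        seen_aspect_ratios.add(aspect_ratio)
        image_size_count += 1

# One processor.apply() call per counted size once the skip marker
# is commented out.
print(image_size_count)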
tests/models/decoder_only/vision_language/processing/test_llava_onevision.py
@@ -1,8 +1,13 @@
import itertools
from functools import partial

import pytest
from PIL import Image
from pqdm.threads import pqdm
from transformers import AutoTokenizer

from vllm.inputs import InputProcessingContext
from vllm.multimodal.parse import ImageSize

from ....utils import build_model_context

@@ -15,22 +20,68 @@ def processor_for_llava_onevision():
return LlavaOnevisionMultiModalProcessor


def _validate_image_prompt_replacements_one(
processor,
num_imgs: int,
failed_size_excs: list[tuple[ImageSize, Exception]],
image_size: ImageSize,
) -> None:
prompt = "<image>" * num_imgs
image = Image.new("RGB", size=image_size)
mm_data = {"image": [image] * num_imgs}

try:
# The processor will throw an error if there is a mismatch
# in the prompt replacements
processed_inputs = processor.apply(prompt, mm_data, {})

image_placeholders = processed_inputs["mm_placeholders"]["image"]
assert len(image_placeholders) == num_imgs

first_placeholder = image_placeholders[0]

assert first_placeholder["offset"] == 0
assert first_placeholder["length"] == len(
processed_inputs["prompt_token_ids"]) // num_imgs
except Exception as exc:
failed_size_excs.append((image_size, exc))


def _test_image_prompt_replacements(
processor,
*,
num_imgs: int,
image_sizes: list[ImageSize],
) -> None:
"""
Ensure LlavaOnevisionMultiModalProcessor
handles prompt replacement properly for input images.
"""
failed_size_excs = list[tuple[ImageSize, Exception]]()

validate_one = partial(
_validate_image_prompt_replacements_one,
processor,
num_imgs,
failed_size_excs,
)
pqdm(image_sizes, validate_one, n_jobs=8, desc="Validating image sizes")

if failed_size_excs:
msg = "Found failing image sizes:" \
+ "\n========\n".join(f"[{size}]\n{exc}"
for size, exc in failed_size_excs)
raise AssertionError(msg)


@pytest.mark.parametrize("model_id",
["llava-hf/llava-onevision-qwen2-0.5b-ov-hf"])
@pytest.mark.parametrize("image_size", [(1669, 2560), (2560, 1669), (183, 488),
(488, 183), (198, 176), (176, 198),
(161, 184), (184, 161)])
@pytest.mark.parametrize("num_imgs", [1, 2])
def test_processor_prompt_replacements(
def test_processor_prompt_replacements_regression(
processor_for_llava_onevision,
model_id: str,
image_size: tuple[int, int],
num_imgs: int,
):
"""
Ensure LlavaOnevisionMultiModalProcessor handles prompt replacement
properly.
"""
ctx = build_model_context(
model_name=model_id,
tokenizer_name=model_id,
@@ -39,22 +90,56 @@ def test_processor_prompt_replacements(
)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
ctx = InputProcessingContext(ctx.model_config, tokenizer)
processor = processor_for_llava_onevision(ctx)

# Build the image str / prompt based on the number of images we pass
prompt = "<image>" * num_imgs
mm_data = {"image": [Image.new("RGB", size=image_size)] * num_imgs}
image_ratios = [(171, 152), (184, 161), (198, 176), (333, 296), (369, 328),
(488, 183), (2560, 1669)]
image_sizes = [
size for w, h in image_ratios
for size in [ImageSize(w, h), ImageSize(h, w)]
]

_test_image_prompt_replacements(
processor,
num_imgs=num_imgs,
image_sizes=image_sizes,
)

# The processor will throw an error if there is a mismatch
# in the prompt replacements

@pytest.mark.skip("This test takes around 2 hours to run. "
"Comment this out to run it manually.")
@pytest.mark.parametrize("model_id",
["llava-hf/llava-onevision-qwen2-0.5b-ov-hf"])
@pytest.mark.parametrize("num_imgs", [1])
def test_processor_prompt_replacements_all(
processor_for_llava_onevision,
model_id: str,
num_imgs: int,
):
ctx = build_model_context(
model_name=model_id,
tokenizer_name=model_id,
mm_processor_kwargs=None,
limit_mm_per_prompt={"image": num_imgs},
)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
ctx = InputProcessingContext(ctx.model_config, tokenizer)
processor = processor_for_llava_onevision(ctx)
processed_inputs = processor.apply(prompt, mm_data, {})

image_placeholders = processed_inputs["mm_placeholders"]["image"]
assert len(image_placeholders) == num_imgs
seen_aspect_ratios = set[float]()
image_sizes = list[ImageSize]()

first_placeholder = image_placeholders[0]
# The aspect ratio of the grid layout is between 1 and 6
# NOTE: Assumes that feature size calculation is the same if we
# swap the width and height of the image
for w, h in itertools.product(range(64, 1024), repeat=2):
aspect_ratio = w / h
if 1 <= aspect_ratio <= 6 and aspect_ratio not in seen_aspect_ratios:
image_sizes.append(ImageSize(w, h))
seen_aspect_ratios.add(aspect_ratio)

# NOTE: There is a BOS token
assert first_placeholder["offset"] == 0
assert first_placeholder["length"] == len(
processed_inputs["prompt_token_ids"]) // num_imgs
_test_image_prompt_replacements(
processor,
num_imgs=num_imgs,
image_sizes=image_sizes,
)
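
Two details distinguish the OneVision tests from their LLaVA-NeXT counterparts: the prompt carries no leading BOS token, so the first placeholder is asserted at offset 0 instead of 1, and the grid layout admits aspect ratios up to 6 instead of 2. A sketch of how the BOS token changes the expected per-image placeholder length (token counts are illustrative, not measured):

def expected_placeholder_length(num_prompt_tokens: int, num_imgs: int,
                                has_bos: bool) -> int:
    # Mirrors the assertions above: subtract the BOS token (if any),
    # then split the remaining tokens evenly across the images.
    non_image_tokens = 1 if has_bos else 0
    return (num_prompt_tokens - non_image_tokens) // num_imgs

print(expected_placeholder_length(2929, 2, has_bos=True))   # LLaVA-NeXT: 1464
print(expected_placeholder_length(2928, 2, has_bos=False))  # OneVision: 1464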
