[VLM] Move supported limits and max tokens to merged multi-modal processor #11669

Merged

@@ -4,7 +4,7 @@
import pytest
from transformers import AutoTokenizer

from vllm.inputs import InputContext, InputProcessingContext
from vllm.inputs import InputProcessingContext
from vllm.model_executor.models.phi3v import _IMAGE_TOKEN_ID

from .....conftest import _ImageAssets
@@ -20,42 +20,6 @@ def processor_for_phi3v():
return Phi3VMultiModalProcessor


@pytest.fixture()
def get_max_phi3v_image_tokens():
from vllm.model_executor.models.phi3v import get_max_phi3v_image_tokens
return get_max_phi3v_image_tokens


@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("num_crops,expected_max_tokens", [
(4, 781),
(16, 2653),
])
def test_max_tokens_override(get_max_phi3v_image_tokens, model: str,
num_crops: int, expected_max_tokens: int):
"""Ensure get_max_phi3v_image_tokens handles num_crops properly."""
# NOTE: mm_processor_kwargs on the context in this test is unused, since
# this is testing the mapper directly. In practice, the processor kwargs
# are wrapped in a closure when calling the max tokens func. We explicitly
# do NOT use the mm_processor_kwargs in the model context here to ensure
# that the max image tokens implementation is referencing a mix of the
# kwargs to the function and the original mm_processor_kwargs in case
# values are somehow updated and end up in a bad state.
ctx = build_model_context(
model_name=model,
tokenizer_name=model,
trust_remote_code=True,
mm_processor_kwargs=None,
)

actual_max_tokens = get_max_phi3v_image_tokens(
InputContext(ctx.model_config),
num_crops=num_crops,
)

assert expected_max_tokens == actual_max_tokens
Comment on lines -34 to -56
@DarkLight1337 (Member, Author) commented on Jan 1, 2025

This test is unnecessary because we already check its consistency against the dummy data in the multi-modal processor at the profiling stage.
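
A rough illustration of the invariant that profiling-stage check relies on (hypothetical helper, not code from this PR; it only assumes the `get_mm_max_tokens_per_item` and `get_dummy_data` methods that the diffs below add to the merged processor):

```python
# Hypothetical sketch: a conservative form of the consistency check described
# above -- the dummy data the processor builds for profiling must stay within
# the per-item max token counts the same processor reports.
def check_dummy_data_consistency(processor, seq_len: int) -> None:
    mm_max_tokens = processor.get_mm_max_tokens_per_item()
    dummy_data = processor.get_dummy_data(seq_len)

    for modality, max_tokens in mm_max_tokens.items():
        for placeholder in dummy_data.multi_modal_placeholders[modality]:
            # Each placeholder range records how many tokens the item occupies.
            assert placeholder["length"] <= max_tokens
```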



@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
"num_crops,expected_toks_per_img",
@@ -3,7 +3,7 @@
import pytest
from transformers import AutoTokenizer

from vllm.inputs import InputContext, InputProcessingContext
from vllm.inputs import InputProcessingContext

from .....conftest import _ImageAssets
from ....utils import build_model_context
@@ -22,39 +22,6 @@ def processor_for_qwen2_vl():
return Qwen2VLMultiModalProcessor


@pytest.fixture()
def get_max_qwen2_vl_image_tokens():
from vllm.model_executor.models.qwen2_vl import (
get_max_qwen2_vl_image_tokens)
return get_max_qwen2_vl_image_tokens


@pytest.mark.parametrize("mm_processor_kwargs,expected_max_tokens", [
({}, 16384),
({
MIN_PIXELS: 64**2,
MAX_PIXELS: 512**2
}, 324),
])
@pytest.mark.parametrize("model", [MODEL])
def test_qwen2_vl_max_image_tokens(
get_max_qwen2_vl_image_tokens,
model: str,
mm_processor_kwargs: Dict[str, Any],
expected_max_tokens: int,
):
"""Ensure that the max token calc handles min/max pixels properly."""
ctx = build_model_context(
model_name=model,
tokenizer_name=model,
mm_processor_kwargs=None,
)

actual_max_tokens = get_max_qwen2_vl_image_tokens(
InputContext(ctx.model_config), **mm_processor_kwargs)
assert actual_max_tokens == expected_max_tokens


Comment on lines -25 to -57
@DarkLight1337 (Member, Author)

Ditto

@pytest.mark.parametrize(
"mm_processor_kwargs, expected_toks_per_img, expected_pixels_shape", [
({}, 1426, (5704, 1176)),
8 changes: 1 addition & 7 deletions vllm/inputs/registry.py
@@ -331,13 +331,7 @@ def dummy_data_for_profiling(
trust_remote_code=model_config.trust_remote_code,
)
processor = mm_registry.create_processor(model_config, tokenizer)

mm_counts = mm_registry.get_mm_limits_per_prompt(model_config)
mm_max_tokens = mm_registry.get_max_tokens_by_modality(
model_config)

dummy_data = processor.get_dummy_data(seq_len, mm_counts,
mm_max_tokens)
dummy_data = processor.get_dummy_data(seq_len)
else:
model_cls, _ = get_model_architecture(model_config)
if is_encoder_data:
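
For context, a minimal sketch of the call path this hunk leaves behind: the registry builds the processor, and the processor now owns both the per-modality limits and the per-item max token counts. Names such as `model_config`, `tokenizer`, and `seq_len` are borrowed from `dummy_data_for_profiling` above; treat this as illustrative rather than the exact vLLM API surface.

```python
from vllm.multimodal import MULTIMODAL_REGISTRY


def profiling_path_sketch(model_config, tokenizer, seq_len: int):
    """Sketch of the simplified profiling path shown in the hunk above."""
    processor = MULTIMODAL_REGISTRY.create_processor(model_config, tokenizer)

    # Per-modality item caps (None = no fixed limit) and worst-case
    # placeholder-token counts are now reported by the processor itself...
    mm_limits = processor.get_supported_mm_limits()
    mm_max_tokens = processor.get_mm_max_tokens_per_item()

    # ...so dummy-data generation needs only the target sequence length.
    dummy_data = processor.get_dummy_data(seq_len)
    return dummy_data, mm_limits, mm_max_tokens
```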
75 changes: 42 additions & 33 deletions vllm/model_executor/models/aria.py
@@ -1,5 +1,5 @@
from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict,
Union)
from typing import (Callable, Iterable, List, Mapping, Optional, Set, Tuple,
TypedDict, Union)

import torch
import torch.nn as nn
@@ -9,7 +9,6 @@
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, QuantizationConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.inputs import InputContext
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -87,8 +86,8 @@ def __init__(
def forward(
self,
pixel_values: torch.Tensor,
pixel_mask: Optional[torch.BoolTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.BoolTensor]]:
pixel_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
patch_attention_mask = self._create_patch_attention_mask(pixel_mask)

vit_oup = self.vision_model(
@@ -100,7 +99,8 @@ def forward(

return vit_oup, image_atts

def _create_patch_attention_mask(self, pixel_mask):
def _create_patch_attention_mask(
self, pixel_mask: Optional[torch.Tensor]) -> torch.Tensor:
if pixel_mask is None:
return None

@@ -115,7 +115,8 @@ def _create_patch_attention_mask(self, pixel_mask):
)
return (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()

def _create_image_attention_mask(self, patch_attention_mask):
def _create_image_attention_mask(
self, patch_attention_mask: torch.Tensor) -> torch.Tensor:
if patch_attention_mask is None:
return None

@@ -125,13 +126,13 @@ def _create_image_attention_mask(self, patch_attention_mask):

class FFN(nn.Module):

def __init__(self, embed_dim, ff_dim, output_dim):
def __init__(self, embed_dim: int, ff_dim: int, output_dim: int) -> None:
super().__init__()
self.linear_in = ColumnParallelLinear(embed_dim, ff_dim, bias=False)
self.linear_out = RowParallelLinear(ff_dim, output_dim, bias=False)
self.act = get_act_fn("gelu_new")

def forward(self, hidden_states):
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states, _ = self.linear_in(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states, _ = self.linear_out(hidden_states)
@@ -140,7 +141,7 @@ def forward(self, hidden_states):

class CrossAttention(nn.Module):

def __init__(self, kv_dim, embed_dim, num_heads, drop_out_rate=0):
def __init__(self, kv_dim: int, embed_dim: int, num_heads: int) -> None:
super().__init__()
self.num_heads = num_heads
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False)
@@ -149,12 +150,16 @@ def __init__(self, kv_dim, embed_dim, num_heads, drop_out_rate=0):

self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
self.linear = nn.Linear(embed_dim, embed_dim)
self.dropout = nn.Dropout(drop_out_rate)

self.layer_norm = nn.LayerNorm(embed_dim)
self.ln_kv = nn.LayerNorm(kv_dim)

def forward(self, x, hidden_states, attn_mask=None, add_residual=False):
def forward(
self,
x: torch.Tensor,
hidden_states: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
normed_hidden_states = self.layer_norm(hidden_states)
query = self.q_proj(normed_hidden_states).permute(1, 0, 2)

@@ -169,11 +174,7 @@ def forward(self, x, hidden_states, attn_mask=None, add_residual=False):

attn_output = attn_output.permute(1, 0, 2)

if add_residual:
attn_output = hidden_states + self.dropout(
self.linear(attn_output))
else:
attn_output = self.dropout(self.linear(attn_output))
attn_output = self.linear(attn_output)

return attn_output

@@ -201,14 +202,14 @@ class AriaProjector(nn.Module):

def __init__(
self,
patch_to_query_dict,
embed_dim,
num_heads,
kv_dim,
ff_dim,
output_dim,
norm_layer=nn.LayerNorm,
):
patch_to_query_dict: dict[int, int],
embed_dim: int,
num_heads: int,
kv_dim: int,
ff_dim: int,
output_dim: int,
norm_layer: Callable[[int], nn.Module] = nn.LayerNorm,
) -> None:
super().__init__()
self.patch_to_query_dict = patch_to_query_dict
self.embed_dim = embed_dim
@@ -224,7 +225,11 @@ def __init__(
self.ln_ffn = norm_layer(embed_dim)
self.ffn = FFN(embed_dim, ff_dim, output_dim)

def forward(self, x, attn_mask=None):
def forward(
self,
x: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
bs = x.shape[0]
queries = self.query.unsqueeze(0).repeat(bs, 1, 1)

@@ -442,12 +447,17 @@ def build_mm_projector(config: PretrainedConfig):
)


def get_max_aria_image_tokens(ctx: InputContext):
hf_config = ctx.get_hf_config()
return max(hf_config.projector_patch_to_query_dict.values())
class AriaMultiModalProcessor(BaseMultiModalProcessor):

def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}

def _get_num_image_tokens(self) -> int:
hf_config = self.ctx.get_hf_config()
return max(hf_config.projector_patch_to_query_dict.values())

class AriaMultiModalProcessor(BaseMultiModalProcessor):
def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
return {"image": self._get_num_image_tokens()}

def _get_mm_fields_config(
self,
@@ -468,13 +478,13 @@ def _get_prompt_replacements(
hf_config = self.ctx.get_hf_config()
image_token_id = hf_config.image_token_index

max_image_tokens = get_max_aria_image_tokens(self.ctx)
num_image_tokens = self._get_num_image_tokens()

return [
PromptReplacement(
modality="image",
target=[image_token_id],
replacement=[image_token_id] * max_image_tokens,
replacement=[image_token_id] * num_image_tokens,
)
]

@@ -504,7 +514,6 @@ def _get_dummy_mm_inputs(
)


@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_aria_image_tokens)
@MULTIMODAL_REGISTRY.register_processor(AriaMultiModalProcessor)
class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
"""
19 changes: 9 additions & 10 deletions vllm/model_executor/models/blip2.py
@@ -9,7 +9,6 @@

from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, VllmConfig
from vllm.inputs import InputContext
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
@@ -18,7 +17,6 @@
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputsV2, MultiModalKwargs,
NestedTensors, PlaceholderRange)
from vllm.multimodal.parse import MultiModalDataParser
from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessorInputs,
PromptReplacement)
@@ -398,15 +396,17 @@ def forward(
return sequence_output


def get_max_blip2_image_tokens(ctx: InputContext):
hf_config = ctx.get_hf_config(Blip2Config)
return hf_config.num_query_tokens
class Blip2MultiModalProcessor(BaseMultiModalProcessor):

def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": 1}

class Blip2MultiModalProcessor(BaseMultiModalProcessor):
def _get_num_image_tokens(self) -> int:
hf_config = self.ctx.get_hf_config(Blip2Config)
return hf_config.num_query_tokens

def _get_data_parser(self) -> MultiModalDataParser:
return MultiModalDataParser(max_mm_counts={"image": 1})
def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
return {"image": self._get_num_image_tokens()}

def _get_hf_processor(self) -> Blip2Processor:
return self.ctx.get_hf_processor(Blip2Processor)
@@ -427,7 +427,7 @@ def _get_prompt_replacements(
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
max_image_tokens = get_max_blip2_image_tokens(self.ctx)
max_image_tokens = self._get_num_image_tokens()

return [
PromptReplacement(
@@ -480,7 +480,6 @@ def _get_dummy_mm_inputs(
)


@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_blip2_image_tokens)
@MULTIMODAL_REGISTRY.register_processor(Blip2MultiModalProcessor)
class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):

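
Taken together, the per-model changes in this PR follow one pattern: delete the module-level `get_max_*_image_tokens` helper and its `register_max_image_tokens` registration, and expose both the supported item limits and the per-item max token count as methods on the merged processor. A condensed, hypothetical skeleton of that pattern (the model name and the `num_query_tokens` config field are placeholders, not part of this diff):

```python
from typing import Mapping, Optional

from vllm.multimodal.processing import BaseMultiModalProcessor


class MyModelMultiModalProcessor(BaseMultiModalProcessor):
    """Hypothetical processor illustrating the interface this PR introduces."""

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        # None = no fixed per-prompt limit; a concrete number (e.g. 1 for
        # BLIP-2 above) caps how many items of that modality are allowed.
        return {"image": 1}

    def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
        # Worst-case number of placeholder tokens one image can expand to,
        # derived from the HF config just like the Aria/BLIP-2 diffs above.
        hf_config = self.ctx.get_hf_config()
        return {"image": hf_config.num_query_tokens}  # placeholder field


# Registration no longer needs register_max_image_tokens; register_processor
# alone is enough:
#
#     @MULTIMODAL_REGISTRY.register_processor(MyModelMultiModalProcessor)
#     class MyModelForConditionalGeneration(nn.Module, SupportsMultiModal):
#         ...
```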