[VLM] Separate out profiling-related logic #11746

Merged · 11 commits · Jan 6, 2025
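For orientation, here is a minimal, self-contained sketch of the pattern the diff below introduces: profiling hooks move out of each model's multimodal processor into a dedicated `*ProfilingInfo` class, with a shared `*ProcessingMixin` providing common accessors, and the processor exposing its profiling info via `_get_profiling_info()`. The stub base classes and the `ctx` object are simplified stand-ins for illustration only, not vLLM's actual implementations; only the class and method names mirror the diff.

```python
# Sketch of the refactoring pattern in this PR (simplified stand-ins, not
# vLLM's real base classes): shared helpers live in a mixin, profiling-only
# hooks live in a ProfilingInfo class, and the processor delegates to it.
from typing import Any, Mapping, Optional


class ProcessingMixin:
    """Helpers shared by a processor and its profiling info."""

    def __init__(self, ctx: Any) -> None:
        self.ctx = ctx  # stand-in for vLLM's input-processing context

    def _get_hf_config(self):
        return self.ctx.get_hf_config()


class BaseProfilingInfo(ProcessingMixin):
    """Profiling-only hooks, separated from processing logic."""

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        raise NotImplementedError

    def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
        raise NotImplementedError


class BaseMultiModalProcessor(ProcessingMixin):
    """Processing logic; profiling is delegated to a separate object."""

    def _get_profiling_info(self) -> BaseProfilingInfo:
        raise NotImplementedError

    @property
    def profiling_info(self) -> BaseProfilingInfo:
        return self._get_profiling_info()


# A per-model implementation then mirrors the split, e.g. for Aria:
class AriaProcessingMixin(ProcessingMixin):

    def _get_num_image_tokens(self) -> int:
        hf_config = self._get_hf_config()
        return max(hf_config.projector_patch_to_query_dict.values())


class AriaProfilingInfo(AriaProcessingMixin, BaseProfilingInfo):

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        return {"image": None}

    def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
        return {"image": self._get_num_image_tokens()}


class AriaMultiModalProcessor(AriaProcessingMixin, BaseMultiModalProcessor):

    def _get_profiling_info(self) -> BaseProfilingInfo:
        return AriaProfilingInfo(self.ctx)
```
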
tests/multimodal/test_processing.py (7 changes: 4 additions & 3 deletions)
@@ -586,17 +586,18 @@ def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid):
)

processor = processor_factory(ctx, cache=None)
profiler = processor.profiling_info

mock_supported_mm_limits = MagicMock(return_value={"image": num_supported})
processor.get_supported_mm_limits = mock_supported_mm_limits
profiler.get_supported_mm_limits = mock_supported_mm_limits

if is_valid:
exc_ctx = nullcontext()
else:
exc_ctx = pytest.raises(ValueError, match="this model only supports")

with exc_ctx:
processor._get_and_validate_dummy_mm_counts()
profiler.get_mm_limits()


@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
@@ -723,7 +724,7 @@ def _test_processing_cache_correctness(
}

mm_counts = {k: len(vs) for k, vs in mm_data.items()}
prompt = baseline_processor._get_dummy_processor_inputs(
prompt = baseline_processor.profiling_info.get_dummy_processor_inputs(
model_config.max_model_len,
mm_counts,
).prompt_text
vllm/model_executor/models/aria.py (79 changes: 47 additions & 32 deletions)
@@ -24,8 +24,9 @@
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
NestedTensors)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessorInputs,
MultiModalDataItems, ProcessingMixin,
PromptReplacement)
from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.aria import (AriaMoELMConfig,
AriaVisionConfig)
@@ -444,18 +445,58 @@ def build_mm_projector(config: PretrainedConfig):
)


class AriaMultiModalProcessor(BaseMultiModalProcessor):
class AriaProcessingMixin(ProcessingMixin):

def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
def _get_hf_config(self):
return self.ctx.get_hf_config()

def _get_vision_config(self) -> AriaVisionConfig:
return self._get_hf_config().vision_config

def _get_num_image_tokens(self) -> int:
hf_config = self.ctx.get_hf_config()
hf_config = self._get_hf_config()
return max(hf_config.projector_patch_to_query_dict.values())


class AriaProfilingInfo(AriaProcessingMixin, BaseProfilingInfo):

def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}

def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
return {"image": self._get_num_image_tokens()}

def get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
vision_config = self._get_vision_config()

max_image_size = vision_config.image_size
num_images = mm_counts.get("image", 0)

mm_data = {
"image":
self._get_dummy_images(width=max_image_size,
height=max_image_size,
num_images=num_images)
}

hf_processor = self._get_hf_processor()
image_token: str = hf_processor.image_token # type: ignore

return ProcessorInputs(
prompt_text=image_token * num_images,
mm_data=mm_data,
)


class AriaMultiModalProcessor(AriaProcessingMixin, BaseMultiModalProcessor):

def _get_profiling_info(self) -> BaseProfilingInfo:
return AriaProfilingInfo(self.ctx)

def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
@@ -472,7 +513,7 @@ def _get_prompt_replacements(
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
hf_config = self.ctx.get_hf_config()
hf_config = self._get_hf_config()
image_token_id = hf_config.image_token_index

num_image_tokens = self._get_num_image_tokens()
@@ -485,32 +526,6 @@ def _get_prompt_replacements(
)
]

def _get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
hf_config = self.ctx.get_hf_config()
vision_config: AriaVisionConfig = hf_config.vision_config

max_image_size = vision_config.image_size
num_images = mm_counts.get("image", 0)

mm_data = {
"image":
self._get_dummy_images(width=max_image_size,
height=max_image_size,
num_images=num_images)
}

hf_processor = self._get_hf_processor()
image_token: str = hf_processor.image_token # type: ignore

return ProcessorInputs(
prompt_text=image_token * num_images,
mm_data=mm_data,
)


@MULTIMODAL_REGISTRY.register_processor(AriaMultiModalProcessor)
class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
vllm/model_executor/models/blip2.py (78 changes: 44 additions & 34 deletions)
@@ -4,8 +4,8 @@

import torch
import torch.nn as nn
from transformers import (BatchFeature, Blip2Config, Blip2Processor,
Blip2QFormerConfig, apply_chunking_to_forward)
from transformers import (BatchFeature, Blip2Config, Blip2QFormerConfig,
apply_chunking_to_forward)

from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, VllmConfig
@@ -18,8 +18,9 @@
MultiModalInputsV2, MultiModalKwargs,
NestedTensors, PlaceholderRange)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessorInputs,
MultiModalDataItems, ProcessingMixin,
PromptReplacement)
from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
from vllm.sequence import IntermediateTensors

from .blip import BlipVisionModel
@@ -396,20 +397,52 @@ def forward(
return sequence_output


class Blip2MultiModalProcessor(BaseMultiModalProcessor):
class Blip2ProcessingMixin(ProcessingMixin):

def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": 1}
def _get_hf_config(self):
return self.ctx.get_hf_config(Blip2Config)

def _get_num_image_tokens(self) -> int:
hf_config = self.ctx.get_hf_config(Blip2Config)
hf_config = self._get_hf_config()
return hf_config.num_query_tokens


class Blip2ProfilingInfo(Blip2ProcessingMixin, BaseProfilingInfo):

def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": 1}

def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
return {"image": self._get_num_image_tokens()}

def _get_hf_processor(self) -> Blip2Processor:
return self.ctx.get_hf_processor(Blip2Processor)
def get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
hf_config = self._get_hf_config()
vision_config = hf_config.vision_config

max_image_size = vision_config.image_size
num_images = mm_counts.get("image", 0)

mm_data = {
"image":
self._get_dummy_images(width=max_image_size,
height=max_image_size,
num_images=num_images)
}

return ProcessorInputs(
prompt_text="",
mm_data=mm_data,
)


class Blip2MultiModalProcessor(Blip2ProcessingMixin, BaseMultiModalProcessor):

def _get_profiling_info(self) -> BaseProfilingInfo:
return Blip2ProfilingInfo(self.ctx)

def _get_mm_fields_config(
self,
@@ -427,13 +460,13 @@ def _get_prompt_replacements(
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
max_image_tokens = self._get_num_image_tokens()
num_image_tokens = self._get_num_image_tokens()

return [
PromptReplacement(
modality="image",
target="</s>",
replacement="<image>" * max_image_tokens + "</s>",
replacement="<image>" * num_image_tokens + "</s>",
)
]

@@ -457,29 +490,6 @@ def apply(

return result

def _get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
hf_config = self.ctx.get_hf_config(Blip2Config)
vision_config = hf_config.vision_config

max_image_size = vision_config.image_size
num_images = mm_counts.get("image", 0)

mm_data = {
"image":
self._get_dummy_images(width=max_image_size,
height=max_image_size,
num_images=num_images)
}

return ProcessorInputs(
prompt_text="",
mm_data=mm_data,
)


@MULTIMODAL_REGISTRY.register_processor(Blip2MultiModalProcessor)
class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
vllm/model_executor/models/chameleon.py (72 changes: 43 additions & 29 deletions)
@@ -31,8 +31,9 @@
MultiModalInputsV2, MultiModalKwargs,
NestedTensors, PlaceholderRange)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessorInputs,
MultiModalDataItems, ProcessingMixin,
PromptReplacement)
from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.utils import print_warning_once

Expand All @@ -48,20 +49,55 @@ class ChameleonImagePixelInputs(TypedDict):
"""Shape: `(batch_size * num_images, num_channels, height, width)`"""


class ChameleonMultiModalProcessor(BaseMultiModalProcessor):
class ChameleonProcessingMixin(ProcessingMixin):

def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": 1}
def _get_hf_config(self):
return self.ctx.get_hf_config(ChameleonConfig)

def _get_hf_processor(self):
return self.ctx.get_hf_processor(ChameleonProcessor)

def _get_num_image_tokens(self) -> int:
processor = self._get_hf_processor()
return processor.image_seq_length


class ChameleonProfilingInfo(ChameleonProcessingMixin, BaseProfilingInfo):

def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": 1}

def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
return {"image": self._get_num_image_tokens()}

def _get_hf_processor(self) -> ChameleonProcessor:
return self.ctx.get_hf_processor(ChameleonProcessor)
def get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
config = self._get_hf_config()

width = height = config.vq_config.resolution
num_images = mm_counts.get("image", 0)

mm_data = {
"image":
self._get_dummy_images(width=width,
height=height,
num_images=num_images)
}

return ProcessorInputs(
prompt_text="<image>" * num_images,
mm_data=mm_data,
)


class ChameleonMultiModalProcessor(ChameleonProcessingMixin,
BaseMultiModalProcessor):

def _get_profiling_info(self) -> BaseProfilingInfo:
return ChameleonProfilingInfo(self.ctx)

def _get_mm_fields_config(
self,
@@ -76,7 +112,7 @@ def _get_prompt_replacements(
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
processor = self._get_hf_processor()
processor = self._get_hf_processor(**hf_processor_mm_kwargs)

return [
PromptReplacement(
@@ -90,28 +126,6 @@ def _get_prompt_replacements(
)
]

def _get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
config = self.ctx.get_hf_config(ChameleonConfig)

width = height = config.vq_config.resolution
num_images = mm_counts.get("image", 0)

mm_data = {
"image":
self._get_dummy_images(width=width,
height=height,
num_images=num_images)
}

return ProcessorInputs(
prompt_text="<image>" * num_images,
mm_data=mm_data,
)

def apply(
self,
prompt_text: str,