diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py
index d4e6406589e98..7ea3d78c4b8e9 100644
--- a/vllm/model_executor/models/llava.py
+++ b/vllm/model_executor/models/llava.py
@@ -31,6 +31,7 @@
                               ProcessingMixin, PromptReplacement)
 from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
 from vllm.sequence import IntermediateTensors
+from vllm.utils import is_list_of
 
 from .clip import CLIPVisionModel
 from .interfaces import SupportsMultiModal, SupportsPP
@@ -521,7 +522,7 @@ def sampler(self):
         return get_sampler()
 
     def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
-        # The image size may be different for Pixtral-HF
+        # Only the longest edge is equal to image_size for Pixtral-HF
         if self.config.vision_config.model_type == "pixtral":
             return data
 
@@ -550,10 +551,12 @@ def _parse_and_validate_image_input(
                 raise ValueError("Incorrect type of pixel values. "
                                  f"Got type: {type(pixel_values)}")
 
+            pixel_values = flatten_bn(pixel_values,
+                                      concat=is_list_of(pixel_values, list))
+
             return LlavaImagePixelInputs(
                 type="pixel_values",
-                data=self._validate_pixel_values(
-                    flatten_bn(pixel_values, concat=True)),
+                data=self._validate_pixel_values(pixel_values),
             )
 
         if image_embeds is not None:
diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py
index 31017f16d3c97..4ed3b237ae0e2 100644
--- a/vllm/model_executor/models/utils.py
+++ b/vllm/model_executor/models/utils.py
@@ -281,6 +281,15 @@ def flatten_bn(
     ...
 
 
+@overload
+def flatten_bn(
+    x: Union[List[torch.Tensor], torch.Tensor],
+    *,
+    concat: bool = False,
+) -> Union[List[torch.Tensor], torch.Tensor]:
+    ...
+
+
 def flatten_bn(
     x: Union[List[torch.Tensor], torch.Tensor],
     *,
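
Note on the flatten_bn change: the catch-all @overload added in utils.py lets callers pass a runtime bool for concat, which is exactly what the updated LLaVA code does with concat=is_list_of(pixel_values, list). The sketch below is a minimal, self-contained illustration of that overload pattern; the _flatten_bn body is a simplified stand-in (handling only tensors and flat lists of tensors), not the actual vLLM implementation.

from typing import List, Literal, Union, overload

import torch


@overload
def _flatten_bn(
    x: Union[List[torch.Tensor], torch.Tensor],
    *,
    concat: Literal[True],
) -> torch.Tensor:
    ...


@overload
def _flatten_bn(
    x: Union[List[torch.Tensor], torch.Tensor],
    *,
    concat: bool = False,
) -> Union[List[torch.Tensor], torch.Tensor]:
    ...


def _flatten_bn(x, *, concat=False):
    # Simplified stand-in: flatten the batch (B) and per-item (N)
    # dimensions of batched multimodal inputs.
    if isinstance(x, torch.Tensor):
        # Batched tensor of shape (B, N, ...) -> (B * N, ...).
        return x.flatten(0, 1)
    if concat:
        # Merge per-item tensors into one tensor; this requires matching
        # shapes in the non-batch dimensions.
        return torch.cat(x)
    # Keep the tensors separate, e.g. when image sizes differ per item.
    return list(x)


# With the catch-all overload, a call that passes a runtime bool still
# type-checks and resolves to the Union return type:
# xs = [torch.rand(3, 336, 336), torch.rand(3, 224, 448)]
# out = _flatten_bn(xs, concat=False)  # list of tensors, sizes preserved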