From dc7d1ed9bef667726e4f5ca873a0b2a4bce2dcee Mon Sep 17 00:00:00 2001
From: Cyrus Leung
Date: Wed, 8 Jan 2025 10:17:16 +0800
Subject: [PATCH] [Bugfix] Fix image input for Pixtral-HF (#11741)

Signed-off-by: DarkLight1337
Signed-off-by: Fred Reiss
---
 ...e_inference_vision_language_multi_image.py | 41 ++++++++++++++++---
 vllm/model_executor/models/llava.py           |  6 +++
 vllm/model_executor/models/pixtral.py         |  2 +-
 vllm/model_executor/models/utils.py           |  9 ++++
 4 files changed, 52 insertions(+), 6 deletions(-)

diff --git a/examples/offline_inference_vision_language_multi_image.py b/examples/offline_inference_vision_language_multi_image.py
index 6af8d7768e75d..cf2e90a325c6a 100644
--- a/examples/offline_inference_vision_language_multi_image.py
+++ b/examples/offline_inference_vision_language_multi_image.py
@@ -23,7 +23,7 @@
 class ModelRequestData(NamedTuple):
     llm: LLM
     prompt: str
-    stop_token_ids: Optional[List[str]]
+    stop_token_ids: Optional[List[int]]
     image_data: List[Image]
     chat_template: Optional[str]
 
@@ -44,12 +44,14 @@ def load_aria(question, image_urls: List[str]) -> ModelRequestData:
     prompt = (f"<|im_start|>user\n{placeholders}{question}<|im_end|>\n"
               "<|im_start|>assistant\n")
     stop_token_ids = [93532, 93653, 944, 93421, 1019, 93653, 93519]
+
     return ModelRequestData(
         llm=llm,
         prompt=prompt,
         stop_token_ids=stop_token_ids,
         image_data=[fetch_image(url) for url in image_urls],
-        chat_template=None)
+        chat_template=None,
+    )
 
 
 def load_h2onvl(question: str, image_urls: List[str]) -> ModelRequestData:
@@ -166,7 +168,8 @@ def load_mllama(question, image_urls: List[str]) -> ModelRequestData:
         limit_mm_per_prompt={"image": len(image_urls)},
     )
 
-    prompt = f"<|image|><|image|><|begin_of_text|>{question}"
+    placeholders = "<|image|>" * len(image_urls)
+    prompt = f"{placeholders}<|begin_of_text|>{question}"
     return ModelRequestData(
         llm=llm,
         prompt=prompt,
@@ -209,6 +212,31 @@ def load_nvlm_d(question: str, image_urls: List[str]):
     )
 
 
+def load_pixtral_hf(question: str, image_urls: List[str]) -> ModelRequestData:
+    model_name = "mistral-community/pixtral-12b"
+
+    # Adjust this as necessary to fit in GPU
+    llm = LLM(
+        model=model_name,
+        max_model_len=8192,
+        max_num_seqs=2,
+        tensor_parallel_size=2,
+        limit_mm_per_prompt={"image": len(image_urls)},
+    )
+
+    placeholders = "[IMG]" * len(image_urls)
+    prompt = f"[INST]{question}\n{placeholders}[/INST]"
+    stop_token_ids = None
+
+    return ModelRequestData(
+        llm=llm,
+        prompt=prompt,
+        stop_token_ids=stop_token_ids,
+        image_data=[fetch_image(url) for url in image_urls],
+        chat_template=None,
+    )
+
+
 def load_phi3v(question: str, image_urls: List[str]) -> ModelRequestData:
     # num_crops is an override kwarg to the multimodal image processor;
     # For some models, e.g., Phi-3.5-vision-instruct, it is recommended
@@ -244,7 +272,8 @@ def load_phi3v(question: str, image_urls: List[str]) -> ModelRequestData:
     )
 
 
-def load_qwenvl_chat(question: str, image_urls: List[str]) -> ModelRequestData:
+def load_qwen_vl_chat(question: str,
+                      image_urls: List[str]) -> ModelRequestData:
     model_name = "Qwen/Qwen-VL-Chat"
     llm = LLM(
         model=model_name,
@@ -274,6 +303,7 @@ def load_qwenvl_chat(question: str, image_urls: List[str]) -> ModelRequestData:
 
     stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>"]
     stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
+
     return ModelRequestData(
         llm=llm,
         prompt=prompt,
@@ -348,7 +378,8 @@ def load_qwen2_vl(question, image_urls: List[str]) -> ModelRequestData:
     "mllama": load_mllama,
     "NVLM_D": load_nvlm_d,
     "phi3_v": load_phi3v,
-    "qwen_vl_chat": load_qwenvl_chat,
+    "pixtral_hf": load_pixtral_hf,
+    "qwen_vl_chat": load_qwen_vl_chat,
     "qwen2_vl": load_qwen2_vl,
 }
 
diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py
index 4299af8cd03a2..305f1364dba23 100644
--- a/vllm/model_executor/models/llava.py
+++ b/vllm/model_executor/models/llava.py
@@ -546,6 +546,12 @@ def _parse_and_validate_image_input(
                 raise ValueError("Incorrect type of pixel values. "
                                  f"Got type: {type(pixel_values)}")
 
+            if self.config.vision_config.model_type == "pixtral":
+                return LlavaImagePixelInputs(
+                    type="pixel_values",
+                    data=flatten_bn(pixel_values),
+                )
+
             return LlavaImagePixelInputs(
                 type="pixel_values",
                 data=self._validate_pixel_values(
diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py
index 9e1d38512c0b4..b74bb3c8a3f88 100644
--- a/vllm/model_executor/models/pixtral.py
+++ b/vllm/model_executor/models/pixtral.py
@@ -774,7 +774,7 @@ def get_num_image_tokens(
     ) -> int:
         return get_pixtral_hf_image_feature_size(
             image_size=self.vision_config.image_size,
-            patch_size=self.get_image_size(),
+            patch_size=self.vision_config.patch_size,
         )
 
     def get_max_image_tokens(self) -> int:
diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py
index 31017f16d3c97..4ed3b237ae0e2 100644
--- a/vllm/model_executor/models/utils.py
+++ b/vllm/model_executor/models/utils.py
@@ -281,6 +281,15 @@ def flatten_bn(
     ...
 
 
+@overload
+def flatten_bn(
+    x: Union[List[torch.Tensor], torch.Tensor],
+    *,
+    concat: bool = False,
+) -> Union[List[torch.Tensor], torch.Tensor]:
+    ...
+
+
 def flatten_bn(
     x: Union[List[torch.Tensor], torch.Tensor],
     *,