
[Bugfix] Fix image input for Pixtral-HF #11741

Merged (11 commits) on Jan 8, 2025
41 changes: 36 additions & 5 deletions examples/offline_inference_vision_language_multi_image.py
@@ -23,7 +23,7 @@
class ModelRequestData(NamedTuple):
llm: LLM
prompt: str
-stop_token_ids: Optional[List[str]]
+stop_token_ids: Optional[List[int]]
image_data: List[Image]
chat_template: Optional[str]

@@ -44,12 +44,14 @@ def load_aria(question, image_urls: List[str]) -> ModelRequestData:
prompt = (f"<|im_start|>user\n{placeholders}{question}<|im_end|>\n"
"<|im_start|>assistant\n")
stop_token_ids = [93532, 93653, 944, 93421, 1019, 93653, 93519]

return ModelRequestData(
llm=llm,
prompt=prompt,
stop_token_ids=stop_token_ids,
image_data=[fetch_image(url) for url in image_urls],
-chat_template=None)
+chat_template=None,
+)


def load_h2onvl(question: str, image_urls: List[str]) -> ModelRequestData:
@@ -166,7 +168,8 @@ def load_mllama(question, image_urls: List[str]) -> ModelRequestData:
limit_mm_per_prompt={"image": len(image_urls)},
)

-prompt = f"<|image|><|image|><|begin_of_text|>{question}"
+placeholders = "<|image|>" * len(image_urls)
+prompt = f"{placeholders}<|begin_of_text|>{question}"
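+# e.g. two image URLs yield "<|image|><|image|><|begin_of_text|>{question}"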
return ModelRequestData(
llm=llm,
prompt=prompt,
@@ -209,6 +212,31 @@ def load_nvlm_d(question: str, image_urls: List[str]):
)


def load_pixtral_hf(question: str, image_urls: List[str]) -> ModelRequestData:
model_name = "mistral-community/pixtral-12b"
Review comment (Member): if you want to use a chat template, you could use this model where I added it: https://huggingface.co/mgoin/pixtral-12b

Reply (Member, Author): I think the prompt format is straightforward enough that we don't need a chat template for this.

# Adjust this as necessary to fit in GPU memory
llm = LLM(
model=model_name,
max_model_len=8192,
max_num_seqs=2,
tensor_parallel_size=2,
limit_mm_per_prompt={"image": len(image_urls)},
)

placeholders = "[IMG]" * len(image_urls)
prompt = f"<s>[INST]{question}\n{placeholders}[/INST]"
stop_token_ids = None

return ModelRequestData(
llm=llm,
prompt=prompt,
stop_token_ids=stop_token_ids,
image_data=[fetch_image(url) for url in image_urls],
chat_template=None,
)
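For concreteness, here is the prompt string this helper builds for two images — a small sketch derived from the code above (the question text and URLs are made-up placeholders):

```python
question = "What is shown in these images?"  # placeholder question
image_urls = [
    "https://example.com/1.jpg",  # hypothetical URLs
    "https://example.com/2.jpg",
]

placeholders = "[IMG]" * len(image_urls)
prompt = f"<s>[INST]{question}\n{placeholders}[/INST]"
# prompt == '<s>[INST]What is shown in these images?\n[IMG][IMG][/INST]'
```

Because the format is just a fixed `[INST]` wrapper plus one `[IMG]` token per image, the example stays self-contained without a chat template, which matches the author's reply in the review thread above.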


def load_phi3v(question: str, image_urls: List[str]) -> ModelRequestData:
# num_crops is an override kwarg to the multimodal image processor;
# For some models, e.g., Phi-3.5-vision-instruct, it is recommended
@@ -244,7 +272,8 @@ def load_phi3v(question: str, image_urls: List[str]) -> ModelRequestData:
)


-def load_qwenvl_chat(question: str, image_urls: List[str]) -> ModelRequestData:
+def load_qwen_vl_chat(question: str,
+                      image_urls: List[str]) -> ModelRequestData:
model_name = "Qwen/Qwen-VL-Chat"
llm = LLM(
model=model_name,
@@ -274,6 +303,7 @@ def load_qwenvl_chat(question: str, image_urls: List[str]) -> ModelRequestData:

stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>"]
stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]

return ModelRequestData(
llm=llm,
prompt=prompt,
@@ -348,7 +378,8 @@ def load_qwen2_vl(question, image_urls: List[str]) -> ModelRequestData:
"mllama": load_mllama,
"NVLM_D": load_nvlm_d,
"phi3_v": load_phi3v,
"qwen_vl_chat": load_qwenvl_chat,
"pixtral_hf": load_pixtral_hf,
"qwen_vl_chat": load_qwen_vl_chat,
"qwen2_vl": load_qwen2_vl,
}

6 changes: 6 additions & 0 deletions vllm/model_executor/models/llava.py
@@ -546,6 +546,12 @@ def _parse_and_validate_image_input(
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")

if self.config.vision_config.model_type == "pixtral":
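# Pixtral-HF accepts images at their native resolutions, so the
# per-image tensors can differ in shape; flatten them as-is and skip
# the fixed-shape validation used below.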
return LlavaImagePixelInputs(
type="pixel_values",
data=flatten_bn(pixel_values),
)

return LlavaImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(
2 changes: 1 addition & 1 deletion vllm/model_executor/models/pixtral.py
@@ -774,7 +774,7 @@ def get_num_image_tokens(
) -> int:
return get_pixtral_hf_image_feature_size(
image_size=self.vision_config.image_size,
-patch_size=self.get_image_size(),
+patch_size=self.vision_config.patch_size,
)

def get_max_image_tokens(self) -> int:
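A quick sanity check of this fix, using assumed Pixtral-12B defaults (image_size=1024, patch_size=16 — not taken from this diff): the old code passed the image size where the patch size belongs, collapsing the patch grid to a single cell.

```python
# Assumed Pixtral-12B defaults; hypothetical values for illustration.
image_size = 1024
patch_size = 16

buggy = image_size // image_size  # 1 patch per side: old argument was the image size
fixed = image_size // patch_size  # 64 patches per side: vision_config.patch_size

print(buggy ** 2, fixed ** 2)  # 1 vs 4096 patches in the full grid
```

The real `get_pixtral_hf_image_feature_size` may also account for separator tokens and aspect-ratio resizing; the sketch only shows why the old argument drastically undercounted image tokens.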
9 changes: 9 additions & 0 deletions vllm/model_executor/models/utils.py
@@ -281,6 +281,15 @@ def flatten_bn(
...


@overload
def flatten_bn(
x: Union[List[torch.Tensor], torch.Tensor],
*,
concat: bool = False,
) -> Union[List[torch.Tensor], torch.Tensor]:
...


def flatten_bn(
x: Union[List[torch.Tensor], torch.Tensor],
*,
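For readers unfamiliar with the helper, a minimal sketch of the behavior the new overload advertises — an assumption-based reading of the signature (batch dim B, per-prompt dim N), not the vLLM implementation:

```python
from typing import List, Union

import torch


def flatten_bn(
    x: Union[List[torch.Tensor], torch.Tensor],
    *,
    concat: bool = False,
) -> Union[List[torch.Tensor], torch.Tensor]:
    """Flatten the batch (B) and per-prompt (N) dims of multimodal inputs."""
    if isinstance(x, torch.Tensor):
        # A stacked (B, N, ...) tensor collapses to (B * N, ...).
        return x.flatten(0, 1)
    if concat:
        # Concatenate along dim 0 (remaining dims must match).
        return torch.cat(x)
    # Without concat, keep the list of per-item tensors.
    return x
```

The `concat: bool = False` overload matters for Pixtral-HF precisely because its variable-resolution images arrive as a list that cannot be stacked into one tensor; with `concat` left at its default the list passes through, which is what the llava.py change above relies on.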