Skip to content

Commit

Permalink
[Bugfix] Fix image input for Pixtral-HF (vllm-project#11741)
Browse files Browse the repository at this point in the history
Signed-off-by: DarkLight1337 <[email protected]>
  • Loading branch information
DarkLight1337 authored and frreiss committed Jan 10, 2025
1 parent 43a3e68 commit 550b91c
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 6 deletions.
41 changes: 36 additions & 5 deletions examples/offline_inference_vision_language_multi_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
class ModelRequestData(NamedTuple):
llm: LLM
prompt: str
stop_token_ids: Optional[List[str]]
stop_token_ids: Optional[List[int]]
image_data: List[Image]
chat_template: Optional[str]

Expand All @@ -44,12 +44,14 @@ def load_aria(question, image_urls: List[str]) -> ModelRequestData:
prompt = (f"<|im_start|>user\n{placeholders}{question}<|im_end|>\n"
"<|im_start|>assistant\n")
stop_token_ids = [93532, 93653, 944, 93421, 1019, 93653, 93519]

return ModelRequestData(
llm=llm,
prompt=prompt,
stop_token_ids=stop_token_ids,
image_data=[fetch_image(url) for url in image_urls],
chat_template=None)
chat_template=None,
)


def load_h2onvl(question: str, image_urls: List[str]) -> ModelRequestData:
Expand Down Expand Up @@ -166,7 +168,8 @@ def load_mllama(question, image_urls: List[str]) -> ModelRequestData:
limit_mm_per_prompt={"image": len(image_urls)},
)

prompt = f"<|image|><|image|><|begin_of_text|>{question}"
placeholders = "<|image|>" * len(image_urls)
prompt = f"{placeholders}<|begin_of_text|>{question}"
return ModelRequestData(
llm=llm,
prompt=prompt,
Expand Down Expand Up @@ -209,6 +212,31 @@ def load_nvlm_d(question: str, image_urls: List[str]):
)


def load_pixtral_hf(question: str, image_urls: List[str]) -> ModelRequestData:
model_name = "mistral-community/pixtral-12b"

# Adjust this as necessary to fit in GPU
llm = LLM(
model=model_name,
max_model_len=8192,
max_num_seqs=2,
tensor_parallel_size=2,
limit_mm_per_prompt={"image": len(image_urls)},
)

placeholders = "[IMG]" * len(image_urls)
prompt = f"<s>[INST]{question}\n{placeholders}[/INST]"
stop_token_ids = None

return ModelRequestData(
llm=llm,
prompt=prompt,
stop_token_ids=stop_token_ids,
image_data=[fetch_image(url) for url in image_urls],
chat_template=None,
)


def load_phi3v(question: str, image_urls: List[str]) -> ModelRequestData:
# num_crops is an override kwarg to the multimodal image processor;
# For some models, e.g., Phi-3.5-vision-instruct, it is recommended
Expand Down Expand Up @@ -244,7 +272,8 @@ def load_phi3v(question: str, image_urls: List[str]) -> ModelRequestData:
)


def load_qwenvl_chat(question: str, image_urls: List[str]) -> ModelRequestData:
def load_qwen_vl_chat(question: str,
image_urls: List[str]) -> ModelRequestData:
model_name = "Qwen/Qwen-VL-Chat"
llm = LLM(
model=model_name,
Expand Down Expand Up @@ -274,6 +303,7 @@ def load_qwenvl_chat(question: str, image_urls: List[str]) -> ModelRequestData:

stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>"]
stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]

return ModelRequestData(
llm=llm,
prompt=prompt,
Expand Down Expand Up @@ -348,7 +378,8 @@ def load_qwen2_vl(question, image_urls: List[str]) -> ModelRequestData:
"mllama": load_mllama,
"NVLM_D": load_nvlm_d,
"phi3_v": load_phi3v,
"qwen_vl_chat": load_qwenvl_chat,
"pixtral_hf": load_pixtral_hf,
"qwen_vl_chat": load_qwen_vl_chat,
"qwen2_vl": load_qwen2_vl,
}

Expand Down
6 changes: 6 additions & 0 deletions vllm/model_executor/models/llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,6 +546,12 @@ def _parse_and_validate_image_input(
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")

if self.config.vision_config.model_type == "pixtral":
return LlavaImagePixelInputs(
type="pixel_values",
data=flatten_bn(pixel_values),
)

return LlavaImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(
Expand Down
2 changes: 1 addition & 1 deletion vllm/model_executor/models/pixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -774,7 +774,7 @@ def get_num_image_tokens(
) -> int:
return get_pixtral_hf_image_feature_size(
image_size=self.vision_config.image_size,
patch_size=self.get_image_size(),
patch_size=self.vision_config.patch_size,
)

def get_max_image_tokens(self) -> int:
Expand Down
9 changes: 9 additions & 0 deletions vllm/model_executor/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,15 @@ def flatten_bn(
...


@overload
def flatten_bn(
x: Union[List[torch.Tensor], torch.Tensor],
*,
concat: bool = False,
) -> Union[List[torch.Tensor], torch.Tensor]:
...


def flatten_bn(
x: Union[List[torch.Tensor], torch.Tensor],
*,
Expand Down

0 comments on commit 550b91c

Please sign in to comment.