
[Bugfix][V1] Fix molmo text-only inputs #11676

Merged
14 commits merged on Jan 6, 2025
11 changes: 11 additions & 0 deletions tests/models/decoder_only/vision_language/test_models.py
@@ -341,6 +341,17 @@
),
hf_output_post_proc=model_utils.minicpmv_trunc_hf_output,
),
"molmo": VLMTestInfo(
models=["allenai/Molmo-7B-D-0924"],
test_type=(VLMTestType.IMAGE),
prompt_formatter=lambda img_prompt:"User: " + img_prompt + " Assistant:", # noqa: E501
max_model_len=4096,
max_num_seqs=2,
image_size_factors=[(),(1.0, 1.0, 1.0)],
patch_hf_runner=model_utils.mlomo_patch_hf_runner,
postprocess_inputs=model_utils.molmo_post_processor

),
# Tests for phi3v currently live in another file because of a bug in
# transformers. Once this issue is fixed, we can enable them here instead.
# https://github.com/huggingface/transformers/issues/34307
tests/models/decoder_only/vision_language/vlm_utils/model_utils.py
@@ -5,17 +5,20 @@
import re
import types
from pathlib import PosixPath
from typing import Callable, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
from PIL.Image import Image
from transformers import AutoConfig, AutoTokenizer, BatchEncoding
from transformers import (AutoConfig, AutoTokenizer, BatchEncoding,
GenerationConfig)

from vllm.sequence import SampleLogprobs
from vllm.transformers_utils.tokenizer import patch_padding_side
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE

from .....conftest import HfRunner, ImageAsset, _ImageAssets
from .....conftest import (HfRunner, ImageAsset, PromptAudioInput,
PromptImageInput, PromptVideoInput, _ImageAssets)
from ....utils import TokensTextLogprobs
from .types import RunnerOutput


@@ -222,6 +225,11 @@
return {"model_inputs": hf_inputs}


def molmo_post_processor(hf_inputs: BatchEncoding, dtype: str):
hf_inputs = cast_dtype_post_processor("images")(hf_inputs, dtype)
return {k: v.unsqueeze(0) for k, v in hf_inputs.items()}
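# Annotation (not part of the original diff): the unsqueeze(0) gives every
# processed tensor a leading batch axis, e.g. an input_ids tensor of shape
# (seq_len,) becomes (1, seq_len), so per-prompt outputs can later be
# concatenated along dim=0 into a single batch.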


####### Prompt path encoders for models that need models on disk
def qwen_prompt_path_encoder(
tmp_path: PosixPath, prompt: str, assets: Union[List[ImageAsset],
@@ -451,3 +459,82 @@
hf_model.model.generate = types.MethodType(_generate, hf_model.model)

return hf_model


def _generate_greedy_logprobs_limit(
self,
prompts: List[str],
max_tokens: int,
num_logprobs: int,
images: Optional[PromptImageInput] = None,
audios: Optional[PromptAudioInput] = None,
videos: Optional[PromptVideoInput] = None,
**kwargs: Any,
) -> List[TokensTextLogprobs]:
all_inputs = self.get_inputs(prompts,
images=images,
videos=videos,
audios=audios)

# Process in batches for inference.
if len(all_inputs):
input_ids_lst = []
images_lst = []
images_input_idx_lst = []
        image_masks_lst = []
for inputs in all_inputs:
input_ids_lst.append(inputs["input_ids"])
images_lst.append(inputs["images"])
images_input_idx_lst.append(inputs["image_input_idx"])
            image_masks_lst.append(inputs["image_masks"])
batch_inputs = {}
batch_inputs['input_ids'] = torch.cat(input_ids_lst, dim=0)
batch_inputs['images'] = torch.cat(images_lst, dim=0)
batch_inputs['image_input_idx'] = torch.cat(images_input_idx_lst,
dim=0)
        batch_inputs['image_masks'] = torch.cat(image_masks_lst, dim=0)
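        # Annotation (not part of the original diff): torch.cat along dim=0
        # assumes each per-prompt tensor carries a leading batch axis of 1
        # (added by molmo_post_processor above) and matching trailing shapes;
        # otherwise the concatenation would raise.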

outputs = self.model.generate_from_batch(
batch=self.wrap_device(batch_inputs,
device=self.model.device.type),
generation_config=GenerationConfig(
max_new_tokens=max_tokens,
stop_strings="<|endoftext|>",
do_sample=False,
),
tokenizer=self.tokenizer,
output_hidden_states=True,
return_dict_in_generate=True,
)

all_logprobs: List[List[Dict[int, float]]] = []
all_output_ids: List[List[int]] = []
all_output_strs: List[str] = []

for index in range(len(all_inputs)):
(
seq_logprobs_lst,
output_len,
) = self._hidden_states_to_logprobs(outputs.hidden_states,
num_logprobs)
all_logprobs.append(seq_logprobs_lst)
seq_ids = outputs.sequences[index]
output_ids = seq_ids[-output_len:]
all_output_ids.append(output_ids.tolist())
all_output_strs.append(self.tokenizer.decode(output_ids))
outputs = zip(all_output_ids, all_output_strs, all_logprobs)
return [(output_ids, output_str, output_logprobs)
for output_ids, output_str, output_logprobs in outputs]
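# Annotation (not part of the original diff): the zip above pairs each
# prompt's token ids, decoded text, and per-token logprobs into the
# TokensTextLogprobs tuples that the test comparison helpers consume.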


####### Molmo-specific HuggingFace runner patchers
def molmo_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
"""Patches and returns an instance of the HfRunner to use for Molmo."""
hf_processor = hf_model.processor

def _processor(*args, **kwargs):
return hf_processor.process(*args, **kwargs)

hf_model.processor = _processor
HfRunner.generate_greedy_logprobs_limit = _generate_greedy_logprobs_limit

Check failure on line 539 in tests/models/decoder_only/vision_language/vlm_utils/model_utils.py
GitHub Actions / mypy (3.9, 3.10, 3.11, 3.12)
Cannot assign to a method  [method-assign]
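One way to satisfy mypy here (a sketch only, not the change made in this PR) would be to silence the check inline, since the runner method is being patched deliberately:

# Hypothetical variant of the assignment above with the error suppressed.
HfRunner.generate_greedy_logprobs_limit = (  # type: ignore[method-assign]
    _generate_greedy_logprobs_limit)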
return hf_model
56 changes: 17 additions & 39 deletions vllm/model_executor/models/molmo.py
@@ -1081,45 +1081,25 @@ def input_processor_for_molmo(ctx: InputContext, inputs: DecoderOnlyInputs):
else:
out = processor.process(None, image, tokens=inputs["prompt_token_ids"])

    image_processor = processor.image_processor
    max_total_crops = 1 + image_processor.max_crops
    if image is not None:
        images, image_input_idx, image_masks = pad_images(
            max_total_crops,
            out["images"],
            out["image_input_idx"],
            out.get("image_masks"),
        )
    else:
        base_image_input_size = image_processor.base_image_input_size
        image_patch_size = image_processor.image_patch_size
        image_num_patch = (
            base_image_input_size[0] // image_patch_size,
            base_image_input_size[1] // image_patch_size,
        )
        n_pixels = image_patch_size * image_patch_size * 3
        n_patches = image_num_patch[0] * image_num_patch[1]

        image_length_w = image_processor.image_token_length_w
        image_length_h = image_processor.image_token_length_h
        tokens_per_image = image_length_w * image_length_h
        images = torch.full(
            (max_total_crops, n_patches, n_pixels),
            -1,
            dtype=torch.float32,
        )
        image_input_idx = torch.full(
            (max_total_crops, tokens_per_image),
            -1,
            dtype=torch.int32,
        )
        if image_processor.image_padding_mask:
            image_masks = torch.full(
                (max_total_crops, n_patches),
                -1,
                dtype=torch.float32,
            )

ywang96 (Member) commented on lines -1093 to -1114 (Jan 2, 2025):

I wasn't sure why this was in the code originally when the AI2 team made the PR to support Molmo on vLLM, but I guess it wasn't an issue back then: on V0 we didn't use the placeholder ranges for these "dummy" image input indices padded into the prompt token ids.

    # If there is no image, return directly.
    if image is None:
        new_prompt_token_ids = out["input_ids"].tolist()
        prompt = inputs.get("prompt")
        if prompt is None:
            prompt = tokenizer.decode(new_prompt_token_ids)
        return token_inputs(
            prompt_token_ids=new_prompt_token_ids,
            prompt=prompt,
        )

    image_processor = processor.image_processor
    max_total_crops = 1 + image_processor.max_crops
    images, image_input_idx, image_masks = pad_images(
        max_total_crops,
        out["images"],
        out["image_input_idx"],
        out.get("image_masks"),
    )
    image_data = dict(
        images=images,
        image_input_idx=image_input_idx,
@@ -1143,11 +1123,9 @@ def input_processor_for_molmo(ctx: InputContext, inputs: DecoderOnlyInputs):
            offset = i
            size += 1
    image_data["image_start_end"] = (offset, offset + size)

    prompt = inputs.get("prompt")
    if prompt is None:
        prompt = tokenizer.decode(new_prompt_token_ids)

    return token_inputs(
        prompt_token_ids=new_prompt_token_ids,
        prompt=prompt,
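For context, a minimal way to exercise the text-only path this PR fixes (a sketch: the model name comes from the test entry above, and trust_remote_code is assumed to be required for Molmo's custom processor):

# Text-only prompt: with this fix, input_processor_for_molmo returns early
# instead of padding dummy image tensors into the request.
from vllm import LLM, SamplingParams

llm = LLM(model="allenai/Molmo-7B-D-0924", trust_remote_code=True)
outputs = llm.generate(["User: Name one prime number. Assistant:"],
                       SamplingParams(temperature=0.0, max_tokens=16))
print(outputs[0].outputs[0].text)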