[Bugfix][V1] Fix molmo text-only inputs (#11676)
Signed-off-by: Jee Jee Li <[email protected]>
jeejeelee authored Jan 6, 2025
1 parent 4ca5d40 commit 32c9eff
Showing 3 changed files with 123 additions and 42 deletions.
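For context, the failing case is a text-only request: a Molmo prompt with no image attached. A minimal sketch of such a request (offline vLLM API; model name taken from the test below, prompt text illustrative):

    from vllm import LLM, SamplingParams

    llm = LLM(model="allenai/Molmo-7B-D-0924",
              trust_remote_code=True,
              max_model_len=4096)
    # No multi-modal data attached -- this exercises the text-only path.
    outputs = llm.generate(
        "User: Describe the transformer architecture. Assistant:",
        SamplingParams(temperature=0.0, max_tokens=64),
    )
    print(outputs[0].outputs[0].text)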
10 changes: 10 additions & 0 deletions tests/models/decoder_only/vision_language/test_models.py
@@ -341,6 +341,16 @@
        ),
        hf_output_post_proc=model_utils.minicpmv_trunc_hf_output,
    ),
    "molmo": VLMTestInfo(
        models=["allenai/Molmo-7B-D-0924"],
        test_type=(VLMTestType.IMAGE),
        prompt_formatter=lambda img_prompt: "User: " + img_prompt + " Assistant:",  # noqa: E501
        max_model_len=4096,
        max_num_seqs=2,
        image_size_factors=[(), (1.0, 1.0, 1.0)],
        patch_hf_runner=model_utils.mlomo_patch_hf_runner,
        postprocess_inputs=model_utils.molmo_post_processor,
    ),
    # Tests for phi3v currently live in another file because of a bug in
    # transformers. Once this issue is fixed, we can enable them here instead.
    # https://github.com/huggingface/transformers/issues/34307
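Note the empty tuple in image_size_factors: in this harness each tuple holds per-image rescale factors, so () yields a case with no images at all (the text-only input fixed here) alongside (1.0, 1.0, 1.0), a three-image case at full size. Once registered, the entry can be run with the usual pytest filter, e.g.:

    pytest tests/models/decoder_only/vision_language/test_models.py -k molmo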
99 changes: 96 additions & 3 deletions tests/models/decoder_only/vision_language/vlm_utils/model_utils.py
@@ -5,17 +5,20 @@
import re
import types
from pathlib import PosixPath
from typing import Callable, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
from PIL.Image import Image
from transformers import AutoConfig, AutoTokenizer, BatchEncoding
from transformers import (AutoConfig, AutoTokenizer, BatchEncoding,
                          GenerationConfig)

from vllm.sequence import SampleLogprobs
from vllm.transformers_utils.tokenizer import patch_padding_side
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE

from .....conftest import HfRunner, ImageAsset, _ImageAssets
from .....conftest import (HfRunner, ImageAsset, PromptAudioInput,
                           PromptImageInput, PromptVideoInput, _ImageAssets)
from ....utils import TokensTextLogprobs
from .types import RunnerOutput


@@ -222,6 +225,11 @@ def wrap_inputs_post_processor(hf_inputs: BatchEncoding, dtype: str):
    return {"model_inputs": hf_inputs}


def molmo_post_processor(hf_inputs: BatchEncoding, dtype: str):
    hf_inputs = cast_dtype_post_processor("images")(hf_inputs, dtype)
    return {k: v.unsqueeze(0) for k, v in hf_inputs.items()}
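
# Note: Molmo's HF processor returns unbatched tensors, so molmo_post_processor
# above casts the image tensor to the requested dtype and then unsqueezes every
# value to add a leading batch dimension, e.g. an input_ids tensor of shape
# (seq_len,) becomes (1, seq_len). (Shapes here are illustrative.)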


####### Prompt path encoders for models that need models on disk
def qwen_prompt_path_encoder(
        tmp_path: PosixPath, prompt: str, assets: Union[List[ImageAsset],
@@ -451,3 +459,88 @@ def _generate(self, *args, **kwargs):
    hf_model.model.generate = types.MethodType(_generate, hf_model.model)

    return hf_model


def _generate_greedy_logprobs_limit(
    self,
    prompts: List[str],
    max_tokens: int,
    num_logprobs: int,
    images: Optional[PromptImageInput] = None,
    audios: Optional[PromptAudioInput] = None,
    videos: Optional[PromptVideoInput] = None,
    **kwargs: Any,
) -> List[TokensTextLogprobs]:
    """Molmo-specific replacement for HfRunner.generate_greedy_logprobs_limit.

    Molmo's processor emits extra per-image tensors (images, image_input_idx,
    image_masks), and generation goes through the model's own
    generate_from_batch() rather than the generic HF generate() path.
    """
    all_inputs = self.get_inputs(prompts,
                                 images=images,
                                 videos=videos,
                                 audios=audios)

    # Process in batches for inference.
    if len(all_inputs):
        input_ids_lst = []
        images_lst = []
        images_input_idx_lst = []
        image_masks_lst = []
        for inputs in all_inputs:
            input_ids_lst.append(inputs["input_ids"])
            images_lst.append(inputs["images"])
            images_input_idx_lst.append(inputs["image_input_idx"])
            image_masks_lst.append(inputs["image_masks"])
        batch_inputs = {}
        batch_inputs['input_ids'] = torch.cat(input_ids_lst, dim=0)
        batch_inputs['images'] = torch.cat(images_lst, dim=0)
        batch_inputs['image_input_idx'] = torch.cat(images_input_idx_lst,
                                                    dim=0)
        batch_inputs['image_masks'] = torch.cat(image_masks_lst, dim=0)

    outputs = self.model.generate_from_batch(
        batch=self.wrap_device(batch_inputs,
                               device=self.model.device.type),
        generation_config=GenerationConfig(
            max_new_tokens=max_tokens,
            stop_strings="<|endoftext|>",
            do_sample=False,
        ),
        tokenizer=self.tokenizer,
        output_hidden_states=True,
        return_dict_in_generate=True,
    )

    all_logprobs: List[List[Dict[int, float]]] = []
    all_output_ids: List[List[int]] = []
    all_output_strs: List[str] = []

    for index in range(len(all_inputs)):
        (
            seq_logprobs_lst,
            output_len,
        ) = self._hidden_states_to_logprobs(outputs.hidden_states,
                                            num_logprobs)
        all_logprobs.append(seq_logprobs_lst)
        seq_ids = outputs.sequences[index]
        output_ids = seq_ids[-output_len:]
        all_output_ids.append(output_ids.tolist())
        all_output_strs.append(self.tokenizer.decode(output_ids))
    outputs = zip(all_output_ids, all_output_strs, all_logprobs)
    return [(output_ids, output_str, output_logprobs)
            for output_ids, output_str, output_logprobs in outputs]


####### Molmo-specific HuggingFace runner patchers
def mlomo_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
    """Patches and returns an instance of the HfRunner to use for Molmo."""
    hf_processor = hf_model.processor

    def _processor(*args, **kwargs):
        return hf_processor.process(*args, **kwargs)

    hf_model.processor = _processor

    setattr(  # noqa: B010
        hf_model,
        "generate_greedy_logprobs_limit",
        types.MethodType(_generate_greedy_logprobs_limit, hf_model),
    )

    return hf_model
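Taken together, the test harness wires these pieces up through the VLMTestInfo entry in test_models.py; a condensed sketch of the equivalent manual usage (hypothetical variable names, same call signatures as above):

    hf_model = mlomo_patch_hf_runner(hf_model)  # processor now routes via .process()
    results = hf_model.generate_greedy_logprobs_limit(
        ["User: What is in this image? Assistant:"],
        max_tokens=128,
        num_logprobs=5,
        images=[image],  # a PIL.Image.Image
    )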
56 changes: 17 additions & 39 deletions vllm/model_executor/models/molmo.py
@@ -1081,45 +1081,25 @@ def input_processor_for_molmo(ctx: InputContext, inputs: DecoderOnlyInputs):
    else:
        out = processor.process(None, image, tokens=inputs["prompt_token_ids"])

    image_processor = processor.image_processor
    max_total_crops = 1 + image_processor.max_crops
    if image is not None:
        images, image_input_idx, image_masks = pad_images(
            max_total_crops,
            out["images"],
            out["image_input_idx"],
            out.get("image_masks"),
        )
    else:
        base_image_input_size = image_processor.base_image_input_size
        image_patch_size = image_processor.image_patch_size
        image_num_patch = (
            base_image_input_size[0] // image_patch_size,
            base_image_input_size[1] // image_patch_size,
        )
        n_pixels = image_patch_size * image_patch_size * 3
        n_patches = image_num_patch[0] * image_num_patch[1]

        image_length_w = image_processor.image_token_length_w
        image_length_h = image_processor.image_token_length_h
        tokens_per_image = image_length_w * image_length_h
        images = torch.full(
            (max_total_crops, n_patches, n_pixels),
            -1,
            dtype=torch.float32,
        )
        image_input_idx = torch.full(
            (max_total_crops, tokens_per_image),
            -1,
            dtype=torch.int32,
        )
    # If there is no image, return directly.
    if image is None:
        new_prompt_token_ids = out["input_ids"].tolist()
        prompt = inputs.get("prompt")
        if prompt is None:
            prompt = tokenizer.decode(new_prompt_token_ids)
        return token_inputs(
            prompt_token_ids=new_prompt_token_ids,
            prompt=prompt,
        )
        if image_processor.image_padding_mask:
            image_masks = torch.full(
                (max_total_crops, n_patches),
                -1,
                dtype=torch.float32,
            )

    image_processor = processor.image_processor
    max_total_crops = 1 + image_processor.max_crops
    images, image_input_idx, image_masks = pad_images(
        max_total_crops,
        out["images"],
        out["image_input_idx"],
        out.get("image_masks"),
    )
    image_data = dict(
        images=images,
        image_input_idx=image_input_idx,
@@ -1143,11 +1123,9 @@ def input_processor_for_molmo(ctx: InputContext, inputs: DecoderOnlyInputs):
            offset = i
        size += 1
    image_data["image_start_end"] = (offset, offset + size)

    prompt = inputs.get("prompt")
    if prompt is None:
        prompt = tokenizer.decode(new_prompt_token_ids)

    return token_inputs(
        prompt_token_ids=new_prompt_token_ids,
        prompt=prompt,
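The net effect of the molmo.py change: when no image is present, the processor output is returned immediately as token inputs, rather than fabricating dummy image tensors sized from the image processor's crop geometry. A condensed sketch of the new control flow (simplified from the hunks above):

    out = processor.process(None, image, tokens=inputs["prompt_token_ids"])
    if image is None:
        # Text-only: nothing to pad; return the tokenized prompt directly.
        new_prompt_token_ids = out["input_ids"].tolist()
        prompt = inputs.get("prompt")
        if prompt is None:
            prompt = tokenizer.decode(new_prompt_token_ids)
        return token_inputs(prompt_token_ids=new_prompt_token_ids, prompt=prompt)
    # Image path: pad crops to a fixed count, then build image_data as before.
    images, image_input_idx, image_masks = pad_images(
        1 + processor.image_processor.max_crops,
        out["images"],
        out["image_input_idx"],
        out.get("image_masks"),
    )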
