[Bugfix][V1] Fix molmo text-only inputs #11676

Merged: 14 commits, Jan 6, 2025

vllm/model_executor/models/molmo.py: 17 additions & 39 deletions (56 changes)
@@ -1081,45 +1081,25 @@ def input_processor_for_molmo(ctx: InputContext, inputs: DecoderOnlyInputs):
     else:
         out = processor.process(None, image, tokens=inputs["prompt_token_ids"])
 
-    image_processor = processor.image_processor
-    max_total_crops = 1 + image_processor.max_crops
-    if image is not None:
-        images, image_input_idx, image_masks = pad_images(
-            max_total_crops,
-            out["images"],
-            out["image_input_idx"],
-            out.get("image_masks"),
-        )
-    else:
-        base_image_input_size = image_processor.base_image_input_size
-        image_patch_size = image_processor.image_patch_size
-        image_num_patch = (
-            base_image_input_size[0] // image_patch_size,
-            base_image_input_size[1] // image_patch_size,
-        )
-        n_pixels = image_patch_size * image_patch_size * 3
-        n_patches = image_num_patch[0] * image_num_patch[1]
-
-        image_length_w = image_processor.image_token_length_w
-        image_length_h = image_processor.image_token_length_h
-        tokens_per_image = image_length_w * image_length_h
-        images = torch.full(
-            (max_total_crops, n_patches, n_pixels),
-            -1,
-            dtype=torch.float32,
-        )
-        image_input_idx = torch.full(
-            (max_total_crops, tokens_per_image),
-            -1,
-            dtype=torch.int32,
Comment on lines -1093 to -1114

@ywang96 (Member) commented on Jan 2, 2025:
I wasn't sure why this was in the code originally when the AI2 team made the PR to support Molmo on vLLM, but I guess it wasn't an issue back then because it didn't matter on V0 since we didn't use the placeholder ranges for these "dummy" image input indices padded to the prompt token ids.
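
To make the placeholder-range point concrete, here is a minimal sketch (simplified, assumed names and shapes; not vLLM's actual V0/V1 internals) of how a dummy, all -1 `image_input_idx` differs from one produced for a real image:

```python
# Minimal sketch: deriving an image placeholder range from image_input_idx.
# Shapes and the helper are illustrative only.
import torch

def placeholder_range(image_input_idx: torch.Tensor):
    """Return (offset, num_image_tokens) in the prompt, or None if there are none."""
    valid = image_input_idx[image_input_idx >= 0]
    if valid.numel() == 0:
        return None
    return int(valid.min()), int(valid.numel())

# Text-only prompt under the old code: a dummy tensor padded with -1,
# even though no image tokens exist in the prompt.
dummy = torch.full((13, 144), -1, dtype=torch.int32)
print(placeholder_range(dummy))  # None, yet image tensors were still attached

# Prompt with a real image: indices point at actual prompt positions.
real = torch.arange(5, 5 + 144, dtype=torch.int32).reshape(1, 144)
print(placeholder_range(real))   # (5, 144)
```

V0 never consulted such ranges, so the dummy tensors went unnoticed; V1 does, which is why the new code below returns plain token inputs as soon as `image` is None.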

+    # If there is no image, return directly.
+    if image is None:
+        new_prompt_token_ids = out["input_ids"].tolist()
+        prompt = inputs.get("prompt")
+        if prompt is None:
+            prompt = tokenizer.decode(new_prompt_token_ids)
+        return token_inputs(
+            prompt_token_ids=new_prompt_token_ids,
+            prompt=prompt,
         )
-        if image_processor.image_padding_mask:
-            image_masks = torch.full(
-                (max_total_crops, n_patches),
-                -1,
-                dtype=torch.float32,
-            )
 
+    image_processor = processor.image_processor
+    max_total_crops = 1 + image_processor.max_crops
+    images, image_input_idx, image_masks = pad_images(
+        max_total_crops,
+        out["images"],
+        out["image_input_idx"],
+        out.get("image_masks"),
+    )
     image_data = dict(
         images=images,
         image_input_idx=image_input_idx,
@@ -1143,11 +1123,9 @@ def input_processor_for_molmo(ctx: InputContext, inputs: DecoderOnlyInputs):
                 offset = i
             size += 1
     image_data["image_start_end"] = (offset, offset + size)
-
     prompt = inputs.get("prompt")
     if prompt is None:
         prompt = tokenizer.decode(new_prompt_token_ids)
-
     return token_inputs(
         prompt_token_ids=new_prompt_token_ids,
         prompt=prompt,
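
For reference, a minimal way to exercise the text-only path this PR fixes (model id and sampling settings are illustrative; assumes the V1 engine is enabled, e.g. via VLLM_USE_V1=1):

```python
# Hedged usage sketch: send a text-only prompt to Molmo; with this fix the
# input processor returns plain token inputs instead of dummy image tensors.
from vllm import LLM, SamplingParams

llm = LLM(model="allenai/Molmo-7B-D-0924", trust_remote_code=True)
outputs = llm.generate(["Why is the sky blue?"],
                       SamplingParams(temperature=0.0, max_tokens=64))
print(outputs[0].outputs[0].text)
```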