From f7001d8deb29cf7798eaddd248bac58155c1ff6a Mon Sep 17 00:00:00 2001
From: Travis Johnson <tsjohnso@us.ibm.com>
Date: Fri, 10 Jan 2025 16:26:00 -0700
Subject: [PATCH] [Bugfix] Check that number of images matches number of
 <|image|> tokens with mllama (#11939)

Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
---
 vllm/model_executor/models/mllama.py | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py
index c5046e06edecb..593a4d3fb6940 100644
--- a/vllm/model_executor/models/mllama.py
+++ b/vllm/model_executor/models/mllama.py
@@ -123,6 +123,13 @@ def input_processor_for_mllama(
 
     assert is_list_of(image_data, Image.Image)
 
+    num_image_tokens = dec_inputs['prompt_token_ids'].count(
+        MLLAMA_IMAGE_TOKEN_ID)
+    if num_image_tokens != len(image_data):
+        raise ValueError(
+            f"The number of image tokens ({num_image_tokens}) must be"
+            f" the same as the number of images ({len(image_data)})")
+
     # Since only the last group of consecutive images
     # are attended by the decoded tokens, we only need to
     # get the number of tiles for those images.
@@ -1493,6 +1500,8 @@ def convert_sparse_cross_attention_mask_to_dense(
             dense_mask[seq_start + start:seq_start + end,
                        tile_start:tile_start + tile] = 1
             tile_start += tile
+        assert ts != -1
+        assert td != 0
         tile_range_for_decode.append((ts, ts + td))
         seq_start += length