Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Model] Remove hardcoded image tokens ids from Pixtral #11582

Merged
merged 1 commit into from
Dec 28, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 13 additions & 14 deletions vllm/model_executor/models/pixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,6 @@
except ImportError:
USE_XFORMERS_OPS = False

# These token ids cannot be retrieved from model config
# so we hardcode them here.
PIXTRAL_12B_IMAGE_BREAK_ID = 12
PIXTRAL_12B_IMAGE_END_ID = 13
PIXTRAL_LARGE_IMAGE_BREAK_ID = 14
PIXTRAL_LARGE_IMAGE_END_ID = 15


def get_max_pixtral_image_tokens(ctx: InputContext):
tokenizer = cached_get_tokenizer(
Expand Down Expand Up @@ -201,6 +194,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
if key in dataclass_fields
}

if not ("image_break_token_id" in vision_args
and "image_end_token_id" in vision_args):
raise ValueError(
"'image_break_token_id' and 'image_end_token_id' not found "
"in the vision_encoder arguments. Please download the latest "
"version of 'params.json' from the model repository.")

self.vision_args = VisionEncoderArgs(**vision_args)

# init MistralForCausalLM
Expand Down Expand Up @@ -240,9 +240,8 @@ def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:

# NOTE: Image embeddings are split into separate tensors for each image
# by the indices of `[IMG_END]` token.
image_end_condition = (image_tokens == PIXTRAL_12B_IMAGE_END_ID) | (
image_tokens == PIXTRAL_LARGE_IMAGE_END_ID)
split_indices = torch.where(image_end_condition)[0] + 1
image_end_mask = image_tokens == self.vision_args.image_end_token_id
split_indices = torch.where(image_end_mask)[0] + 1
if len(split_indices) <= 1:
# Do not split, return as tensor of shape [1, fs, hs]
return image_embeds.unsqueeze(0)
Expand All @@ -265,10 +264,8 @@ def get_input_embeddings(
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings, [
self.vision_args.image_token_id,
PIXTRAL_12B_IMAGE_END_ID,
PIXTRAL_12B_IMAGE_BREAK_ID,
PIXTRAL_LARGE_IMAGE_BREAK_ID,
PIXTRAL_LARGE_IMAGE_END_ID,
self.vision_args.image_break_token_id,
self.vision_args.image_end_token_id,
])
return inputs_embeds

Expand Down Expand Up @@ -409,6 +406,8 @@ class VisionEncoderArgs:
num_attention_heads: int
rope_theta: float # for rope-2D
image_token_id: int
image_break_token_id: int
image_end_token_id: int
adapter_bias: bool = True


Expand Down
Loading