-
-
Notifications
You must be signed in to change notification settings - Fork 5.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[V1][VLM] V1 support for selected single-image models. #11632
Changes from 12 commits
425d3c4
1ca9369
8edcc83
5f76291
814f3bd
efeb999
5e568e8
135fd5c
0a8dbe0
03f741d
8bce949
bbde414
ea928c6
55eada7
bbd5752
0452b99
938c0bf
6cc54a7
ba713ba
b0efc4f
cdbd969
48c6946
ea76759
bc976a7
f79f79a
45ec10c
0926717
0fe561d
3512ed6
5e0f66c
1c243ab
09d64f4
6d6d71c
ea93a2c
9aeb7b2
768c1d9
0c82c51
d0d1fdc
afcf7b1
cb9522d
df832df
868e8e9
cc9c5f1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,24 +1,27 @@ | ||
import math | ||
from typing import Iterable, List, Optional, Set, Tuple, TypedDict, Union | ||
from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict, | ||
Union) | ||
|
||
import torch | ||
import torch.nn as nn | ||
from PIL import Image | ||
from torch.nn.init import trunc_normal_ | ||
from transformers import LlamaConfig | ||
|
||
from vllm.attention import AttentionMetadata | ||
from vllm.config import CacheConfig, QuantizationConfig, VllmConfig | ||
from vllm.distributed import get_tensor_model_parallel_rank | ||
from vllm.inputs import INPUT_REGISTRY, token_inputs | ||
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, | ||
InputContext, token_inputs) | ||
from vllm.model_executor.layers.activation import get_act_fn | ||
from vllm.model_executor.layers.fused_moe import FusedMoE | ||
from vllm.model_executor.layers.linear import (ColumnParallelLinear, | ||
RowParallelLinear) | ||
from vllm.model_executor.layers.logits_processor import LogitsProcessor | ||
from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( | ||
get_compressed_tensors_cache_scale) | ||
from vllm.model_executor.layers.sampler import (Sampler, SamplerOutput, | ||
SamplingMetadata) | ||
from vllm.model_executor.layers.sampler import (SamplerOutput, | ||
SamplingMetadata, get_sampler) | ||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead | ||
from vllm.model_executor.model_loader.weight_utils import ( | ||
default_weight_loader, maybe_remap_kv_scale_name) | ||
|
@@ -35,10 +38,12 @@ | |
from vllm.multimodal.image import cached_get_image_processor | ||
from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors | ||
from vllm.multimodal.utils import (cached_get_tokenizer, | ||
consecutive_placeholder_ranges, | ||
repeat_and_pad_placeholder_tokens) | ||
from vllm.sequence import IntermediateTensors | ||
from vllm.sequence import IntermediateTensors, SequenceData | ||
from vllm.transformers_utils.configs.aria import (AriaMoELMConfig, | ||
AriaVisionConfig) | ||
from vllm.utils import is_list_of | ||
|
||
from .utils import flatten_bn | ||
|
||
|
@@ -445,15 +450,74 @@ def build_mm_projector(config): | |
) | ||
|
||
|
||
def get_max_multimodal_tokens(ctx): | ||
return max(ctx.model_config.hf_config.image_size2tokens.values()) | ||
|
||
|
||
def input_mapper_for_aria(ctx, data): | ||
return MultiModalKwargs(data) | ||
|
||
|
||
def input_processor(ctx, llm_inputs): | ||
def get_aria_max_multimodal_tokens(ctx: InputContext): | ||
hf_config = ctx.get_hf_config() | ||
image_size2tokens = { | ||
int(math.sqrt(k) * hf_config.vision_config.patch_size): v | ||
for k, v in hf_config.projector_patch_to_query_dict.items() | ||
} | ||
return max(image_size2tokens.values()) | ||
|
||
|
||
def dummy_seq_data_for_aria(ctx: InputContext, seq_len: int, num_images: int): | ||
image_feature_size = get_aria_max_multimodal_tokens(ctx) | ||
hf_config = ctx.get_hf_config() | ||
return SequenceData.from_prompt_token_counts( | ||
(hf_config.image_token_index, image_feature_size * num_images), | ||
(0, seq_len - image_feature_size * num_images), | ||
), { | ||
"image": | ||
consecutive_placeholder_ranges(num_items=num_images, | ||
item_size=image_feature_size) | ||
} | ||
|
||
|
||
def dummy_image_for_aria( | ||
ctx: InputContext, | ||
num_images: int, | ||
): | ||
hf_config = ctx.get_hf_config() | ||
max_image_size = hf_config.vision_config.image_size | ||
image = Image.new("RGB", (max_image_size, max_image_size), color=0) | ||
images = [image] * num_images | ||
|
||
return {"image": images} | ||
|
||
|
||
def dummy_data_for_aria(ctx: InputContext, seq_len: int, | ||
mm_counts: Mapping[str, int]): | ||
num_images = mm_counts["image"] | ||
seq_data, ranges = dummy_seq_data_for_aria(ctx, seq_len, num_images) | ||
mm_data = dummy_image_for_aria(ctx, num_images) | ||
return DummyData(seq_data, mm_data, ranges) | ||
|
||
|
||
def input_mapper_for_aria(ctx: InputContext, data: object): | ||
data_list = data if isinstance(data, list) else [data] | ||
|
||
# For profiling with dummy image data | ||
if is_list_of(data_list, Image.Image): | ||
hf_config = ctx.get_hf_config() | ||
max_image_size = hf_config.vision_config.image_size | ||
model_config = ctx.model_config | ||
image_processor = cached_get_image_processor( | ||
model_config.model, | ||
trust_remote_code=model_config.trust_remote_code) | ||
image_inputs = image_processor.preprocess( | ||
data_list, | ||
max_image_size=max_image_size, | ||
split_image=False, | ||
return_tensors="pt").data | ||
image_inputs['pixel_values'] = image_inputs['pixel_values'].to( | ||
ctx.model_config.dtype) | ||
return MultiModalKwargs(image_inputs) | ||
|
||
# For actual inference when image has been processed with | ||
# prompt in input processor | ||
return MultiModalKwargs(data_list[0]) | ||
|
||
|
||
def input_processor_for_aria(ctx: InputContext, llm_inputs: DecoderOnlyInputs): | ||
multi_modal_data = llm_inputs.get("multi_modal_data") | ||
# if it is pure text input, use it as is | ||
if multi_modal_data is None or "image" not in multi_modal_data: | ||
|
@@ -494,9 +558,12 @@ def input_processor(ctx, llm_inputs): | |
repeat_count=num_crops, | ||
) | ||
|
||
repeat_count = [hf_config.image_size2tokens[max_image_size] | ||
] * sum(num_crops).item() | ||
new_prompt, new_token_ids, _ = repeat_and_pad_placeholder_tokens( | ||
image_size2tokens = { | ||
int(math.sqrt(k) * hf_config.vision_config.patch_size): v | ||
for k, v in hf_config.projector_patch_to_query_dict.items() | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Seems that this is a fixed value; perhaps we can move it to […]. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yep, I can do that. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I realized we actually don't need this calculation. |
||
repeat_count = [image_size2tokens[max_image_size]] * sum(num_crops).item() | ||
new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens( | ||
tokenizer, | ||
None, | ||
prompt_token_ids, | ||
|
@@ -508,12 +575,14 @@ def input_processor(ctx, llm_inputs): | |
prompt_token_ids=new_token_ids, | ||
prompt=new_prompt, | ||
multi_modal_data={"image": image_inputs}, | ||
multi_modal_placeholders={"image": ranges}, | ||
) | ||
|
||
|
||
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_multimodal_tokens) | ||
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_aria_max_multimodal_tokens) | ||
@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_aria) | ||
@INPUT_REGISTRY.register_input_processor(input_processor) | ||
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_aria) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The code for dummy data generation was entirely missing and I'm not sure why, so I added it in this PR since it's required for V1. cc @xffxff, who originally added this model. |
||
@INPUT_REGISTRY.register_input_processor(input_processor_for_aria) | ||
class AriaForConditionalGeneration(nn.Module, SupportsMultiModal): | ||
""" | ||
Aria model for conditional generation tasks. | ||
|
@@ -540,12 +609,6 @@ def __init__( | |
config = vllm_config.model_config.hf_config | ||
quant_config = vllm_config.quant_config | ||
|
||
# prepare the image_size to tokens mapping for the image preprocess, see | ||
# input_processor | ||
config.image_size2tokens = { | ||
int(math.sqrt(k) * config.vision_config.patch_size): v | ||
for k, v in config.projector_patch_to_query_dict.items() | ||
} | ||
self.config = config | ||
self.vision_tower = AriaVisionModel(config.vision_config) | ||
self.multi_modal_projector = build_mm_projector(config) | ||
|
@@ -566,7 +629,7 @@ def __init__( | |
logit_scale = getattr(config, "logit_scale", 1.0) | ||
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, | ||
self.vocab_size, logit_scale) | ||
self.sampler = Sampler() | ||
self.sampler = get_sampler() | ||
|
||
def _validate_image_sizes( | ||
self, images: List[torch.Tensor]) -> List[torch.Tensor]: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Llava-next was already supported on V1 so this is just a doc update.