Commit e7c7c5e
[V1][VLM] V1 support for selected single-image models. (#11632)
Signed-off-by: Roger Wang <[email protected]>
Signed-off-by: DarkLight1337 <[email protected]>
Signed-off-by: Isotr0py <[email protected]>
Co-authored-by: DarkLight1337 <[email protected]>
Co-authored-by: Isotr0py <[email protected]>
3 people authored Dec 31, 2024
1 parent 8c3230d commit e7c7c5e
Showing 19 changed files with 590 additions and 636 deletions.
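
For readers landing on this commit: vLLM's V1 engine was opt-in at this point in the project's history. The sketch below is illustrative only (it is not part of the diff) and assumes the `VLLM_USE_V1=1` environment flag used during the V1 alpha; later releases may select the engine differently.

```python
import os

# Assumption: during the V1 alpha the engine was selected via this
# environment variable, which must be set before importing vllm.
os.environ["VLLM_USE_V1"] = "1"

from vllm import LLM

# Fuyu is one of the single-image models this commit brings to V1.
llm = LLM(model="adept/fuyu-8b", max_model_len=4096, max_num_seqs=2)
```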
10 changes: 5 additions & 5 deletions docs/source/models/supported_models.md
@@ -570,28 +570,28 @@ See [this page](#generative-models) for more information on how to use generative models.
   - `rhymes-ai/Aria`
   -
   - ✅︎
-  -
+  - ✅︎
 * - `Blip2ForConditionalGeneration`
   - BLIP-2
   - T + I<sup>E</sup>
   - `Salesforce/blip2-opt-2.7b`, `Salesforce/blip2-opt-6.7b`, etc.
   -
   - ✅︎
-  -
+  - ✅︎
 * - `ChameleonForConditionalGeneration`
   - Chameleon
   - T + I
   - `facebook/chameleon-7b` etc.
   -
   - ✅︎
-  -
+  - ✅︎
 * - `FuyuForCausalLM`
   - Fuyu
   - T + I
   - `adept/fuyu-8b` etc.
   -
   - ✅︎
-  -
+  - ✅︎
 * - `ChatGLMModel`
   - GLM-4V
   - T + I
@@ -633,7 +633,7 @@ See [this page](#generative-models) for more information on how to use generative models.
   - `llava-hf/llava-v1.6-mistral-7b-hf`, `llava-hf/llava-v1.6-vicuna-7b-hf`, etc.
   -
   - ✅︎
-  -
+  - ✅︎
 * - `LlavaNextVideoForConditionalGeneration`
   - LLaVA-NeXT-Video
   - T + V
10 changes: 8 additions & 2 deletions examples/offline_inference_vision_language.py
@@ -24,10 +24,13 @@ def run_aria(question: str, modality: str):
     assert modality == "image"
     model_name = "rhymes-ai/Aria"
 
+    # NOTE: Need L40 (or equivalent) to avoid OOM
     llm = LLM(model=model_name,
               tokenizer_mode="slow",
-              trust_remote_code=True,
               dtype="bfloat16",
+              max_model_len=4096,
+              max_num_seqs=2,
+              trust_remote_code=True,
               disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
 
     prompt = (f"<|im_start|>user\n<fim_prefix><|img|><fim_suffix>\n{question}"
@@ -57,6 +60,7 @@ def run_chameleon(question: str, modality: str):
     prompt = f"{question}<image>"
     llm = LLM(model="facebook/chameleon-7b",
               max_model_len=4096,
+              max_num_seqs=2,
               disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
     stop_token_ids = None
     return llm, prompt, stop_token_ids
@@ -257,7 +261,7 @@ def run_minicpmv(question: str, modality: str):
     # 2.5
     # model_name = "openbmb/MiniCPM-Llama3-V-2_5"
 
-    #2.6
+    # 2.6
     model_name = "openbmb/MiniCPM-V-2_6"
     tokenizer = AutoTokenizer.from_pretrained(model_name,
                                               trust_remote_code=True)
@@ -430,9 +434,11 @@ def run_pixtral_hf(question: str, modality: str):
 
     model_name = "mistral-community/pixtral-12b"
 
+    # NOTE: Need L40 (or equivalent) to avoid OOM
     llm = LLM(
         model=model_name,
         max_model_len=8192,
+        max_num_seqs=2,
         disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
     )
 
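
In this example script, `max_model_len` bounds the per-sequence context (and hence the KV-cache footprint) and `max_num_seqs` bounds how many sequences are batched concurrently, which is why the additions above keep these memory-hungry models from OOMing on smaller GPUs. A standalone sketch of the same knobs outside the script (the image path is hypothetical; the prompt format follows the Chameleon example above):

```python
from PIL import Image

from vllm import LLM, SamplingParams

# Same memory-saving knobs as the updated example functions.
llm = LLM(model="facebook/chameleon-7b",
          max_model_len=4096,  # shorter context -> smaller KV cache
          max_num_seqs=2)      # fewer concurrent sequences -> lower peak memory

image = Image.open("example.jpg")  # hypothetical local image file
outputs = llm.generate(
    {
        "prompt": "Describe the image.<image>",
        "multi_modal_data": {"image": image},
    },
    SamplingParams(max_tokens=64),
)
print(outputs[0].outputs[0].text)
```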
7 changes: 2 additions & 5 deletions tests/models/decoder_only/vision_language/test_models.py
@@ -140,10 +140,7 @@
     "aria": VLMTestInfo(
         models=["rhymes-ai/Aria"],
         tokenizer_mode="slow",
-        test_type=(
-            VLMTestType.IMAGE,
-            VLMTestType.MULTI_IMAGE,
-        ),
+        test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
         dtype="bfloat16",
         prompt_formatter=lambda img_prompt: f"<|im_start|>user\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n ", # noqa: E501
         img_idx_to_prompt=lambda idx: "<fim_prefix><|img|><fim_suffix>\n",
@@ -179,6 +176,7 @@
         test_type=VLMTestType.IMAGE,
         prompt_formatter=lambda img_prompt: f"USER: {img_prompt}\nASSISTANT:",
         max_model_len=4096,
+        max_num_seqs=2,
         auto_cls=AutoModelForVision2Seq,
         postprocess_inputs=model_utils.cast_dtype_post_processor(
             "pixel_values"
@@ -201,7 +199,6 @@
         vllm_output_post_proc=model_utils.fuyu_vllm_to_hf_output,
         num_logprobs=10,
         image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
-        marks=[large_gpu_mark(min_gb=48)],
     ),
     "glm4": VLMTestInfo(
         models=["THUDM/glm-4v-9b"],
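
Dropping `large_gpu_mark(min_gb=48)` works because the new `max_num_seqs=2` cap shrinks the Fuyu test's memory footprint enough for ordinary CI GPUs. For readers unfamiliar with the helper, here is a sketch of what a `large_gpu_mark`-style decorator typically does (an illustrative reimplementation, not vLLM's exact code):

```python
import pytest
import torch

def large_gpu_mark(min_gb: int) -> "pytest.MarkDecorator":
    """Skip the test unless the current GPU has at least `min_gb` of memory."""
    if not torch.cuda.is_available():
        return pytest.mark.skip(reason="requires a CUDA GPU")
    total_gb = torch.cuda.get_device_properties(0).total_memory / float(1 << 30)
    return pytest.mark.skipif(
        total_gb < min_gb,
        reason=f"requires at least {min_gb} GiB of GPU memory",
    )
```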
29 changes: 16 additions & 13 deletions tests/multimodal/test_processing.py
@@ -528,7 +528,7 @@ def _rand_audio(
 
 def _test_processing_cache_correctness(
     model_id: str,
-    modalities: set[str],
+    modalities: dict[str, bool],
     hit_rate: float,
     num_batches: int,
     simplify_rate: float,
@@ -583,9 +583,8 @@ def _test_processing_cache_correctness(
         partial(_rand_audio, rng, min_len=256, max_len=512, sr=16000),
     }
     input_max_count = {
-        "image": 3,
-        "video": 3,
-        "audio": 3,
+        modality: 3 if supports_multi else 1
+        for modality, supports_multi in modalities.items()
     }
 
     for batch_idx in range(num_batches):
@@ -624,20 +623,24 @@
 
 # yapf: disable
 @pytest.mark.parametrize(("model_id", "modalities"), [
-    ("llava-hf/llava-1.5-7b-hf", {"image"}),
-    ("TIGER-Lab/Mantis-8B-siglip-llama3", {"image"}),
-    ("mistral-community/pixtral-12b", {"image"}),
-    ("Qwen/Qwen2-VL-2B-Instruct", {"image", "video"}),
-    ("Qwen/Qwen2-Audio-7B-Instruct", {"audio"}),
-    ("fixie-ai/ultravox-v0_3", {"audio"}),
+    ("rhymes-ai/Aria", {"image": True}),
+    ("Salesforce/blip2-opt-2.7b", {"image": False}),
+    ("facebook/chameleon-7b", {"image": True}),
+    ("adept/fuyu-8b", {"image": False}),
+    ("llava-hf/llava-1.5-7b-hf", {"image": True}),
+    ("TIGER-Lab/Mantis-8B-siglip-llama3", {"image": True}),
+    ("mistral-community/pixtral-12b", {"image": True}),
+    ("Qwen/Qwen2-VL-2B-Instruct", {"image": True, "video": True}),
+    ("Qwen/Qwen2-Audio-7B-Instruct", {"audio": True}),
+    ("fixie-ai/ultravox-v0_3", {"audio": True}),
 ])
 @pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
 @pytest.mark.parametrize("num_batches", [32])
 @pytest.mark.parametrize("simplify_rate", [1.0])
 # yapf: enable
 def test_processing_cache_correctness(
     model_id: str,
-    modalities: set[str],
+    modalities: dict[str, bool],
     hit_rate: float,
     num_batches: int,
     simplify_rate: float,
@@ -653,15 +656,15 @@ def test_processing_cache_correctness(
 
 # yapf: disable
 @pytest.mark.parametrize(("model_id", "modalities"), [
-    ("microsoft/Phi-3-vision-128k-instruct", {"image"}),
+    ("microsoft/Phi-3-vision-128k-instruct", {"image": True}),
 ])
 @pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
 @pytest.mark.parametrize("num_batches", [32])
 @pytest.mark.parametrize("simplify_rate", [1.0])
 # yapf: enable
 def test_processing_cache_correctness_phi3v(
     model_id: str,
-    modalities: set[str],
+    modalities: dict[str, bool],
     hit_rate: float,
     num_batches: int,
     simplify_rate: float,
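
The signature change from `set[str]` to `dict[str, bool]` lets each parametrization record whether the model accepts multiple inputs per modality: `True` entries (e.g. Chameleon) are exercised with up to 3 inputs per prompt, while single-image models marked `False` (BLIP-2, Fuyu) are capped at 1. A minimal standalone illustration of the comprehension above, with hypothetical values:

```python
# Hypothetical modality table in the new dict[str, bool] form.
modalities = {"image": False, "audio": True}

# Mirrors the comprehension added to _test_processing_cache_correctness:
# multi-input modalities get up to 3 inputs per prompt, single-input get 1.
input_max_count = {
    modality: 3 if supports_multi else 1
    for modality, supports_multi in modalities.items()
}
assert input_max_count == {"image": 1, "audio": 3}
```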