[Model] Initialize support for Deepseek-VL2 models (#11578)
Signed-off-by: Isotr0py <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
Isotr0py and DarkLight1337 authored Jan 12, 2025
1 parent 43f3d9e commit f967e51
Showing 17 changed files with 1,050 additions and 9 deletions.
1 change: 1 addition & 0 deletions .buildkite/test-pipeline.yaml
@@ -52,6 +52,7 @@ steps:
- tests/worker
- tests/standalone_tests/lazy_torch_compile.py
commands:
- pip install git+https://github.com/Isotr0py/DeepSeek-VL2.git # Used by multimodal processing test
- python3 standalone_tests/lazy_torch_compile.py
- pytest -v -s mq_llm_engine # MQLLMEngine
- pytest -v -s async_engine # AsyncLLMEngine
20 changes: 19 additions & 1 deletion docs/source/models/supported_models.md
@@ -610,6 +610,13 @@ See [this page](#generative-models) for more information on how to use generativ
-
- ✅︎
- ✅︎
* - `DeepseekVLV2ForCausalLM`
- DeepSeek-VL2
- T + I<sup>+</sup>
- `deepseek-ai/deepseek-vl2-tiny` (WIP), `deepseek-ai/deepseek-vl2-small`, `deepseek-ai/deepseek-vl2`, etc. (see note)
-
- ✅︎
- ✅︎
* - `FuyuForCausalLM`
- Fuyu
- T + I
@@ -755,8 +762,19 @@ See [this page](#generative-models) for more information on how to use generativ
<sup>E</sup> Pre-computed embeddings can be inputted for this modality.
<sup>+</sup> Multiple items can be inputted per text prompt for this modality.

````{note}
The `deepseek-ai/deepseek-vl2-tiny` model is not supported yet.
To use the `DeepSeek-VL2` series models, you need to install a forked version of the `deepseek_vl2` package:
```shell
pip install git+https://github.com/Isotr0py/DeepSeek-VL2.git
```
In addition, to run the `DeepSeek-VL2` series models, you have to pass `--hf_overrides '{"architectures": ["DeepseekVLV2ForCausalLM"]}'` when running vLLM.
````
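
For reference, here is a minimal offline-inference sketch that applies this override. The prompt template and override value follow the example script added in this commit; the image asset helper, model choice, and generation settings are illustrative assumptions, not the only supported configuration:

```python
# Minimal sketch: DeepSeek-VL2 offline inference with the required hf_overrides.
from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset  # bundled demo image (assumed available)

llm = LLM(
    model="deepseek-ai/deepseek-vl2-small",
    max_model_len=4096,
    hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]},
)

# Prompt format taken from the example script in this commit.
prompt = "<|User|>: <image>\nWhat's the color of the tower?\n\n<|Assistant|>:"

outputs = llm.generate(
    {
        "prompt": prompt,
        "multi_modal_data": {"image": ImageAsset("cherry_blossom").pil_image},
    },
    SamplingParams(max_tokens=64),
)
print(outputs[0].outputs[0].text)
```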

```{note}
To use `TIGER-Lab/Mantis-8B-siglip-llama3`, you have pass `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'` when running vLLM.
To use `TIGER-Lab/Mantis-8B-siglip-llama3`, you have to pass `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'` when running vLLM.
```

```{note}
18 changes: 18 additions & 0 deletions examples/offline_inference/vision_language.py
@@ -66,6 +66,23 @@ def run_chameleon(question: str, modality: str):
return llm, prompt, stop_token_ids


# Deepseek-VL2
def run_deepseek_vl2(question: str, modality: str):
assert modality == "image"

model_name = "deepseek-ai/deepseek-vl2-small"

llm = LLM(model=model_name,
max_model_len=4096,
max_num_seqs=2,
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]})

prompt = f"<|User|>: <image>\n{question}\n\n<|Assistant|>:"
stop_token_ids = None
return llm, prompt, stop_token_ids


# Fuyu
def run_fuyu(question: str, modality: str):
assert modality == "image"
@@ -498,6 +515,7 @@ def run_qwen2_vl(question: str, modality: str):
"aria": run_aria,
"blip-2": run_blip2,
"chameleon": run_chameleon,
"deepseek_vl_v2": run_deepseek_vl2,
"fuyu": run_fuyu,
"glm4v": run_glm4v,
"h2ovl_chat": run_h2ovl,
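With the `deepseek_vl_v2` entry registered in `model_example_map` above, the single-image example should be runnable by selecting that key through the script's model-type argument (the exact flag name is assumed here), e.g. `python examples/offline_inference/vision_language.py --model-type deepseek_vl_v2`.
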
23 changes: 23 additions & 0 deletions examples/offline_inference/vision_language_multi_image.py
@@ -54,6 +54,28 @@ def load_aria(question, image_urls: List[str]) -> ModelRequestData:
)


def load_deepseek_vl2(question: str, image_urls: List[str]) -> ModelRequestData:
model_name = "deepseek-ai/deepseek-vl2-small"

llm = LLM(model=model_name,
max_model_len=4096,
max_num_seqs=2,
hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]},
limit_mm_per_prompt={"image": len(image_urls)})

placeholder = "".join(f"image_{i}:<image>\n"
for i, _ in enumerate(image_urls, start=1))
prompt = f"<|User|>: {placeholder}{question}\n\n<|Assistant|>:"

return ModelRequestData(
llm=llm,
prompt=prompt,
stop_token_ids=None,
image_data=[fetch_image(url) for url in image_urls],
chat_template=None,
)


def load_h2onvl(question: str, image_urls: List[str]) -> ModelRequestData:
model_name = "h2oai/h2ovl-mississippi-2b"

@@ -372,6 +394,7 @@ def load_qwen2_vl(question, image_urls: List[str]) -> ModelRequestData:

model_example_map = {
"aria": load_aria,
"deepseek_vl2": load_deepseek_vl2,
"h2ovl_chat": load_h2onvl,
"idefics3": load_idefics3,
"internvl_chat": load_internvl,
27 changes: 27 additions & 0 deletions tests/models/decoder_only/vision_language/test_models.py
@@ -188,6 +188,33 @@
max_tokens=8,
dtype="bfloat16",
),
"deepseek_vl_v2": VLMTestInfo(
models=["deepseek-ai/deepseek-vl2-small"],
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
dtype="bfloat16",
prompt_formatter=lambda img_prompt: f"<|User|>: {img_prompt}\n\n<|Assistant|>: ", # noqa: E501
max_model_len=4096,
max_num_seqs=2,
single_image_prompts=IMAGE_ASSETS.prompts({
"stop_sign": "<image>\nWhat's the color of the stop sign and car?",
"cherry_blossom": "<image>\nWhat's the color of the tower?",
}),
multi_image_prompt="image_1:<image>\nimage_2:<image>\nDescribe the two images shortly.", # noqa: E501
vllm_runner_kwargs={"hf_overrides": {"architectures": ["DeepseekVLV2ForCausalLM"]}}, # noqa: E501
image_size_factors=[(0.10, 0.15)],
patch_hf_runner=model_utils.deepseekvl2_patch_hf_runner,
postprocess_inputs=model_utils.cast_dtype_post_processor("images"),
hf_output_post_proc=model_utils.deepseekvl2_trunc_hf_output,
stop_str=["<|end▁of▁sentence|>", "<|begin▁of▁sentence|>"], # noqa: E501
num_logprobs=5,
marks=[
pytest.mark.skipif(
not is_flash_attn_2_available(),
reason="Model needs flash-attn for numeric convergence.",
),
large_gpu_mark(min_gb=48),
],
),
"fuyu": VLMTestInfo(
models=["adept/fuyu-8b"],
test_type=VLMTestType.IMAGE,
36 changes: 36 additions & 0 deletions tests/models/decoder_only/vision_language/vlm_utils/model_utils.py
@@ -183,6 +183,14 @@ def paligemma_vllm_to_hf_output(vllm_output: RunnerOutput,


####### Post-processors for HF outputs
def deepseekvl2_trunc_hf_output(hf_output: RunnerOutput,
model: str) -> RunnerOutput:
output_ids, output_str, out_logprobs = hf_output
if output_str.endswith("<|end▁of▁sentence|>"):
output_str = output_str.split("<|end▁of▁sentence|>")[0]
return output_ids, output_str, out_logprobs


def minicpmv_trunc_hf_output(hf_output: RunnerOutput,
model: str) -> RunnerOutput:
output_ids, output_str, out_logprobs = hf_output
@@ -261,6 +269,34 @@ def qwen_prompt_path_encoder(


####### Model-specific HuggingFace runner patchers
def deepseekvl2_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
"""Patches and returns an instance of the HfRunner to use for GLM4."""
hf_processor = hf_model.processor

def processor(*args, text="", images=None, **kwargs):
if isinstance(images, Image):
images = [images]
# inputs is a custom class instead of dict or BatchFeature
inputs = hf_processor(
*args,
prompt=text,
images=images,
**kwargs,
)
inputs = {
k: inputs[k]
for k in inputs.keys() # noqa
if k not in ("seq_lens", "sft_format")
}
inputs = BatchEncoding(data=inputs, tensor_type="pt")
return inputs

hf_model.processor = processor
hf_model.model.get_output_embeddings = lambda: \
hf_model.model.language.model.embed_tokens
return hf_model


def glm_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
"""Patches and returns an instance of the HfRunner to use for GLM4."""
hf_processor = hf_model.processor
2 changes: 2 additions & 0 deletions tests/models/registry.py
@@ -179,6 +179,8 @@ class _HfExamplesInfo:
trust_remote_code=True),
"ChatGLMForConditionalGeneration": _HfExamplesInfo("chatglm2-6b",
is_available_online=False),
# TODO(Isotr0py): Use deepseek-vl2-tiny for test after it's supported
"DeepseekVLV2ForCausalLM": _HfExamplesInfo("deepseek-ai/deepseek-vl2-small"), # noqa: E501
"FuyuForCausalLM": _HfExamplesInfo("adept/fuyu-8b"),
"H2OVLChatModel": _HfExamplesInfo("h2oai/h2ovl-mississippi-800m"),
"InternVLChatModel": _HfExamplesInfo("OpenGVLab/InternVL2-1B",
3 changes: 3 additions & 0 deletions tests/models/test_initialization.py
@@ -26,6 +26,9 @@ def test_can_initialize(model_arch):

# Avoid OOM
def hf_overrides(hf_config: PretrainedConfig) -> PretrainedConfig:
if hf_config.model_type == "deepseek_vl_v2":
hf_config.update({"architectures": ["DeepseekVLV2ForCausalLM"]})

if hasattr(hf_config, "text_config"):
text_config: PretrainedConfig = hf_config.text_config
else:
4 changes: 2 additions & 2 deletions vllm/entrypoints/chat_utils.py
@@ -403,8 +403,8 @@ def _placeholder_str(self, modality: ModalityStr,
if model_type.startswith("llava"):
return self._cached_token_str(self._tokenizer,
hf_config.image_token_index)
if model_type in ("chameleon", "internvl_chat", "NVLM_D",
"h2ovl_chat"):
if model_type in ("chameleon", "deepseek_vl_v2", "internvl_chat",
"NVLM_D", "h2ovl_chat"):
return "<image>"
if model_type == "mllama":
return "<|image|>"
18 changes: 17 additions & 1 deletion vllm/model_executor/models/deepseek_v2.py
@@ -243,7 +243,11 @@ def __init__(
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj")
rope_scaling["rope_type"] = 'deepseek_yarn'
if rope_scaling:
rope_scaling["rope_type"] = 'deepseek_yarn'
self.use_normal_rope = False
else:
self.use_normal_rope = True
self.rotary_emb = get_rope(qk_rope_head_dim,
rotary_dim=qk_rope_head_dim,
max_position=max_position_embeddings,
@@ -298,7 +302,18 @@ def forward(
self.qk_nope_head_dim + self.v_head_dim)
k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
k_pe = latent_cache[:, :, self.kv_lora_rank:]

if self.use_normal_rope:
seq_len = positions.size(0)
ori_q_pe_shape, ori_k_pe_shape = q_pe.shape, k_pe.shape
q_pe = q_pe.reshape(seq_len, -1)
k_pe = k_pe.reshape(seq_len, -1)

q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)

if self.use_normal_rope:
q_pe, k_pe = q_pe.view(ori_q_pe_shape), k_pe.view(ori_k_pe_shape)

q[..., self.qk_nope_head_dim:] = q_pe
k = torch.empty_like(q)
k[..., :self.qk_nope_head_dim] = k_nope
@@ -355,6 +370,7 @@ def __init__(
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)

if (config.n_routed_experts is not None
and layer_idx >= config.first_k_dense_replace
and layer_idx % config.moe_layer_freq == 0):
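The `use_normal_rope` branch above flattens the rotary slices of `q` and `k` to two dimensions before calling `self.rotary_emb` and restores their original shapes afterwards, so the non-YaRN rotary embedding sees plain `[num_tokens, dim]` tensors. Below is a standalone plain-PyTorch sketch of that flatten/apply/restore pattern; the `rope_2d` helper and tensor sizes are illustrative stand-ins, not vLLM's `get_rope` implementation. The same change is mirrored in `deepseek_v3.py` below.

```python
# Plain-PyTorch sketch of the flatten/apply/restore pattern in the
# `use_normal_rope` branch. `rope_2d` stands in for the rotary embedding
# (identity here); the tensor sizes are illustrative.
import torch


def rope_2d(positions: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Stand-in for a rotary embedding that consumes [num_tokens, dim] tensors."""
    return x  # a real implementation rotates channel pairs per position


num_tokens, num_heads, rope_dim = 8, 16, 64
positions = torch.arange(num_tokens)
q_pe = torch.randn(num_tokens, num_heads, rope_dim)  # per-head positional slice of q
k_pe = torch.randn(num_tokens, 1, rope_dim)          # shared positional slice of k

# Flatten the trailing dims so the 2-D op can consume both tensors ...
ori_q_pe_shape, ori_k_pe_shape = q_pe.shape, k_pe.shape
q_pe = rope_2d(positions, q_pe.reshape(num_tokens, -1))
k_pe = rope_2d(positions, k_pe.reshape(num_tokens, -1))

# ... then restore the original [num_tokens, heads, rope_dim] layouts.
q_pe = q_pe.view(ori_q_pe_shape)
k_pe = k_pe.view(ori_k_pe_shape)
```
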
20 changes: 18 additions & 2 deletions vllm/model_executor/models/deepseek_v3.py
@@ -251,7 +251,11 @@ def __init__(
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj")
rope_scaling["rope_type"] = 'deepseek_yarn'
if rope_scaling:
rope_scaling["rope_type"] = 'deepseek_yarn'
self.use_normal_rope = False
else:
self.use_normal_rope = True
self.rotary_emb = get_rope(qk_rope_head_dim,
rotary_dim=qk_rope_head_dim,
max_position=max_position_embeddings,
@@ -306,7 +310,18 @@ def forward(
self.qk_nope_head_dim + self.v_head_dim)
k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
k_pe = latent_cache[:, :, self.kv_lora_rank:]

if self.use_normal_rope:
seq_len = positions.size(0)
ori_q_pe_shape, ori_k_pe_shape = q_pe.shape, k_pe.shape
q_pe = q_pe.reshape(seq_len, -1)
k_pe = k_pe.reshape(seq_len, -1)

q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)

if self.use_normal_rope:
q_pe, k_pe = q_pe.view(ori_q_pe_shape), k_pe.view(ori_k_pe_shape)

q[..., self.qk_nope_head_dim:] = q_pe
k = torch.empty_like(q)
k[..., :self.qk_nope_head_dim] = k_nope
@@ -583,7 +598,8 @@ def load_weights(self, weights: Iterable[Tuple[str,
continue

# TODO(simon): support nextn predict layers
if self.config.num_nextn_predict_layers > 0:
if hasattr(self.config, "num_nextn_predict_layers"
) and self.config.num_nextn_predict_layers > 0:
assert self.config.num_nextn_predict_layers == 1
layer_idx = self.config.num_hidden_layers
if name.startswith(f"model.layers.{layer_idx}"):
(Diffs for the remaining changed files were not loaded.)
