Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Model] Initialize support for Deepseek-VL2 models #11578

Merged
merged 52 commits into from
Jan 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
52 commits
Select commit Hold shift + click to select a range
b7f3a3b
init deepseekvl2
Isotr0py Dec 27, 2024
9846268
port config
Isotr0py Dec 27, 2024
0fdc10b
code format
Isotr0py Dec 27, 2024
19cf5e7
process image
Isotr0py Dec 27, 2024
550ed2e
init processor
Isotr0py Dec 28, 2024
f2159c4
clean up
Isotr0py Dec 28, 2024
e20aba5
handle image embedding inputs
Isotr0py Dec 28, 2024
54c92fc
add multimodal processor
Isotr0py Dec 28, 2024
dd19a5d
add max tokens implement
Isotr0py Dec 28, 2024
391ba13
implement embeddings merge
Isotr0py Dec 28, 2024
bb88307
add deepseek-vl2 example
Isotr0py Dec 28, 2024
0ec661c
register model
Isotr0py Dec 28, 2024
847cb03
override example arch
Isotr0py Dec 28, 2024
bec7a43
fix processor
Isotr0py Dec 28, 2024
e417b98
Merge branch 'vllm-project:main' into deepseek-vl2
Isotr0py Dec 28, 2024
acc89f6
fix config name
Isotr0py Dec 28, 2024
b9f2d4b
fix processor dtype
Isotr0py Dec 28, 2024
632c77c
fix a typo
Isotr0py Dec 28, 2024
d97849d
fix vit
Isotr0py Dec 29, 2024
6fb3845
fix a typo
Isotr0py Dec 29, 2024
01a5316
add normal rope rotary
Isotr0py Dec 29, 2024
d5ebfcb
code format
Isotr0py Dec 30, 2024
d787200
fix image token
Isotr0py Dec 30, 2024
d491ff0
update docs
Isotr0py Dec 30, 2024
a3ddf41
Merge branch 'main' into deepseek-vl2
Isotr0py Dec 30, 2024
c5bbeff
update docs
Isotr0py Dec 30, 2024
0231368
add registry test
Isotr0py Dec 30, 2024
12f4553
Merge branch 'main' into deepseek-vl2
Isotr0py Jan 2, 2025
3237715
vision embeddings use nested tensors
Isotr0py Jan 2, 2025
6954a6d
update multimodal processor
Isotr0py Jan 2, 2025
a5dd2b9
Merge branch 'vllm-project:main' into deepseek-vl2
Isotr0py Jan 3, 2025
51645ec
update v1 docs
Isotr0py Jan 4, 2025
66b4126
format
Isotr0py Jan 4, 2025
5882f52
support multi-images input
Isotr0py Jan 4, 2025
79f2c4c
remove dead code
Isotr0py Jan 4, 2025
933afc7
test processor
Isotr0py Jan 4, 2025
c1ce202
add model tests
Isotr0py Jan 5, 2025
8c960fb
fix test
Isotr0py Jan 5, 2025
6557f5a
Merge remote-tracking branch 'upstream/main' into deepseek-vl2
Isotr0py Jan 9, 2025
aa150de
update docs
Isotr0py Jan 9, 2025
86c7fa9
update processor impl
Isotr0py Jan 9, 2025
e92f543
add get_image_size_with_most_features
Isotr0py Jan 9, 2025
32b1e9a
use fork deepseek-vl2 repo
Isotr0py Jan 10, 2025
91c5e2d
Merge branch 'vllm-project:main' into deepseek-vl2
Isotr0py Jan 10, 2025
37e1fd2
add normal rope to deepseek v3
Isotr0py Jan 10, 2025
53ac9e0
add check for num_nextn_predict_layers
Isotr0py Jan 10, 2025
6519a66
fix deepseek-v3 based model
Isotr0py Jan 10, 2025
f5ae01f
Merge remote-tracking branch 'upstream/main' into deepseek-vl2
Isotr0py Jan 11, 2025
92ab7fc
update docs
Isotr0py Jan 11, 2025
80a23ee
Update tests/models/registry.py
Isotr0py Jan 11, 2025
5a6d0d6
update todos
Isotr0py Jan 11, 2025
fc38412
add hf_overrides to initialize test
Isotr0py Jan 11, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .buildkite/test-pipeline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ steps:
- tests/worker
- tests/standalone_tests/lazy_torch_compile.py
commands:
- pip install git+https://github.com/Isotr0py/DeepSeek-VL2.git # Used by multimoda processing test
- python3 standalone_tests/lazy_torch_compile.py
- pytest -v -s mq_llm_engine # MQLLMEngine
- pytest -v -s async_engine # AsyncLLMEngine
Expand Down
20 changes: 19 additions & 1 deletion docs/source/models/supported_models.md
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,13 @@ See [this page](#generative-models) for more information on how to use generativ
-
- ✅︎
- ✅︎
* - `DeepseekVLV2ForCausalLM`
- DeepSeek-VL2
- T + I<sup>+</sup>
- `deepseek-ai/deepseek-vl2-tiny`(WIP), `deepseek-ai/deepseek-vl2-small`, `deepseek-ai/deepseek-vl2` etc. (see note)
-
- ✅︎
- ✅︎
* - `FuyuForCausalLM`
- Fuyu
- T + I
Expand Down Expand Up @@ -755,8 +762,19 @@ See [this page](#generative-models) for more information on how to use generativ
<sup>E</sup> Pre-computed embeddings can be inputted for this modality.
<sup>+</sup> Multiple items can be inputted per text prompt for this modality.

````{note}
The `deepseek-ai/deepseek-vl2-tiny` is not supported yet.

To use `DeepSeek-VL2` series models, you need to install a fork version `deepseek_vl2` package:
```shell
pip install git+https://github.com/Isotr0py/DeepSeek-VL2.git
```

Besides, to run `DeepSeek-VL2` series models, you have to pass `--hf_overrides '{"architectures": ["DeepseekVLV2ForCausalLM"]}'` when running vLLM.
````

```{note}
To use `TIGER-Lab/Mantis-8B-siglip-llama3`, you have pass `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'` when running vLLM.
To use `TIGER-Lab/Mantis-8B-siglip-llama3`, you have to pass `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'` when running vLLM.
```

```{note}
Expand Down
18 changes: 18 additions & 0 deletions examples/offline_inference/vision_language.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,23 @@ def run_chameleon(question: str, modality: str):
return llm, prompt, stop_token_ids


# Deepseek-VL2
def run_deepseek_vl2(question: str, modality: str):
assert modality == "image"

model_name = "deepseek-ai/deepseek-vl2-small"

llm = LLM(model=model_name,
max_model_len=4096,
max_num_seqs=2,
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]})

prompt = f"<|User|>: <image>\n{question}\n\n<|Assistant|>:"
stop_token_ids = None
return llm, prompt, stop_token_ids


# Fuyu
def run_fuyu(question: str, modality: str):
assert modality == "image"
Expand Down Expand Up @@ -498,6 +515,7 @@ def run_qwen2_vl(question: str, modality: str):
"aria": run_aria,
"blip-2": run_blip2,
"chameleon": run_chameleon,
"deepseek_vl_v2": run_deepseek_vl2,
"fuyu": run_fuyu,
"glm4v": run_glm4v,
"h2ovl_chat": run_h2ovl,
Expand Down
23 changes: 23 additions & 0 deletions examples/offline_inference/vision_language_multi_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,28 @@ def load_aria(question, image_urls: List[str]) -> ModelRequestData:
)


def load_deepseek_vl2(question: str, image_urls: List[str]):
model_name = "deepseek-ai/deepseek-vl2-small"

llm = LLM(model=model_name,
max_model_len=4096,
max_num_seqs=2,
hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]},
limit_mm_per_prompt={"image": len(image_urls)})

placeholder = "".join(f"image_{i}:<image>\n"
for i, _ in enumerate(image_urls, start=1))
prompt = f"<|User|>: {placeholder}{question}\n\n<|Assistant|>:"

return ModelRequestData(
llm=llm,
prompt=prompt,
stop_token_ids=None,
image_data=[fetch_image(url) for url in image_urls],
chat_template=None,
)


def load_h2onvl(question: str, image_urls: List[str]) -> ModelRequestData:
model_name = "h2oai/h2ovl-mississippi-2b"

Expand Down Expand Up @@ -372,6 +394,7 @@ def load_qwen2_vl(question, image_urls: List[str]) -> ModelRequestData:

model_example_map = {
"aria": load_aria,
"deepseek_vl2": load_deepseek_vl2,
"h2ovl_chat": load_h2onvl,
"idefics3": load_idefics3,
"internvl_chat": load_internvl,
Expand Down
27 changes: 27 additions & 0 deletions tests/models/decoder_only/vision_language/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,33 @@
max_tokens=8,
dtype="bfloat16",
),
"deepseek_vl_v2": VLMTestInfo(
models=["deepseek-ai/deepseek-vl2-small"],
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
dtype="bfloat16",
prompt_formatter=lambda img_prompt: f"<|User|>: {img_prompt}\n\n<|Assistant|>: ", # noqa: E501
max_model_len=4096,
max_num_seqs=2,
single_image_prompts=IMAGE_ASSETS.prompts({
"stop_sign": "<image>\nWhat's the color of the stop sign and car?",
"cherry_blossom": "<image>\nWhat's the color of the tower?",
}),
multi_image_prompt="image_1:<image>\nimage_2:<image>\nDescribe the two images shortly.", # noqa: E501
vllm_runner_kwargs={"hf_overrides": {"architectures": ["DeepseekVLV2ForCausalLM"]}}, # noqa: E501
image_size_factors=[(0.10, 0.15)],
patch_hf_runner=model_utils.deepseekvl2_patch_hf_runner,
postprocess_inputs=model_utils.cast_dtype_post_processor("images"),
hf_output_post_proc=model_utils.deepseekvl2_trunc_hf_output,
stop_str=["<|end▁of▁sentence|>", "<|begin▁of▁sentence|>"], # noqa: E501
num_logprobs=5,
marks=[
pytest.mark.skipif(
not is_flash_attn_2_available(),
reason="Model needs flash-attn for numeric convergence.",
),
large_gpu_mark(min_gb=48),
],
),
"fuyu": VLMTestInfo(
models=["adept/fuyu-8b"],
test_type=VLMTestType.IMAGE,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,14 @@ def paligemma_vllm_to_hf_output(vllm_output: RunnerOutput,


####### Post-processors for HF outputs
def deepseekvl2_trunc_hf_output(hf_output: RunnerOutput,
model: str) -> RunnerOutput:
output_ids, output_str, out_logprobs = hf_output
if output_str.endswith("<|end▁of▁sentence|>"):
output_str = output_str.split("<|end▁of▁sentence|>")[0]
return output_ids, output_str, out_logprobs


def minicpmv_trunc_hf_output(hf_output: RunnerOutput,
model: str) -> RunnerOutput:
output_ids, output_str, out_logprobs = hf_output
Expand Down Expand Up @@ -261,6 +269,34 @@ def qwen_prompt_path_encoder(


####### Model-specific HuggingFace runner patchers
def deepseekvl2_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
"""Patches and returns an instance of the HfRunner to use for GLM4."""
hf_processor = hf_model.processor

def processor(*args, text="", images=None, **kwargs):
if isinstance(images, Image):
images = [images]
# inputs is a custom class instead of dict or BatchFeature
inputs = hf_processor(
*args,
prompt=text,
images=images,
**kwargs,
)
inputs = {
k: inputs[k]
for k in inputs.keys() # noqa
if k not in ("seq_lens", "sft_format")
}
inputs = BatchEncoding(data=inputs, tensor_type="pt")
return inputs

hf_model.processor = processor
hf_model.model.get_output_embeddings = lambda: \
hf_model.model.language.model.embed_tokens
return hf_model


def glm_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
"""Patches and returns an instance of the HfRunner to use for GLM4."""
hf_processor = hf_model.processor
Expand Down
2 changes: 2 additions & 0 deletions tests/models/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,8 @@ class _HfExamplesInfo:
trust_remote_code=True),
"ChatGLMForConditionalGeneration": _HfExamplesInfo("chatglm2-6b",
is_available_online=False),
# TODO(Isotr0py): Use deepseek-vl2-tiny for test after it's supported
"DeepseekVLV2ForCausalLM": _HfExamplesInfo("deepseek-ai/deepseek-vl2-small"), # noqa: E501
"FuyuForCausalLM": _HfExamplesInfo("adept/fuyu-8b"),
"H2OVLChatModel": _HfExamplesInfo("h2oai/h2ovl-mississippi-800m"),
"InternVLChatModel": _HfExamplesInfo("OpenGVLab/InternVL2-1B",
Expand Down
3 changes: 3 additions & 0 deletions tests/models/test_initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ def test_can_initialize(model_arch):

# Avoid OOM
def hf_overrides(hf_config: PretrainedConfig) -> PretrainedConfig:
if hf_config.model_type == "deepseek_vl_v2":
hf_config.update({"architectures": ["DeepseekVLV2ForCausalLM"]})

if hasattr(hf_config, "text_config"):
text_config: PretrainedConfig = hf_config.text_config
else:
Expand Down
4 changes: 2 additions & 2 deletions vllm/entrypoints/chat_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,8 +403,8 @@ def _placeholder_str(self, modality: ModalityStr,
if model_type.startswith("llava"):
return self._cached_token_str(self._tokenizer,
hf_config.image_token_index)
if model_type in ("chameleon", "internvl_chat", "NVLM_D",
"h2ovl_chat"):
if model_type in ("chameleon", "deepseek_vl_v2", "internvl_chat",
"NVLM_D", "h2ovl_chat"):
return "<image>"
if model_type == "mllama":
return "<|image|>"
Expand Down
18 changes: 17 additions & 1 deletion vllm/model_executor/models/deepseek_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,11 @@ def __init__(
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj")
rope_scaling["rope_type"] = 'deepseek_yarn'
if rope_scaling:
rope_scaling["rope_type"] = 'deepseek_yarn'
self.use_normal_rope = False
else:
self.use_normal_rope = True
self.rotary_emb = get_rope(qk_rope_head_dim,
rotary_dim=qk_rope_head_dim,
max_position=max_position_embeddings,
Expand Down Expand Up @@ -298,7 +302,18 @@ def forward(
self.qk_nope_head_dim + self.v_head_dim)
k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
k_pe = latent_cache[:, :, self.kv_lora_rank:]

if self.use_normal_rope:
seq_len = positions.size(0)
ori_q_pe_shape, ori_k_pe_shape = q_pe.shape, k_pe.shape
q_pe = q_pe.reshape(seq_len, -1)
k_pe = k_pe.reshape(seq_len, -1)

q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)

if self.use_normal_rope:
q_pe, k_pe = q_pe.view(ori_q_pe_shape), k_pe.view(ori_k_pe_shape)

q[..., self.qk_nope_head_dim:] = q_pe
k = torch.empty_like(q)
k[..., :self.qk_nope_head_dim] = k_nope
Expand Down Expand Up @@ -355,6 +370,7 @@ def __init__(
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)

if (config.n_routed_experts is not None
and layer_idx >= config.first_k_dense_replace
and layer_idx % config.moe_layer_freq == 0):
Expand Down
20 changes: 18 additions & 2 deletions vllm/model_executor/models/deepseek_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,11 @@ def __init__(
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj")
rope_scaling["rope_type"] = 'deepseek_yarn'
if rope_scaling:
rope_scaling["rope_type"] = 'deepseek_yarn'
self.use_normal_rope = False
else:
self.use_normal_rope = True
self.rotary_emb = get_rope(qk_rope_head_dim,
rotary_dim=qk_rope_head_dim,
max_position=max_position_embeddings,
Expand Down Expand Up @@ -306,7 +310,18 @@ def forward(
self.qk_nope_head_dim + self.v_head_dim)
k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
k_pe = latent_cache[:, :, self.kv_lora_rank:]

if self.use_normal_rope:
seq_len = positions.size(0)
ori_q_pe_shape, ori_k_pe_shape = q_pe.shape, k_pe.shape
q_pe = q_pe.reshape(seq_len, -1)
k_pe = k_pe.reshape(seq_len, -1)

q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)

if self.use_normal_rope:
q_pe, k_pe = q_pe.view(ori_q_pe_shape), k_pe.view(ori_k_pe_shape)

q[..., self.qk_nope_head_dim:] = q_pe
k = torch.empty_like(q)
k[..., :self.qk_nope_head_dim] = k_nope
Expand Down Expand Up @@ -583,7 +598,8 @@ def load_weights(self, weights: Iterable[Tuple[str,
continue

# TODO(simon): support nextn predict layers
if self.config.num_nextn_predict_layers > 0:
if hasattr(self.config, "num_nextn_predict_layers"
) and self.config.num_nextn_predict_layers > 0:
assert self.config.num_nextn_predict_layers == 1
layer_idx = self.config.num_hidden_layers
if name.startswith(f"model.layers.{layer_idx}"):
Expand Down
Loading
Loading