Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Model] Add cogagent model support vLLM #11742

Merged
merged 4 commits into from
Jan 11, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions vllm/model_executor/models/chatglm.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Adapted from
# https://github.com/THUDM/GLM-4
"""Inference-only ChatGLM model compatible with THUDM weights."""
# https://github.com/THUDM/CogAgent
"""Inference-only CogAgent model compatible with THUDM weights."""
from argparse import Namespace
from array import array
from typing import (Dict, Iterable, List, Mapping, Optional, Set, Tuple,
Expand Down Expand Up @@ -201,7 +201,6 @@ def input_processor_for_glmv(ctx: InputContext, inputs: DecoderOnlyInputs):

new_input_ids = []
final_processed_position = 0
final_processed_position = 0

for boi_position, eoi_position in zip(boi_positions, eoi_positions):
assert boi_position < eoi_position
Expand Down Expand Up @@ -275,12 +274,15 @@ def __init__(
# https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141
rope_ratio = getattr(config, "rope_ratio", 1.0)
max_positions = getattr(config, "seq_length", 8192)
# NOTE: THUDM/cogagent-9b-20241220 uses original_rope=False,
# which is equivalent to is_neox_style=True
is_neox_style = not config.original_rope
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim // 2,
max_position=max_positions,
base=10000 * rope_ratio,
is_neox_style=False,
is_neox_style=is_neox_style,
)
self.attn = Attention(self.num_heads,
self.head_dim,
Expand Down Expand Up @@ -779,4 +781,4 @@ def __new__(
return ChatGLMV(vllm_config=vllm_config, prefix=prefix)
# Initialize LLM
else:
return ChatGLM(vllm_config=vllm_config, prefix=prefix)
return ChatGLM(vllm_config=vllm_config, prefix=prefix)
Loading