Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Misc]Add BNB quantization for Qwen2VL #11719

Merged
merged 2 commits into from
Jan 3, 2025
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 28 additions & 21 deletions vllm/model_executor/models/qwen2_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,15 +267,30 @@ def forward(
cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor,
) -> torch.Tensor:
# [s, b, c] --> [s, b, head * 3 * head_dim]

s, b = x.size()[:-1]
# [s, b, c] --> [s, b,3 * head * head_dim]
x, _ = self.qkv(x)

# [s, b, head * 3 * head_dim] --> [s, b, head, 3 * head_dim]
new_x_shape = x.size()[:-1] + (
3,
self.num_attention_heads_per_partition,
3 * self.hidden_size_per_attention_head,
self.hidden_size_per_attention_head,
)
x = x.view(*new_x_shape)
# NOTE: A way to ensure the loaded weights and bias remains unchanged.
# Although this introduces reshape operations that impact performance,
# reshaping the weights would not be conducive to supporting like BNB.
# [s, b, 3, head, head_dim] --> [s, b, head, 3, head_dim]
x = x.transpose(2, 3)
# [s, b, head, 3, head_dim] --> [s, b, head, 3 * head_dim]
new_x_shape = (
s,
b,
self.num_attention_heads_per_partition,
3 * self.hidden_size_per_attention_head,
)
x = x.reshape(*new_x_shape)

# [s, b, head, 3 * head_dim] --> 3 [s, b, head, head_dim]
q, k, v = dist_utils.split_tensor_along_last_dim(x, 3)
Expand Down Expand Up @@ -614,24 +629,6 @@ def load_weights(self, weights: Iterable[Tuple[str,
weight_loader(param, loaded_weight, shard_id)
break
else:
if name.endswith("qkv.weight"):
visual_num_heads = self.num_heads
visual_embed_dim = self.embed_dim
head_size = visual_embed_dim // visual_num_heads
loaded_weight = loaded_weight.view(3, visual_num_heads,
head_size,
Isotr0py marked this conversation as resolved.
Show resolved Hide resolved
visual_embed_dim)
loaded_weight = loaded_weight.transpose(0, 1)
loaded_weight = loaded_weight.reshape(-1, visual_embed_dim)
elif name.endswith("qkv.bias"):
visual_num_heads = self.num_heads
visual_embed_dim = self.embed_dim
head_size = visual_embed_dim // visual_num_heads
loaded_weight = loaded_weight.view(3, visual_num_heads,
head_size)
loaded_weight = loaded_weight.transpose(0, 1)
loaded_weight = loaded_weight.reshape(-1)

param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
Expand Down Expand Up @@ -935,6 +932,16 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
embedding_modules = {}
embedding_padding_modules = []

# BitandBytes specific attributes
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
}

# To ensure correct weight loading and mapping.
hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={
"lm_head.": "language_model.lm_head.",
Expand Down
Loading