[Misc]Add BNB quantization for Qwen2VL (#11719)
Signed-off-by: Jee Jee Li <[email protected]>
Signed-off-by: Isotr0py <[email protected]>
Co-authored-by: Isotr0py <[email protected]>
jeejeelee and Isotr0py authored Jan 3, 2025
1 parent 1543914 commit a655eb3
Showing 1 changed file with 40 additions and 29 deletions.
69 changes: 40 additions & 29 deletions vllm/model_executor/models/qwen2_vl.py
@@ -38,7 +38,7 @@
 
 from vllm.attention import AttentionMetadata
 from vllm.config import VllmConfig
-from vllm.distributed import parallel_state
+from vllm.distributed import parallel_state, tensor_model_parallel_all_gather
 from vllm.distributed import utils as dist_utils
 from vllm.logger import init_logger
 from vllm.model_executor import SamplingMetadata
@@ -239,6 +239,8 @@ def __init__(
         super().__init__()
         # Per attention head and per partition values.
         world_size = parallel_state.get_tensor_model_parallel_world_size()
+        self.tp_size = world_size
+        self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
         self.hidden_size_per_attention_head = dist_utils.divide(
             projection_size, num_heads)
         self.num_attention_heads_per_partition = dist_utils.divide(
@@ -261,24 +263,41 @@ def __init__(
             raise RuntimeError(
                 f"Qwen2-VL does not support {self.attn_backend} backend now.")
 
+    def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
+        # [s, b, 3 * head * head_dim]
+        seq_len, bs, _ = qkv.shape
+        if self.tp_size > 1:
+            qkv = tensor_model_parallel_all_gather(qkv)
+
+        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
+        q, k, v = qkv.chunk(3, dim=2)
+
+        # 3 * [s, b, head * head_dim]
+        if self.tp_size > 1:
+            splitter = partial(dist_utils.split_tensor_along_last_dim,
+                               num_partitions=self.tp_size)
+            q = splitter(q)[self.tp_rank]
+            k = splitter(k)[self.tp_rank]
+            v = splitter(v)[self.tp_rank]
+
+        # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim]
+        new_shape = (seq_len, bs, self.num_attention_heads_per_partition,
+                     self.hidden_size_per_attention_head)
+        q, k, v = (x.view(*new_shape) for x in (q, k, v))
+        return q, k, v
+
     def forward(
         self,
         x: torch.Tensor,
         cu_seqlens: torch.Tensor,
         rotary_pos_emb: torch.Tensor,
     ) -> torch.Tensor:
-        # [s, b, c] --> [s, b, head * 3 * head_dim]
-        x, _ = self.qkv(x)
-
-        # [s, b, head * 3 * head_dim] --> [s, b, head, 3 * head_dim]
-        new_x_shape = x.size()[:-1] + (
-            self.num_attention_heads_per_partition,
-            3 * self.hidden_size_per_attention_head,
-        )
-        x = x.view(*new_x_shape)
+        # [s, b, c] --> [s, b, 3 * head * head_dim]
+        x, _ = self.qkv(x)
 
-        # [s, b, head, 3 * head_dim] --> 3 [s, b, head, head_dim]
-        q, k, v = dist_utils.split_tensor_along_last_dim(x, 3)
+        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
+        q, k, v = self.split_qkv(x)
         batch_size = q.shape[1]
 
         q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous()
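Note on the new split_qkv path (my reading of the hunk above, not text from the commit): under tensor parallelism each rank's fused QKV projection produces only its output shard, so the output is first all-gathered to recover the full stacked [q; k; v] activation, chunked into q/k/v, and each rank then keeps its own contiguous slice of heads. A minimal single-process sketch of that reshaping with toy shapes and plain PyTorch, no distributed calls (not the vLLM code itself):

    import torch

    seq_len, bs, num_heads, head_dim, tp_size, tp_rank = 4, 1, 8, 16, 2, 0

    # Full (already "all-gathered") fused projection: [s, b, 3 * num_heads * head_dim]
    qkv = torch.randn(seq_len, bs, 3 * num_heads * head_dim)

    # Split into q, k, v along the last dim: 3 * [s, b, num_heads * head_dim]
    q, k, v = qkv.chunk(3, dim=2)

    # Each rank keeps its slice of the heads: [s, b, (num_heads // tp_size) * head_dim]
    def take_rank_slice(t: torch.Tensor) -> torch.Tensor:
        per_rank = t.shape[-1] // tp_size
        return t[..., tp_rank * per_rank:(tp_rank + 1) * per_rank]

    q, k, v = (take_rank_slice(t) for t in (q, k, v))

    # Finally view as [s, b, heads_per_rank, head_dim]
    new_shape = (seq_len, bs, num_heads // tp_size, head_dim)
    q, k, v = (t.view(*new_shape) for t in (q, k, v))
    print(q.shape)  # torch.Size([4, 1, 4, 16])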
@@ -614,24 +633,6 @@ def load_weights(self, weights: Iterable[Tuple[str,
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                if name.endswith("qkv.weight"):
-                    visual_num_heads = self.num_heads
-                    visual_embed_dim = self.embed_dim
-                    head_size = visual_embed_dim // visual_num_heads
-                    loaded_weight = loaded_weight.view(3, visual_num_heads,
-                                                       head_size,
-                                                       visual_embed_dim)
-                    loaded_weight = loaded_weight.transpose(0, 1)
-                    loaded_weight = loaded_weight.reshape(-1, visual_embed_dim)
-                elif name.endswith("qkv.bias"):
-                    visual_num_heads = self.num_heads
-                    visual_embed_dim = self.embed_dim
-                    head_size = visual_embed_dim // visual_num_heads
-                    loaded_weight = loaded_weight.view(3, visual_num_heads,
-                                                       head_size)
-                    loaded_weight = loaded_weight.transpose(0, 1)
-                    loaded_weight = loaded_weight.reshape(-1)
-
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
@@ -935,6 +936,16 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
     embedding_modules = {}
     embedding_padding_modules = []
 
+    # BitandBytes specific attributes
+    bitsandbytes_stacked_params_mapping = {
+        # shard_name, weight_name, index
+        "q_proj": ("qkv_proj", 0),
+        "k_proj": ("qkv_proj", 1),
+        "v_proj": ("qkv_proj", 2),
+        "gate_proj": ("gate_up_proj", 0),
+        "up_proj": ("gate_up_proj", 1),
+    }
+
     # To ensure correct weight loading and mapping.
     hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={
         "lm_head.": "language_model.lm_head.",