Commit

Fix AWQ with enable MLA (#2364)
ispobock authored Dec 5, 2024
1 parent 2b0fc59 commit 4a63c18
Showing 1 changed file with 14 additions and 1 deletion.
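
What the change does, as read from the diff below: when MLA is enabled, load_weights splits kv_b_proj into the per-head w_kc and w_vc factors used for weight absorption, which needs the dense floating-point matrix. With an AWQ-quantized checkpoint, kv_b_proj holds packed qweight/scales/qzeros tensors instead of a plain weight attribute, so the old self_attn.kv_b_proj.weight access would fail. The fix detects the AWQ case via hasattr(..., "qweight"), dequantizes with vLLM's ops.awq_dequantize, and transposes the result so it matches the usual [out_features, in_features] weight layout before the existing unflatten/split.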
python/sglang/srt/models/deepseek_v2.py: 15 changes (14 additions, 1 deletion)
@@ -21,6 +21,7 @@
 import torch
 from torch import nn
 from transformers import PretrainedConfig
+from vllm import _custom_ops as ops
 from vllm.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
@@ -894,7 +895,19 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         if not global_server_args_dict["disable_mla"]:
             for layer_id in range(self.config.num_hidden_layers):
                 self_attn = self.model.layers[layer_id].self_attn
-                w_kc, w_vc = self_attn.kv_b_proj.weight.unflatten(
+                if hasattr(self_attn.kv_b_proj, "qweight"):
+                    # AWQ compatible
+                    w = ops.awq_dequantize(
+                        self_attn.kv_b_proj.qweight,
+                        self_attn.kv_b_proj.scales,
+                        self_attn.kv_b_proj.qzeros,
+                        0,
+                        0,
+                        0,
+                    ).T
+                else:
+                    w = self_attn.kv_b_proj.weight
+                w_kc, w_vc = w.unflatten(
                     0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
                 ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
                 self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
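
Below is a minimal, self-contained sketch of the splitting step that the dequantized weight feeds into. It uses illustrative DeepSeek-V2 dimensions (num_heads=128, qk_nope_head_dim=128, v_head_dim=128, kv_lora_rank=512, all assumed here rather than taken from the diff) and a random matrix in place of the real kv_b_proj weight; it only illustrates the tensor shapes, not the model code itself.

import torch

# Illustrative dimensions (assumed, not part of the commit).
num_heads, qk_nope_head_dim, v_head_dim, kv_lora_rank = 128, 128, 128, 512

# kv_b_proj maps the compressed KV latent (kv_lora_rank) to the stacked
# per-head K (nope part) and V projections; nn.Linear stores this as
# [out_features, in_features]. For AWQ checkpoints this is what
# ops.awq_dequantize(...).T would reconstruct.
w = torch.randn(num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank)

# Same unflatten/split as in load_weights above: group the rows per head,
# then separate the K and V blocks along dim 1.
w_kc, w_vc = w.unflatten(0, (-1, qk_nope_head_dim + v_head_dim)).split(
    [qk_nope_head_dim, v_head_dim], dim=1
)
print(w_kc.shape)  # torch.Size([128, 128, 512])
print(w_vc.shape)  # torch.Size([128, 128, 512])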
