From cdd84fcc0c7e278417b505eee58e63222a15179d Mon Sep 17 00:00:00 2001
From: ispobock
Date: Thu, 5 Dec 2024 01:26:10 +0000
Subject: [PATCH] fix awq

---
 python/sglang/srt/models/deepseek_v2.py | 15 ++++++++++++++-
 1 file changed, 14 insertions(+), 1 deletion(-)

diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index e83774ff55e..80db9a35c71 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -21,6 +21,7 @@
 import torch
 from torch import nn
 from transformers import PretrainedConfig
+from vllm import _custom_ops as ops
 from vllm.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
@@ -894,7 +895,19 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         if not global_server_args_dict["disable_mla"]:
             for layer_id in range(self.config.num_hidden_layers):
                 self_attn = self.model.layers[layer_id].self_attn
-                w_kc, w_vc = self_attn.kv_b_proj.weight.unflatten(
+                if hasattr(self_attn.kv_b_proj, "qweight"):
+                    # AWQ compatible
+                    w = ops.awq_dequantize(
+                        self_attn.kv_b_proj.qweight,
+                        self_attn.kv_b_proj.scales,
+                        self_attn.kv_b_proj.qzeros,
+                        0,
+                        0,
+                        0,
+                    ).T
+                else:
+                    w = self_attn.kv_b_proj.weight
+                w_kc, w_vc = w.unflatten(
                     0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
                 ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
                 self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
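
Note (illustrative, not part of the patch): once the AWQ-packed kv_b_proj has been
dequantized and transposed, it has the same layout as the plain fp16 weight, so the
existing unflatten/split logic applies unchanged. Below is a minimal standalone sketch
of that reshaping step; the head dimensions and LoRA rank are made-up values for
illustration, not taken from the patch or the model config.

import torch

# Illustrative dimensions only; real values come from the DeepSeek-V2 config.
num_heads = 16
qk_nope_head_dim = 128
v_head_dim = 128
kv_lora_rank = 512

# Stand-in for the (dequantized) kv_b_proj weight:
# shape [num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank].
w = torch.randn(num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank)

# Same unflatten/split as in load_weights(): recover the per-head K (nope part)
# and V sub-matrices used by the MLA weight-absorption path.
w_kc, w_vc = w.unflatten(0, (-1, qk_nope_head_dim + v_head_dim)).split(
    [qk_nope_head_dim, v_head_dim], dim=1
)
print(w_kc.shape)  # torch.Size([16, 128, 512])
print(w_vc.shape)  # torch.Size([16, 128, 512])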