From ccec961a77d4fd42505671ab0d9b3f16f0f08182 Mon Sep 17 00:00:00 2001 From: "upskyy (Patrick Ha)" Date: Tue, 3 Dec 2024 15:54:40 +0900 Subject: [PATCH] Fix gemma2 lora inference --- python/sglang/srt/models/gemma2.py | 34 ++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/python/sglang/srt/models/gemma2.py b/python/sglang/srt/models/gemma2.py index dbca7268803..0c0e6155d35 100644 --- a/python/sglang/srt/models/gemma2.py +++ b/python/sglang/srt/models/gemma2.py @@ -355,6 +355,40 @@ def forward( input_ids, hidden_states, self.model.embed_tokens, forward_batch ) + def get_hidden_dim(self, module_name): + # return input_dim, output_dim + if module_name in ["q_proj", "qkv_proj"]: + return ( + self.config.hidden_size, + self.config.head_dim * self.config.num_attention_heads, + ) + elif module_name in ["o_proj"]: + return ( + self.config.head_dim * self.config.num_attention_heads, + self.config.hidden_size, + ) + elif module_name in ["kv_proj"]: + return ( + self.config.hidden_size, + self.config.head_dim * self.config.num_key_value_heads, + ) + elif module_name == "gate_up_proj": + return self.config.hidden_size, self.config.intermediate_size + elif module_name == "down_proj": + return self.config.intermediate_size, self.config.hidden_size + else: + raise NotImplementedError() + + def get_module_name(self, name): + params_mapping = { + "q_proj": "qkv_proj", + "k_proj": "qkv_proj", + "v_proj": "qkv_proj", + "gate_proj": "gate_up_proj", + "up_proj": "gate_up_proj", + } + return params_mapping.get(name, name) + def get_attention_sliding_window_size(self): return get_attention_sliding_window_size(self.config)