Fix shape error that occurred when loading lora weight of gemma2 model. (#2330)
upskyy authored Dec 8, 2024
1 parent ef995da commit 63dfab1
Showing 1 changed file with 34 additions and 0 deletions.
34 changes: 34 additions & 0 deletions python/sglang/srt/models/gemma2.py
@@ -355,6 +355,40 @@ def forward(
input_ids, hidden_states, self.model.embed_tokens, forward_batch
)

def get_hidden_dim(self, module_name):
# return input_dim, output_dim
if module_name in ["q_proj", "qkv_proj"]:
return (
self.config.hidden_size,
self.config.head_dim * self.config.num_attention_heads,
)
elif module_name in ["o_proj"]:
return (
self.config.head_dim * self.config.num_attention_heads,
self.config.hidden_size,
)
elif module_name in ["kv_proj"]:
return (
self.config.hidden_size,
self.config.head_dim * self.config.num_key_value_heads,
)
elif module_name == "gate_up_proj":
return self.config.hidden_size, self.config.intermediate_size
elif module_name == "down_proj":
return self.config.intermediate_size, self.config.hidden_size
else:
raise NotImplementedError()

def get_module_name(self, name):
params_mapping = {
"q_proj": "qkv_proj",
"k_proj": "qkv_proj",
"v_proj": "qkv_proj",
"gate_proj": "gate_up_proj",
"up_proj": "gate_up_proj",
}
return params_mapping.get(name, name)

def get_attention_sliding_window_size(self):
return get_attention_sliding_window_size(self.config)

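For context, here is a minimal, hypothetical sketch of how a LoRA loader might consume the two new hooks when sizing adapter weights for gemma2. The helper name init_gemma2_lora_weights, the rank argument, and the plain torch.zeros allocations are illustrative assumptions, not SGLang's actual loader code; only get_module_name and get_hidden_dim come from this commit.

import torch

def init_gemma2_lora_weights(model, checkpoint_module: str, rank: int):
    # Hypothetical helper, not part of this commit or of SGLang's loader.
    # Map a per-checkpoint module name (e.g. "q_proj") to the fused module
    # the model actually uses (e.g. "qkv_proj"), then query the commit's
    # get_hidden_dim() for the (input_dim, output_dim) pair so the LoRA
    # A/B matrices get the correct shapes for this architecture.
    fused_name = model.get_module_name(checkpoint_module)
    input_dim, output_dim = model.get_hidden_dim(fused_name)
    lora_a = torch.zeros(rank, input_dim)    # A: (r, in_features)
    lora_b = torch.zeros(output_dim, rank)   # B: (out_features, r)
    return fused_name, lora_a, lora_b

Because gemma2 sizes its attention projections from head_dim * num_attention_heads (and num_key_value_heads for kv_proj) rather than hidden_size alone, reporting these dimensions explicitly is what avoids the shape mismatch the commit title describes.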
