Fix shape error that occurred when loading LoRA weights of the gemma2 model. #2330
Motivation & Modifications
The `get_hidden_dim` functions of the llama model and the gemma2 model differ, which causes a shape error when loading gemma2 LoRA weights.

For example, in the llama model, `head_dim * num_attention_heads` equals `hidden_size`, so using `self.config.hidden_size, self.config.hidden_size` works fine (3072 = 128 * 24). However, in the gemma2 model, `head_dim * num_attention_heads` and `hidden_size` are unrelated and need to be handled differently: `hidden_size` is 2304, while `head_dim * num_attention_heads` is 2048 (2304 != 2048).

As a result, `q_proj` in the gemma2 model should be defined as `self.config.hidden_size, self.config.head_dim * self.config.num_attention_heads` rather than `self.config.hidden_size, self.config.hidden_size`. Since the input and output dimensions differ, `o_proj` also needs to be adjusted.

I modified the code to ensure that the gemma2 model works correctly with multi-LoRA inference, and I have verified its functionality.
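For reference, the mismatch can be checked directly from the two configs linked below. This is a hedged sketch using `transformers.AutoConfig`; the model IDs are the public Hugging Face repos (the Llama repo is gated, so loading it may require authentication):

```python
from transformers import AutoConfig

# Compare head_dim * num_attention_heads against hidden_size for both models.
# The values come from the config.json files linked in this PR description.
llama_cfg = AutoConfig.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
gemma2_cfg = AutoConfig.from_pretrained("google/gemma-2-2b-it")

# llama: 128 * 24 == 3072 == hidden_size, so (hidden_size, hidden_size) works.
assert llama_cfg.head_dim * llama_cfg.num_attention_heads == llama_cfg.hidden_size

# gemma2: 256 * 8 == 2048 != 2304, so (hidden_size, hidden_size) gives a shape error.
assert gemma2_cfg.head_dim * gemma2_cfg.num_attention_heads != gemma2_cfg.hidden_size
```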
llama3 config: https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct/blob/main/config.json
gemma2 config: https://huggingface.co/google/gemma-2-2b-it/blob/main/config.json
current `get_hidden_dim` functions: https://github.com/sgl-project/sglang/blob/v0.3.6.post2/python/sglang/srt/models/llama.py#L326-L339

The gemma2 model does not have a `get_hidden_dim` function implemented, which allows it to bypass the following code: https://github.com/sgl-project/sglang/blob/v0.3.6.post2/python/sglang/srt/lora/lora_manager.py#L34-L46
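For context, here is a minimal sketch of what a gemma2-specific `get_hidden_dim` could look like, assuming the same module names and branch structure as the llama implementation linked above (the `kv_proj`, `gate_up_proj`, and `down_proj` branches are carried over from `llama.py`); the exact code added in this PR may differ:

```python
def get_hidden_dim(self, module_name):
    # Return (input_dim, output_dim) for each LoRA-targeted projection.
    # Unlike llama, gemma2's q/o projections use head_dim * num_attention_heads,
    # which does not equal hidden_size (2048 vs. 2304 for gemma-2-2b).
    if module_name in ["q_proj", "qkv_proj"]:
        return (
            self.config.hidden_size,
            self.config.head_dim * self.config.num_attention_heads,
        )
    elif module_name == "o_proj":
        return (
            self.config.head_dim * self.config.num_attention_heads,
            self.config.hidden_size,
        )
    elif module_name == "kv_proj":
        return (
            self.config.hidden_size,
            self.config.head_dim * self.config.num_key_value_heads,
        )
    elif module_name == "gate_up_proj":
        return self.config.hidden_size, self.config.intermediate_size
    elif module_name == "down_proj":
        return self.config.intermediate_size, self.config.hidden_size
    else:
        raise NotImplementedError()
```

In such a sketch, `o_proj` gets its own branch because its input and output dimensions are no longer equal, whereas the llama version can group it with `q_proj`.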
Checklist