Add Gemma2 support. #562
Conversation
Thanks @radi-cho for this contribution. As for the sliding window, it's not supported in the attention kernels, and it would require switching to FlashAttention v2 by using …
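The comment above suggests sliding-window support would require FlashAttention v2. Below is a minimal sketch of selecting the FlashAttention-2 backend in plain Transformers; the model id and dtype are assumptions for illustration, it assumes `flash-attn` is installed, and it is not necessarily the exact switch the comment was referring to:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Assumed model id, for illustration only.
model_id = "google/gemma-2-9b-it"

# Select the FlashAttention-2 backend; FA2 requires fp16 or bf16 weights.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
```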
Great. I am not as familiar with FlashAttention, so let's leave that for a future PR and keep this one focused on Gemma2. I just noticed that when I save the model it saves the tied weights …
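As a side note on the tied-weights observation: below is a hedged sketch of one way to keep a tied output head out of a saved state dict. The helper name and the `lm_head.weight` key are assumptions for illustration, not necessarily what this PR does:

```python
import torch

def strip_tied_weights(model):
    """Hypothetical helper: drop the output head from the state dict when it
    is tied to the input embeddings, so the tensor is stored only once."""
    state_dict = dict(model.state_dict())
    if getattr(model.config, "tie_word_embeddings", False):
        state_dict.pop("lm_head.weight", None)  # assumed key name
    return state_dict

# Usage sketch:
# torch.save(strip_tied_weights(model), "checkpoint_without_tied_head.pt")
```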
I tried your model; it only supports torch.bfloat16. Can you make it support both bfloat16 and float16?
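One general way to support both dtypes is to avoid hard-coding `torch.bfloat16` and instead cast activations to the module's own weight dtype. A minimal, self-contained sketch (the module here is hypothetical, not code from this PR):

```python
import torch
import torch.nn as nn

class DtypeAgnosticLinear(nn.Module):
    """Illustrative layer that follows whatever dtype its weights use."""

    def __init__(self, in_features, out_features, dtype=torch.float16):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features, dtype=dtype)

    def forward(self, x):
        # Cast the input to the weight dtype instead of assuming bfloat16.
        return self.linear(x.to(self.linear.weight.dtype))

x = torch.randn(2, 16)
print(DtypeAgnosticLinear(16, 16, dtype=torch.bfloat16)(x).dtype)  # torch.bfloat16
print(DtypeAgnosticLinear(16, 16, dtype=torch.float16)(x).dtype)   # torch.float16
```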
I see that lines 302 to 319 in 2e2635b …
@casper-hansen thanks for pointing this out. I guess the …
Thanks for the PR @radi-cho! This is great work and much welcomed, as the Gemma2 model series requires a bit more effort than others.
This PR aims to add Gemma2 support, as initially discussed in #529.
It handles the additional `pre_feedforward_layernorm` (`op.pre_feedforward_layernorm`) and `post_feedforward_layernorm` layers, and uses `weight + 1` when scaling `Gemma2RMSNorm`. @StoyanGanchev helped with this contribution.
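To illustrate the `weight + 1` point: Gemma-style RMSNorm multiplies the normalized activations by `(1 + weight)` rather than `weight`, so any logic that derives per-channel scales from the norm has to add the offset. A minimal sketch assuming the standard Gemma2 RMSNorm formulation (illustrative only, not the PR's code):

```python
import torch
import torch.nn as nn

class Gemma2StyleRMSNorm(nn.Module):
    """RMSNorm variant that multiplies by (1 + weight), as in Gemma/Gemma2."""

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.zeros(dim))  # zero init => unit scale

    def forward(self, x):
        normed = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return normed * (1.0 + self.weight)

norm = Gemma2StyleRMSNorm(8)
# The effective per-channel scale seen by downstream layers is (weight + 1),
# which is why scaling code should use norm.weight + 1 rather than norm.weight.
effective_scale = norm.weight + 1
print(effective_scale)
```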