
Add Gemma2 support. #562

Merged: 1 commit merged into casper-hansen:main on Aug 4, 2024

Conversation

radi-cho (Contributor) commented on Jul 31, 2024

This PR aims to add Gemma2 support, as initially discussed in #529.

  • Gemma2 with scaling of the pre_feedforward_layernorm op.
  • Gemma2Block with proper residuals, pre_feedforward_layernorm, and post_feedforward_layernorm.
  • Hidden state normalizer (keeping lm_head weights unnormalized since they are tied with the embeddings).
  • Setting weight + 1 when scaling Gemma2RMSNorm (see the sketch after this list).
  • Attention softcapping (the final logit softcapping should be covered by the Gemma2ForCausalLM module).
  • Sliding window mask (?)
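For context, here is a minimal sketch, not the PR code itself, of two of the Gemma2-specific details listed above: the RMSNorm weight being applied as (1 + weight), and tanh-based softcapping of the attention logits. The class and function names are illustrative.

```python
import torch
import torch.nn as nn


class Gemma2RMSNormSketch(nn.Module):
    """Illustrative RMSNorm where the learned weight acts as (1 + weight)."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # Zero-initialized; the effective scale is (1 + weight).
        self.weight = nn.Parameter(torch.zeros(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        orig_dtype = x.dtype
        x = x.float()
        x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return (x * (1.0 + self.weight.float())).to(orig_dtype)


def softcap(scores: torch.Tensor, cap: float = 50.0) -> torch.Tensor:
    """Attention logit softcapping: squash scores through tanh around a fixed cap."""
    return cap * torch.tanh(scores / cap)
```

Gemma2's reference config uses a cap of 50.0 for the attention logits and 30.0 for the final logits.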

@StoyanGanchev helped with this contribution.

casper-hansen (Owner) commented

Thanks @radi-cho for this contribution. As for the sliding window, it's not supported in the current attention kernels; it would require switching to flash attention v2 and using flash_attn_with_kvcache. This has been on my todo-list and is documented in #208. You are free to help with it in this PR or in future PRs, but it's not required for Gemma2 support. A rough sketch of what such a call could look like is below.
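For illustration only, a minimal sketch of a sliding-window decode step with flash_attn_with_kvcache, assuming flash-attn v2.3+ where the function accepts a window_size argument; this is not the AutoAWQ integration itself, and the shapes/values are placeholders.

```python
import torch
from flash_attn import flash_attn_with_kvcache  # assumes flash-attn >= 2.3

batch, heads, head_dim, max_seq = 1, 8, 256, 8192
device, dtype = "cuda", torch.float16

# One new query token and its key/value (shapes: batch, seqlen, heads, head_dim).
q = torch.randn(batch, 1, heads, head_dim, dtype=dtype, device=device)
k_new = torch.randn(batch, 1, heads, head_dim, dtype=dtype, device=device)
v_new = torch.randn(batch, 1, heads, head_dim, dtype=dtype, device=device)

# Pre-allocated KV cache and the number of tokens already stored per sequence.
k_cache = torch.zeros(batch, max_seq, heads, head_dim, dtype=dtype, device=device)
v_cache = torch.zeros_like(k_cache)
cache_seqlens = torch.tensor([128], dtype=torch.int32, device=device)

out = flash_attn_with_kvcache(
    q, k_cache, v_cache,
    k=k_new, v=v_new,            # appended to the cache in place
    cache_seqlens=cache_seqlens,
    causal=True,
    window_size=(4096, 0),       # Gemma2 uses a 4096-token sliding window on alternating layers
)
```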

radi-cho (Contributor, Author) commented on Aug 1, 2024

Great. I am not as familiar with FlashAttention, so let's leave that for a future PR and keep this one focused on Gemma2. I just released gemma-2-2b and gemma-2-9b quantized with the current branch to Hugging Face, along with perplexity evaluations on WikiText-2 and C4.

I just noticed that when I save the model, the tied weights lm_head.weight and model.embed_tokens.weight are saved twice in the safetensors file. @casper-hansen, do you know how to fix that?

haichuan1221 commented

> Great. I am not as familiar with FlashAttention, so let's leave that for a future PR and keep this one focused on Gemma2. I just released gemma-2-2b and gemma-2-9b quantized with the current branch to Hugging Face, along with perplexity evaluations on WikiText-2 and C4.

I tried your model; it only supports torch.bfloat16. Can you make it support both bfloat16 and float16?

casper-hansen (Owner) commented on Aug 2, 2024

> I just noticed that when I save the model, the tied weights lm_head.weight and model.embed_tokens.weight are saved twice in the safetensors file. @casper-hansen, do you know how to fix that?

I see that lm_head and embed_tokens are stored separately in the weights file. I am not sure what is causing it, but I would assume it happens either at load time or at save time. It might be the call to model.state_dict(), but an area to investigate is here:

AutoAWQ/awq/models/base.py, lines 302 to 319 at 2e2635b:

    shards, index = shard_checkpoint(
        self.model.state_dict(), max_shard_size=shard_size, weights_name=model_name
    )

    for shard_file, shard in shards.items():
        if safetensors:
            # safetensors must be in the same memory, so we duplicate and use contiguous memory
            shard = {k: v.clone().contiguous() for k, v in shard.items()}
            save_file(
                shard, os.path.join(save_dir, shard_file), metadata={"format": "pt"}
            )
        else:
            torch.save(shard, os.path.join(save_dir, shard_file))

    # save shard index
    if index is not None:
        with open(f"{save_dir}/{model_name}.index.json", "w+") as file:
            file.write(json.dumps(index, indent=4))

radi-cho (Contributor, Author) commented on Aug 2, 2024

@casper-hansen thanks for pointing this out. I guess the save_quantized method should be extended to make sure tied (duplicated) weights are deleted from the state dict, as is done in modeling_utils.py#L2657-L2689. If this is not Gemma-specific, I would suggest fixing it in a separate commit (even before Gemma is merged). Can you review the current codebase? Something along the lines of the sketch below.
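A rough sketch of that idea, reusing the variable names from the base.py excerpt above and assuming tie_word_embeddings is set on the config; this is not the merged fix:

```python
state_dict = self.model.state_dict()

# If lm_head is tied to the embeddings, keep only one copy in the checkpoint
# so the duplicate tensor is not written to the safetensors shards.
if getattr(self.model.config, "tie_word_embeddings", False):
    state_dict.pop("lm_head.weight", None)

shards, index = shard_checkpoint(
    state_dict, max_shard_size=shard_size, weights_name=model_name
)
```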

casper-hansen merged commit b9e9b73 into casper-hansen:main on Aug 4, 2024
casper-hansen (Owner) commented

Thanks for the PR @radi-cho! This is great work and much appreciated, as the Gemma2 model series requires a bit more effort than others.
