Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Spec Decode] Add Script for converting HF Eagle checkpoint to vLLM compatible checkpoint #11866

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion docs/source/features/spec_decode.md
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ llm = LLM(
tensor_parallel_size=4,
speculative_model="path/to/modified/eagle/model",
speculative_draft_tensor_parallel_size=1,
num_speculative_tokens=2,
)

outputs = llm.generate(prompts, sampling_params)
Expand All @@ -192,7 +193,7 @@ A few important things to consider when using the EAGLE based draft models:

1. The EAGLE draft models available in the [HF repository for EAGLE models](https://huggingface.co/yuhuili) cannot be
used directly with vLLM due to differences in the expected layer names and model definition.
To use these models with vLLM, use the [following script](https://gist.github.com/abhigoyal1997/1e7a4109ccb7704fbc67f625e86b2d6d)
To use these models with vLLM, use the [following script](../../../vllm/spec_decode/scripts/vLLM_compatible_EAGLE_ckpt.py)
to convert them. Note that this script does not modify the model's weights.

In the above example, use the script to first convert
Expand Down
Empty file.
78 changes: 78 additions & 0 deletions vllm/spec_decode/scripts/vLLM_compatible_EAGLE_ckpt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import argparse
import json
import os

import torch
from safetensors.torch import load_file, save_file

# Script to convert an EAGLE checkpoint available at
# https://huggingface.co/yuhuili into vLLM compatible checkpoint.
# Borrowed from
# https://gist.github.com/abhigoyal1997/1e7a4109ccb7704fbc67f625e86b2d6d

# Example Usage

# python3 vllm/spec_decode/scripts/vLLM_compatible_EAGLE_ckpt.py
# --eagle_dir </path/to/eagle/checkpoint/dir>
# --baseline_ckpt_file_lm_head \
# </path/to/base_model/ckpt_file_with_lm_head_weight>


def update_model(eagle_dir, baseline_ckpt_path):
# Load the Eagle model checkpoint
eagle_ckpt_path = os.path.join(eagle_dir, "pytorch_model.bin")
ckpt = torch.load(eagle_ckpt_path)

# Load the baseline model checkpoint
ref_ckpt = load_file(baseline_ckpt_path)

# Update the EAGLE model with the lm_head.weight from the reference model
ckpt['lm_head.weight'] = ref_ckpt['lm_head.weight']

# Save the modified checkpoint as safetensors
save_file(ckpt, os.path.join(eagle_dir, "model.safetensors"))

# Load and modify the configuration
config_path = os.path.join(eagle_dir, "config.json")
with open(config_path) as rf:
cfg = json.load(rf)

if "fc.bias" in ckpt:
cfg["eagle_fc_bias"] = True
sroy745 marked this conversation as resolved.
Show resolved Hide resolved

cfg = {"model_type": "eagle", "model": cfg}

# Save the new configuration
with open(config_path, "w") as wf:
json.dump(cfg, wf)

# Delete the original pytorch model checkpoint file
os.remove(eagle_ckpt_path)

print(f"Model updated and saved to {eagle_dir}/model.safetensors. "
"Use this new checkpoint with vLLM")


def main():
# Parse command-line arguments
parser = argparse.ArgumentParser(
description="Make EAGLE checkpoints available at "
"https://huggingface.co/yuhuili vLLM compatible.")
parser.add_argument('--eagle_dir',
type=str,
help="Directory for the EAGLE checkpoint")
parser.add_argument(
'--baseline_ckpt_file_lm_head',
type=str,
help="Path to the baseline model checkpoint file containing the "
"weights for lm_head. The checkpoint needs to be in safetensor "
"format.")

args = parser.parse_args()

# Update the model
update_model(args.eagle_dir, args.baseline_ckpt_file_lm_head)


if __name__ == "__main__":
main()
Loading