[Model] Support GGUF models newly added in transformers 4.46.0 #9685

Merged · 19 commits · Jan 13, 2025
Changes from all commits
22 changes: 8 additions & 14 deletions examples/offline_inference/gguf_inference.py
@@ -3,27 +3,20 @@
 from vllm import LLM, SamplingParams


-def run_gguf_inference(model_path):
-    PROMPT_TEMPLATE = "<|system|>\n{system_message}</s>\n<|user|>\n{prompt}</s>\n<|assistant|>\n"  # noqa: E501
-    system_message = "You are a friendly chatbot who always responds in the style of a pirate."  # noqa: E501
+def run_gguf_inference(model_path, tokenizer):
     # Sample prompts.
     prompts = [
         "How many helicopters can a human eat in one sitting?",
         "What's the future of AI?",
     ]
-    prompts = [
-        PROMPT_TEMPLATE.format(system_message=system_message, prompt=prompt)
-        for prompt in prompts
-    ]
+    prompts = [[{"role": "user", "content": prompt}] for prompt in prompts]
     # Create a sampling params object.
     sampling_params = SamplingParams(temperature=0, max_tokens=128)

     # Create an LLM.
-    llm = LLM(model=model_path,
-              tokenizer="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
-              gpu_memory_utilization=0.95)
+    llm = LLM(model=model_path, tokenizer=tokenizer)

-    outputs = llm.generate(prompts, sampling_params)
+    outputs = llm.chat(prompts, sampling_params)
     # Print the outputs.
     for output in outputs:
         prompt = output.prompt
@@ -32,7 +25,8 @@ def run_gguf_inference(model_path):


 if __name__ == "__main__":
-    repo_id = "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF"
-    filename = "tinyllama-1.1b-chat-v1.0.Q4_0.gguf"
+    repo_id = "bartowski/Phi-3-medium-4k-instruct-GGUF"
+    filename = "Phi-3-medium-4k-instruct-IQ2_M.gguf"
+    tokenizer = "microsoft/Phi-3-medium-4k-instruct"
     model = hf_hub_download(repo_id, filename=filename)
-    run_gguf_inference(model)
+    run_gguf_inference(model, tokenizer)
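
Side note on the example change above: `llm.chat` renders the message list with the tokenizer's own chat template, which is what the deleted hand-written PROMPT_TEMPLATE used to approximate. A minimal sketch of the equivalent manual step (illustrative only; it just shows what the template expansion produces):

```python
from transformers import AutoTokenizer

# Render one chat message with the model's own chat template -- roughly the
# preprocessing that llm.chat performs before generation.
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-medium-4k-instruct")
messages = [{"role": "user", "content": "What's the future of AI?"}]
prompt = tokenizer.apply_chat_template(messages,
                                       tokenize=False,
                                       add_generation_prompt=True)
print(prompt)  # e.g. "<|user|>\nWhat's the future of AI?<|end|>\n<|assistant|>\n"
```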
105 changes: 74 additions & 31 deletions tests/models/decoder_only/language/test_gguf.py
@@ -4,45 +4,90 @@
 """

 import os
+from typing import List, NamedTuple, Type

 import pytest
 from huggingface_hub import hf_hub_download
 from transformers import AutoTokenizer

 from tests.quantization.utils import is_quant_method_supported

+from ....conftest import VllmRunner
 from ...utils import check_logprobs_close

 os.environ["TOKENIZERS_PARALLELISM"] = "true"

 MAX_MODEL_LEN = 1024


+class GGUFTestConfig(NamedTuple):
+    original_model: str
+    gguf_repo: str
+    gguf_filename: str
+
+    @property
+    def gguf_model(self):
+        return hf_hub_download(self.gguf_repo, filename=self.gguf_filename)
+
+
+LLAMA_CONFIG = GGUFTestConfig(
+    original_model="meta-llama/Llama-3.2-1B-Instruct",
+    gguf_repo="bartowski/Llama-3.2-1B-Instruct-GGUF",
+    gguf_filename="Llama-3.2-1B-Instruct-IQ4_XS.gguf",
+)
+
+QWEN2_CONFIG = GGUFTestConfig(
+    original_model="Qwen/Qwen2.5-1.5B-Instruct",
+    gguf_repo="Qwen/Qwen2.5-1.5B-Instruct-GGUF",
+    gguf_filename="qwen2.5-1.5b-instruct-q6_k.gguf",
+)
+
+PHI3_CONFIG = GGUFTestConfig(
+    original_model="microsoft/Phi-3.5-mini-instruct",
+    gguf_repo="bartowski/Phi-3.5-mini-instruct-GGUF",
+    gguf_filename="Phi-3.5-mini-instruct-IQ4_XS.gguf",
+)
+
+GPT2_CONFIG = GGUFTestConfig(
+    original_model="openai-community/gpt2-large",
+    gguf_repo="QuantFactory/gpt2-large-GGUF",
+    gguf_filename="gpt2-large.Q4_K_M.gguf",
+)
+
+STABLELM_CONFIG = GGUFTestConfig(
+    original_model="stabilityai/stablelm-3b-4e1t",
+    gguf_repo="afrideva/stablelm-3b-4e1t-GGUF",
+    gguf_filename="stablelm-3b-4e1t.q4_k_m.gguf",
+)
+
+STARCODER_CONFIG = GGUFTestConfig(
+    original_model="bigcode/starcoder2-3b",
+    gguf_repo="QuantFactory/starcoder2-3b-GGUF",
+    gguf_filename="starcoder2-3b.Q6_K.gguf",
+)
+
+MODELS = [
+    LLAMA_CONFIG,
+    QWEN2_CONFIG,
+    PHI3_CONFIG,
+    GPT2_CONFIG,
+    STABLELM_CONFIG,
+    # STARCODER_CONFIG,  # broken
+]
+

 @pytest.mark.skipif(not is_quant_method_supported("gguf"),
                     reason="gguf is not supported on this GPU type.")
-@pytest.mark.parametrize(("original_model", "gguf_id", "gguf_path"), [
-    ("meta-llama/Llama-3.2-1B-Instruct",
-     "bartowski/Llama-3.2-1B-Instruct-GGUF",
-     "Llama-3.2-1B-Instruct-Q4_K_M.gguf"),
-    ("meta-llama/Llama-3.2-1B-Instruct",
-     "bartowski/Llama-3.2-1B-Instruct-GGUF",
-     "Llama-3.2-1B-Instruct-IQ4_XS.gguf"),
-    ("Qwen/Qwen2-1.5B-Instruct", "Qwen/Qwen2-1.5B-Instruct-GGUF",
-     "qwen2-1_5b-instruct-q4_k_m.gguf"),
-    ("Qwen/Qwen2-1.5B-Instruct", "legraphista/Qwen2-1.5B-Instruct-IMat-GGUF",
-     "Qwen2-1.5B-Instruct.IQ4_XS.gguf"),
-])
+@pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("max_tokens", [32])
 @pytest.mark.parametrize("num_logprobs", [5])
 @pytest.mark.parametrize("tp_size", [1, 2])
 def test_models(
-    num_gpus_available,
-    vllm_runner,
-    example_prompts,
-    original_model,
-    gguf_id,
-    gguf_path,
+    num_gpus_available: int,
+    vllm_runner: Type[VllmRunner],
+    example_prompts: List[str],
+    model: GGUFTestConfig,
     dtype: str,
     max_tokens: int,
     num_logprobs: int,
@@ -51,28 +96,26 @@ def test_models(
     if num_gpus_available < tp_size:
         pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")

-    gguf_model = hf_hub_download(gguf_id, filename=gguf_path)
-
-    tokenizer = AutoTokenizer.from_pretrained(original_model)
-    messages = [[{
-        'role': 'user',
-        'content': prompt
-    }] for prompt in example_prompts]
-    example_prompts = tokenizer.apply_chat_template(messages,
-                                                    tokenize=False,
-                                                    add_generation_prompt=True)
+    tokenizer = AutoTokenizer.from_pretrained(model.original_model)
+    if tokenizer.chat_template is not None:
+        messages = [[{
+            'role': 'user',
+            'content': prompt
+        }] for prompt in example_prompts]
+        example_prompts = tokenizer.apply_chat_template(
+            messages, tokenize=False, add_generation_prompt=True)

     # Run unquantized model.
-    with vllm_runner(model_name=original_model,
+    with vllm_runner(model_name=model.original_model,
                      dtype=dtype,
                      max_model_len=MAX_MODEL_LEN,
                      tensor_parallel_size=tp_size) as original_model:

         original_outputs = original_model.generate_greedy_logprobs(
             example_prompts[:-1], max_tokens, num_logprobs)

     # Run gguf model.
-    with vllm_runner(model_name=gguf_model,
+    with vllm_runner(model_name=model.gguf_model,
+                     tokenizer_name=model.original_model,
                      dtype=dtype,
                      max_model_len=MAX_MODEL_LEN,
                      tensor_parallel_size=tp_size) as gguf_model:
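
The MODELS list above is what the test parametrizes over, so covering another GGUF conversion is just a matter of appending a config. A sketch, with placeholder repo and filename that are not part of this PR:

```python
# Hypothetical extra entry -- the repo and filename below are placeholders.
NEW_MODEL_CONFIG = GGUFTestConfig(
    original_model="some-org/some-model-instruct",
    gguf_repo="some-org/some-model-instruct-GGUF",
    gguf_filename="some-model-instruct.Q4_K_M.gguf",
)

MODELS.append(NEW_MODEL_CONFIG)  # picked up by @pytest.mark.parametrize("model", MODELS)
```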
58 changes: 35 additions & 23 deletions vllm/model_executor/layers/linear.py
@@ -447,8 +447,14 @@ def weight_loader(self,
         is_gguf_weight = getattr(param, "is_gguf_weight", False)
         is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
         if is_gguf_weight_type:
-            param.data[loaded_shard_id].copy_(loaded_weight)
-            param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
+            if loaded_shard_id is not None:
+                param.data[loaded_shard_id].copy_(loaded_weight)
+                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
+            else:
+                param.shard_weight_type = {
+                    i: loaded_weight.item()
+                    for i, _ in enumerate(self.output_sizes)
+                }
             return

         if is_gguf_weight:
@@ -459,15 +465,15 @@ def weight_loader(self,
             shard_size = loaded_weight.size(output_dim) // tp_size
             start_idx = tp_rank * shard_size

-            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
-                                                 shard_size)
-
-            param.shard_id.append(loaded_shard_id)
-            param.shard_id_map[loaded_shard_id] = len(param.data_container)
-            param.data_container.append(loaded_weight)
-            if len(param.data_container) == 2:
-                self.qweight = param.materialize_nested()
-            return
+            if loaded_shard_id is not None:
+                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
+                                                     shard_size)
+                param.shard_id.append(loaded_shard_id)
+                param.shard_id_map[loaded_shard_id] = len(param.data_container)
+                param.data_container.append(loaded_weight)
+                if len(param.data_container) == 2:
+                    self.qweight = param.materialize_nested()
+                return

         param_data = param.data
         output_dim = getattr(param, "output_dim", None)
@@ -811,10 +817,16 @@ def weight_loader(self,
         # initialize GGUF param after we know the quantize type
         is_gguf_weight = getattr(param, "is_gguf_weight", False)
         is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
-        if is_gguf_weight_type and loaded_shard_id is not None:
+        if is_gguf_weight_type:
             idx_map = {"q": 0, "k": 1, "v": 2}
-            param.data[idx_map[loaded_shard_id]].copy_(loaded_weight)
-            param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
+            if loaded_shard_id is not None:
+                param.data[idx_map[loaded_shard_id]].copy_(loaded_weight)
+                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
+            else:
+                param.shard_weight_type = {
+                    k: loaded_weight.item()
+                    for k in idx_map
+                }
             return

         if is_gguf_weight:
@@ -825,15 +837,15 @@ def weight_loader(self,
             shard_size = loaded_weight.size(output_dim) // tp_size
             start_idx = tp_rank * shard_size

-            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
-                                                 shard_size)
-
-            param.shard_id.append(loaded_shard_id)
-            param.shard_id_map[loaded_shard_id] = len(param.data_container)
-            param.data_container.append(loaded_weight)
-            if len(param.data_container) == 3:
-                self.qweight = param.materialize_nested()
-            return
+            if loaded_shard_id is not None:
+                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
+                                                     shard_size)
+                param.shard_id.append(loaded_shard_id)
+                param.shard_id_map[loaded_shard_id] = len(param.data_container)
+                param.data_container.append(loaded_weight)
+                if len(param.data_container) == 3:
+                    self.qweight = param.materialize_nested()
+                return

         param_data = param.data
         output_dim = getattr(param, "output_dim", None)
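
The new `loaded_shard_id is None` branches cover checkpoints whose fused projection (qkv or gate/up) arrives as a single GGUF tensor, presumably the case for architectures such as GPT-2 that store the attention projection fused; the loader then records the same quantization type for every logical shard. A toy sketch of that bookkeeping (illustrative only, not vLLM code):

```python
# Toy illustration of the shard-type bookkeeping in the weight loaders above.
idx_map = {"q": 0, "k": 1, "v": 2}
weight_type = 2  # stand-in for a GGUF quantization-type id

# Per-shard call: only the named shard is recorded.
shard_weight_type = {"q": weight_type}

# Whole-tensor call (loaded_shard_id is None): every shard gets the same type.
shard_weight_type = {k: weight_type for k in idx_map}
print(shard_weight_type)  # {'q': 2, 'k': 2, 'v': 2}
```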
11 changes: 8 additions & 3 deletions vllm/model_executor/models/gpt2.py
@@ -198,7 +198,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         assert not config.scale_attn_by_inverse_layer_idx
         assert not config.reorder_and_upcast_attn
         self.embed_dim = config.hidden_size
-        self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
+        self.wte = VocabParallelEmbedding(config.vocab_size,
+                                          self.embed_dim,
+                                          quant_config=quant_config,
+                                          prefix=f"{prefix}.wte")
         self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
         self.start_layer, self.end_layer, self.h = make_layers(
             config.num_hidden_layers,
@@ -259,7 +262,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             self.lm_head = self.transformer.wte
         else:
             self.lm_head = ParallelLMHead(self.config.vocab_size,
-                                          self.config.hidden_size)
+                                          self.config.hidden_size,
+                                          quant_config=quant_config,
+                                          prefix=f"{prefix}.lm_head")
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = get_sampler()
         self.make_empty_intermediate_tensors = (
@@ -304,7 +309,7 @@ def load_weights(self, weights: Iterable[Tuple[str,
         params_dict = dict(self.named_parameters(remove_duplicate=False))
         loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
-            if "lm_head.weight" in name:
+            if name.startswith("lm_head"):
                 # GPT-2 ties the weights of the embedding layer and the final
                 # linear layer.
                 continue
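
The load_weights change works because GPT-2 ties its input embedding and output projection, so any lm_head entry in a checkpoint (whatever its exact tensor name) can be skipped and the wte weights reused. A quick sketch of that tying, using the Hugging Face GPT-2 checkpoint purely for illustration:

```python
from transformers import GPT2LMHeadModel

# GPT-2 ties lm_head to the token embedding: both point at the same storage,
# which is why the loader above can drop lm_head.* entries from the checkpoint.
model = GPT2LMHeadModel.from_pretrained("gpt2")
assert model.lm_head.weight.data_ptr() == model.transformer.wte.weight.data_ptr()
```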
3 changes: 2 additions & 1 deletion vllm/model_executor/models/llama.py
@@ -156,7 +156,8 @@ def __init__(
         )

         is_neox_style = True
-        if quant_config is not None and quant_config.get_name() == "gguf":
+        is_gguf = quant_config and quant_config.get_name() == "gguf"
+        if is_gguf and config.model_type == "llama":
             is_neox_style = False

         self.rotary_emb = get_rope(
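
The llama.py guard is narrowed to `model_type == "llama"`, presumably because GGUF llama checkpoints ship Q/K weights permuted for the GPT-J (interleaved) rotary layout rather than the NeoX (rotate-half) one, and that permutation should not be assumed for other model types that merely reuse this attention class. A minimal sketch of the two layouts (purely illustrative; this is not vLLM's get_rope implementation):

```python
import torch


def rotate_neox(x: torch.Tensor) -> torch.Tensor:
    # NeoX / "rotate half": pairs are (x[i], x[i + d/2]).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def rotate_gptj(x: torch.Tensor) -> torch.Tensor:
    # GPT-J style: pairs are interleaved (x[2i], x[2i+1]).
    x1 = x[..., ::2]
    x2 = x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)


x = torch.arange(8, dtype=torch.float32)
print(rotate_neox(x))  # tensor([-4., -5., -6., -7.,  0.,  1.,  2.,  3.])
print(rotate_gptj(x))  # tensor([-1.,  0., -3.,  2., -5.,  4., -7.,  6.])
```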