This repository has been archived by the owner on Oct 11, 2024. It is now read-only.
forked from vllm-project/vllm
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Model] MLPSpeculator speculative decoding support (vllm-project#4947)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Co-authored-by: Thomas Parnell <tpa@zurich.ibm.com>
Co-authored-by: Nick Hill <nickhill@us.ibm.com>
Co-authored-by: Davis Wertheimer <Davis.Wertheimer@ibm.com>
1 parent 5ccb86c · commit b05443a
Showing 18 changed files with 523 additions and 40 deletions.
@@ -0,0 +1,59 @@ (new file: offline example comparing generation with and without MLPSpeculator speculation)
import gc
import time
from typing import List

from vllm import LLM, SamplingParams


def time_generation(llm: LLM, prompts: List[str],
                    sampling_params: SamplingParams):
    # Generate texts from the prompts. The output is a list of RequestOutput
    # objects that contain the prompt, generated text, and other information.
    # Warmup first
    llm.generate(prompts, sampling_params)
    llm.generate(prompts, sampling_params)
    start = time.time()
    outputs = llm.generate(prompts, sampling_params)
    end = time.time()
    # Report seconds per generated token.
    print((end - start) / sum([len(o.outputs[0].token_ids) for o in outputs]))
    # Print the outputs.
    for output in outputs:
        generated_text = output.outputs[0].text
        print(f"text: {generated_text!r}")


if __name__ == "__main__":

    template = (
        "Below is an instruction that describes a task. Write a response "
        "that appropriately completes the request.\n\n### Instruction:\n{}"
        "\n\n### Response:\n")

    # Sample prompts.
    prompts = [
        "Write about the president of the United States.",
    ]
    prompts = [template.format(prompt) for prompt in prompts]

    # Create a sampling params object.
    sampling_params = SamplingParams(temperature=0.0, max_tokens=200)

    # Create an LLM without spec decoding
    llm = LLM(model="meta-llama/Llama-2-13b-chat-hf")

    print("Without speculation")
    time_generation(llm, prompts, sampling_params)

    del llm
    gc.collect()

    # Create an LLM with spec decoding
    llm = LLM(
        model="meta-llama/Llama-2-13b-chat-hf",
        speculative_model="ibm-fms/llama-13b-accelerator",
        # These are currently required for MLPSpeculator decoding
        use_v2_block_manager=True,
        enforce_eager=True,
    )

    print("With speculation")
    time_generation(llm, prompts, sampling_params)
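Note that the script prints seconds per generated token; a minimal helper sketch (not part of this commit; the name is hypothetical) that converts the same measurement into tokens per second for easier comparison:

# Hypothetical helper, not part of the commit: turn the elapsed time and the
# outputs from llm.generate(...) into a tokens-per-second figure.
def tokens_per_second(elapsed_s: float, outputs) -> float:
    total_tokens = sum(len(o.outputs[0].token_ids) for o in outputs)
    return total_tokens / elapsed_s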
@@ -0,0 +1,143 @@ (new file: the MLPSpeculator draft-model implementation)
import math
from typing import Iterable, List, Tuple

import torch
import torch.nn as nn

from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import SamplerOutput


class MLPSpeculatorLayerNorm(nn.Module):
    """
    An L2 normalization implementation
    ...
    Args
    ----
    normalized_shape : int
        Dimensionality of input data (size of final tensor axis)
    eps : float
        Safety term to prevent division by zero. Make sure the chosen value
        fits in the range of your encoding scheme
        (i.e. fp16 requires eps >= 6e-8).
    """

    def __init__(
        self,
        normalized_shape,
        eps=1e-06,
    ):
        super(MLPSpeculatorLayerNorm, self).__init__()
        self.weight = nn.Parameter(torch.empty(normalized_shape))
        self.bias = nn.Parameter(torch.empty(normalized_shape))
        self.eps = eps

    def forward(self, x):
        # Normalize by the root-mean-square of the last axis, then apply
        # the learned elementwise scale and bias.
        xf = x
        xf = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + self.eps)
        x = xf.type_as(x)
        x = self.weight * x
        x = x + self.bias
        return x
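This layer is essentially an RMSNorm with a learned bias. A minimal standalone sketch of the same computation (plain PyTorch; shapes and values are hypothetical):

import torch

# Hypothetical standalone version of MLPSpeculatorLayerNorm.forward:
# divide by the RMS of the last axis, then apply elementwise weight and bias.
x = torch.randn(2, 8)
weight, bias, eps = torch.ones(8), torch.zeros(8), 1e-06
y = weight * (x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)) + bias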
class MLPSpeculator(nn.Module):

    def __init__(self, config, **kwargs) -> None:
        super().__init__()
        self.n_predict = config.n_predict
        self.vocab_size = config.vocab_size
        self.emb_dim = config.emb_dim
        self.inner_dim = config.inner_dim if config.inner_dim != 0 \
            else config.emb_dim

        self.max_speculative_tokens = getattr(config, "max_speculative_tokens",
                                              self.n_predict)

        # One embedding, projection, head and norm per speculative step.
        self.emb = nn.ModuleList([
            VocabParallelEmbedding(config.vocab_size,
                                   self.inner_dim,
                                   org_num_embeddings=config.vocab_size)
            for _ in range(self.max_speculative_tokens)
        ])

        self.proj = nn.ModuleList([
            nn.Linear((self.emb_dim if i == 0 else self.inner_dim),
                      self.inner_dim,
                      bias=False) for i in range(self.max_speculative_tokens)
        ])

        self.head = nn.ModuleList([
            nn.Linear(self.inner_dim, self.vocab_size, bias=False)
            for _ in range(self.max_speculative_tokens)
        ])
        self.ln = nn.ModuleList([
            MLPSpeculatorLayerNorm(self.inner_dim)
            for _ in range(self.max_speculative_tokens)
        ])

        self.state_weight = 0.5**(0.5 / config.n_predict)
        self.emb_weight = math.sqrt(
            (1 - self.state_weight**2) * (self.inner_dim / 2))
        self.activation = nn.GELU()
        self.config = config
        self.logits_processor = LogitsProcessor(config.vocab_size,
                                                config.vocab_size, 1.0)
        self.sampler = Sampler()

    def generate_proposals(
        self,
        input_ids: torch.Tensor,
        previous_hidden_states: torch.Tensor,
        num_predict_tokens: int,
        sampling_metadata: SamplingMetadata,
    ) -> List[SamplerOutput]:
        if num_predict_tokens > self.max_speculative_tokens:
            raise ValueError(f"Max speculative tokens for model is "
                             f"{self.max_speculative_tokens}, but "
                             f"{num_predict_tokens} were requested")

        # b x 1 x d
        previous_hidden_states = previous_hidden_states.unsqueeze(1)

        # b x 1
        last_tokens = input_ids.unsqueeze(1)

        next_tokens = []

        for head_index in range(num_predict_tokens):

            # Project and predict
            z = self.emb[head_index](last_tokens)  # b k d
            states = self.proj[head_index](previous_hidden_states)

            # Weighted add of state_weight*state and emb_weight*z
            # Let subsequent LN take care of denominator
            # state_weight is close to 1, so shouldn't be any precision issues
            states.add_(z, alpha=self.emb_weight / self.state_weight)

            states = self.activation(self.ln[head_index](states))  # b k d
            # TODO: not yet supporting top_k_tokens_per_head
            previous_hidden_states = states

            logits = self.logits_processor(self.head[head_index].weight,
                                           states, sampling_metadata)

            output = self.sampler(logits.flatten(0, 1), sampling_metadata)
            last_tokens = output.sampled_token_ids
            next_tokens.append(output)

        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            param = params_dict[name.replace("speculator.", "")]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
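On the weighted add in generate_proposals: because the following MLPSpeculatorLayerNorm divides out the overall scale, adding z with alpha = emb_weight / state_weight is equivalent, up to that rescaling, to the intended combination state_weight * states + emb_weight * z. A minimal sketch of the equivalence with hypothetical tensors and weights:

import torch

# Hypothetical check of the in-place add trick: what the code computes is
# the intended weighted sum divided by state_weight, and the subsequent
# RMS-style norm removes that constant factor anyway.
state_weight, emb_weight = 0.9, 0.4
states, z = torch.randn(2, 1, 8), torch.randn(2, 1, 8)

fused = states + (emb_weight / state_weight) * z   # what states.add_(z, alpha=...) computes
explicit = state_weight * states + emb_weight * z  # the intended combination
assert torch.allclose(state_weight * fused, explicit, atol=1e-6)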
@@ -0,0 +1,87 @@ (new file: the MLPSpeculator proposer worker)
from typing import List, Optional, Tuple

import torch

from vllm.model_executor import SamplingMetadata
from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
                           SequenceGroupMetadata)
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
from vllm.worker.model_runner import ModelInput


class MLPSpeculatorWorker(NonLLMProposerWorkerBase, MultiStepWorker):
    """Worker for MLPSpeculator models.

    Not currently compatible with LoRA or chunked prefill.
    """

    @torch.inference_mode()
    def sampler_output(
        self,
        execute_model_req: ExecuteModelRequest,
        sample_len: int,
    ) -> Tuple[List[SamplerOutput], bool]:
        """Run the model forward pass to generate sample_len future tokens.

        Returns the list of sampler outputs, one per layer, along with an
        indicator of whether the torch tensors in the sampler output need to
        be transposed by the later sampler_output_to_torch logic. For the MLP
        speculator worker, this indicator is True.
        """
        self._raise_if_unsupported(execute_model_req)

        seq_group_metadata_list = execute_model_req.seq_group_metadata_list

        (input_tokens, seq_lens,
         query_lens) = self._prepare_input_tensors(seq_group_metadata_list)

        sampling_metadata = SamplingMetadata.prepare(
            seq_group_metadata_list, seq_lens, query_lens, self.device,
            self.model_runner.pin_memory)

        model_outputs = self.model_runner.model.generate_proposals(
            input_ids=input_tokens,
            previous_hidden_states=execute_model_req.previous_hidden_states.
            hidden_states,
            num_predict_tokens=sample_len,
            sampling_metadata=sampling_metadata)

        assert len(model_outputs) == sample_len

        return model_outputs, True

    def _prepare_input_tensors(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
    ) -> Tuple[torch.Tensor, List[int], List[int]]:
        if not seq_group_metadata_list:
            return ModelInput.empty(self.device)

        input_tokens: List[int] = []
        seq_lens: List[int] = []
        query_lens: List[int] = []

        for seq_group_metadata in seq_group_metadata_list:
            is_prompt = seq_group_metadata.is_prompt

            for seq_data in seq_group_metadata.seq_data.values():
                seq_data_len = seq_data.get_len()
                if is_prompt:
                    # Prompt phase: queue all not-yet-computed prompt tokens.
                    context_len = seq_data.get_num_computed_tokens()
                    seq_len = min(
                        seq_data_len,
                        context_len + seq_group_metadata.token_chunk_size)
                    tokens = seq_data.get_token_ids()[context_len:seq_len]
                    seq_lens.append(seq_len)
                    input_tokens.extend(tokens)
                    query_lens.append(seq_len - context_len)
                else:
                    # Decode phase: only the most recent token is queried.
                    seq_lens.append(seq_data_len)
                    input_tokens.append(seq_data.get_last_token_id())
                    query_lens.append(1)

        input_tokens_tensor = torch.tensor(input_tokens,
                                           dtype=torch.long,
                                           device=self.device)
        return input_tokens_tensor, seq_lens, query_lens
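To illustrate the bookkeeping in _prepare_input_tensors, a minimal sketch with hypothetical token ids (plain Python, no vLLM objects): a prompt sequence of 5 tokens with none computed yet, batched with a decode sequence that already has 9 tokens.

# Hypothetical illustration of the flattening above.
prompt_tokens = [11, 12, 13, 14, 15]
decode_tokens = [21, 22, 23, 24, 25, 26, 27, 28, 29]

# All uncomputed prompt tokens, plus only the last token of the decode seq.
input_tokens = prompt_tokens + [decode_tokens[-1]]
seq_lens = [5, 9]    # full sequence lengths
query_lens = [5, 1]  # tokens queried this step: whole prompt vs. one token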
@@ -0,0 +1,50 @@ (new file: the MLPSpeculator Transformers config)
from typing import List, Optional

from transformers import PretrainedConfig


class MLPSpeculatorConfig(PretrainedConfig):
    model_type = "mlp_speculator"

    attribute_map = {
        "hidden_size": "emb_dim",
    }

    def __init__(self,
                 vocab_size: int = 32000,
                 emb_dim: int = 4096,
                 inner_dim: int = 0,
                 n_predict: int = 3,
                 top_k_tokens_per_head: Optional[List[int]] = None,
                 n_candidates: int = 5,
                 **kwargs):
        """
        Initialize an MLPSpeculatorConfig

        Args:
            vocab_size: int
                the model vocab size
            emb_dim: int
                the model embedding dimension
            inner_dim: int
                the inner dimension of the model. If 0, will be the emb_dim.
            n_predict: int
                the number of lookaheads for the speculator
            top_k_tokens_per_head: List[int]
                Number of tokens to consider from each head when forming the
                candidate tree. For each candidate branch in the tree, head n
                produces topk[n] additional sub-branches.
            n_candidates: int
                number of child candidates to create per sequence
        """
        if top_k_tokens_per_head is None:
            top_k_tokens_per_head = [5, 4, 3]
        assert len(top_k_tokens_per_head) == n_predict
        self.vocab_size = vocab_size
        self.emb_dim = emb_dim
        self.inner_dim = inner_dim
        self.n_predict = n_predict
        self.top_k_tokens_per_head = top_k_tokens_per_head
        self.n_candidates = n_candidates
        super().__init__(**kwargs)
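A minimal sketch of constructing the config by hand (the values here are hypothetical; in practice they come from a checkpoint's config.json):

# Hypothetical example: a 3-step speculator for a 4096-dim base model.
config = MLPSpeculatorConfig(
    vocab_size=32000,
    emb_dim=4096,
    inner_dim=4096,
    n_predict=3,
    top_k_tokens_per_head=[5, 4, 3],  # must have length n_predict
)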