Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Speculative Decoding] Medusa Implementation with Top-1 proposer #4978

Merged
merged 18 commits into from
Jul 10, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -927,6 +927,10 @@ def maybe_create_spec_config(
"Speculative decoding with mlp_speculator models does not "
"yet support distributed inferencing (TP > 1).")

if draft_model_config.hf_config.model_type == "medusa":
abhigoyal1997 marked this conversation as resolved.
Show resolved Hide resolved
draft_model_config.hf_config.set_num_lookahead_tokens(
num_speculative_tokens)
abhigoyal1997 marked this conversation as resolved.
Show resolved Hide resolved

n_predict = getattr(draft_model_config.hf_config, "n_predict",
None)
if n_predict is not None:
Expand Down
1 change: 1 addition & 0 deletions vllm/model_executor/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
"ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
"XverseForCausalLM": ("xverse", "XverseForCausalLM"),
"Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
"MedusaModel": ("medusa", "Medusa"),
"MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}

Expand Down
148 changes: 148 additions & 0 deletions vllm/model_executor/models/medusa.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
from typing import Iterable, List, Optional, Tuple

import torch
import torch.nn as nn

from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs.medusa import MedusaConfig


class ResidualBlock(nn.Module):

def __init__(self, hidden_size: int, num_layers: int) -> None:
super().__init__()

self.layers = nn.ModuleList([
nn.Linear(hidden_size, hidden_size, bias=False)
for _ in range(num_layers)
])
self.act = nn.SiLU()

def forward(self, x: torch.Tensor) -> torch.Tensor:
for layer in self.layers:
x = x + self.act(layer(x))
return x


class Medusa(nn.Module):

def __init__(self, config: MedusaConfig, **_) -> None:
super().__init__()
self.config = config
self.blocks = nn.ModuleList([
ResidualBlock(hidden_size=self.config.hidden_size,
num_layers=self.config.num_hidden_layers)
for _ in range(self.config.num_heads)
])
self.orig_vocab_size = config.vocab_size
self.truncated_vocab_size = config.truncated_vocab_size
self.unpadded_vocab_size = self.truncated_vocab_size

self.lm_heads = nn.ModuleList([
ParallelLMHead(
self.unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=self.truncated_vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
) for _ in range(self.config.num_heads)
])

logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
self.truncated_vocab_size,
logit_scale)

self.token_map = None

def forward(self, hidden_states: torch.Tensor) -> List[torch.Tensor]:
return [block(hidden_states) for block in self.blocks]

def compute_logits(
self, hidden_states: List[torch.Tensor],
sampling_metadata: SamplingMetadata) -> List[torch.Tensor]:
logits = []

for hs, lm_head in zip(hidden_states, self.lm_heads):
_logits = self.logits_processor(lm_head.weight, hs,
sampling_metadata)

if self.token_map is None:
logits.append(_logits)
else:
logits.append(-torch.inf * torch.ones(
abhigoyal1997 marked this conversation as resolved.
Show resolved Hide resolved
size=(*_logits.shape[:-1], self.orig_vocab_size),
device=_logits.device,
dtype=_logits.dtype))

logits[-1][..., self.token_map] = _logits

return logits

def sample(
self,
logits: List[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> List[SamplerOutput]:
logits = torch.stack(logits, dim=0).float()
logprobs = torch.log_softmax(logits, dim=-1)
token_ids = logits.argmax(-1) # support only top-1 for now
probs = torch.softmax(logits, dim=-1)
Comment on lines +90 to +93
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we use the lossless rejection sampler, we will have to run vLLM's standard sampling routine here -- the probability distribution must be modified in the same way as the scoring probability distributions, else you will get distributional drift in the output.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please elaborate on the distribution shift? The tokens from the draft model are either accepted or rejected based on target model distribution, right? So even if the tokens from the draft are from a slightly different distribution, the final output should still match the target model distribution due to rejection. Is this understanding wrong or am I missing something?

The issue with using the standard sampling is that it was causing too much overhead. So if we do need to use it, we might need some optimizations there to get some speed-up out of Medusa.

Copy link
Contributor Author

@abhigoyal1997 abhigoyal1997 Jun 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's one case that I have noticed generates different tokens sometimes (not sure if this is what you are referring to though).
If without Medusa the logits of top-2 tokens have very close values (or same), then with Medusa those values sometimes change a little bit (I don't know why this is happening since Medusa shouldn't affect the output logits of the target model). This causes different tokens to be preferred by the target model, even for greedy sampling, depending on how those values change.

These images show this:
Screenshot 2024-06-05 at 6 14 04 PM
Screenshot 2024-06-05 at 6 17 23 PM

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I realised this was happening because of bf16 precision, not seeing any such shift when using fp32.


token_id_list = []
token_prob_list = []
token_logprob_list = []

for idx, seq_group in enumerate(sampling_metadata.seq_groups):
token_id_list.append(token_ids[:, seq_group.sample_indices])
token_prob_list.append(probs[:, seq_group.sample_indices])
token_logprob_list.append(logprobs[:, seq_group.sample_indices])

outputs: List[Optional[SamplerOutput]] = []
for idx in range(len(sampling_metadata.seq_groups)):
outputs.append(
SamplerOutput(
outputs=None,
sampled_token_probs=token_prob_list[idx].squeeze(),
logprobs=token_logprob_list[idx].squeeze(),
sampled_token_ids=token_id_list[idx].squeeze(),
))

return outputs

def generate_proposals(
self,
previous_hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> List[SamplerOutput]:
return self.sample(
logits=self.compute_logits(
hidden_states=self.forward(previous_hidden_states),
sampling_metadata=sampling_metadata,
),
sampling_metadata=sampling_metadata,
)

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
name = name.replace("medusa_heads.", "")

if name == "token_map":
self.token_map = nn.Parameter(
loaded_weight, requires_grad=False).to(
device=self.lm_heads[0].weight.device)
continue

# Skip loading extra heads
if name not in params_dict:
continue

param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
127 changes: 127 additions & 0 deletions vllm/spec_decode/medusa_worker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
import weakref
from typing import List, Optional, Tuple

import torch

from vllm.model_executor import SamplingMetadata
from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
SequenceGroupMetadata)
from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
from vllm.spec_decode.top1_proposer import Top1Proposer
from vllm.worker.worker import Worker


class MedusaWorker(NonLLMProposerWorkerBase, Worker):
"""Worker for Medusa.
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

# Lazy initialization list.
self._proposer: Top1Proposer

def init_device(self):
super().init_device()

self._proposer = Top1Proposer(
weakref.proxy(self), # type: ignore[arg-type]
self.device,
self.vocab_size,
max_proposal_len=self.max_model_len,
)

def set_include_gpu_probs_tensor(self):
pass

@torch.inference_mode()
def sampler_output(
self,
execute_model_req: ExecuteModelRequest,
sample_len: int,
) -> Tuple[List[SamplerOutput], bool]:
"""Run the model forward pass to generate sample_len future tokens.
Returns the list of sampler output, one per layer, along with indicator
of whether torch tensor in sampler output need to be transposed in
latter sampler_output_to_torch logic.

For medusa worker, this indicator shall be False.
"""
self._raise_if_unsupported(execute_model_req)

seq_group_metadata_list = execute_model_req.seq_group_metadata_list

seq_lens, query_lens = self._prepare_input_tensors(
seq_group_metadata_list)

sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list, seq_lens, query_lens, self.device,
self.model_runner.pin_memory)

model_outputs = self.model_runner.model.generate_proposals(
previous_hidden_states=execute_model_req.previous_hidden_states.
hidden_states,
sampling_metadata=sampling_metadata)

return model_outputs, False

def _prepare_input_tensors(
self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
) -> Tuple[List[int], List[int]]:
if not seq_group_metadata_list:
return [], []

seq_lens: List[int] = []
query_lens: List[int] = []

for seq_group_metadata in seq_group_metadata_list:
is_prompt = seq_group_metadata.is_prompt

for seq_data in seq_group_metadata.seq_data.values():
seq_data_len = seq_data.get_len()
if is_prompt:
context_len = seq_data.get_num_computed_tokens()
seq_len = min(
seq_data_len,
context_len + seq_group_metadata.token_chunk_size)
seq_lens.append(seq_len)
query_lens.append(seq_len - context_len)
else:
seq_lens.append(seq_data_len)
query_lens.append(1)

return seq_lens, query_lens

def get_spec_proposals(
self,
execute_model_req: ExecuteModelRequest,
) -> SpeculativeProposals:
"""Produce speculations given an input batch of sequences. The number of
speculative tokens per sequence is determined by max_proposal_len.
"""

return self._proposer.get_spec_proposals(execute_model_req)

def _raise_if_unsupported(
self,
execute_model_req: ExecuteModelRequest,
) -> None:
"""MedusaWorker does not yet implement support for cache swap
operations or beam search.
"""
if any([
execute_model_req.blocks_to_swap_in,
execute_model_req.blocks_to_swap_out,
execute_model_req.blocks_to_copy
]):
raise NotImplementedError(
"MedusaWorker does not support cache operations")

if any(
len(seq_group_metadata.seq_data.keys()) != 1
for seq_group_metadata in
execute_model_req.seq_group_metadata_list):
raise NotImplementedError(
"MedusaWorker does not support beam search.")
5 changes: 5 additions & 0 deletions vllm/spec_decode/spec_decode_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeScorer, SpeculativeScores)
from vllm.spec_decode.medusa_worker import MedusaWorker
from vllm.spec_decode.metrics import AsyncMetricsCollector
from vllm.spec_decode.mlp_speculator_worker import MLPSpeculatorWorker
from vllm.spec_decode.multi_step_worker import MultiStepWorker
Expand Down Expand Up @@ -107,6 +108,10 @@ def create_worker(
proposer_worker = NGramWorker(**draft_worker_kwargs)
proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min,
ngram_prompt_lookup_max)
elif draft_worker_kwargs[
"model_config"].hf_config.model_type == "medusa":
proposer_worker = MedusaWorker(**draft_worker_kwargs)
disable_bonus_tokens = False
elif draft_worker_kwargs[
"model_config"].hf_config.model_type == "mlp_speculator":
proposer_worker = MLPSpeculatorWorker(**draft_worker_kwargs)
Expand Down
6 changes: 4 additions & 2 deletions vllm/transformers_utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
from vllm.envs import VLLM_USE_MODELSCOPE
from vllm.logger import init_logger
from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig,
JAISConfig, MLPSpeculatorConfig,
MPTConfig, RWConfig)
JAISConfig, MedusaConfig,
MLPSpeculatorConfig, MPTConfig,
RWConfig)

if VLLM_USE_MODELSCOPE:
from modelscope import AutoConfig
Expand All @@ -24,6 +25,7 @@
"RefinedWebModel": RWConfig, # For tiiuae/falcon-7b(-instruct)
"jais": JAISConfig,
"mlp_speculator": MLPSpeculatorConfig,
"medusa": MedusaConfig,
}

for name, cls in _CONFIG_REGISTRY.items():
abhigoyal1997 marked this conversation as resolved.
Show resolved Hide resolved
Expand Down
2 changes: 2 additions & 0 deletions vllm/transformers_utils/configs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# `FalconConfig` class from the official HuggingFace transformers library.
from vllm.transformers_utils.configs.falcon import RWConfig
from vllm.transformers_utils.configs.jais import JAISConfig
from vllm.transformers_utils.configs.medusa import MedusaConfig
from vllm.transformers_utils.configs.mlp_speculator import MLPSpeculatorConfig
from vllm.transformers_utils.configs.mpt import MPTConfig

Expand All @@ -14,5 +15,6 @@
"MPTConfig",
"RWConfig",
"JAISConfig",
"MedusaConfig",
"MLPSpeculatorConfig",
]
Loading
Loading