Commit b16b541

fixing errors after merging main
abhigoyal1997 committed May 23, 2024
1 parent 4028273 commit b16b541
Showing 8 changed files with 41 additions and 44 deletions.
vllm/executor/gpu_executor.py (0 additions & 8 deletions)

@@ -89,10 +89,6 @@ def execute_model(
         self, execute_model_req: ExecuteModelRequest
     ) -> List[Union[SamplerOutput, PoolerOutput]]:
         output = self.driver_worker.execute_model(execute_model_req)
-
-        if not isinstance(output[0], (SamplerOutput, PoolerOutput)):
-            output = [sampler_output for sampler_output, _ in output]
-
         return output
 
     def add_lora(self, lora_request: LoRARequest) -> bool:
@@ -120,8 +116,4 @@ async def execute_model_async(
     ) -> List[Union[SamplerOutput, PoolerOutput]]:
         output = await make_async(self.driver_worker.execute_model
                                   )(execute_model_req=execute_model_req, )
-
-        if not isinstance(output[0], (SamplerOutput, PoolerOutput)):
-            output = [sampler_output for sampler_output, _ in output]
-
         return output

vllm/spec_decode/batch_expansion.py (4 additions & 3 deletions)

@@ -43,7 +43,7 @@ def score_proposals(
         self,
         execute_model_req: ExecuteModelRequest,
         proposals: SpeculativeProposals,
-    ) -> Tuple[SpeculativeScores, Optional[ExtraTensorData]]:
+    ) -> SpeculativeScores:
         """Score the proposed tokens via the scorer model.
 
         This converts each input sequence to a set of k+1 target sequences. The
@@ -82,7 +82,7 @@ def score_proposals(
             execute_model_req=execute_model_req.clone(
                 seq_group_metadata_list=target_seq_group_metadata_list, ))
         assert len(target_sampler_output) == 1, "expected single-step output"
-        target_sampler_output, _ = target_sampler_output[0]
+        target_sampler_output = target_sampler_output[0]
 
         (all_tokens, all_probs, spec_logprobs,
          all_extra_output_data) = self._contract_batch(
@@ -99,7 +99,8 @@ def score_proposals(
             probs=all_probs,
             token_ids=all_tokens,
             logprobs=spec_logprobs,
-        ), all_extra_output_data
+            extra_tensor_data=all_extra_output_data,
+        )
 
     def _expand_batch(
         self,

vllm/spec_decode/interfaces.py (7 additions & 3 deletions)

@@ -1,6 +1,6 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import Optional, Tuple
+from typing import Optional
 
 import torch
 
@@ -47,10 +47,14 @@ class SpeculativeScores:
     # tokens and also non-speculative normal decoding.
     token_ids: torch.Tensor
 
+    # Extra data output by the model
+    extra_tensor_data: Optional[ExtraTensorData]
+
     def __repr__(self):
         return (f"SpeculativeScores("
                 f"probs={self.probs.shape}, "
-                f"token_ids={self.token_ids.shape})")
+                f"token_ids={self.token_ids.shape}, "
+                f"extra_tensor_data={self.extra_tensor_data})")
 
 
 class SpeculativeProposer(ABC):
@@ -70,5 +74,5 @@ def score_proposals(
         self,
         execute_model_req: ExecuteModelRequest,
         proposals: SpeculativeProposals,
-    ) -> Tuple[SpeculativeScores, Optional[ExtraTensorData]]:
+    ) -> SpeculativeScores:
         raise NotImplementedError

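Taken together, these hunks move the extra tensor data from a second tuple element into a field on SpeculativeScores. A minimal sketch of the resulting calling convention, using list/dict stand-ins for torch tensors and vLLM's ExtraTensorData (none of this sketch is part of the commit):

from dataclasses import dataclass
from typing import Dict, List, Optional

ExtraTensorData = Dict[str, List[float]]  # stand-in for vLLM's tensor container


@dataclass
class SpeculativeScores:  # stand-in mirroring the fields above
    probs: List[float]
    token_ids: List[int]
    logprobs: List[float]
    extra_tensor_data: Optional[ExtraTensorData] = None  # the new field


# Before: callers unpacked `scores, extra = score_proposals(...)`.
# After: callers receive one object and read the attribute.
scores = SpeculativeScores(
    probs=[0.9, 0.1],
    token_ids=[42, 7],
    logprobs=[-0.1, -2.3],
    extra_tensor_data={"hidden_states": [0.5, -0.5]},
)
print(scores.extra_tensor_data)  # {'hidden_states': [0.5, -0.5]}
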
vllm/spec_decode/multi_head_worker.py (1 addition & 1 deletion)

@@ -48,7 +48,7 @@ def sampler_output(
         execute_model_req: ExecuteModelRequest,
         sample_len: int,
     ) -> Tuple[List[SamplerOutput], bool]:
-        model_outputs, _ = super().execute_model(
+        model_outputs = super().execute_model(
             execute_model_req=execute_model_req)[0]
         return model_outputs, False
 

vllm/spec_decode/spec_decode_worker.py (9 additions & 11 deletions)

@@ -278,6 +278,9 @@ def execute_model(
         self._maybe_disable_speculative_tokens(
             disable_all_speculation, execute_model_req.seq_group_metadata_list)
 
+        execute_model_req.extra_outputs = self.proposer_worker.extra_inputs\
+            .copy()
+
         # If no spec tokens, call the proposer and scorer workers normally.
         # Used for prefill.
         if num_lookahead_slots == 0 or len(
@@ -327,15 +330,12 @@ def _run_no_spec(self, execute_model_req: ExecuteModelRequest,
         not called, meaning that the kv-cache in proposer for requests is not
         updated, so they cannot enable spec decode in the rest decoding.
         """
-        execute_model_req.extra_outputs = self.proposer_worker.extra_inputs\
-            .copy()
-        model_outputs = self.scorer_worker.execute_model(execute_model_req)
-        assert len(model_outputs) == 1
-        sampler_output, prefill_extra_tensor_data = model_outputs[0]
+        sampler_output = self.scorer_worker.execute_model(execute_model_req)
+        assert len(sampler_output) == 1
+        sampler_output = sampler_output[0]
 
-        execute_model_req.extra_outputs.clear()
-        execute_model_req.extra_inputs = prefill_extra_tensor_data
+        execute_model_req.extra_inputs = sampler_output.extra_tensor_data
 
         if not skip_proposer:
             self.proposer_worker.execute_model(execute_model_req)
@@ -389,9 +389,7 @@ def _run_speculative_decoding_step(
         # Generate proposals using draft worker.
         proposals = self.proposer_worker.get_spec_proposals(execute_model_req)
 
-        execute_model_req.extra_outputs = self.proposer_worker.extra_inputs\
-            .copy()
-        proposal_scores, extra_tensor_data = self.scorer.score_proposals(
+        proposal_scores = self.scorer.score_proposals(
             execute_model_req,
             proposals,
         )
@@ -405,7 +403,7 @@ def _run_speculative_decoding_step(
             accepted_token_ids,
             target_logprobs=target_logprobs,
             k=execute_model_req.num_lookahead_slots,
-            extra_tensor_data=extra_tensor_data)
+            extra_tensor_data=proposal_scores.extra_tensor_data)
 
     @nvtx_range("spec_decode_worker._verify_tokens")
     def _verify_tokens(

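The first hunk hoists the extra_outputs assignment out of both run paths (it was previously set, and later cleared, inside _run_no_spec and again in _run_speculative_decoding_step) into execute_model, so it happens exactly once before the branch. A toy sketch of that pattern with stand-in names, not vLLM's real signatures:

from dataclasses import dataclass, field
from typing import Set


@dataclass
class Request:  # stand-in for ExecuteModelRequest
    extra_outputs: Set[str] = field(default_factory=set)


PROPOSER_EXTRA_INPUTS: Set[str] = {"hidden_states"}  # illustrative value


def execute_model(req: Request, no_spec: bool) -> str:
    # Hoisted: both paths need the proposer's extra inputs requested as
    # scorer outputs, so populate the request once before branching instead
    # of repeating (and later clearing) the assignment inside each path.
    req.extra_outputs = set(PROPOSER_EXTRA_INPUTS)
    return _run_no_spec(req) if no_spec else _run_speculative_step(req)


def _run_no_spec(req: Request) -> str:
    return f"prefill path, extra_outputs={req.extra_outputs}"


def _run_speculative_step(req: Request) -> str:
    return f"speculative path, extra_outputs={req.extra_outputs}"


print(execute_model(Request(), no_spec=True))
print(execute_model(Request(), no_spec=False))
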
vllm/transformers_utils/config.py (2 additions & 2 deletions)

@@ -22,8 +22,8 @@
     "medusa": MedusaConfig,
 }
 
-with contextlib.suppress(ValueError):
-    for name, cls in _CONFIG_REGISTRY.items():
+for name, cls in _CONFIG_REGISTRY.items():
+    with contextlib.suppress(ValueError):
         AutoConfig.register(name, cls)
 
 

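The swap matters because contextlib.suppress ends its block on the first suppressed exception: wrapped around the loop, one ValueError (say, a name transformers has already registered) silently abandons every remaining registration; wrapped around the single register call, only the failing entry is skipped. A self-contained repro, with a toy register() standing in for AutoConfig.register:

import contextlib

_seen = set()


def register(name: str) -> None:
    # Toy stand-in for AutoConfig.register: rejects duplicate names.
    if name in _seen:
        raise ValueError(f"{name!r} already registered")
    _seen.add(name)


registry = ["mpt", "mpt", "medusa"]  # the duplicate triggers ValueError

# Buggy shape: suppress wraps the loop, so the first ValueError silently
# ends the loop and "medusa" is never registered.
_seen.clear()
with contextlib.suppress(ValueError):
    for name in registry:
        register(name)
print(sorted(_seen))  # ['mpt']

# Fixed shape: suppress wraps one call, so only the duplicate is skipped.
_seen.clear()
for name in registry:
    with contextlib.suppress(ValueError):
        register(name)
print(sorted(_seen))  # ['medusa', 'mpt']
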
vllm/worker/model_runner.py (13 additions & 12 deletions)

@@ -722,7 +722,7 @@ def execute_model(
         kv_caches: List[torch.Tensor],
         extra_inputs: ExtraTensorData = None,
         extra_outputs: Optional[Set[str]] = None,
-    ) -> Tuple[Optional[SamplerOutput], ExtraTensorData]:
+    ) -> Optional[SamplerOutput]:
         (input_tokens, input_positions, attn_metadata, sampling_metadata,
          lora_requests, lora_mapping, multi_modal_input, prepared_extra_inputs
          ) = self.prepare_input_tensors(seq_group_metadata_list)
@@ -771,19 +771,19 @@ def execute_model(
         logits = self.model.compute_logits(hidden_states, sampling_metadata)
 
         # Only perform sampling in the driver worker.
-        if self.is_driver_worker:
-            # Sample the next token.
-            output = self.model.sample(
-                logits=logits,
-                sampling_metadata=sampling_metadata,
-            )
-            if extra_outputs:
-                sampled_extra_tensor_data = extra_tensor_data.index_select(
-                    0, sampling_metadata.selected_token_indices)
-        else:
-            output = None
+        if not self.is_driver_worker:
+            return None
+
+        # Sample the next token.
+        output = self.model.sample(
+            logits=logits,
+            sampling_metadata=sampling_metadata,
+        )
+        if extra_outputs:
+            sampled_extra_tensor_data = extra_tensor_data.index_select(
+                0, sampling_metadata.selected_token_indices)
 
         if extra_outputs:
             if prefill_meta is not None:
                 for k in extra_tensor_data:
                     extra_tensor_data[k] = extra_tensor_data[k].roll(shifts=1,
@@ -794,12 +794,13 @@ def execute_model(
             if output is not None:
                 _move_extra_tensor_data_to_seq_outputs(
                     output, sampled_extra_tensor_data, sampling_metadata)
+
+            output.extra_tensor_data = extra_tensor_data
         else:
             extra_tensor_data.clear()
+            if output is not None:
+                output.extra_tensor_data = sampled_extra_tensor_data
 
-        return output, extra_tensor_data
+        return output
 
     @torch.inference_mode()
     def profile_run(self) -> None:

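The middle hunk is a guard-clause refactor: rather than nesting sampling under if self.is_driver_worker: and carrying an output = None else-branch to the end of the function, non-driver workers now return None immediately, and the single return output at the bottom replaces the old (output, extra_tensor_data) tuple. A generic sketch of the guard-clause shape (toy signature, not vLLM's):

from typing import List, Optional


def execute_model(is_driver_worker: bool,
                  logits: List[float]) -> Optional[List[int]]:
    # Guard clause: only the driver worker samples; everyone else is done.
    if not is_driver_worker:
        return None

    # From here on the sampling path reads straight-line, and the result can
    # never be None, which is what lets later None-checks be simplified.
    return [max(range(len(logits)), key=logits.__getitem__)]  # greedy "sample"


print(execute_model(True, [0.1, 0.7, 0.2]))   # [1]
print(execute_model(False, [0.1, 0.7, 0.2]))  # None
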
vllm/worker/worker.py (5 additions & 4 deletions)

@@ -15,8 +15,7 @@
     set_custom_all_reduce)
 from vllm.lora.request import LoRARequest
 from vllm.model_executor import set_random_seed
-from vllm.sequence import (ExecuteModelRequest, ExtraTensorData, PoolerOutput,
-                           SamplerOutput)
+from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput
 from vllm.worker.cache_engine import CacheEngine
 from vllm.worker.embedding_model_runner import EmbeddingModelRunner
 from vllm.worker.model_runner import ModelRunner
@@ -273,8 +272,4 @@ def execute_model(
         output = self.model_runner.execute_model(
             seq_group_metadata_list,
             self.gpu_cache,
-            extra_inputs=execute_model_req.extra_inputs,
-            extra_outputs=execute_model_req.extra_outputs)
+            extra_inputs=None
+            if execute_model_req is None else execute_model_req.extra_inputs,
+            extra_outputs=None
+            if execute_model_req is None else execute_model_req.extra_outputs)
 
         # Worker only supports single-step execution. Wrap the output in a list
         # to conform to interface.

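The second hunk stops dereferencing execute_model_req unconditionally; the added conditionals suggest the request can legitimately be None here (for example, a call on a worker that receives no request payload), though that reading is an inference from the guard itself. A minimal sketch of the guard with stand-in types:

from dataclasses import dataclass
from typing import Optional, Set


@dataclass
class ExecuteModelRequest:  # stand-in for vLLM's request type
    extra_inputs: Optional[dict] = None
    extra_outputs: Optional[Set[str]] = None


def unpack(req: Optional[ExecuteModelRequest]):
    # Mirror of the guard in the hunk: never touch fields of a None request.
    extra_inputs = None if req is None else req.extra_inputs
    extra_outputs = None if req is None else req.extra_outputs
    return extra_inputs, extra_outputs


print(unpack(None))                                      # (None, None)
print(unpack(ExecuteModelRequest({"h": [1.0]}, {"h"})))  # ({'h': [1.0]}, {'h'})
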
