fixes for passing format.sh
abhigoyal1997 committed May 22, 2024
1 parent 6dd8d26 commit d794d13
Showing 11 changed files with 48 additions and 43 deletions.
7 changes: 4 additions & 3 deletions vllm/config.py
@@ -1,7 +1,8 @@
 import enum
 import json
 from dataclasses import dataclass, field, fields
-from typing import TYPE_CHECKING, ClassVar, Dict, List, Optional, Set, Tuple, Union
+from typing import (TYPE_CHECKING, ClassVar, Dict, List, Optional, Set, Tuple,
+                    Union)
 
 import torch
 from transformers import PretrainedConfig
@@ -98,7 +99,7 @@ def __init__(
         max_logprobs: int = 5,
         skip_tokenizer_init: bool = False,
         served_model_name: Optional[Union[str, List[str]]] = None,
-        extra_inputs: Set[str] = set(),
+        extra_inputs: Optional[Set[str]] = None,
     ) -> None:
         self.model = model
         self.tokenizer = tokenizer
@@ -132,7 +133,7 @@ def __init__
 
         self.extra_inputs: Dict[str, Tuple[Tuple[int],
                                            Optional[torch.dtype]]] = {}
-        if "hidden_states" in extra_inputs:
+        if extra_inputs and "hidden_states" in extra_inputs:
             self.extra_inputs["hidden_states"] = ((
                 self.hf_config.hidden_size, ), None)
 
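
The `extra_inputs` change above is the classic mutable-default-argument fix: a default like `Set[str] = set()` is built once, at function-definition time, so every call that omits the argument shares (and can mutate) the same set. A minimal standalone sketch of the pitfall and of the `Optional[...] = None` idiom adopted here (toy names, not vLLM code):

    def register(name, seen=set()):  # BUG: one set shared across all calls
        seen.add(name)
        return seen

    print(register("hidden_states"))  # {'hidden_states'}
    print(register("logits"))         # {'hidden_states', 'logits'} -- leaked

    def register_fixed(name, seen=None):  # fresh state on every call
        seen = set() if seen is None else seen
        seen.add(name)
        return seen

The same fix is applied to `extra_outputs` in vllm/worker/model_runner.py below.
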
7 changes: 3 additions & 4 deletions vllm/engine/arg_utils.py
@@ -488,9 +488,8 @@ def add_cli_args(
             '--extra-inputs-for-draft-model',
             type=nullable_str,
             default=EngineArgs.extra_inputs_for_draft_model,
-            help=
-            'Extra model inputs used by draft model. These should come as outputs from the target model.'
-        )
+            help='Extra model inputs used by draft model. '
+            'These should come as outputs from the target model.')
 
         parser.add_argument(
             '--num-speculative-tokens',
@@ -595,7 +594,7 @@ def create_engine_config(self, ) -> EngineConfig:
 
         try:
             extra_inputs = set(self.extra_inputs_for_draft_model.split(","))
-        except:
+        except Exception:
             extra_inputs = set()
 
         speculative_config = SpeculativeConfig.maybe_create_spec_config(
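
Switching the bare `except:` to `except Exception:` matters because a bare clause also traps `BaseException` subclasses such as `KeyboardInterrupt` and `SystemExit`, so a Ctrl-C during engine configuration could be silently turned into an empty `extra_inputs` set. A toy version of the pattern (illustrative names, not vLLM code):

    def parse_csv_set(raw):
        try:
            # raises AttributeError when raw is None
            return set(raw.split(","))
        except Exception:  # KeyboardInterrupt/SystemExit still propagate
            return set()

    print(parse_csv_set("a,b") == {"a", "b"})  # True
    print(parse_csv_set(None))                 # set()
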
6 changes: 2 additions & 4 deletions vllm/executor/gpu_executor.py
@@ -90,8 +90,7 @@ def execute_model(
     ) -> List[Union[SamplerOutput, PoolerOutput]]:
         output = self.driver_worker.execute_model(execute_model_req)
 
-        if not (isinstance(output[0], SamplerOutput)
-                or isinstance(output[0], PoolerOutput)):
+        if not isinstance(output[0], (SamplerOutput, PoolerOutput)):
             output = [sampler_output for sampler_output, _ in output]
 
         return output
@@ -122,8 +121,7 @@ async def execute_model_async(
         output = await make_async(self.driver_worker.execute_model
                                   )(execute_model_req=execute_model_req, )
 
-        if not (isinstance(output[0], SamplerOutput)
-                or isinstance(output[0], PoolerOutput)):
+        if not isinstance(output[0], (SamplerOutput, PoolerOutput)):
             output = [sampler_output for sampler_output, _ in output]
 
         return output
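
`isinstance` accepts a tuple of types, so the chained `or` collapses into a single call; linters such as flake8-simplify flag the chained form (SIM101). The two spellings are equivalent:

    x = 3.5
    assert (isinstance(x, (int, float)) ==
            (isinstance(x, int) or isinstance(x, float)))
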
7 changes: 5 additions & 2 deletions vllm/sequence.py
@@ -6,6 +6,7 @@
 from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union
 
 import torch
+
 from vllm.block import LogicalTokenBlock
 from vllm.lora.request import LoRARequest
 from vllm.pooling_params import PoolingParams
@@ -96,10 +97,12 @@ def stack(
         data: List[Optional["ExtraTensorData"]],
         dim: int = 0,
     ) -> Optional["ExtraTensorData"]:
-        if len(data) == 0: return None
+        if len(data) == 0:
+            return None
 
         for d in data:
-            if d is None: return None
+            if d is None:
+                return None
 
         assert isinstance(data[0], ExtraTensorData)
 
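
Splitting `if cond: return None` onto two lines clears pycodestyle's E701 ("multiple statements on one line (colon)") without changing behavior; it is one of the checks a formatting script such as format.sh typically enforces.
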
27 changes: 15 additions & 12 deletions vllm/spec_decode/batch_expansion.py
@@ -84,15 +84,16 @@ def score_proposals(
         assert len(target_sampler_output) == 1, "expected single-step output"
         target_sampler_output, _ = target_sampler_output[0]
 
-        all_tokens, all_probs, spec_logprobs, all_extra_output_data = self._contract_batch(
-            contracted_bs=len(execute_model_req.seq_group_metadata_list),
-            target_sampler_output=target_sampler_output,
-            proposals=proposals,
-            num_scoring_tokens=num_scoring_tokens,
-            non_spec_indices=non_spec_indices,
-            spec_indices=spec_indices,
-            k=execute_model_req.num_lookahead_slots,
-        )
+        (all_tokens, all_probs, spec_logprobs,
+         all_extra_output_data) = self._contract_batch(
+             contracted_bs=len(execute_model_req.seq_group_metadata_list),
+             target_sampler_output=target_sampler_output,
+             proposals=proposals,
+             num_scoring_tokens=num_scoring_tokens,
+             non_spec_indices=non_spec_indices,
+             spec_indices=spec_indices,
+             k=execute_model_req.num_lookahead_slots,
+         )
 
         return SpeculativeScores(
             probs=all_probs,
@@ -217,7 +218,8 @@ def _contract_batch(
         all_probs[non_spec_indices, :1, :] = non_spec_target_probs
         all_logprobs[non_spec_indices, :1, :] = non_spec_target_logprobs
 
-        if all_extra_output_data and non_spec_target_extra_output_data is not None:
+        if all_extra_output_data and \
+                non_spec_target_extra_output_data is not None:
             for k in all_extra_output_data:
                 all_extra_output_data[k][
                     non_spec_indices, :
@@ -382,8 +384,9 @@ def _split_scoring_output(
         if sampler_output.extra_tensor_data is None:
             spec_extra_output_data, no_spec_extra_output_data = (None, None)
         else:
-            spec_extra_output_data, no_spec_extra_output_data = sampler_output.extra_tensor_data.split(
-                split_sizes)
+            spec_extra_output_data, no_spec_extra_output_data = sampler_output\
+                .extra_tensor_data\
+                .split(split_sizes)
 
         # Convert scores to tensors.
         sampler_output.sampled_token_probs = spec_probs
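
The rewrites in this file are all in service of the line-length check (pycodestyle E501, under what is assumed to be the usual ~80-column limit): parenthesizing the assignment targets lets a long tuple unpacking wrap without backslashes. A generic sketch of the pattern:

    def pair():
        return 1, 2

    # Parenthesized targets wrap cleanly under a fixed column limit:
    (first_long_descriptive_name,
     second_long_descriptive_name) = pair()
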
10 changes: 5 additions & 5 deletions vllm/spec_decode/spec_decode_worker.py
@@ -12,8 +12,8 @@
 from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                          SpeculativeScorer, SpeculativeScores)
 from vllm.spec_decode.metrics import AsyncMetricsCollector
-from vllm.spec_decode.multi_step_worker import MultiStepWorker
 from vllm.spec_decode.multi_head_worker import MultiHeadWorker
+from vllm.spec_decode.multi_step_worker import MultiStepWorker
 from vllm.spec_decode.ngram_worker import NGramWorker
 from vllm.spec_decode.util import (create_sequence_group_output,
                                    get_all_num_logprobs, get_all_seq_ids,
@@ -337,8 +337,8 @@ def _run_no_spec(self, execute_model_req: ExecuteModelRequest,
         not called, meaning that the kv-cache in proposer for requests is not
         updated, so they cannot enable spec decode in the rest decoding.
         """
-        execute_model_req.extra_outputs = self.proposer_worker.extra_inputs.copy(
-        )
+        execute_model_req.extra_outputs = self.proposer_worker.extra_inputs\
+            .copy()
         model_outputs = self.scorer_worker.execute_model(execute_model_req)
         assert len(model_outputs) == 1
 
@@ -392,8 +392,8 @@ def _run_speculative_decoding_step(
         # Generate proposals using draft worker.
         proposals = self.proposer_worker.get_spec_proposals(execute_model_req)
 
-        execute_model_req.extra_outputs = self.proposer_worker.extra_inputs.copy(
-        )
+        execute_model_req.extra_outputs = self.proposer_worker.extra_inputs\
+            .copy()
         proposal_scores, extra_tensor_data = self.scorer.score_proposals(
             execute_model_req,
             proposals,
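
Note that PEP 8 generally prefers implicit continuation inside parentheses over the backslash continuations committed here; `execute_model_req.extra_outputs = (self.proposer_worker.extra_inputs.copy())` would be an equivalent wrapping.
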
3 changes: 2 additions & 1 deletion vllm/spec_decode/util.py
@@ -193,7 +193,8 @@ def sampler_output_to_torch(
             sampled_extra_output_data[k] = sampled_extra_output_data[
                 k].transpose(0, 1)
 
-    return sampled_token_ids, sampled_token_probs, sampled_token_logprobs, sampled_extra_output_data
+    return (sampled_token_ids, sampled_token_probs, sampled_token_logprobs,
+            sampled_extra_output_data)
 
 
 def maybe_mock_device_tensors(sampler_output: SamplerOutput, batch_size: int,
11 changes: 5 additions & 6 deletions vllm/transformers_utils/config.py
@@ -1,11 +1,12 @@
+import contextlib
 from typing import Dict, Optional
 
 from transformers import AutoConfig, PretrainedConfig
 
 from vllm.logger import init_logger
 from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig,
-                                             JAISConfig, MPTConfig, RWConfig,
-                                             MedusaConfig)
+                                             JAISConfig, MedusaConfig,
+                                             MPTConfig, RWConfig)
 
 logger = init_logger(__name__)
 
@@ -21,11 +22,9 @@
"medusa": MedusaConfig,
}

for name, cls in _CONFIG_REGISTRY.items():
try:
with contextlib.suppress(ValueError):
for name, cls in _CONFIG_REGISTRY.items():
AutoConfig.register(name, cls)
except:
pass


def get_config(model: str,
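
`contextlib.suppress(ValueError)` replaces the try/except/pass and, unlike the old bare `except:`, silences only the expected error (`AutoConfig.register` raises `ValueError`, e.g. when a name is already registered). One behavioral nuance of wrapping the whole loop in a single `with` block: the first `ValueError` exits the loop, whereas the old per-iteration try/except kept registering the remaining configs. A standalone sketch of the construct (toy registry, not the transformers API):

    import contextlib

    registry = {}

    def register(name):
        if name in registry:
            raise ValueError(f"{name!r} already registered")
        registry[name] = object()

    with contextlib.suppress(ValueError):
        register("medusa")
        register("medusa")  # raises ValueError; suppress() exits the block
        register("mpt")     # never reached: suppression is not a retry
    print(sorted(registry))  # ['medusa']
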
2 changes: 1 addition & 1 deletion vllm/transformers_utils/configs/__init__.py
@@ -5,8 +5,8 @@
 # `FalconConfig` class from the official HuggingFace transformers library.
 from vllm.transformers_utils.configs.falcon import RWConfig
 from vllm.transformers_utils.configs.jais import JAISConfig
-from vllm.transformers_utils.configs.mpt import MPTConfig
 from vllm.transformers_utils.configs.medusa import MedusaConfig
+from vllm.transformers_utils.configs.mpt import MPTConfig
 
 __all__ = [
     "ChatGLMConfig",
3 changes: 2 additions & 1 deletion vllm/worker/embedding_model_runner.py
@@ -12,7 +12,8 @@
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.pooling_params import PoolingParams
-from vllm.sequence import ExtraTensorData, PoolerOutput, SequenceData, SequenceGroupMetadata
+from vllm.sequence import (ExtraTensorData, PoolerOutput, SequenceData,
+                           SequenceGroupMetadata)
 from vllm.worker.model_runner import ModelRunner
 
 logger = init_logger(__name__)
8 changes: 4 additions & 4 deletions vllm/worker/model_runner.py
@@ -543,8 +543,8 @@ def _prepare_model_input(
                                              dtype=torch.long,
                                              device=self.device)
 
-        extra_inputs_tensor = None if extra_inputs is None else ExtraTensorData.stack(
-            extra_inputs)
+        extra_inputs_tensor = None if extra_inputs is None else \
+            ExtraTensorData.stack(extra_inputs)
 
         if self.attn_backend.get_name() == "flashinfer":
             if not hasattr(self, "flashinfer_workspace_buffer"):
@@ -713,7 +713,7 @@ def execute_model(
         seq_group_metadata_list: List[SequenceGroupMetadata],
         kv_caches: List[torch.Tensor],
         extra_inputs: ExtraTensorData = None,
-        extra_outputs: Set[str] = set(),
+        extra_outputs: Optional[Set[str]] = None,
     ) -> Tuple[Optional[SamplerOutput], ExtraTensorData]:
         (input_tokens, input_positions, attn_metadata, sampling_metadata,
          lora_requests, lora_mapping, multi_modal_input, prepared_extra_inputs
@@ -756,7 +756,7 @@
 
         hidden_states = model_executable(**execute_model_kwargs)
 
-        if "hidden_states" in extra_outputs:
+        if extra_outputs and "hidden_states" in extra_outputs:
             extra_tensor_data["hidden_states"] = hidden_states
 
         # Compute the logits.
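
The `extra_outputs` guard mirrors the `extra_inputs` handling in vllm/config.py: with the default now `None`, `if extra_outputs and "hidden_states" in extra_outputs:` treats both `None` and an empty set as "no extra outputs requested". A quick check of the truth table:

    for value in (None, set(), {"hidden_states"}):
        print(value, "->", bool(value and "hidden_states" in value))
    # None -> False, set() -> False, {'hidden_states'} -> True
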
