Skip to content

Commit

Permalink
[Intel GPU] Fix xpu decode input (vllm-project#9145)
Browse files Browse the repository at this point in the history
Signed-off-by: Amit Garg <[email protected]>
  • Loading branch information
jikunshang authored and garg-amit committed Oct 28, 2024
1 parent 22ef180 commit 1cd50f1
Showing 1 changed file with 14 additions and 7 deletions.
21 changes: 14 additions & 7 deletions vllm/worker/xpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from vllm.distributed import get_pp_group
from vllm.inputs import INPUT_REGISTRY, InputRegistry
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadataCache
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader import get_model
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
Expand Down Expand Up @@ -136,7 +137,7 @@ def build(self) -> ModelInputForXPU:
(input_tokens, input_positions,
attn_metadata) = self._prepare_decode(
self.seq_group_metadata_list)
seq_lens = []
seq_lens = None
multi_modal_kwargs = None

return self.model_input_cls(
Expand Down Expand Up @@ -406,6 +407,10 @@ def __init__(
# Lazy initialization.
self.model: nn.Module # Set after init_Model

self.sampling_metadata_cache: SamplingMetadataCache = \
SamplingMetadataCache() \
if self.parallel_config.pipeline_parallel_size == 1 else None

def load_model(self) -> None:
with DeviceMemoryProfiler() as m:
self.model = get_model(
Expand Down Expand Up @@ -540,12 +545,14 @@ def prepare_model_input(
seq_group_metadata_list, finished_requests_ids)
# Sampling metadata is only required for the final pp group
generators = self.get_generators(finished_requests_ids)
sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list,
model_input.seq_lens,
model_input.query_lens,
self.device,
pin_memory=False,
generators=generators)
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
model_input.seq_lens,
model_input.query_lens,
self.device,
pin_memory=False,
generators=generators,
cache=self.sampling_metadata_cache)

return dataclasses.replace(model_input,
sampling_metadata=sampling_metadata,
Expand Down

0 comments on commit 1cd50f1

Please sign in to comment.