Skip to content

Commit

Permalink
[Intel GPU] Fix xpu decode input (#9145)
Browse files Browse the repository at this point in the history
  • Loading branch information
jikunshang authored Oct 8, 2024
1 parent 04c12f8 commit 80b57f0
Showing 1 changed file with 14 additions and 7 deletions.
21 changes: 14 additions & 7 deletions vllm/worker/xpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from vllm.distributed import get_pp_group
from vllm.inputs import INPUT_REGISTRY, InputRegistry
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadataCache
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader import get_model
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
Expand Down Expand Up @@ -136,7 +137,7 @@ def build(self) -> ModelInputForXPU:
(input_tokens, input_positions,
attn_metadata) = self._prepare_decode(
self.seq_group_metadata_list)
seq_lens = []
seq_lens = None
multi_modal_kwargs = None

return self.model_input_cls(
Expand Down Expand Up @@ -390,6 +391,10 @@ def __init__(
# Lazy initialization.
self.model: nn.Module # Set after init_Model

self.sampling_metadata_cache: SamplingMetadataCache = \
SamplingMetadataCache() \
if self.parallel_config.pipeline_parallel_size == 1 else None

def load_model(self) -> None:
with DeviceMemoryProfiler() as m:
self.model = get_model(
Expand Down Expand Up @@ -524,12 +529,14 @@ def prepare_model_input(
seq_group_metadata_list, finished_requests_ids)
# Sampling metadata is only required for the final pp group
generators = self.get_generators(finished_requests_ids)
sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list,
model_input.seq_lens,
model_input.query_lens,
self.device,
pin_memory=False,
generators=generators)
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
model_input.seq_lens,
model_input.query_lens,
self.device,
pin_memory=False,
generators=generators,
cache=self.sampling_metadata_cache)

return dataclasses.replace(model_input,
sampling_metadata=sampling_metadata,
Expand Down

0 comments on commit 80b57f0

Please sign in to comment.