Commit a8561b8
Comments & minor changes (vllm-project#8)
gc-fu authored Oct 24, 2023
1 parent 1ab029d commit a8561b8
Showing 2 changed files with 5 additions and 5 deletions.
vllm/engine/async_llm_engine.py (1 change: 1 addition & 0 deletions)
@@ -436,6 +436,7 @@ async def generate(
         arrival_time = time.monotonic()
 
         try:
+            # print("In generate-sampling_params" + str(sampling_params))
             stream = await self.add_request(request_id,
                                             prompt,
                                             sampling_params,
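The only change in this file is a commented-out debug print ahead of add_request. If the goal is ad-hoc visibility into the incoming sampling_params, a debug-level log line is a common lighter-weight alternative; the sketch below is a suggestion under that assumption, not part of the commit:

    import logging

    logger = logging.getLogger(__name__)

    # Inside generate(), just before add_request; emitted only when the
    # logger is configured at DEBUG level, so it stays quiet otherwise.
    logger.debug("In generate - sampling_params: %s", sampling_params)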
vllm/model_executor/models/bigdl_llama.py (9 changes: 4 additions & 5 deletions)
@@ -7,7 +7,7 @@
 from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
                                                get_tokenizer)
 from vllm.model_executor.quantization_utils import QuantizationConfig
-from vllm.sequence import SamplerOutput, SequenceOutputs
+from vllm.sequence import SamplerOutput, SequenceOutputs, SequenceGroupMetadata
 import math
 
 import pdb
@@ -55,16 +55,14 @@ def decode(self, generated_ids: List[int]) -> str:
             generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
         )
 
+    # TODO(gc): fix this Optional problem
     def forward(
-        self, seq_group_meta_data_lists, kv_cache: Optional = None
+        self, seq_group_meta_data_lists: List[SequenceGroupMetadata], kv_cache: Optional = None
     ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
         kv_cache_0 = self.model.config.num_hidden_layers
         kv_cache_1 = 2
         bigdl_kv_cache = [[torch.tensor([], device=self.device, dtype = self.dtype) for _ in range(kv_cache_1)] for _ in range(kv_cache_0)]
         seq_len = len(seq_group_meta_data_lists)
-        # for i in range(seq_len):
-        #     if kv_cache.get(i) is None:
-        #         kv_cache[i] = bigdl_kv_cache[:]
 
         bigdl_input_ids = []
         bigdl_position_ids = []
@@ -84,6 +82,7 @@ def forward(
             bigdl_input_ids.append(cur_seq_input_ids)
 
             bigdl_sampling_params[seq_id] = seq_group_meta_data.sampling_params
+            # print("sampling params for seq " + str(seq_id) + " is " + str(seq_group_meta_data.sampling_params))
 
             context_len = seq_data.get_len()
             bigdl_position_ids.append(range(context_len))
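The new TODO flags the bare Optional annotation on kv_cache, which is unparameterized and therefore reads as Optional[Any]. Below is a minimal sketch of one possible fix, assuming the cache mirrors the bigdl_kv_cache structure built inside forward (one [key, value] pair of tensors per hidden layer); the parameterized type is an inference from that code, not something this commit establishes:

    from typing import List, Optional, Tuple

    import torch

    from vllm.sequence import SequenceGroupMetadata

    def forward(
        self,
        seq_group_meta_data_lists: List[SequenceGroupMetadata],
        # Assumed shape: one [key, value] tensor pair per hidden layer,
        # matching how bigdl_kv_cache is initialized in this method.
        kv_cache: Optional[List[List[torch.Tensor]]] = None,
    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
        ...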
