[ci] Fix sampler tests (vllm-project#11922)
Signed-off-by: youkaichao <[email protected]>
youkaichao authored and frreiss committed Jan 10, 2025
1 parent 77af093 commit 2e0fc0e
Showing 2 changed files with 10 additions and 2 deletions.
.buildkite/test-pipeline.yaml (1 addition, 0 deletions)

@@ -214,6 +214,7 @@ steps:
   - vllm/model_executor/layers
   - vllm/sampling_metadata.py
   - tests/samplers
+  - tests/conftest.py
   commands:
   - pytest -v -s samplers
   - VLLM_USE_FLASHINFER_SAMPLER=1 pytest -v -s samplers
tests/conftest.py (9 additions, 2 deletions)

@@ -28,12 +28,13 @@
                               init_distributed_environment,
                               initialize_model_parallel)
 from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
-                         to_enc_dec_tuple_list, zip_enc_dec_prompts)
+                         TokensPrompt, to_enc_dec_tuple_list,
+                         zip_enc_dec_prompts)
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import BeamSearchParams
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless,
-                        identity)
+                        identity, is_list_of)
 
 logger = init_logger(__name__)
 
@@ -886,6 +887,12 @@ def generate_beam_search(
         beam_width: int,
         max_tokens: int,
     ) -> List[Tuple[List[List[int]], List[str]]]:
+        if is_list_of(prompts, str, check="all"):
+            prompts = [TextPrompt(prompt=prompt) for prompt in prompts]
+        else:
+            prompts = [
+                TokensPrompt(prompt_token_ids=tokens) for tokens in prompts
+            ]
         outputs = self.model.beam_search(
             prompts,
             BeamSearchParams(beam_width=beam_width, max_tokens=max_tokens))
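For context, the branch added to generate_beam_search normalizes heterogeneous test prompts before they reach model.beam_search: plain strings are wrapped as TextPrompt dicts, while pre-tokenized inputs are wrapped as TokensPrompt dicts. Below is a minimal, self-contained sketch of that logic. normalize_prompts is a hypothetical helper name, and the local TypedDicts and is_list_of stand in for vllm.inputs.TextPrompt, vllm.inputs.TokensPrompt, and vllm.utils.is_list_of so the snippet runs without vLLM installed.

from typing import List, TypedDict, Union


class TextPrompt(TypedDict):
    # Stand-in for vllm.inputs.TextPrompt: a raw string prompt.
    prompt: str


class TokensPrompt(TypedDict):
    # Stand-in for vllm.inputs.TokensPrompt: a pre-tokenized prompt.
    prompt_token_ids: List[int]


def is_list_of(value, typ, *, check: str = "first") -> bool:
    # Stand-in for vllm.utils.is_list_of; with check="all", every element
    # must be an instance of typ (an empty list passes vacuously).
    if not isinstance(value, list):
        return False
    if check == "first":
        return len(value) == 0 or isinstance(value[0], typ)
    return all(isinstance(item, typ) for item in value)


def normalize_prompts(
    prompts: Union[List[str], List[List[int]]],
) -> List[Union[TextPrompt, TokensPrompt]]:
    # Same branch the commit adds: wrap strings as TextPrompt, otherwise
    # assume lists of token IDs and wrap as TokensPrompt, so beam_search
    # always receives explicit prompt dicts rather than bare lists.
    if is_list_of(prompts, str, check="all"):
        return [TextPrompt(prompt=prompt) for prompt in prompts]
    return [TokensPrompt(prompt_token_ids=tokens) for tokens in prompts]


print(normalize_prompts(["hello world"]))  # [{'prompt': 'hello world'}]
print(normalize_prompts([[1, 2, 3]]))      # [{'prompt_token_ids': [1, 2, 3]}]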
