diff --git a/test/srt/test_engine_token_ids.py b/test/srt/test_engine_token_ids.py index 9dfba0782db..4dee24edc9d 100644 --- a/test/srt/test_engine_token_ids.py +++ b/test/srt/test_engine_token_ids.py @@ -19,20 +19,23 @@ def test_token_ids_in_generate(self): "The capital of France is", "The future of AI is", ] + sampling_params = {"temperature": 0, "top_p": 0.95} outputs = llm.generate(prompts, sampling_params) for prompt, output in zip(prompts, outputs): - # SGLang's input_ids has a start token, so we remove it for comparison. - deocode_input = tokenizer.decode(output["input_ids"][1:]) - assert ( - deocode_input in prompt + deocode_input = tokenizer.decode( + output["input_ids"], skip_special_tokens=True + ) + assert (deocode_input in prompt) or ( + prompt in deocode_input ), f"Decode input: {deocode_input} mismatch for: {prompt}" - # SGLang's output_ids does not have a start token. - deocode_output = tokenizer.decode(output["output_ids"]) - assert ( - deocode_output in output["text"] + deocode_output = tokenizer.decode( + output["output_ids"], skip_special_tokens=True + ) + assert (deocode_output in output["text"]) or ( + output["text"] in deocode_output ), f"Decode output: {deocode_output} mismatch for: {output['text']}" llm.shutdown()