diff --git a/tests/kernels/attention/flash_infer_kv_fp8_test.py b/tests/kernels/attention/flash_infer_kv_fp8_test.py
index 7200dbf0..b64e6bb5 100644
--- a/tests/kernels/attention/flash_infer_kv_fp8_test.py
+++ b/tests/kernels/attention/flash_infer_kv_fp8_test.py
@@ -7,7 +7,7 @@ import scalellm._C.kernels as kernels # type: ignore
-@pytest.mark.parametrize("seq_lens", [[(1, 100), (15, 15), (111, 234), (1000, 10000)]])
+@pytest.mark.parametrize("seq_lens", [[(1, 100)], [(100, 100)], [(1, 100), (15, 15), (111, 234), (1000, 10000)]])
 @pytest.mark.parametrize("num_heads", [(8, 8), (8, 4), (8, 2), (8, 1)])
 @pytest.mark.parametrize("head_size", [64, 128, 256])
 @pytest.mark.parametrize("n_blocks", [100])
diff --git a/tests/kernels/attention/flash_infer_test.py b/tests/kernels/attention/flash_infer_test.py
index 8738a3ce..d57495b3 100644
--- a/tests/kernels/attention/flash_infer_test.py
+++ b/tests/kernels/attention/flash_infer_test.py
@@ -7,7 +7,7 @@ import scalellm._C.kernels as kernels # type: ignore
-@pytest.mark.parametrize("seq_lens", [[(1, 100), (15, 15), (111, 234), (1000, 10000)]])
+@pytest.mark.parametrize("seq_lens", [[(1, 100)], [(100, 100)], [(1, 100), (15, 15), (111, 234), (1000, 10000)]])
 @pytest.mark.parametrize("num_heads", [(8, 8), (8, 4), (8, 2), (8, 1)])
 @pytest.mark.parametrize("head_size", [64, 128, 256])
 @pytest.mark.parametrize("n_blocks", [100])
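For context (not part of the patch itself), a minimal sketch of how the updated parametrization behaves: with three inner lists instead of one, pytest now collects three `seq_lens` cases per combination of the other parameters, adding a single-sequence case `[(1, 100)]` and an equal-length case `[(100, 100)]` alongside the original mixed batch. The test name and the (q_len, kv_len) reading of each pair below are assumptions for illustration, not taken from the files.

```python
# Sketch only, not the actual kernel test: shows how the new "seq_lens"
# parametrization expands. Each inner list is one collected test case; each
# pair in that list is assumed to be (q_len, kv_len) for one sequence.
import pytest


@pytest.mark.parametrize(
    "seq_lens",
    [
        [(1, 100)],    # new: single sequence, short query
        [(100, 100)],  # new: single sequence, query length equals kv length
        [(1, 100), (15, 15), (111, 234), (1000, 10000)],  # original mixed batch
    ],
)
def test_seq_lens_expansion(seq_lens):
    # pytest runs this body once per inner list above (three cases total).
    for q_len, kv_len in seq_lens:
        assert q_len >= 1 and kv_len >= 1
```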