Skip to content

Commit

Permalink
ut: add more tests for different warp layout (#340)
Browse files Browse the repository at this point in the history
  • Loading branch information
guocuimi authored Oct 18, 2024
1 parent 7838ca5 commit 0876f8a
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion tests/kernels/attention/flash_infer_kv_fp8_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import scalellm._C.kernels as kernels # type: ignore


@pytest.mark.parametrize("seq_lens", [[(1, 100), (15, 15), (111, 234), (1000, 10000)]])
@pytest.mark.parametrize("seq_lens", [[(1, 100)], [(100, 100)], [(1, 100), (15, 15), (111, 234), (1000, 10000)]])
@pytest.mark.parametrize("num_heads", [(8, 8), (8, 4), (8, 2), (8, 1)])
@pytest.mark.parametrize("head_size", [64, 128, 256])
@pytest.mark.parametrize("n_blocks", [100])
Expand Down
2 changes: 1 addition & 1 deletion tests/kernels/attention/flash_infer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import scalellm._C.kernels as kernels # type: ignore


@pytest.mark.parametrize("seq_lens", [[(1, 100), (15, 15), (111, 234), (1000, 10000)]])
@pytest.mark.parametrize("seq_lens", [[(1, 100)], [(100, 100)], [(1, 100), (15, 15), (111, 234), (1000, 10000)]])
@pytest.mark.parametrize("num_heads", [(8, 8), (8, 4), (8, 2), (8, 1)])
@pytest.mark.parametrize("head_size", [64, 128, 256])
@pytest.mark.parametrize("n_blocks", [100])
Expand Down

0 comments on commit 0876f8a

Please sign in to comment.