From ef2fba0017b044b55daecc78d464b37920e88ab3 Mon Sep 17 00:00:00 2001 From: Benson Ma Date: Wed, 5 Feb 2025 12:56:36 -0800 Subject: [PATCH] Re-organize SLL ops, pt 8 Summary: - Re-organize the remaining SLL triton ops Differential Revision: D68970862 --- .github/scripts/fbgemm_gpu_test.bash | 11 --- fbgemm_gpu/fbgemm_gpu/sll/__init__.py | 21 +---- fbgemm_gpu/fbgemm_gpu/sll/triton/__init__.py | 16 ++++ fbgemm_gpu/fbgemm_gpu/sll/triton/common.py | 22 +++++ ...agged_dense_elementwise_mul_jagged_out.py} | 82 ------------------- ...ton_jagged_self_substraction_jagged_out.py | 73 +++++++++++++++++ .../sll/array_jagged_bmm_jagged_out_test.py | 6 +- fbgemm_gpu/test/sll/common.py | 3 +- .../sll/dense_jagged_cat_jagged_out_test.py | 73 +++++++++++++++++ fbgemm_gpu/test/sll/jagged_dense_bmm_test.py | 5 +- .../sll/jagged_dense_elementwise_add_test.py | 2 +- ...d_dense_elementwise_mul_jagged_out_test.py | 9 +- .../sll/jagged_dense_flash_attention_test.py | 3 +- .../sll/jagged_flash_attention_basic_test.py | 1 + .../sll/jagged_jagged_bmm_jagged_out_test.py | 7 +- fbgemm_gpu/test/sll/jagged_jagged_bmm_test.py | 5 +- ...gged_self_substraction_jagged_out_test.py} | 57 +------------ fbgemm_gpu/test/sll/jagged_softmax_test.py | 6 +- 18 files changed, 206 insertions(+), 196 deletions(-) rename fbgemm_gpu/fbgemm_gpu/sll/{triton_sll.py => triton/triton_jagged_dense_elementwise_mul_jagged_out.py} (68%) create mode 100644 fbgemm_gpu/fbgemm_gpu/sll/triton/triton_jagged_self_substraction_jagged_out.py create mode 100644 fbgemm_gpu/test/sll/dense_jagged_cat_jagged_out_test.py rename fbgemm_gpu/test/sll/{triton_sll_test.py => jagged_self_substraction_jagged_out_test.py} (60%) diff --git a/.github/scripts/fbgemm_gpu_test.bash b/.github/scripts/fbgemm_gpu_test.bash index 0bd83f73ac..6a13733ce7 100644 --- a/.github/scripts/fbgemm_gpu_test.bash +++ b/.github/scripts/fbgemm_gpu_test.bash @@ -83,17 +83,6 @@ __configure_fbgemm_gpu_test_cpu () { # These tests have non-CPU operators referenced in @given ./uvm/copy_test.py ./uvm/uvm_test.py - ./sll/triton_sll_test.py - ./sll/array_jagged_bmm_jagged_out_test.py - ./sll/jagged_dense_elementwise_add_test.py - ./sll/jagged_flash_attention_basic_test.py - ./sll/jagged_jagged_bmm_jagged_out_test.py - ./sll/jagged_dense_flash_attention_test.py - ./sll/multi_head_jagged_flash_attention_test.py - ./sll/jagged_dense_bmm_test.py - ./sll/jagged_dense_elementwise_mul_jagged_out_test.py - ./sll/jagged_jagged_bmm_test.py - ./sll/jagged_softmax_test.py ) } diff --git a/fbgemm_gpu/fbgemm_gpu/sll/__init__.py b/fbgemm_gpu/fbgemm_gpu/sll/__init__.py index 271f5fa31e..bd89e4ff5f 100644 --- a/fbgemm_gpu/fbgemm_gpu/sll/__init__.py +++ b/fbgemm_gpu/fbgemm_gpu/sll/__init__.py @@ -33,11 +33,6 @@ meta_jagged_self_substraction_jagged_out, ) -from fbgemm_gpu.sll.triton_sll import ( # noqa F401 - jagged_dense_elementwise_mul_jagged_out, - triton_jagged_self_substraction_jagged_out, -) - from fbgemm_gpu.utils import TorchLibraryFragment lib = TorchLibraryFragment("fbgemm") @@ -262,25 +257,11 @@ }, } -# pyre-ignore[5] -sll_gpu_registrations = { - "sll_jagged_self_substraction_jagged_out": { - "CUDA": triton_jagged_self_substraction_jagged_out, - }, - "sll_jagged_dense_elementwise_mul_jagged_out": { - "CUDA": jagged_dense_elementwise_mul_jagged_out, - "AutogradCUDA": jagged_dense_elementwise_mul_jagged_out, - }, -} - for op_name, dispatches in sll_cpu_registrations.items(): lib.register(op_name, dispatches) if torch.cuda.is_available(): - from fbgemm_gpu.sll.triton import op_registrations - - for op_name, 
dispatches in op_registrations.items(): - lib.register(op_name, dispatches) + from fbgemm_gpu.sll.triton import op_registrations as sll_gpu_registrations for op_name, dispatches in sll_gpu_registrations.items(): lib.register(op_name, dispatches) diff --git a/fbgemm_gpu/fbgemm_gpu/sll/triton/__init__.py b/fbgemm_gpu/fbgemm_gpu/sll/triton/__init__.py index 5c5fa183d9..c3ee692ced 100644 --- a/fbgemm_gpu/fbgemm_gpu/sll/triton/__init__.py +++ b/fbgemm_gpu/fbgemm_gpu/sll/triton/__init__.py @@ -37,6 +37,11 @@ JaggedDenseAdd, # noqa F401 ) +from fbgemm_gpu.sll.triton.triton_jagged_dense_elementwise_mul_jagged_out import ( # noqa F401 + jagged_dense_elementwise_mul_jagged_out, + JaggedDenseElementwiseMul, # noqa F401 +) + from fbgemm_gpu.sll.triton.triton_jagged_dense_flash_attention import ( # noqa F401 jagged_dense_flash_attention, JaggedDenseFlashAttention, # noqa F401 @@ -47,6 +52,10 @@ JaggedFlashAttentionBasic, # noqa F401 ) +from fbgemm_gpu.sll.triton.triton_jagged_self_substraction_jagged_out import ( + triton_jagged_self_substraction_jagged_out, +) + from fbgemm_gpu.sll.triton.triton_jagged_softmax import ( # noqa F401 jagged2_softmax, Jagged2Softmax, # noqa F401 @@ -108,4 +117,11 @@ "CUDA": multi_head_jagged_flash_attention, "AutogradCUDA": multi_head_jagged_flash_attention, }, + "sll_jagged_self_substraction_jagged_out": { + "CUDA": triton_jagged_self_substraction_jagged_out, + }, + "sll_jagged_dense_elementwise_mul_jagged_out": { + "CUDA": jagged_dense_elementwise_mul_jagged_out, + "AutogradCUDA": jagged_dense_elementwise_mul_jagged_out, + }, } diff --git a/fbgemm_gpu/fbgemm_gpu/sll/triton/common.py b/fbgemm_gpu/fbgemm_gpu/sll/triton/common.py index d26c25c6a2..32f0827cc5 100644 --- a/fbgemm_gpu/fbgemm_gpu/sll/triton/common.py +++ b/fbgemm_gpu/fbgemm_gpu/sll/triton/common.py @@ -9,6 +9,28 @@ import torch +def next_power_of_two(N: int) -> int: + if N > 4096: + raise Exception(f"{N} is too large that is not supported yet") + + if N > 2048: + return 4096 + elif N > 1024: + return 2048 + elif N > 512: + return 1024 + elif N > 256: + return 512 + elif N > 128: + return 256 + elif N > 64: + return 128 + elif N > 32: + return 64 + else: + return 32 + + def expect_contiguous(x: torch.Tensor) -> torch.Tensor: if not x.is_contiguous(): return x.contiguous() diff --git a/fbgemm_gpu/fbgemm_gpu/sll/triton_sll.py b/fbgemm_gpu/fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_mul_jagged_out.py similarity index 68% rename from fbgemm_gpu/fbgemm_gpu/sll/triton_sll.py rename to fbgemm_gpu/fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_mul_jagged_out.py index a014c2dfb9..0468944e2f 100644 --- a/fbgemm_gpu/fbgemm_gpu/sll/triton_sll.py +++ b/fbgemm_gpu/fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_mul_jagged_out.py @@ -11,61 +11,6 @@ import triton.language as tl -def next_power_of_two(N: int) -> int: - if N > 4096: - raise Exception(f"{N} is too large that is not supported yet") - - if N > 2048: - return 4096 - elif N > 1024: - return 2048 - elif N > 512: - return 1024 - elif N > 256: - return 512 - elif N > 128: - return 256 - elif N > 64: - return 128 - elif N > 32: - return 64 - else: - return 32 - - -@triton.jit -def jagged_self_substraction_jagged_out_kernel( - a_ptr, # jagged - b_ptr, # jagged - a_offsets_ptr, - b_offsets_ptr, - max_seq_len, - BLOCK_SIZE: tl.constexpr, -): - pid_batch = tl.program_id(0) - pid_index = tl.program_id(1) - - a_offset = tl.load(a_offsets_ptr + pid_batch) - a_length = tl.load(a_offsets_ptr + pid_batch + 1) - a_offset - a_length = tl.minimum(a_length, 
max_seq_len + 1) - - if a_length <= 1: - return - - N = a_length - 1 - if pid_index >= N: - return - - a_cur = tl.load(a_ptr + a_offset + pid_index) - offs = tl.arange(0, BLOCK_SIZE) - mask = offs < N - a_row = tl.load(a_ptr + a_offset + offs + 1, mask=mask) - b = a_cur - a_row - - b_offset = tl.load(b_offsets_ptr + pid_batch) - tl.store(b_ptr + b_offset + pid_index * N + offs, b, mask=mask) - - @triton.jit def jagged_dense_elementwise_mul_jagged_out_kernel( a_ptr, # 1d jagged @@ -123,33 +68,6 @@ def jagged_dense_elementwise_mul_jagged_out_kernel( c_ptrs += BLOCK_N -def triton_jagged_self_substraction_jagged_out( - jagged_A: torch.Tensor, - offsets_a: torch.Tensor, - offsets_b: torch.Tensor, - max_seq_len, -) -> torch.Tensor: - B = offsets_a.size(0) - 1 - - jagged_B = torch.empty( - (int(offsets_b[-1].item())), device=jagged_A.device, dtype=jagged_A.dtype - ) - - BLOCK_SIZE = max(next_power_of_two(max_seq_len), 16) - grid = (B, max_seq_len) - - jagged_self_substraction_jagged_out_kernel[grid]( - jagged_A, - jagged_B, - offsets_a, - offsets_b, - max_seq_len, - BLOCK_SIZE, # pyre-fixme[6]: For 6th argument expected `constexpr` but got `int`. - ) - - return jagged_B - - def triton_jagged_dense_elementwise_mul_jagged_out( jagged_A, dense_B, diff --git a/fbgemm_gpu/fbgemm_gpu/sll/triton/triton_jagged_self_substraction_jagged_out.py b/fbgemm_gpu/fbgemm_gpu/sll/triton/triton_jagged_self_substraction_jagged_out.py new file mode 100644 index 0000000000..cdd6130507 --- /dev/null +++ b/fbgemm_gpu/fbgemm_gpu/sll/triton/triton_jagged_self_substraction_jagged_out.py @@ -0,0 +1,73 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-unsafe + +import torch +import triton +import triton.language as tl + +from .common import next_power_of_two + + +@triton.jit +def jagged_self_substraction_jagged_out_kernel( + a_ptr, # jagged + b_ptr, # jagged + a_offsets_ptr, + b_offsets_ptr, + max_seq_len, + BLOCK_SIZE: tl.constexpr, +): + pid_batch = tl.program_id(0) + pid_index = tl.program_id(1) + + a_offset = tl.load(a_offsets_ptr + pid_batch) + a_length = tl.load(a_offsets_ptr + pid_batch + 1) - a_offset + a_length = tl.minimum(a_length, max_seq_len + 1) + + if a_length <= 1: + return + + N = a_length - 1 + if pid_index >= N: + return + + a_cur = tl.load(a_ptr + a_offset + pid_index) + offs = tl.arange(0, BLOCK_SIZE) + mask = offs < N + a_row = tl.load(a_ptr + a_offset + offs + 1, mask=mask) + b = a_cur - a_row + + b_offset = tl.load(b_offsets_ptr + pid_batch) + tl.store(b_ptr + b_offset + pid_index * N + offs, b, mask=mask) + + +def triton_jagged_self_substraction_jagged_out( + jagged_A: torch.Tensor, + offsets_a: torch.Tensor, + offsets_b: torch.Tensor, + max_seq_len, +) -> torch.Tensor: + B = offsets_a.size(0) - 1 + + jagged_B = torch.empty( + (int(offsets_b[-1].item())), device=jagged_A.device, dtype=jagged_A.dtype + ) + + BLOCK_SIZE = max(next_power_of_two(max_seq_len), 16) + grid = (B, max_seq_len) + + jagged_self_substraction_jagged_out_kernel[grid]( + jagged_A, + jagged_B, + offsets_a, + offsets_b, + max_seq_len, + BLOCK_SIZE, + ) + + return jagged_B diff --git a/fbgemm_gpu/test/sll/array_jagged_bmm_jagged_out_test.py b/fbgemm_gpu/test/sll/array_jagged_bmm_jagged_out_test.py index e6ea137495..5d701d789a 100644 --- a/fbgemm_gpu/test/sll/array_jagged_bmm_jagged_out_test.py +++ b/fbgemm_gpu/test/sll/array_jagged_bmm_jagged_out_test.py @@ -31,7 +31,7 @@ class ArrayJaggedBmmJaggedTest(unittest.TestCase): ) @unittest.skipIf(*gpu_unavailable) @unittest.skipIf(*running_on_rocm) - @settings(deadline=20000) + @settings(deadline=30000) def test_triton_array_jagged_bmm_jagged_out( self, B: int, @@ -157,7 +157,7 @@ def ref_array_jagged_bmm_jagged_out( ) @unittest.skipIf(*gpu_unavailable) @unittest.skipIf(*running_on_rocm) - @settings(deadline=20000) + @settings(deadline=30000) def test_triton_array_jagged_bmm_jagged_out_with_grad( self, B: int, @@ -244,7 +244,7 @@ def test_triton_array_jagged_bmm_jagged_out_with_grad( ) @unittest.skipIf(*gpu_unavailable) @unittest.skipIf(*running_on_rocm) - @settings(deadline=20000) + @settings(deadline=30000) def test_triton_array_jagged_bmm_jagged_out_meta_backend( self, B: int, diff --git a/fbgemm_gpu/test/sll/common.py b/fbgemm_gpu/test/sll/common.py index 201913d3dc..c00f5e4ed1 100644 --- a/fbgemm_gpu/test/sll/common.py +++ b/fbgemm_gpu/test/sll/common.py @@ -8,8 +8,7 @@ # pyre-ignore-all-errors[56] import fbgemm_gpu -import fbgemm_gpu.sll.cpu_sll -import fbgemm_gpu.sll.triton_sll +import fbgemm_gpu.sll import torch # pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`. diff --git a/fbgemm_gpu/test/sll/dense_jagged_cat_jagged_out_test.py b/fbgemm_gpu/test/sll/dense_jagged_cat_jagged_out_test.py new file mode 100644 index 0000000000..335589ede2 --- /dev/null +++ b/fbgemm_gpu/test/sll/dense_jagged_cat_jagged_out_test.py @@ -0,0 +1,73 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-strict +import unittest + +import fbgemm_gpu.sll # noqa F401 +import torch +from hypothesis import given, settings, strategies as st + +from .common import open_source # noqa + +if open_source: + # pyre-ignore[21] + from test_utils import gpu_unavailable, running_on_rocm +else: + from fbgemm_gpu.test.test_utils import gpu_unavailable, running_on_rocm + + +class DenseJaggedCatJaggedOutTest(unittest.TestCase): + @unittest.skipIf(*gpu_unavailable) + @unittest.skipIf(*running_on_rocm) + # pyre-fixme[56]: Pyre was not able to infer the type of argument + @given( + B=st.integers(10, 512), + max_L=st.integers(1, 200), + device_type=st.sampled_from(["cpu", "cuda"]), + enable_pt2=st.sampled_from([True, False]), + ) + @settings(deadline=None) + def test_dense_jagged_cat_jagged_out( + self, + B: int, + max_L: int, + device_type: str, + enable_pt2: bool, + ) -> None: + device = torch.device(device_type) + lengths = torch.randint(0, max_L + 1, (B,), device=device) + offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths) + c_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths + 1) + a = torch.randint(0, 100000000, (B,), device=device) + b = torch.randint(0, 100000000, (int(lengths.sum().item()),), device=device) + + ref = torch.cat( + [ + ( + torch.cat((a[i : i + 1], b[offsets[i] : offsets[i + 1]]), dim=-1) + if lengths[i] > 0 + else a[i : i + 1] + ) + for i in range(B) + ], + dim=-1, + ) + + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. + def model(a, b, offsets, max_L): + return torch.ops.fbgemm.sll_dense_jagged_cat_jagged_out( + a, b, offsets, max_L + ) + + if enable_pt2: + model = torch.compile(model) + + ret, c_offsets_computed = model(a, b, offsets, max_L) + + assert torch.allclose(ref, ret) + assert torch.equal(c_offsets, c_offsets_computed) diff --git a/fbgemm_gpu/test/sll/jagged_dense_bmm_test.py b/fbgemm_gpu/test/sll/jagged_dense_bmm_test.py index 044154507c..030e8e71ed 100644 --- a/fbgemm_gpu/test/sll/jagged_dense_bmm_test.py +++ b/fbgemm_gpu/test/sll/jagged_dense_bmm_test.py @@ -7,10 +7,7 @@ # pyre-strict import unittest -import fbgemm_gpu -import fbgemm_gpu.sll.cpu_sll # noqa F401 -import fbgemm_gpu.sll.triton_sll # noqa F401 - +import fbgemm_gpu.sll # noqa F401 import torch from hypothesis import given, settings, strategies as st diff --git a/fbgemm_gpu/test/sll/jagged_dense_elementwise_add_test.py b/fbgemm_gpu/test/sll/jagged_dense_elementwise_add_test.py index 75850e9170..98b2faa243 100644 --- a/fbgemm_gpu/test/sll/jagged_dense_elementwise_add_test.py +++ b/fbgemm_gpu/test/sll/jagged_dense_elementwise_add_test.py @@ -34,7 +34,7 @@ class JaggedDenseElementwiseAddTest(unittest.TestCase): device_type=st.sampled_from(["cpu", "cuda"]), ) @unittest.skipIf(*gpu_unavailable) - @settings(deadline=20000) + @settings(deadline=30000) def test_triton_jagged_dense_add( self, B: int, D: int, N: int, use_fbgemm_kernel: bool, device_type: str ) -> None: diff --git a/fbgemm_gpu/test/sll/jagged_dense_elementwise_mul_jagged_out_test.py b/fbgemm_gpu/test/sll/jagged_dense_elementwise_mul_jagged_out_test.py index 6f4ece87a6..35093a321e 100644 --- a/fbgemm_gpu/test/sll/jagged_dense_elementwise_mul_jagged_out_test.py +++ b/fbgemm_gpu/test/sll/jagged_dense_elementwise_mul_jagged_out_test.py @@ -7,10 +7,7 @@ # pyre-strict import unittest -import fbgemm_gpu -import fbgemm_gpu.sll.cpu_sll # noqa F401 -import fbgemm_gpu.sll.triton_sll # noqa F401 - +import fbgemm_gpu.sll # noqa F401 import torch from hypothesis import given, settings, 
strategies as st @@ -31,7 +28,7 @@ class JaggedDenseElementwiseMulJaggedOutTest(unittest.TestCase): B=st.integers(10, 512), L=st.integers(1, 200), ) - @settings(deadline=20000) + @settings(deadline=30000) def test_jagged_dense_elementwise_mul_jagged_out( self, B: int, @@ -164,7 +161,7 @@ def test_jagged_dense_elementwise_mul_jagged_out_with_grad( L=st.integers(1, 200), device_type=st.sampled_from(["meta"]), ) - @settings(deadline=20000) + @settings(deadline=30000) def test_jagged_dense_elementwise_mul_jagged_out_meta_backend( self, B: int, diff --git a/fbgemm_gpu/test/sll/jagged_dense_flash_attention_test.py b/fbgemm_gpu/test/sll/jagged_dense_flash_attention_test.py index ec0a5cbbba..d31b69a979 100644 --- a/fbgemm_gpu/test/sll/jagged_dense_flash_attention_test.py +++ b/fbgemm_gpu/test/sll/jagged_dense_flash_attention_test.py @@ -8,6 +8,7 @@ import unittest +import fbgemm_gpu.sll # noqa F401 import hypothesis.strategies as st import torch from hypothesis import given, settings @@ -33,7 +34,7 @@ class JaggedDenseFlashAttentionTest(unittest.TestCase): ) @unittest.skipIf(*gpu_unavailable) @unittest.skipIf(*running_on_rocm) - @settings(deadline=20000) + @settings(deadline=30000) def test_jagged_dense_flash_attention( self, B: int, diff --git a/fbgemm_gpu/test/sll/jagged_flash_attention_basic_test.py b/fbgemm_gpu/test/sll/jagged_flash_attention_basic_test.py index 8353c5a842..1363aeb4c3 100644 --- a/fbgemm_gpu/test/sll/jagged_flash_attention_basic_test.py +++ b/fbgemm_gpu/test/sll/jagged_flash_attention_basic_test.py @@ -8,6 +8,7 @@ import unittest +import fbgemm_gpu.sll # noqa F401 import hypothesis.strategies as st import torch from hypothesis import given, settings diff --git a/fbgemm_gpu/test/sll/jagged_jagged_bmm_jagged_out_test.py b/fbgemm_gpu/test/sll/jagged_jagged_bmm_jagged_out_test.py index 035e731d9b..a9b447a08f 100644 --- a/fbgemm_gpu/test/sll/jagged_jagged_bmm_jagged_out_test.py +++ b/fbgemm_gpu/test/sll/jagged_jagged_bmm_jagged_out_test.py @@ -8,13 +8,12 @@ import unittest +import fbgemm_gpu.sll # noqa F401 import hypothesis.strategies as st import torch from fbgemm_gpu.sll.triton import triton_jagged_jagged_bmm_jagged_out from hypothesis import given, settings -from .common import open_source # noqa - class JaggedJaggedBmmJaggedOutTest(unittest.TestCase): # pyre-fixme[56]: Pyre was not able to infer the type of argument @@ -23,7 +22,7 @@ class JaggedJaggedBmmJaggedOutTest(unittest.TestCase): max_L=st.integers(1, 200), K=st.integers(1, 100), ) - @settings(deadline=20000) + @settings(deadline=30000) def test_triton_jagged_jagged_bmm_jagged_out( self, B: int, @@ -97,7 +96,7 @@ def ref_jagged_jagged_bmm_jagged_out( K=st.integers(1, 100), device_type=st.sampled_from(["meta"]), ) - @settings(deadline=20000) + @settings(deadline=30000) def test_triton_jagged_jagged_bmm_jagged_out_meta_backend( self, B: int, diff --git a/fbgemm_gpu/test/sll/jagged_jagged_bmm_test.py b/fbgemm_gpu/test/sll/jagged_jagged_bmm_test.py index dd7b365218..6aa89ada85 100644 --- a/fbgemm_gpu/test/sll/jagged_jagged_bmm_test.py +++ b/fbgemm_gpu/test/sll/jagged_jagged_bmm_test.py @@ -7,9 +7,8 @@ # pyre-strict import unittest -import fbgemm_gpu -import fbgemm_gpu.sll.cpu_sll # noqa F401 -import fbgemm_gpu.sll.triton_sll # noqa F401 +import fbgemm_gpu # noqa F401 +import fbgemm_gpu.sll # noqa F401 import torch from hypothesis import given, settings, strategies as st diff --git a/fbgemm_gpu/test/sll/triton_sll_test.py b/fbgemm_gpu/test/sll/jagged_self_substraction_jagged_out_test.py similarity index 60% rename from 
fbgemm_gpu/test/sll/triton_sll_test.py rename to fbgemm_gpu/test/sll/jagged_self_substraction_jagged_out_test.py index 1b0b1c461c..1b358f0b8d 100644 --- a/fbgemm_gpu/test/sll/triton_sll_test.py +++ b/fbgemm_gpu/test/sll/jagged_self_substraction_jagged_out_test.py @@ -7,9 +7,7 @@ # pyre-strict import unittest -import fbgemm_gpu -import fbgemm_gpu.sll.cpu_sll # noqa F401 -import fbgemm_gpu.sll.triton_sll # noqa F401 +import fbgemm_gpu.sll # noqa F401 import torch from hypothesis import given, settings, strategies as st from torch.testing._internal.optests import opcheck @@ -23,58 +21,7 @@ from fbgemm_gpu.test.test_utils import gpu_unavailable, running_on_rocm -class TritonSLLTest(unittest.TestCase): - @unittest.skipIf(*gpu_unavailable) - @unittest.skipIf(*running_on_rocm) - # pyre-fixme[56]: Pyre was not able to infer the type of argument - @given( - B=st.integers(10, 512), - max_L=st.integers(1, 200), - device_type=st.sampled_from(["cpu", "cuda"]), - enable_pt2=st.sampled_from([True, False]), - ) - @settings(deadline=None) - def test_dense_jagged_cat_jagged_out( - self, - B: int, - max_L: int, - device_type: str, - enable_pt2: bool, - ) -> None: - device = torch.device(device_type) - lengths = torch.randint(0, max_L + 1, (B,), device=device) - offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths) - c_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths + 1) - a = torch.randint(0, 100000000, (B,), device=device) - b = torch.randint(0, 100000000, (int(lengths.sum().item()),), device=device) - - ref = torch.cat( - [ - ( - torch.cat((a[i : i + 1], b[offsets[i] : offsets[i + 1]]), dim=-1) - if lengths[i] > 0 - else a[i : i + 1] - ) - for i in range(B) - ], - dim=-1, - ) - - # pyre-fixme[3]: Return type must be annotated. - # pyre-fixme[2]: Parameter must be annotated. - def model(a, b, offsets, max_L): - return torch.ops.fbgemm.sll_dense_jagged_cat_jagged_out( - a, b, offsets, max_L - ) - - if enable_pt2: - model = torch.compile(model) - - ret, c_offsets_computed = model(a, b, offsets, max_L) - - assert torch.allclose(ref, ret) - assert torch.equal(c_offsets, c_offsets_computed) - +class JaggedSelfSubtractionJaggedOutTest(unittest.TestCase): @unittest.skipIf(*gpu_unavailable) @unittest.skipIf(*running_on_rocm) # pyre-fixme[56]: Pyre was not able to infer the type of argument diff --git a/fbgemm_gpu/test/sll/jagged_softmax_test.py b/fbgemm_gpu/test/sll/jagged_softmax_test.py index 52539c86b2..79ebe24b22 100644 --- a/fbgemm_gpu/test/sll/jagged_softmax_test.py +++ b/fbgemm_gpu/test/sll/jagged_softmax_test.py @@ -7,10 +7,8 @@ # pyre-strict import unittest -import fbgemm_gpu -import fbgemm_gpu.sll.cpu_sll # noqa F401 -import fbgemm_gpu.sll.triton_sll # noqa F401 - +import fbgemm_gpu # noqa F401 +import fbgemm_gpu.sll # noqa F401 import torch from hypothesis import given, settings, strategies as st
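
Note (illustrative only, not part of the patch): a minimal sketch of how the consolidated registration table is consumed after this change, mirroring the updated fbgemm_gpu/fbgemm_gpu/sll/__init__.py hunk above. It assumes a CUDA-capable fbgemm_gpu build; the loop only inspects fbgemm_gpu.sll.triton.op_registrations, which after this diff also carries the dispatch entries for sll_jagged_self_substraction_jagged_out and sll_jagged_dense_elementwise_mul_jagged_out alongside the previously migrated ops.

    import torch
    import fbgemm_gpu.sll  # noqa F401  # importing this package performs the op registrations

    if torch.cuda.is_available():
        # Same import path the updated __init__.py uses for GPU registrations.
        from fbgemm_gpu.sll.triton import op_registrations as sll_gpu_registrations

        # Each entry maps an op name to its dispatch-key -> callable table,
        # e.g. {"CUDA": ..., "AutogradCUDA": ...}, exactly as passed to
        # lib.register(op_name, dispatches) in fbgemm_gpu.sll.__init__.
        for op_name, dispatches in sll_gpu_registrations.items():
            print(op_name, sorted(dispatches.keys()))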