Skip to content

Commit

Permalink
Add call_kernel for old kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
ZzEeKkAa committed Dec 27, 2023
1 parent bac450f commit 8d33921
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 4 deletions.
3 changes: 2 additions & 1 deletion numba_dpex/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from numba import __version__ as numba_version
from numba.np.ufunc.decorators import Vectorize

from numba_dpex.core.kernel_interface.launcher import call_kernel
from numba_dpex.vectorizers import Vectorize as DpexVectorize

from .numba_patches import patch_arrayexpr_tree_to_ir, patch_is_ufunc
Expand Down Expand Up @@ -145,4 +146,4 @@ def parse_sem_version(version_string: str) -> Tuple[int, int, int]:
__version__ = get_versions()["version"]
del get_versions

__all__ = types.__all__ + ["Range", "NdRange"]
__all__ = types.__all__ + ["Range", "NdRange", "call_kernel"]
18 changes: 18 additions & 0 deletions numba_dpex/core/kernel_interface/launcher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
"""Launcher package to provide the same way of calling kernel as experimental
one."""


def call_kernel(kernel_fn, index_space, *kernel_args) -> None:
"""Syntax sugar for calling kernel the same way as experimental one.
It is a temporary glue for the experimental kernel migration.
Args:
kernel_fn (numba_dpex.experimental.KernelDispatcher): A
numba_dpex.kernel decorated function that is compiled to a
KernelDispatcher by numba_dpex.
index_space (Range | NdRange): A numba_dpex.Range or numba_dpex.NdRange
type object that specifies the index space for the kernel.
kernel_args : List of objects that are passed to the numba_dpex.kernel
decorated function.
"""
kernel_fn[index_space](*kernel_args)
6 changes: 3 additions & 3 deletions numba_dpex/tests/kernel_tests/test_atomic_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def f(a):
def test_kernel_atomic_simple(input_arrays, kernel_result_pair):
a, dtype = input_arrays()
kernel, expected = kernel_result_pair
kernel[dpex.Range(global_size)](a)
dpex.call_kernel(kernel, dpex.Range(global_size), a)
assert a[0] == expected


Expand Down Expand Up @@ -96,7 +96,7 @@ def test_kernel_atomic_local(input_arrays, return_list_of_op):
op_type, expected = return_list_of_op
f = get_func_local(op_type, dtype)
kernel = dpex.kernel(f)
kernel[dpex.NdRange(dpex.Range(N), dpex.Range(N))](a)
dpex.call_kernel(kernel, dpex.NdRange(dpex.Range(N), dpex.Range(N)), a)
assert a[0] == expected


Expand Down Expand Up @@ -134,7 +134,7 @@ def test_kernel_atomic_multi_dim(
dim = return_list_of_dim
kernel = get_kernel_multi_dim(op_type, len(dim))
a = np.zeros(dim, dtype=return_dtype)
kernel[dpex.Range(global_size)](a)
dpex.call_kernel(kernel, dpex.Range(global_size), a)
assert a[0] == expected


Expand Down

0 comments on commit 8d33921

Please sign in to comment.