Skip to content

Commit

Permalink
Merge pull request #1260 from IntelPython/feature/syntax_sugar_call_k…
Browse files Browse the repository at this point in the history
…ernel

Add call_kernel for the old kernel
  • Loading branch information
Diptorup Deb authored Dec 28, 2023
2 parents bac450f + be8cfcb commit 96c3d44
Show file tree
Hide file tree
Showing 19 changed files with 66 additions and 39 deletions.
3 changes: 2 additions & 1 deletion numba_dpex/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from numba import __version__ as numba_version
from numba.np.ufunc.decorators import Vectorize

from numba_dpex.core.kernel_interface.launcher import call_kernel
from numba_dpex.vectorizers import Vectorize as DpexVectorize

from .numba_patches import patch_arrayexpr_tree_to_ir, patch_is_ufunc
Expand Down Expand Up @@ -145,4 +146,4 @@ def parse_sem_version(version_string: str) -> Tuple[int, int, int]:
__version__ = get_versions()["version"]
del get_versions

__all__ = types.__all__ + ["Range", "NdRange"]
__all__ = types.__all__ + ["Range", "NdRange", "call_kernel"]
18 changes: 18 additions & 0 deletions numba_dpex/core/kernel_interface/launcher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
"""Launcher package to provide the same way of calling kernel as experimental
one."""


def call_kernel(kernel_fn, index_space, *kernel_args) -> None:
"""Syntax sugar for calling kernel the same way as experimental one.
It is a temporary glue for the experimental kernel migration.
Args:
kernel_fn (numba_dpex.experimental.KernelDispatcher): A
numba_dpex.kernel decorated function that is compiled to a
KernelDispatcher by numba_dpex.
index_space (Range | NdRange): A numba_dpex.Range or numba_dpex.NdRange
type object that specifies the index space for the kernel.
kernel_args : List of objects that are passed to the numba_dpex.kernel
decorated function.
"""
kernel_fn[index_space](*kernel_args)
6 changes: 3 additions & 3 deletions numba_dpex/tests/kernel_tests/test_atomic_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def f(a):
def test_kernel_atomic_simple(input_arrays, kernel_result_pair):
a, dtype = input_arrays()
kernel, expected = kernel_result_pair
kernel[dpex.Range(global_size)](a)
dpex.call_kernel(kernel, dpex.Range(global_size), a)
assert a[0] == expected


Expand Down Expand Up @@ -96,7 +96,7 @@ def test_kernel_atomic_local(input_arrays, return_list_of_op):
op_type, expected = return_list_of_op
f = get_func_local(op_type, dtype)
kernel = dpex.kernel(f)
kernel[dpex.NdRange(dpex.Range(N), dpex.Range(N))](a)
dpex.call_kernel(kernel, dpex.NdRange(dpex.Range(N), dpex.Range(N)), a)
assert a[0] == expected


Expand Down Expand Up @@ -134,7 +134,7 @@ def test_kernel_atomic_multi_dim(
dim = return_list_of_dim
kernel = get_kernel_multi_dim(op_type, len(dim))
a = np.zeros(dim, dtype=return_dtype)
kernel[dpex.Range(global_size)](a)
dpex.call_kernel(kernel, dpex.Range(global_size), a)
assert a[0] == expected


Expand Down
8 changes: 5 additions & 3 deletions numba_dpex/tests/kernel_tests/test_barrier.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def twice(A):
orig = dpt.asnumpy(arr)
global_size = (N,)
local_size = (N // 2,)
twice[NdRange(global_size, local_size)](arr)
dpex.call_kernel(twice, NdRange(global_size, local_size), arr)
after = dpt.asnumpy(arr)
# The computation is correct?
np.testing.assert_allclose(orig * 2, after)
Expand All @@ -43,7 +43,7 @@ def twice(A):
N = 256
arr = dpt.arange(N, dtype=dpt.float32)
orig = dpt.asnumpy(arr)
twice[Range(N)](arr)
dpex.call_kernel(twice, Range(N), arr)
after = dpt.asnumpy(arr)
# The computation is correct?
np.testing.assert_allclose(orig * 2, after)
Expand All @@ -66,7 +66,9 @@ def reverse_array(A):

arr = dpt.arange(blocksize, dtype=dpt.float32)
orig = dpt.asnumpy(arr)
reverse_array[NdRange(Range(blocksize), Range(blocksize))](arr)
dpex.call_kernel(
reverse_array, NdRange(Range(blocksize), Range(blocksize)), arr
)
after = dpt.asnumpy(arr)
expected = orig[::-1] + orig
np.testing.assert_allclose(expected, after)
4 changes: 2 additions & 2 deletions numba_dpex/tests/kernel_tests/test_complex_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def test_numeric_kernel_arg_complex_scalar(input_arrays):
a, b, _ = input_arrays
s = a.dtype.type(2 + 1j)

kernel_scalar[dpex.Range(N)](a, b, s)
dpex.call_kernel(kernel_scalar, dpex.Range(N), a, b, s)

nb = dpnp.asnumpy(b)
nexpected = numpy.full_like(nb, fill_value=2 + 1j)
Expand All @@ -65,7 +65,7 @@ def test_numeric_kernel_arg_complex_array(input_arrays):

a, b, c = input_arrays

kernel_array[dpex.Range(N)](a, b, c)
dpex.call_kernel(kernel_array, dpex.Range(N), a, b, c)

nb = dpnp.asnumpy(b)
nexpected = numpy.full_like(nb, fill_value=0 + 0j)
Expand Down
4 changes: 3 additions & 1 deletion numba_dpex/tests/kernel_tests/test_dpnp_ndarray_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,9 @@ def test_setting_private_from_dpnp_ndarray():
global_range = ndpx.Range(N_POINTS // N_POINTS_PER_WORK_ITEM)
local_range = ndpx.Range(LOCAL_SIZE)
try:
_kernel[ndpx.NdRange(global_range, local_range)](COEFFICIENTS)
ndpx.call_kernel(
_kernel, ndpx.NdRange(global_range, local_range), COEFFICIENTS
)
except Exception as e:
assert (
False
Expand Down
2 changes: 1 addition & 1 deletion numba_dpex/tests/kernel_tests/test_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,6 @@ def test_func_call_from_kernel():
a = dpnp.ones(1024)
b = dpnp.ones(1024)

f[dpex.Range(1024)](a, b)
dpex.call_kernel(f, dpex.Range(1024), a, b)
nb = dpnp.asnumpy(b)
assert numpy.all(nb == 2)
12 changes: 6 additions & 6 deletions numba_dpex/tests/kernel_tests/test_func_specialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def kernel_function(a, b):
a = dpnp.ones(N)
b = dpnp.ones(N)

k[dpex.Range(N)](a, b)
dpex.call_kernel(k, dpex.Range(N), a, b)

assert np.array_equal(dpnp.asnumpy(b), dpnp.asnumpy(a) + 1)

Expand All @@ -56,7 +56,7 @@ def kernel_function(a, b):
a = dpnp.ones(N, dtype=dpnp.int32)
b = dpnp.ones(N, dtype=dpnp.int32)

k[dpex.Range(N)](a, b)
dpex.call_kernel(k, dpex.Range(N), a, b)

assert np.array_equal(dpnp.asnumpy(b), dpnp.asnumpy(a) + 1)

Expand All @@ -65,7 +65,7 @@ def kernel_function(a, b):
b = dpnp.ones(N, dtype=dpnp.int64)

with pytest.raises(Exception) as e:
k[dpex.Range(N)](a, b)
dpex.call_kernel(k, dpex.Range(N), a, b)

assert " >>> <unknown function>(int64)" in e.value.args[0]

Expand All @@ -86,15 +86,15 @@ def kernel_function(a, b):
a = dpnp.ones(N, dtype=dpnp.int32)
b = dpnp.ones(N, dtype=dpnp.int32)

k[dpex.Range(N)](a, b)
dpex.call_kernel(k, dpex.Range(N), a, b)

assert np.array_equal(dpnp.asnumpy(b), dpnp.asnumpy(a) + 1)

# Test with float32, should work
a = dpnp.ones(N, dtype=dpnp.float32)
b = dpnp.ones(N, dtype=dpnp.float32)

k[dpex.Range(N)](a, b)
dpex.call_kernel(k, dpex.Range(N), a, b)

assert np.array_equal(dpnp.asnumpy(b), dpnp.asnumpy(a) + 1)

Expand All @@ -103,6 +103,6 @@ def kernel_function(a, b):
b = dpnp.ones(N, dtype=dpnp.int64)

with pytest.raises(Exception) as e:
k[dpex.Range(N)](a, b)
dpex.call_kernel(k, dpex.Range(N), a, b)

assert " >>> <unknown function>(int64)" in e.value.args[0]
2 changes: 1 addition & 1 deletion numba_dpex/tests/kernel_tests/test_invalid_kernel_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,4 @@ def test_passing_numpy_arrays_as_kernel_args():
c = numpy.zeros(N)

with pytest.raises(UnsupportedKernelArgumentError):
vecadd_kernel[dpex.Range(N)](a, b, c)
dpex.call_kernel(vecadd_kernel, dpex.Range(N), a, b, c)
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,4 @@ def test_return(sig):

with pytest.raises(dpex.core.exceptions.KernelHasReturnValueError):
kernel_fn = dpex.kernel(sig)(f)
kernel_fn[dpex.Range(a.size)](a)
dpex.call_kernel(kernel_fn, dpex.Range(a.size), a)
8 changes: 6 additions & 2 deletions numba_dpex/tests/kernel_tests/test_kernel_specialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,9 @@ def test_missing_specialization_error():
c = dpt.zeros(1024, dtype=dpt.int32)

with pytest.raises(MissingSpecializationError):
specialized_kernel1(data_parallel_sum)[Range(1024)](a, b, c)
dpex.call_kernel(
specialized_kernel1(data_parallel_sum), Range(1024), a, b, c
)


def test_execution_of_specialized_kernel():
Expand All @@ -69,7 +71,9 @@ def test_execution_of_specialized_kernel():
b = dpt.ones(1024, dtype=dpt.int64)
c = dpt.zeros(1024, dtype=dpt.int64)

specialized_kernel1(data_parallel_sum)[Range(1024)](a, b, c)
dpex.call_kernel(
specialized_kernel1(data_parallel_sum), Range(1024), a, b, c
)

npc = dpt.asnumpy(c)
import numpy as np
Expand Down
2 changes: 1 addition & 1 deletion numba_dpex/tests/kernel_tests/test_math_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def f(a, b):
i = dpex.get_global_id(0)
b[i] = uop(a[i])

f[dpex.Range(a.size)](a, b)
dpex.call_kernel(f, dpex.Range(a.size), a, b)

expected = dpnp_uop(a)

Expand Down
2 changes: 1 addition & 1 deletion numba_dpex/tests/kernel_tests/test_ndrange_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,4 @@ def test_ndrange_config_error(error, ranges):

with pytest.raises(error):
range = NdRange(ranges[0], ranges[1])
kernel_vector_sum[range](a, b, c)
ndpx.call_kernel(kernel_vector_sum, range, a, b, c)
8 changes: 4 additions & 4 deletions numba_dpex/tests/kernel_tests/test_print.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def print_scalar_val(s):

a = input_arrays

print_scalar_val[dpex.Range(1)](a)
dpex.call_kernel(print_scalar_val, dpex.Range(1), a)
captured = capfd.readouterr()
assert "printing ... 10" in captured.out

Expand All @@ -60,7 +60,7 @@ def print_scalar_val(s):

a = input_arrays

print_scalar_val[dpex.Range(1)](a)
dpex.call_kernel(print_scalar_val, dpex.Range(1), a)
captured = capfd.readouterr()

assert "10" in captured.out
Expand All @@ -85,7 +85,7 @@ def print_string(a):
a = input_arrays

with pytest.raises(LoweringError):
print_string[dpex.Range(1)](a)
dpex.call_kernel(print_string, dpex.Range(1), a)


@skip_on_gpu
Expand All @@ -101,4 +101,4 @@ def print_string(a):
a = input_arrays

with pytest.raises(LoweringError):
print_string[dpex.Range(1)](a)
dpex.call_kernel(print_string, dpex.Range(1), a)
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def kernel_with_private_memory_allocation(A):
def test_private_memory_allocation():
N = 64
arr = dpnp.zeros(N, dtype=dpnp.float32)
kernel_with_private_memory_allocation[dpex.Range(N)](arr)
dpex.call_kernel(kernel_with_private_memory_allocation, dpex.Range(N), arr)

nparr = dpnp.asnumpy(arr)

Expand Down
6 changes: 3 additions & 3 deletions numba_dpex/tests/kernel_tests/test_scalar_arg_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def test_numeric_kernel_arg_types1(input_arrays):
a, b = input_arrays
s = a.dtype.type(2)

scaling_kernel[dpex.Range(N)](a, b, s)
dpex.call_kernel(scaling_kernel, dpex.Range(N), a, b, s)

nb = dpnp.asnumpy(b)
nexpected = numpy.full_like(nb, fill_value=2)
Expand All @@ -65,14 +65,14 @@ def test_bool_kernel_arg_type(input_arrays):
"""
a, b = input_arrays

kernel_with_bool_arg[dpex.Range(a.size)](a, b, True)
dpex.call_kernel(kernel_with_bool_arg, dpex.Range(a.size), a, b, True)

nb = dpnp.asnumpy(b)
nexpected_true = numpy.full_like(nb, fill_value=2)

assert numpy.allclose(nb, nexpected_true)

kernel_with_bool_arg[dpex.Range(a.size)](a, b, False)
dpex.call_kernel(kernel_with_bool_arg, dpex.Range(a.size), a, b, False)

nb = dpnp.asnumpy(b)
nexpected_false = numpy.zeros_like(nb)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def test_kernel_valid_usm_obj(dtype):
C = DuckUSMArray(shape=buffC.shape, dtype=dtype, host_buffer=buffC)

try:
vecadd[dpex.Range(N)](A, B, C)
dpex.call_kernel(vecadd, dpex.Range(N), A, B, C)
except Exception:
pytest.fail(
"Could not pass Python object with sycl_usm_array_interface"
Expand All @@ -112,4 +112,4 @@ def test_kernel_invalid_usm_obj(dtype):
C = PseudoDuckUSMArray()

with pytest.raises(Exception):
vecadd[dpex.Range(N)](A, B, C)
dpex.call_kernel(vecadd, dpex.Range(N), A, B, C)
2 changes: 1 addition & 1 deletion numba_dpex/tests/kernel_tests/test_usm_ndarray_interop.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def data_parallel_sum(a, b, c):

c = dpt.empty_like(a)

data_parallel_sum[dpex.Range(N, N)](a, b, c)
dpex.call_kernel(data_parallel_sum, dpex.Range(N, N), a, b, c)

na = dpt.asnumpy(a)
nb = dpt.asnumpy(b)
Expand Down
10 changes: 5 additions & 5 deletions numba_dpex/tests/misc/test_warnings.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def test_opt_warning():
config.DPEX_OPT = 3

with pytest.warns(UserWarning):
foo[dpex.Range(10)](dpnp.arange(10))
dpex.call_kernel(foo, dpex.Range(10), dpnp.arange(10))

config.DPEX_OPT = bkp

Expand All @@ -27,7 +27,7 @@ def test_inline_threshold_eq_3_warning():
config.INLINE_THRESHOLD = 3

with pytest.warns(UserWarning):
foo[dpex.Range(10)](dpnp.arange(10))
dpex.call_kernel(foo, dpex.Range(10), dpnp.arange(10))

config.INLINE_THRESHOLD = bkp

Expand All @@ -37,7 +37,7 @@ def test_inline_threshold_negative_val_warning_():
config.INLINE_THRESHOLD = -1

with pytest.warns(UserWarning):
foo[dpex.Range(10)](dpnp.arange(10))
dpex.call_kernel(foo, dpex.Range(10), dpnp.arange(10))

config.INLINE_THRESHOLD = bkp

Expand All @@ -47,12 +47,12 @@ def test_inline_threshold_gt_3_warning():
config.INLINE_THRESHOLD = 4

with pytest.warns(UserWarning):
foo[dpex.Range(10)](dpnp.arange(10))
dpex.call_kernel(foo, dpex.Range(10), dpnp.arange(10))

config.INLINE_THRESHOLD = bkp


def test_no_warning():
with warnings.catch_warnings():
warnings.simplefilter("error")
foo[dpex.Range(10)](dpnp.arange(10))
dpex.call_kernel(foo, dpex.Range(10), dpnp.arange(10))

0 comments on commit 96c3d44

Please sign in to comment.