Skip to content

Commit

Permalink
Use item object for experimental kernels in tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ZzEeKkAa authored and Diptorup Deb committed Feb 23, 2024
1 parent 99c4aa1 commit 90ab8ae
Show file tree
Hide file tree
Showing 10 changed files with 84 additions and 76 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@
import dpctl
from numba.core import types

import numba_dpex as dpex
from numba_dpex import DpctlSyclQueue, DpnpNdArray
from numba_dpex import experimental as dpex_exp
from numba_dpex import int64
from numba_dpex.core.types.kernel_api.index_space_ids import ItemType
from numba_dpex.kernel_api import Item


def kernel_func(a, b, c):
i = dpex.get_global_id(0)
def kernel_func(item: Item, a, b, c):
i = item.get_id(0)
c[i] = a[i] + b[i]


Expand All @@ -36,7 +37,7 @@ def test_codegen_with_max_inline_threshold():

queue_ty = DpctlSyclQueue(dpctl.SyclQueue())
i64arr_ty = DpnpNdArray(ndim=1, dtype=int64, layout="C", queue=queue_ty)
kernel_sig = types.void(i64arr_ty, i64arr_ty, i64arr_ty)
kernel_sig = types.void(ItemType(1), i64arr_ty, i64arr_ty, i64arr_ty)

disp = dpex_exp.kernel(inline_threshold=3)(kernel_func)
disp.compile(kernel_sig)
Expand All @@ -57,7 +58,7 @@ def test_codegen_without_max_inline_threshold():

queue_ty = DpctlSyclQueue(dpctl.SyclQueue())
i64arr_ty = DpnpNdArray(ndim=1, dtype=int64, layout="C", queue=queue_ty)
kernel_sig = types.void(i64arr_ty, i64arr_ty, i64arr_ty)
kernel_sig = types.void(ItemType(1), i64arr_ty, i64arr_ty, i64arr_ty)

disp = dpex_exp.kernel(kernel_func)
disp.compile(kernel_sig)
Expand All @@ -70,4 +71,4 @@ def test_codegen_without_max_inline_threshold():
if not f.is_declaration:
count_of_non_declaration_type_functions += 1

assert count_of_non_declaration_type_functions == 2
assert count_of_non_declaration_type_functions == 3
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@
import pytest
from numba.core.errors import TypingError

import numba_dpex as dpex
import numba_dpex.experimental as dpex_exp
from numba_dpex.kernel_api import AtomicRef
from numba_dpex.kernel_api import AtomicRef, Item, Range
from numba_dpex.tests._helper import get_all_dtypes

list_of_supported_dtypes = get_all_dtypes(
Expand Down Expand Up @@ -45,8 +44,8 @@ def test_fetch_phi_fn(input_arrays, ref_index, fetch_phi_fn):
"""A test for all fetch_phi atomic functions."""

@dpex_exp.kernel
def _kernel(a, b, ref_index):
i = dpex.get_global_id(0)
def _kernel(item: Item, a, b, ref_index):
i = item.get_id(0)
v = AtomicRef(b, index=ref_index)
getattr(v, fetch_phi_fn)(a[i])

Expand All @@ -60,9 +59,9 @@ def _kernel(a, b, ref_index):
# fetch_and, fetch_or, fetch_xor accept only int arguments.
# test for TypingError when float arguments are passed.
with pytest.raises(TypingError):
dpex_exp.call_kernel(_kernel, dpex.Range(10), a, b, ref_index)
dpex_exp.call_kernel(_kernel, Range(10), a, b, ref_index)
else:
dpex_exp.call_kernel(_kernel, dpex.Range(10), a, b, ref_index)
dpex_exp.call_kernel(_kernel, Range(10), a, b, ref_index)
# Verify that `a` accumulated at b[ref_index] by kernel
# matches the `a` accumulated at b[ref_index+1] using Python
for i in range(a.size):
Expand All @@ -76,8 +75,8 @@ def test_fetch_phi_retval(fetch_phi_fn):
"""A test for all fetch_phi atomic functions."""

@dpex_exp.kernel
def _kernel(a, b, c):
i = dpex.get_global_id(0)
def _kernel(item: Item, a, b, c):
i = item.get_id(0)
v = AtomicRef(b, index=i)
c[i] = getattr(v, fetch_phi_fn)(a[i])

Expand All @@ -89,7 +88,7 @@ def _kernel(a, b, c):
b_copy = dpnp.copy(b)
c_copy = dpnp.copy(c)

dpex_exp.call_kernel(_kernel, dpex.Range(10), a, b, c)
dpex_exp.call_kernel(_kernel, Range(10), a, b, c)

# Verify if the value returned by fetch_phi kernel
# stored into `c` is same as the value returned
Expand All @@ -108,8 +107,8 @@ def test_fetch_phi_diff_types(fetch_phi_fn):
"""

@dpex_exp.kernel
def _kernel(a, b):
i = dpex.get_global_id(0)
def _kernel(item: Item, a, b):
i = item.get_id(0)
v = AtomicRef(b, index=0)
getattr(v, fetch_phi_fn)(a[i])

Expand All @@ -118,19 +117,19 @@ def _kernel(a, b):
b = dpnp.zeros(N, dtype=dpnp.int32)

with pytest.raises(TypingError):
dpex_exp.call_kernel(_kernel, dpex.Range(10), a, b)
dpex_exp.call_kernel(_kernel, Range(10), a, b)


@dpex_exp.kernel
def atomic_ref_0(a):
i = dpex.get_global_id(0)
def atomic_ref_0(item: Item, a):
i = item.get_id(0)
v = AtomicRef(a, index=0)
v.fetch_add(a[i + 2])


@dpex_exp.kernel
def atomic_ref_1(a):
i = dpex.get_global_id(0)
def atomic_ref_1(item: Item, a):
i = item.get_id(0)
v = AtomicRef(a, index=1)
v.fetch_add(a[i + 2])

Expand All @@ -144,24 +143,24 @@ def test_spirv_compiler_flags_add():
N = 10
a = dpnp.ones(N, dtype=dpnp.float32)

dpex_exp.call_kernel(atomic_ref_0, dpex.Range(N - 2), a)
dpex_exp.call_kernel(atomic_ref_1, dpex.Range(N - 2), a)
dpex_exp.call_kernel(atomic_ref_0, Range(N - 2), a)
dpex_exp.call_kernel(atomic_ref_1, Range(N - 2), a)

assert a[0] == N - 1
assert a[1] == N - 1


@dpex_exp.kernel
def atomic_max_0(a):
i = dpex.get_global_id(0)
def atomic_max_0(item: Item, a):
i = item.get_id(0)
v = AtomicRef(a, index=0)
if i != 0:
v.fetch_max(a[i])


@dpex_exp.kernel
def atomic_max_1(a):
i = dpex.get_global_id(0)
def atomic_max_1(item: Item, a):
i = item.get_id(0)
v = AtomicRef(a, index=0)
if i != 0:
v.fetch_max(a[i])
Expand All @@ -177,8 +176,8 @@ def test_spirv_compiler_flags_max():
a = dpnp.arange(N, dtype=dpnp.float32)
b = dpnp.arange(N, dtype=dpnp.float32)

dpex_exp.call_kernel(atomic_max_0, dpex.Range(N), a)
dpex_exp.call_kernel(atomic_max_1, dpex.Range(N), b)
dpex_exp.call_kernel(atomic_max_0, Range(N), a)
dpex_exp.call_kernel(atomic_max_1, Range(N), b)

assert a[0] == N - 1
assert b[0] == N - 1
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,21 @@
import pytest
from numba.core.errors import TypingError

import numba_dpex as dpex
import numba_dpex.experimental as dpex_exp
from numba_dpex.kernel_api import AddressSpace, AtomicRef
from numba_dpex.kernel_api import AddressSpace, AtomicRef, Item, Range


def test_atomic_ref_compilation():
@dpex_exp.kernel
def atomic_ref_kernel(a, b):
i = dpex.get_global_id(0)
def atomic_ref_kernel(item: Item, a, b):
i = item.get_id(0)
v = AtomicRef(b, index=0, address_space=AddressSpace.GLOBAL)
v.fetch_add(a[i])

a = dpnp.ones(10)
b = dpnp.zeros(10)
try:
dpex_exp.call_kernel(atomic_ref_kernel, dpex.Range(10), a, b)
dpex_exp.call_kernel(atomic_ref_kernel, Range(10), a, b)
except Exception:
pytest.fail("Unexpected execution failure")

Expand All @@ -33,13 +32,13 @@ def test_atomic_ref_compilation_failure():
"""

@dpex_exp.kernel
def atomic_ref_kernel(a, b):
i = dpex.get_global_id(0)
def atomic_ref_kernel(item: Item, a, b):
i = item.get_id(0)
v = AtomicRef(b, index=0, address_space=AddressSpace.LOCAL)
v.fetch_add(a[i])

a = dpnp.ones(10)
b = dpnp.zeros(10)

with pytest.raises(TypingError):
dpex_exp.call_kernel(atomic_ref_kernel, dpex.Range(10), a, b)
dpex_exp.call_kernel(atomic_ref_kernel, Range(10), a, b)
7 changes: 3 additions & 4 deletions numba_dpex/tests/experimental/test_async_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,9 @@
import pytest
from numba.core.errors import TypingError

import numba_dpex as dpex
import numba_dpex.experimental as exp_dpex
from numba_dpex import Range
from numba_dpex.experimental import testing
from numba_dpex.kernel_api import Item, Range


@exp_dpex.kernel(
Expand All @@ -19,8 +18,8 @@
no_cpython_wrapper=True,
no_cfunc_wrapper=True,
)
def add(a, b, c):
i = dpex.get_global_id(0)
def add(item: Item, a, b, c):
i = item.get_id(0)
c[i] = b[i] + a[i]


Expand Down
9 changes: 5 additions & 4 deletions numba_dpex/tests/experimental/test_compiler_warnings.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@
import pytest
from numba.core import types

import numba_dpex as dpex
from numba_dpex import DpctlSyclQueue, DpnpNdArray
from numba_dpex import experimental as dpex_exp
from numba_dpex import int64
from numba_dpex.core.types.kernel_api.index_space_ids import ItemType
from numba_dpex.kernel_api import Item


def _kernel(a, b, c):
i = dpex.get_global_id(0)
def _kernel(item: Item, a, b, c):
i = item.get_id(0)
c[i] = a[i] + b[i]


Expand All @@ -30,5 +31,5 @@ def test_inline_threshold_level_warning():
with pytest.warns(UserWarning):
queue_ty = DpctlSyclQueue(dpctl.SyclQueue())
i64arr_ty = DpnpNdArray(ndim=1, dtype=int64, layout="C", queue=queue_ty)
kernel_sig = types.void(i64arr_ty, i64arr_ty, i64arr_ty)
kernel_sig = types.void(ItemType(1), i64arr_ty, i64arr_ty, i64arr_ty)
dpex_exp.kernel(inline_threshold=3)(_kernel).compile(kernel_sig)
5 changes: 3 additions & 2 deletions numba_dpex/tests/experimental/test_inline_threshold_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@

import numba_dpex as dpex
from numba_dpex import experimental as dpex_exp
from numba_dpex.kernel_api import Item


def kernel_func(a, b, c):
i = dpex.get_global_id(0)
def kernel_func(item: Item, a, b, c):
i = item.get_id(0)
c[i] = a[i] + b[i]


Expand Down
27 changes: 18 additions & 9 deletions numba_dpex/tests/experimental/test_kernel_specialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,26 +7,30 @@
import pytest
from numba.core.errors import TypingError

import numba_dpex as dpex
import numba_dpex.experimental as dpex_exp
from numba_dpex import DpnpNdArray, float32, int64
from numba_dpex.core.exceptions import InvalidKernelSpecializationError
from numba_dpex.kernel_api import Range
from numba_dpex.core.types.kernel_api.index_space_ids import ItemType
from numba_dpex.kernel_api import Item, Range

i64arrty = DpnpNdArray(ndim=1, dtype=int64, layout="C")
f32arrty = DpnpNdArray(ndim=1, dtype=float32, layout="C")
item_ty = ItemType(ndim=1)

specialized_kernel1 = dpex_exp.kernel((i64arrty, i64arrty, i64arrty))
specialized_kernel1 = dpex_exp.kernel((item_ty, i64arrty, i64arrty, i64arrty))
specialized_kernel2 = dpex_exp.kernel(
[(i64arrty, i64arrty, i64arrty), (f32arrty, f32arrty, f32arrty)]
[
(item_ty, i64arrty, i64arrty, i64arrty),
(item_ty, f32arrty, f32arrty, f32arrty),
]
)


def data_parallel_sum(a, b, c):
def data_parallel_sum(item: Item, a, b, c):
"""
Vector addition using the ``kernel`` decorator.
"""
i = dpex.get_global_id(0)
i = item.get_id(0)
c[i] = a[i] + b[i]


Expand All @@ -46,7 +50,9 @@ def test_invalid_specialization_error():
"""Test if an InvalidKernelSpecializationError is raised when attempting to
specialize with NumPy arrays.
"""
specialized_kernel3 = dpex_exp.kernel((int64[::1], int64[::1], int64[::1]))
specialized_kernel3 = dpex_exp.kernel(
(item_ty, int64[::1], int64[::1], int64[::1])
)
with pytest.raises(InvalidKernelSpecializationError):
specialized_kernel3(data_parallel_sum)

Expand Down Expand Up @@ -90,11 +96,14 @@ def test_string_specialization():
"""Test if NotImplementedError is raised when signature is a string"""

with pytest.raises(NotImplementedError):
dpex_exp.kernel("(i64arrty, i64arrty, i64arrty)")
dpex_exp.kernel("(item_ty, i64arrty, i64arrty, i64arrty)")

with pytest.raises(NotImplementedError):
dpex_exp.kernel(
["(i64arrty, i64arrty, i64arrty)", "(f32arrty, f32arrty, f32arrty)"]
[
"(item_ty, i64arrty, i64arrty, i64arrty)",
"(item_ty, f32arrty, f32arrty, f32arrty)",
]
)

with pytest.raises(ValueError):
Expand Down
Loading

0 comments on commit 90ab8ae

Please sign in to comment.