Skip to content

Commit

Permalink
Disallow LocalAccessor arguments to RangeType kernels
Browse files Browse the repository at this point in the history
  • Loading branch information
Diptorup Deb committed Mar 15, 2024
1 parent eb8d7ac commit ae0a1b9
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 12 deletions.
30 changes: 30 additions & 0 deletions numba_dpex/experimental/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
ItemType,
NdItemType,
)
from numba_dpex.core.types.kernel_api.local_accessor import LocalAccessorType
from numba_dpex.core.utils import kernel_launcher as kl
from numba_dpex.dpctl_iface import libsyclinterface_bindings as sycl
from numba_dpex.dpctl_iface.wrappers import wrap_event_reference
Expand All @@ -42,6 +43,23 @@ class _LLRange(NamedTuple):
local_range_extents: list


def _has_a_local_accessor_argument(args):
"""Checks if there exists at least one LocalAccessorType object in the
input tuple.
Args:
args (_type_): A tuple of numba.core.Type objects
Returns:
bool : True if at least one LocalAccessorType object was found,
otherwise False.
"""
for arg in args:
if isinstance(arg, LocalAccessorType):
return True
return False


def _wrap_event_reference_tuple(ctx, builder, event1, event2):
"""Creates tuple data model from two event data models, so it can be
boxed to Python."""
Expand Down Expand Up @@ -153,6 +171,18 @@ def _submit_kernel( # pylint: disable=too-many-arguments
DeprecationWarning,
)

# Validate local accessor arguments are passed only to a kernel that is
# launched with an NdRange index space. Reference section 4.7.6.11. of the
# SYCL 2020 specification: A local_accessor must not be used in a SYCL
# kernel function that is invoked via single_task or via the simple form of
# parallel_for that takes a range parameter.
if _has_a_local_accessor_argument(ty_kernel_args_tuple) and isinstance(
ty_index_space, RangeType
):
raise TypeError(
"A RangeType kernel cannot have a LocalAccessor argument"
)

# ty_kernel_fn is type specific to exact function, so we can get function
# directly from type and compile it. Thats why we don't need to get it in
# codegen
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import dpnp
import pytest
from numba.core.errors import TypingError

import numba_dpex as dpex
import numba_dpex.experimental as dpex_exp
Expand All @@ -21,23 +22,24 @@
)


@pytest.mark.parametrize("supported_dtype", list_of_supported_dtypes)
def test_local_accessor(supported_dtype):
"""A test for passing a LocalAccessor object as a kernel argument."""
@dpex_exp.kernel
def _kernel(nd_item: NdItem, a, slm):
i = nd_item.get_global_linear_id()
j = nd_item.get_local_linear_id()

@dpex_exp.kernel
def _kernel(nd_item: NdItem, a, slm):
i = nd_item.get_global_linear_id()
j = nd_item.get_local_linear_id()
slm[j] = 0
group_barrier(nd_item.get_group(), MemoryScope.WORK_GROUP)

slm[j] = 0
for m in range(100):
slm[j] += i * m
group_barrier(nd_item.get_group(), MemoryScope.WORK_GROUP)

for m in range(100):
slm[j] += i * m
group_barrier(nd_item.get_group(), MemoryScope.WORK_GROUP)
a[i] = slm[j]

a[i] = slm[j]

@pytest.mark.parametrize("supported_dtype", list_of_supported_dtypes)
def test_local_accessor(supported_dtype):
"""A test for passing a LocalAccessor object as a kernel argument."""

N = 32
a = dpnp.empty(N, dtype=supported_dtype)
Expand All @@ -52,3 +54,18 @@ def _kernel(nd_item: NdItem, a, slm):

for idx in range(N):
assert a[idx] == 4950 * idx


def test_local_accessor_argument_to_range_kernel():
"""Checks if an exception is raised when passing a local accessor to a
RangeType kernel.
"""
N = 32
a = dpnp.empty(N)
slm = LocalAccessor((32 * 64), dtype=a.dtype)

# Passing a local_accessor to a RangeType kernel should raise an exception.
# A TypeError is raised if NUMBA_CAPTURED_ERROR=new_style and a
# numba.TypingError is raised if NUMBA_CAPTURED_ERROR=old_style
with pytest.raises((TypeError, TypingError)):
dpex_exp.call_kernel(_kernel, dpex.Range(N), a, slm)

0 comments on commit ae0a1b9

Please sign in to comment.