Skip to content

Commit

Permalink
A better implementation for the Range/NdRange class, with examples
Browse files Browse the repository at this point in the history
Better deprecation warnings added

Following exact sycl::range/nd_range specification for kernel lauch parameters

Default cache size is set to 128, like numba
  • Loading branch information
chudur-budur committed Jan 27, 2023
1 parent e66c0fc commit f536bd1
Show file tree
Hide file tree
Showing 7 changed files with 237 additions and 172 deletions.
2 changes: 1 addition & 1 deletion numba_dpex/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def __getattr__(name):
ENABLE_CACHE = _readenv("NUMBA_DPEX_ENABLE_CACHE", int, 1)
# Capacity of the cache, execute it like:
# NUMBA_DPEX_CACHE_SIZE=20 python <code>
CACHE_SIZE = _readenv("NUMBA_DPEX_CACHE_SIZE", int, 10)
CACHE_SIZE = _readenv("NUMBA_DPEX_CACHE_SIZE", int, 128)

TESTING_SKIP_NO_DPNP = _readenv("NUMBA_DPEX_TESTING_SKIP_NO_DPNP", int, 0)
TESTING_SKIP_NO_DEBUGGING = _readenv(
Expand Down
124 changes: 52 additions & 72 deletions numba_dpex/core/kernel_interface/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,9 @@
)
from numba_dpex.core.kernel_interface.arg_pack_unpacker import Packer
from numba_dpex.core.kernel_interface.spirv_kernel import SpirvKernel
from numba_dpex.core.kernel_interface.utils import Ranges
from numba_dpex.core.kernel_interface.utils import NdRange, Range
from numba_dpex.core.types import USMNdArray

simplefilter("always", DeprecationWarning)


def get_ordered_arg_access_types(pyfunc, access_types):
"""Deprecated and to be removed in next release."""
Expand Down Expand Up @@ -445,56 +443,6 @@ def _determine_kernel_launch_queue(self, args, argtypes):
else:
raise ExecutionQueueInferenceError(self.kernel_name)

def _raise_invalid_kernel_enqueue_args(self):
error_message = (
"Incorrect number of arguments for enqueuing numba_dpex.kernel. "
"Usage: device_env, global size, local size. "
"The local size argument is optional."
)
raise InvalidKernelLaunchArgsError(error_message)

def _ensure_valid_work_item_grid(self, val):
if not isinstance(val, (tuple, list, int)):
error_message = (
"Cannot create work item dimension from provided argument"
)
raise ValueError(error_message)

if isinstance(val, int):
val = [val]

# TODO: we need some way to check the max dimensions
"""
if len(val) > device_env.get_max_work_item_dims():
error_message = ("Unsupported number of work item dimensions ")
raise ValueError(error_message)
"""

return list(
val[::-1]
) # reversing due to sycl and opencl interop kernel range mismatch semantic

def _ensure_valid_work_group_size(self, val, work_item_grid):
if not isinstance(val, (tuple, list, int)):
error_message = (
"Cannot create work item dimension from provided argument"
)
raise ValueError(error_message)

if isinstance(val, int):
val = [val]

if len(val) != len(work_item_grid):
error_message = (
"Unsupported number of work item dimensions, "
+ "dimensions of global and local work items has to be the same "
)
raise IllegalRangeValueError(error_message)

return list(
val[::-1]
) # reversing due to sycl and opencl interop kernel range mismatch semantic

def __getitem__(self, args):
"""Mimic's ``numba.cuda`` square-bracket notation for configuring the
global_range and local_range settings when launching a kernel on a
Expand Down Expand Up @@ -522,8 +470,11 @@ def __getitem__(self, args):
global_range and local_range attributes initialized.
"""

if isinstance(args, Ranges):
if isinstance(args, Range):
# we need inversions, see github issue #889
self._global_range = list(args)[::-1]
elif isinstance(args, NdRange):
# we need inversions, see github issue #889
self._global_range = list(args.global_range)[::-1]
self._local_range = list(args.local_range)[::-1]
else:
Expand All @@ -534,44 +485,73 @@ def __getitem__(self, args):
and isinstance(args[1], int)
):
warn(
"Ambiguous kernel launch paramters. "
+ "If your data have dimensions > 1, "
+ "include a default/empty local_range. "
+ "i.e. <function>[(M,N), numba_dpex.DEFAULT_LOCAL_RANGE](<params>), "
"Ambiguous kernel launch paramters. If your data have "
+ "dimensions > 1, include a default/empty local_range:\n"
+ " <function>[(X,Y), numba_dpex.DEFAULT_LOCAL_RANGE](<params>)\n"
+ "otherwise your code might produce erroneous results.",
DeprecationWarning,
stacklevel=2,
)
self._global_range = [args[0]]
self._local_range = [args[1]]
return self

if not isinstance(args, Iterable):
args = [args]
warn(
"The current syntax for specification of kernel lauch "
+ "parameters is deprecated. Users should set the kernel "
+ "parameters through Range/NdRange classes.\n"
+ "Example:\n"
+ " from numba_dpex.core.kernel_interface.utils import Range,NdRange\n\n"
+ " # for global range only\n"
+ " <function>[Range(X,Y)](<parameters>)\n"
+ " # or,\n"
+ " # for both global and local ranges\n"
+ " <function>[NdRange((X,Y), (P,Q))](<parameters>)",
DeprecationWarning,
stacklevel=2,
)

ls = None
args = [args] if not isinstance(args, Iterable) else args
nargs = len(args)

# Check if the kernel enquing arguments are sane
if nargs < 1 or nargs > 2:
self._raise_invalid_kernel_enqueue_args()
raise InvalidKernelLaunchArgsError(kernel_name=self.kernel_name)

gs = self._ensure_valid_work_item_grid(args[0])
g_range = (
[args[0]] if not isinstance(args[0], Iterable) else args[0]
)
# If the optional local size argument is provided
l_range = None
if nargs == 2:
if args[1] != []:
ls = self._ensure_valid_work_group_size(args[1], gs)
l_range = (
[args[1]]
if not isinstance(args[1], Iterable)
else args[1]
)
else:
warn(
"Empty local_range calls will be deprecated in the future.",
"Empty local_range calls are deprecated. Please use Range/NdRange "
+ "to specify the kernel launch parameters:\n"
+ "Example:\n"
+ " from numba_dpex.core.kernel_interface.utils import Range,NdRange\n\n"
+ " # for global range only\n"
+ " <function>[Range(X,Y)](<parameters>)\n"
+ " # or,\n"
+ " # for both global and local ranges\n"
+ " <function>[NdRange((X,Y), (P,Q))](<parameters>)",
DeprecationWarning,
stacklevel=2,
)

self._global_range = list(gs)[::-1]
self._local_range = list(ls)[::-1] if ls else None
if len(g_range) < 1:
raise IllegalRangeValueError(kernel_name=self.kernel_name)

# we need inversions, see github issue #889
self._global_range = list(g_range)[::-1]
self._local_range = list(l_range)[::-1] if l_range else None

if self._global_range == [] and self._local_range is None:
raise IllegalRangeValueError(
"Illegal range values for kernel launch parameters."
)
return self

def _check_ranges(self, device):
Expand Down
Loading

0 comments on commit f536bd1

Please sign in to comment.