Skip to content

Commit

Permalink
A better implementation for the Range/NdRange class, with examples
Browse files Browse the repository at this point in the history
  • Loading branch information
chudur-budur committed Jan 26, 2023
1 parent e66c0fc commit 719536a
Show file tree
Hide file tree
Showing 6 changed files with 208 additions and 167 deletions.
92 changes: 24 additions & 68 deletions numba_dpex/core/kernel_interface/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,9 @@
)
from numba_dpex.core.kernel_interface.arg_pack_unpacker import Packer
from numba_dpex.core.kernel_interface.spirv_kernel import SpirvKernel
from numba_dpex.core.kernel_interface.utils import Ranges
from numba_dpex.core.kernel_interface.utils import NdRange
from numba_dpex.core.types import USMNdArray

simplefilter("always", DeprecationWarning)


def get_ordered_arg_access_types(pyfunc, access_types):
"""Deprecated and to be removed in next release."""
Expand Down Expand Up @@ -445,56 +443,6 @@ def _determine_kernel_launch_queue(self, args, argtypes):
else:
raise ExecutionQueueInferenceError(self.kernel_name)

def _raise_invalid_kernel_enqueue_args(self):
    """Raise the error used for a malformed kernel launch configuration.

    Raises:
        InvalidKernelLaunchArgsError: Always. The message states the
            expected usage of the launch arguments.
    """
    # The three fragments are joined by implicit literal concatenation, so
    # the resulting message is a single string.
    raise InvalidKernelLaunchArgsError(
        "Incorrect number of arguments for enqueuing numba_dpex.kernel. "
        "Usage: device_env, global size, local size. "
        "The local size argument is optional."
    )

def _ensure_valid_work_item_grid(self, val):
if not isinstance(val, (tuple, list, int)):
error_message = (
"Cannot create work item dimension from provided argument"
)
raise ValueError(error_message)

if isinstance(val, int):
val = [val]

# TODO: we need some way to check the max dimensions
"""
if len(val) > device_env.get_max_work_item_dims():
error_message = ("Unsupported number of work item dimensions ")
raise ValueError(error_message)
"""

return list(
val[::-1]
) # reversing due to sycl and opencl interop kernel range mismatch semantic

def _ensure_valid_work_group_size(self, val, work_item_grid):
if not isinstance(val, (tuple, list, int)):
error_message = (
"Cannot create work item dimension from provided argument"
)
raise ValueError(error_message)

if isinstance(val, int):
val = [val]

if len(val) != len(work_item_grid):
error_message = (
"Unsupported number of work item dimensions, "
+ "dimensions of global and local work items has to be the same "
)
raise IllegalRangeValueError(error_message)

return list(
val[::-1]
) # reversing due to sycl and opencl interop kernel range mismatch semantic

def __getitem__(self, args):
"""Mimic's ``numba.cuda`` square-bracket notation for configuring the
global_range and local_range settings when launching a kernel on a
Expand Down Expand Up @@ -523,9 +471,11 @@ def __getitem__(self, args):
"""

if isinstance(args, Ranges):
if isinstance(args, NdRange):
self._global_range = list(args.global_range)[::-1]
self._local_range = list(args.local_range)[::-1]
self._local_range = (
list(args.local_range)[::-1] if args.local_range else None
)
else:
if (
isinstance(args, tuple)
Expand All @@ -540,38 +490,44 @@ def __getitem__(self, args):
+ "i.e. <function>[(M,N), numba_dpex.DEFAULT_LOCAL_RANGE](<params>), "
+ "otherwise your code might produce erroneous results.",
DeprecationWarning,
stacklevel=2,
)
self._global_range = [args[0]]
self._local_range = [args[1]]
return self

if not isinstance(args, Iterable):
args = [args]

ls = None
args = [args] if not isinstance(args, Iterable) else args
nargs = len(args)

# Check if the kernel enqueuing arguments are sane
if nargs < 1 or nargs > 2:
self._raise_invalid_kernel_enqueue_args()
raise InvalidKernelLaunchArgsError(kernel_name=self.kernel_name)

gs = self._ensure_valid_work_item_grid(args[0])
g_range = (
[args[0]] if not isinstance(args[0], Iterable) else args[0]
)
# If the optional local size argument is provided
l_range = None
if nargs == 2:
if args[1] != []:
ls = self._ensure_valid_work_group_size(args[1], gs)
l_range = (
[args[1]]
if not isinstance(args[1], Iterable)
else args[1]
)
else:
warn(
"Empty local_range calls will be deprecated in the future.",
DeprecationWarning,
stacklevel=2,
)

self._global_range = list(gs)[::-1]
self._local_range = list(ls)[::-1] if ls else None
if len(g_range) < 1:
raise IllegalRangeValueError(kernel_name=self.kernel_name)

self._global_range = list(g_range)[::-1]
self._local_range = list(l_range)[::-1] if l_range else None

if self._global_range == [] and self._local_range is None:
raise IllegalRangeValueError(
"Illegal range values for kernel launch parameters."
)
return self

def _check_ranges(self, device):
Expand Down
Loading

0 comments on commit 719536a

Please sign in to comment.