Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

A more consistent kernel launch parameter syntax #888

Merged
merged 6 commits into from
Jan 31, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion numba_dpex/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def __getattr__(name):
ENABLE_CACHE = _readenv("NUMBA_DPEX_ENABLE_CACHE", int, 1)
# Capacity of the cache, execute it like:
# NUMBA_DPEX_CACHE_SIZE=20 python <code>
CACHE_SIZE = _readenv("NUMBA_DPEX_CACHE_SIZE", int, 10)
CACHE_SIZE = _readenv("NUMBA_DPEX_CACHE_SIZE", int, 128)
chudur-budur marked this conversation as resolved.
Show resolved Hide resolved

TESTING_SKIP_NO_DPNP = _readenv("NUMBA_DPEX_TESTING_SKIP_NO_DPNP", int, 0)
TESTING_SKIP_NO_DEBUGGING = _readenv(
Expand Down
116 changes: 77 additions & 39 deletions numba_dpex/core/kernel_interface/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# SPDX-License-Identifier: Apache-2.0


from collections.abc import Iterable
from inspect import signature
from warnings import warn

Expand Down Expand Up @@ -32,6 +33,7 @@
)
from numba_dpex.core.kernel_interface.arg_pack_unpacker import Packer
from numba_dpex.core.kernel_interface.spirv_kernel import SpirvKernel
from numba_dpex.core.kernel_interface.utils import NdRange, Range
from numba_dpex.core.types import USMNdArray


Expand Down Expand Up @@ -468,51 +470,87 @@ def __getitem__(self, args):
global_range and local_range attributes initialized.

"""
if isinstance(args, int):
self._global_range = [args]
self._local_range = None
elif isinstance(args, tuple) or isinstance(args, list):
if len(args) == 1 and all(isinstance(v, int) for v in args):
self._global_range = list(args)
self._local_range = None
elif len(args) == 2:
gr = args[0]
lr = args[1]
if isinstance(gr, int):
self._global_range = [gr]
elif len(gr) != 0 and all(isinstance(v, int) for v in gr):
self._global_range = list(gr)
else:
raise IllegalRangeValueError(kernel_name=self.kernel_name)
if isinstance(args, Range):
# we need inversions, see github issue #889
self._global_range = list(args)[::-1]
elif isinstance(args, NdRange):
# we need inversions, see github issue #889
self._global_range = list(args.global_range)[::-1]
self._local_range = list(args.local_range)[::-1]
else:
if (
chudur-budur marked this conversation as resolved.
Show resolved Hide resolved
isinstance(args, tuple)
and len(args) == 2
and isinstance(args[0], int)
and isinstance(args[1], int)
):
warn(
"Ambiguous kernel launch paramters. If your data have "
+ "dimensions > 1, include a default/empty local_range:\n"
+ " <function>[(X,Y), numba_dpex.DEFAULT_LOCAL_RANGE](<params>)\n"
+ "otherwise your code might produce erroneous results.",
DeprecationWarning,
stacklevel=2,
)
self._global_range = [args[0]]
self._local_range = [args[1]]
return self

warn(
"The current syntax for specification of kernel lauch "
+ "parameters is deprecated. Users should set the kernel "
+ "parameters through Range/NdRange classes.\n"
+ "Example:\n"
+ " from numba_dpex.core.kernel_interface.utils import Range,NdRange\n\n"
+ " # for global range only\n"
+ " <function>[Range(X,Y)](<parameters>)\n"
+ " # or,\n"
+ " # for both global and local ranges\n"
+ " <function>[NdRange((X,Y), (P,Q))](<parameters>)",
DeprecationWarning,
stacklevel=2,
)

if isinstance(lr, int):
self._local_range = [lr]
elif isinstance(lr, list) and len(lr) == 0:
# deprecation warning
args = [args] if not isinstance(args, Iterable) else args
nargs = len(args)

# Check if the kernel enquing arguments are sane
if nargs < 1 or nargs > 2:
raise InvalidKernelLaunchArgsError(kernel_name=self.kernel_name)

g_range = (
[args[0]] if not isinstance(args[0], Iterable) else args[0]
)
# If the optional local size argument is provided
l_range = None
if nargs == 2:
if args[1] != []:
l_range = (
[args[1]]
if not isinstance(args[1], Iterable)
else args[1]
)
else:
warn(
"Specifying the local range as an empty list "
"(DEFAULT_LOCAL_SIZE) is deprecated. The kernel will "
"be executed as a basic data-parallel kernel over the "
"global range. Specify a valid local range to execute "
"the kernel as an ND-range kernel.",
"Empty local_range calls are deprecated. Please use Range/NdRange "
+ "to specify the kernel launch parameters:\n"
+ "Example:\n"
+ " from numba_dpex.core.kernel_interface.utils import Range,NdRange\n\n"
+ " # for global range only\n"
+ " <function>[Range(X,Y)](<parameters>)\n"
+ " # or,\n"
+ " # for both global and local ranges\n"
+ " <function>[NdRange((X,Y), (P,Q))](<parameters>)",
DeprecationWarning,
stacklevel=2,
)
self._local_range = None
elif len(lr) != 0 and all(isinstance(v, int) for v in lr):
self._local_range = list(lr)
else:
raise IllegalRangeValueError(kernel_name=self.kernel_name)
else:
raise InvalidKernelLaunchArgsError(kernel_name=self.kernel_name)
else:
raise InvalidKernelLaunchArgsError(kernel_name=self.kernel_name)

# FIXME:[::-1] is done as OpenCL and SYCl have different orders when
# it comes to specifying dimensions.
self._global_range = list(self._global_range)[::-1]
if self._local_range:
self._local_range = list(self._local_range)[::-1]
if len(g_range) < 1:
raise IllegalRangeValueError(kernel_name=self.kernel_name)

# we need inversions, see github issue #889
self._global_range = list(g_range)[::-1]
self._local_range = list(l_range)[::-1] if l_range else None

return self

Expand Down
218 changes: 218 additions & 0 deletions numba_dpex/core/kernel_interface/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
from collections.abc import Iterable


class Range(tuple):
    """A 1-, 2-, or 3-dimensional range used as a kernel launch parameter.

    A range describes the number of elements in each dimension of a buffer
    or index space. It holds one, two, or three ints, depending on the
    dimensionality of the object it describes.

    The class is a thin wrapper over ``tuple`` and is intended to mimic the
    behavior of `sycl::range`.
    """

    def __new__(cls, dim0, dim1=None, dim2=None):
        """Construct a 1, 2, or 3 dimensional range.

        Args:
            dim0 (int): The range of the first dimension.
            dim1 (int, optional): The range of the second dimension.
                Defaults to None (dimension not present).
            dim2 (int, optional): The range of the third dimension.
                Defaults to None (dimension not present).

        Raises:
            TypeError: If dim0 is not an int.
            TypeError: If dim1 is given and is not an int, or if dim2 is
                given while dim1 is omitted.
            TypeError: If dim2 is given and is not an int.
        """
        if not isinstance(dim0, int):
            raise TypeError("dim0 of a Range must be an int.")
        extents = [dim0]

        # Compare against None rather than relying on truthiness: a zero
        # extent is a legitimate value and must not change the number of
        # dimensions, and falsy non-int values must be rejected.
        if dim1 is not None:
            if not isinstance(dim1, int):
                raise TypeError("dim1 of a Range must be an int.")
            extents.append(dim1)
        elif dim2 is not None:
            # A third dimension without a second one is ambiguous.
            raise TypeError(
                "dim1 of a Range must be an int when dim2 is given."
            )

        if dim2 is not None:
            if not isinstance(dim2, int):
                raise TypeError("dim2 of a Range must be an int.")
            extents.append(dim2)

        return super().__new__(cls, tuple(extents))

    def get(self, index):
        """Return the range of a single dimension.

        Args:
            index (int): The index of the dimension, i.e. a value in [0, 2].

        Returns:
            int: The range of the dimension indexed by `index`.

        Raises:
            IndexError: If `index` is outside the range's dimensions.
        """
        return self[index]

    def size(self):
        """Return the total number of elements described by the range.

        The size is the product of the ranges of the individual dimensions.

        Returns:
            int: The size of the range.
        """
        total = 1
        for extent in self:
            total *= extent
        return total


class NdRange:
    """A class to encapsulate all kernel launch parameters.

    The NdRange defines the index space for a work group as well as
    the global index space. It is passed to parallel_for to execute
    a kernel on a set of work items.

    The class holds two Range objects: `global_range` describes the global
    index space and `local_range` describes the index space of a work
    group. It is intended to mimic the behavior of `sycl::nd_range`.
    """

    def __init__(self, global_size, local_size):
        """Construct an NdRange from a global and a local range.

        Args:
            global_size (Range or iterable of int's): The values for
                the global_range.
            local_size (Range or iterable of int's): The values for
                the local_range.

        Raises:
            TypeError: If either argument is neither a Range nor an
                iterable, or if its items are rejected by Range.
        """
        self._global_range = self._as_range(global_size, "global_size")
        self._local_range = self._as_range(local_size, "local_size")

    @staticmethod
    def _as_range(value, arg_name):
        """Return `value` as a Range, raising TypeError if it cannot be one."""
        # Range is a tuple subclass, so it also satisfies the Iterable check.
        if not isinstance(value, Iterable):
            raise TypeError(
                "Unknown argument type for NdRange " + arg_name + "."
            )
        if isinstance(value, Range):
            return value
        return Range(*value)

    @property
    def global_range(self):
        """Accessor for global_range.

        Returns:
            Range: The `global_range` `Range` object.
        """
        return self._global_range

    @property
    def local_range(self):
        """Accessor for local_range.

        Returns:
            Range: The `local_range` `Range` object.
        """
        return self._local_range

    def get_global_range(self):
        """Return a Range defining the index space.

        Returns:
            Range: A `Range` object defining the index space.
        """
        return self._global_range

    def get_local_range(self):
        """Return a Range defining the index space of a work group.

        Returns:
            Range: A `Range` object to specify index space of a work group.
        """
        return self._local_range

    def __str__(self):
        """str() function for NdRange class.

        Returns:
            str: The global and local ranges formatted as
                "(<global_range>, <local_range>)".
        """
        return (
            "(" + str(self._global_range) + ", " + str(self._local_range) + ")"
        )

    def __repr__(self):
        """repr() function for NdRange class.

        Returns:
            str: The same representation as returned by str().
        """
        return self.__str__()


if __name__ == "__main__":
    # Ad-hoc smoke test that prints how Range and NdRange behave for valid
    # and invalid constructor arguments.
    r1 = Range(1)
    print("r1 =", r1)

    r2 = Range(1, 2)
    print("r2 =", r2)

    r3 = Range(1, 2, 3)
    print("r3 =", r3, ", len(r3) =", len(r3))

    r3 = Range(*(1, 2, 3))
    print("r3 =", r3, ", len(r3) =", len(r3))

    r3 = Range(*[1, 2, 3])
    print("r3 =", r3, ", len(r3) =", len(r3))

    print("r1.get(0) =", r1.get(0))
    # r2 has only two dimensions, so index 2 is out of range.
    try:
        print("r2.get(2) =", r2.get(2))
    except IndexError as e:
        print(e)

    print("r3.get(0) =", r3.get(0))
    print("r3.get(1) =", r3.get(1))

    print("r1[0] =", r1[0])
    try:
        print("r2[2] =", r2[2])
    except IndexError as e:
        print(e)

    print("r3[0] =", r3[0])
    print("r3[1] =", r3[1])

    # A Range accepts at most three dimensions.
    try:
        r4 = Range(1, 2, 3, 4)
    except TypeError as e:
        print(e)

    try:
        r5 = Range(*(1, 2, 3, 4))
    except TypeError as e:
        print(e)

    # NdRange requires both a global and a local range; omitting the local
    # range is reported instead of aborting the rest of the script.
    try:
        ndr1 = NdRange(Range(1, 2))
        print("ndr1 =", ndr1)
    except TypeError as e:
        print(e)

    ndr2 = NdRange(Range(1, 2), Range(1, 1, 1))
    print("ndr2 =", ndr2)

    try:
        ndr3 = NdRange((1, 2))
        print("ndr3 =", ndr3)
    except TypeError as e:
        print(e)

    ndr4 = NdRange((1, 2), (1, 1, 1))
    print("ndr4 =", ndr4)
2 changes: 2 additions & 0 deletions numba_dpex/examples/kernel/kernel_specialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,3 +128,5 @@ def data_parallel_sum2(a, b, c):
"strings."
)
print(e)

print("Done...")
Loading