Skip to content

Commit

Permalink
Getting rid of core.kernel_interface.indexers.NdRange, using core.ker…
Browse files Browse the repository at this point in the history
…nel_interface.utils.Ranges instead.
  • Loading branch information
chudur-budur committed Jan 25, 2023
1 parent f091385 commit 8373c3b
Show file tree
Hide file tree
Showing 5 changed files with 156 additions and 167 deletions.
4 changes: 2 additions & 2 deletions numba_dpex/core/kernel_interface/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@
UnsupportedWorkItemSizeError,
)
from numba_dpex.core.kernel_interface.arg_pack_unpacker import Packer
from numba_dpex.core.kernel_interface.indexers import NdRange
from numba_dpex.core.kernel_interface.spirv_kernel import SpirvKernel
from numba_dpex.core.kernel_interface.utils import Ranges
from numba_dpex.core.types import USMNdArray

simplefilter("always", DeprecationWarning)
Expand Down Expand Up @@ -523,7 +523,7 @@ def __getitem__(self, args):
"""

if isinstance(args, NdRange):
if isinstance(args, Ranges):
self._global_range = list(args.global_range)[::-1]
self._local_range = list(args.local_range)[::-1]
else:
Expand Down
155 changes: 0 additions & 155 deletions numba_dpex/core/kernel_interface/indexers.py

This file was deleted.

146 changes: 146 additions & 0 deletions numba_dpex/core/kernel_interface/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
class Ranges:
"""A data structure to encapsulate kernel lauch parameters.
This is just a wrapper class on top of tuples. The kernel
launch parameter is consisted of two tuples of int. The
first tuple is called `global_range` and the second tuple
is called `local_range`.
The `global_range` is analogous to DPC++'s "global size"
and the `local_range` is analogous to DPC++'s "workgroup size",
respectively.
"""

def __init__(self, global_range, local_range=None):
"""Constructor for Ranges.
Args:
global_range (tuple): A tuple of int to specify DPC++'s
global size.
local_range (tuple, optional): A tuple of int to specify
DPC++'s workgroup size. Defaults to None.
"""
self._global_range = global_range
self._local_range = local_range
self._check_sanity()

def _check_sanity(self):
"""Sanity checks for the global and local range tuples.
Raises:
ValueError: If the length of global_range is more than 3, if tuple.
ValueError: If each of value global_range is not an int, if tuple.
ValueError: If the global_range is not a tuple or an int.
ValueError: If the length of local_range is more than 3, if tuple.
ValueError: If the dimensions of local_range
and global_range are not same, if tuples.
ValueError: If each of value local_range is not an int, if tuple.
ValueError: If the range limits in the global_range is not
divisible by the range limit in the local_range
at the corresponding dimension.
ValueError: If the local_range is not a tuple or an int.
"""
if isinstance(self._global_range, tuple):
if len(self._global_range) > 3:
raise ValueError(
"The maximum allowed dimension for global_range is 3."
)
for i in range(len(self._global_range)):
if not isinstance(self._global_range[i], int):
raise ValueError("The range limit values must be an int.")
elif isinstance(self._global_range, int):
self._global_range = tuple([self._global_range])
else:
raise ValueError("global_range must be a tuple or an int.")
if self._local_range:
if isinstance(self._local_range, tuple):
if len(self._local_range) > 3:
raise ValueError(
"The maximum allowed dimension for local_range is 3."
)
if len(self._global_range) != len(self._local_range):
raise ValueError(
"global_range and local_range must "
+ "have the same dimensions."
)
for i in range(len(self._local_range)):
if not isinstance(self._local_range[i], int):
raise ValueError(
"The range limit values must be an int."
)
if self._global_range[i] % self._local_range[i] != 0:
raise ValueError(
"Each limit in global_range must be divisible "
+ "by each limit in local_range at "
+ " the corresponding dimension."
)
elif isinstance(self._local_range, int):
self._local_range = tuple([self._local_range])
else:
raise ValueError("local_range must be a tuple or an int.")

@property
def global_range(self):
"""global_range accessor.
Returns:
tuple: global_range
"""
return self._global_range

@property
def local_range(self):
"""local_range accessor.
Returns:
tuple: local_range
"""
return self._local_range

def __str__(self) -> str:
"""str() function for this class.
Returns:
str: str representation of a Ranges object.
"""
return (
"(" + str(self._global_range) + ", " + str(self._local_range) + ")"
)

def __repr__(self) -> str:
"""repr() function for this class.
Returns:
str: str representation of a Ranges object.
"""
return self.__str__()


# tester
if __name__ == "__main__":
ranges = Ranges(1)
print(ranges)

ranges = Ranges(1, 1)
print(ranges)

ranges = Ranges((2, 2, 2), (1, 1, 1))
print(ranges)

ranges = Ranges((2, 2, 2))
print(ranges)

try:
ranges = Ranges((1, 1, 1, 1))
except Exception as e:
print(e)

try:
ranges = Ranges((2, 2, 2), (1, 1))
except Exception as e:
print(e)

try:
ranges = Ranges((3, 3, 3), (2, 2, 2))
except Exception as e:
print(e)
2 changes: 1 addition & 1 deletion numba_dpex/examples/kernel/vector_sum2D.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def data_parallel_sum(a, b, c):


def driver(a, b, c, global_size):
data_parallel_sum[global_size](a, b, c)
data_parallel_sum[global_size, dpex.DEFAULT_LOCAL_SIZE](a, b, c)


def main():
Expand Down
16 changes: 7 additions & 9 deletions numba_dpex/tests/kernel_tests/test_ndrange_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,7 @@
import pytest

import numba_dpex as ndpx
from numba_dpex.core.exceptions import (
UnmatchedNumberOfRangeDimsError,
UnsupportedGroupWorkItemSizeError,
)
from numba_dpex.core.kernel_interface.utils import Ranges


# Data parallel kernel implementing vector sum
Expand All @@ -19,13 +16,13 @@ def kernel_vector_sum(a, b, c):


@pytest.mark.parametrize(
"error, ndrange",
"error, ranges",
[
(UnmatchedNumberOfRangeDimsError, ((2, 2), (1, 1, 1))),
(UnsupportedGroupWorkItemSizeError, ((3, 3, 3), (2, 2, 2))),
(ValueError, ((2, 2), (1, 1, 1))),
(ValueError, ((3, 3, 3), (2, 2, 2))),
],
)
def test_ndrange_config_error(error, ndrange):
def test_ndrange_config_error(error, ranges):
"""Test if a exception is raised when calling a
ndrange kernel with unspported arguments.
"""
Expand All @@ -35,4 +32,5 @@ def test_ndrange_config_error(error, ndrange):
c = dpt.zeros(1024, dtype=dpt.int64)

with pytest.raises(error):
kernel_vector_sum[ndrange](a, b, c)
range = Ranges(ranges[0], ranges[1])
kernel_vector_sum[range](a, b, c)

0 comments on commit 8373c3b

Please sign in to comment.