Skip to content

Commit

Permalink
Improve dispatcher checks for laych args and add unit test.
Browse files Browse the repository at this point in the history
  • Loading branch information
Diptorup Deb committed Jan 18, 2023
1 parent 664d57e commit c74ea24
Show file tree
Hide file tree
Showing 3 changed files with 148 additions and 45 deletions.
13 changes: 2 additions & 11 deletions numba_dpex/core/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,18 +55,9 @@ class InvalidKernelLaunchArgsError(Exception):
"""

def __init__(self, kernel_name):
warn(
"The InvalidKernelLaunchArgsError class is deprecated, and will "
+ "be removed once kernel launching using __getitem__ for the "
+ "KernelLauncher class is removed.",
DeprecationWarning,
stacklevel=2,
)
self.message = (
"Invalid launch arguments specified for launching the Kernel "
f'"{kernel_name}". Launch arguments can only be a tuple '
"specifying the global range or a tuple of tuples specifying "
"global and local ranges."
"Invalid global and local range arguments specified for launching "
f' the Kernel "{kernel_name}". Refer documentation for details.'
)
super().__init__(self.message)

Expand Down
67 changes: 33 additions & 34 deletions numba_dpex/core/kernel_interface/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,46 +461,45 @@ def __getitem__(self, args):
global_range and local_range attributes initialized.
"""

if isinstance(args, int):
self._global_range = [args]
self._local_range = None
elif (isinstance(args, tuple) or isinstance(args, list)) and all(
isinstance(v, int) for v in args
):
self._global_range = list(args)
self._local_range = None
elif isinstance(args, tuple) and len(args) == 2:

gr = args[0]
lr = args[1]
if isinstance(gr, int):
self._global_range = [gr]
elif all(isinstance(v, int) for v in gr) and len(gr) != 0:
self._global_range = list(gr)
else:
raise IllegalRangeValueError(kernel_name=self.kernel_name)

if isinstance(lr, int):
self._local_range = [lr]
elif isinstance(lr, list) and len(lr) == 0:
# deprecation warning
warn(
"Specifying the local range as an empty list "
"(DEFAULT_LOCAL_SIZE) is deprecated. The kernel will be "
"executed as a basic data-parallel kernel over the global "
"range. Specify a valid local range to execute the kernel "
"as an ND-range kernel.",
DeprecationWarning,
stacklevel=2,
)
elif isinstance(args, tuple) or isinstance(args, list):
if len(args) == 1 and all(isinstance(v, int) for v in args):
self._global_range = list(args)
self._local_range = None
elif all(isinstance(v, int) for v in lr) and len(lr) != 0:
self._local_range = list(lr)
elif len(args) == 2:
gr = args[0]
lr = args[1]
if isinstance(gr, int):
self._global_range = [gr]
elif len(gr) != 0 and all(isinstance(v, int) for v in gr):
self._global_range = list(gr)
else:
raise IllegalRangeValueError(kernel_name=self.kernel_name)

if isinstance(lr, int):
self._local_range = [lr]
elif isinstance(lr, list) and len(lr) == 0:
# deprecation warning
warn(
"Specifying the local range as an empty list "
"(DEFAULT_LOCAL_SIZE) is deprecated. The kernel will "
"be executed as a basic data-parallel kernel over the "
"global range. Specify a valid local range to execute "
"the kernel as an ND-range kernel.",
DeprecationWarning,
stacklevel=2,
)
self._local_range = None
elif len(lr) != 0 and all(isinstance(v, int) for v in lr):
self._local_range = list(lr)
else:
raise IllegalRangeValueError(kernel_name=self.kernel_name)
else:
raise IllegalRangeValueError(kernel_name=self.kernel_name)
raise InvalidKernelLaunchArgsError(kernel_name=self.kernel_name)
else:
raise IllegalRangeValueError(kernel_name=self.kernel_name)
raise InvalidKernelLaunchArgsError(kernel_name=self.kernel_name)

# FIXME:[::-1] is done as OpenCL and SYCl have different orders when
# it comes to specifying dimensions.
Expand Down
113 changes: 113 additions & 0 deletions numba_dpex/tests/kernel_tests/test_kernel_launch_params.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
# SPDX-FileCopyrightText: 2020 - 2022 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0

import pytest

import numba_dpex as dpex
from numba_dpex.core.exceptions import (
IllegalRangeValueError,
InvalidKernelLaunchArgsError,
)


@dpex.kernel
def vecadd(a, b, c):
i = dpex.get_global_id(0)
a[i] = b[i] + c[i]


def test_1D_global_range_as_int():
k = vecadd[10]
assert k._global_range == [10]
assert k._local_range is None


def test_1D_global_range_as_one_tuple():
k = vecadd[
10,
]
assert k._global_range == [10]
assert k._local_range is None


def test_1D_global_range_as_list():
k = vecadd[[10]]
assert k._global_range == [10]
assert k._local_range is None


def test_1D_global_range_and_1D_local_range():
k = vecadd[10, 10]
assert k._global_range == [10]
assert k._local_range == [10]


def test_1D_global_range_and_1D_local_range2():
k = vecadd[[10, 10]]
assert k._global_range == [10]
assert k._local_range == [10]


def test_1D_global_range_and_1D_local_range3():
k = vecadd[(10,), (10,)]
assert k._global_range == [10]
assert k._local_range == [10]


def test_2D_global_range_and_2D_local_range():
k = vecadd[(10, 10), (10, 10)]
assert k._global_range == [10, 10]
assert k._local_range == [10, 10]


def test_2D_global_range_and_2D_local_range2():
k = vecadd[[10, 10], (10, 10)]
assert k._global_range == [10, 10]
assert k._local_range == [10, 10]


def test_2D_global_range_and_2D_local_range3():
k = vecadd[(10, 10), [10, 10]]
assert k._global_range == [10, 10]
assert k._local_range == [10, 10]


def test_2D_global_range_and_2D_local_range4():
k = vecadd[[10, 10], [10, 10]]
assert k._global_range == [10, 10]
assert k._local_range == [10, 10]


def test_deprecation_warning_for_empty_local_range():
with pytest.deprecated_call():
k = vecadd[[10, 10], []]
assert k._global_range == [10, 10]
assert k._local_range is None


def test_deprecation_warning_for_empty_local_range2():
with pytest.deprecated_call():
k = vecadd[10, []]
assert k._global_range == [10]
assert k._local_range is None


def test_illegal_kernel_launch_arg():
with pytest.raises(InvalidKernelLaunchArgsError):
vecadd[10, 10, []]


def test_illegal_range_error():
with pytest.raises(IllegalRangeValueError):
vecadd[[], []]


def test_illegal_range_error2():
with pytest.raises(IllegalRangeValueError):
vecadd[[], 10]


def test_illegal_range_error3():
with pytest.raises(IllegalRangeValueError):
vecadd[(), 10]

0 comments on commit c74ea24

Please sign in to comment.