-
Notifications
You must be signed in to change notification settings - Fork 33
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Improve dispatcher checks for laych args and add unit test.
- Loading branch information
Diptorup Deb
committed
Jan 18, 2023
1 parent
664d57e
commit c74ea24
Showing
3 changed files
with
148 additions
and
45 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
113 changes: 113 additions & 0 deletions
113
numba_dpex/tests/kernel_tests/test_kernel_launch_params.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
# SPDX-FileCopyrightText: 2020 - 2022 Intel Corporation | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import pytest | ||
|
||
import numba_dpex as dpex | ||
from numba_dpex.core.exceptions import ( | ||
IllegalRangeValueError, | ||
InvalidKernelLaunchArgsError, | ||
) | ||
|
||
|
||
@dpex.kernel | ||
def vecadd(a, b, c): | ||
i = dpex.get_global_id(0) | ||
a[i] = b[i] + c[i] | ||
|
||
|
||
def test_1D_global_range_as_int(): | ||
k = vecadd[10] | ||
assert k._global_range == [10] | ||
assert k._local_range is None | ||
|
||
|
||
def test_1D_global_range_as_one_tuple(): | ||
k = vecadd[ | ||
10, | ||
] | ||
assert k._global_range == [10] | ||
assert k._local_range is None | ||
|
||
|
||
def test_1D_global_range_as_list(): | ||
k = vecadd[[10]] | ||
assert k._global_range == [10] | ||
assert k._local_range is None | ||
|
||
|
||
def test_1D_global_range_and_1D_local_range(): | ||
k = vecadd[10, 10] | ||
assert k._global_range == [10] | ||
assert k._local_range == [10] | ||
|
||
|
||
def test_1D_global_range_and_1D_local_range2(): | ||
k = vecadd[[10, 10]] | ||
assert k._global_range == [10] | ||
assert k._local_range == [10] | ||
|
||
|
||
def test_1D_global_range_and_1D_local_range3(): | ||
k = vecadd[(10,), (10,)] | ||
assert k._global_range == [10] | ||
assert k._local_range == [10] | ||
|
||
|
||
def test_2D_global_range_and_2D_local_range(): | ||
k = vecadd[(10, 10), (10, 10)] | ||
assert k._global_range == [10, 10] | ||
assert k._local_range == [10, 10] | ||
|
||
|
||
def test_2D_global_range_and_2D_local_range2(): | ||
k = vecadd[[10, 10], (10, 10)] | ||
assert k._global_range == [10, 10] | ||
assert k._local_range == [10, 10] | ||
|
||
|
||
def test_2D_global_range_and_2D_local_range3(): | ||
k = vecadd[(10, 10), [10, 10]] | ||
assert k._global_range == [10, 10] | ||
assert k._local_range == [10, 10] | ||
|
||
|
||
def test_2D_global_range_and_2D_local_range4(): | ||
k = vecadd[[10, 10], [10, 10]] | ||
assert k._global_range == [10, 10] | ||
assert k._local_range == [10, 10] | ||
|
||
|
||
def test_deprecation_warning_for_empty_local_range(): | ||
with pytest.deprecated_call(): | ||
k = vecadd[[10, 10], []] | ||
assert k._global_range == [10, 10] | ||
assert k._local_range is None | ||
|
||
|
||
def test_deprecation_warning_for_empty_local_range2(): | ||
with pytest.deprecated_call(): | ||
k = vecadd[10, []] | ||
assert k._global_range == [10] | ||
assert k._local_range is None | ||
|
||
|
||
def test_illegal_kernel_launch_arg(): | ||
with pytest.raises(InvalidKernelLaunchArgsError): | ||
vecadd[10, 10, []] | ||
|
||
|
||
def test_illegal_range_error(): | ||
with pytest.raises(IllegalRangeValueError): | ||
vecadd[[], []] | ||
|
||
|
||
def test_illegal_range_error2(): | ||
with pytest.raises(IllegalRangeValueError): | ||
vecadd[[], 10] | ||
|
||
|
||
def test_illegal_range_error3(): | ||
with pytest.raises(IllegalRangeValueError): | ||
vecadd[(), 10] |