Skip to content

Commit

Permalink
Add local accessor device func and python simulator tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ZzEeKkAa committed Mar 19, 2024
1 parent 8cad614 commit 2200c3c
Showing 1 changed file with 24 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,54 +9,43 @@

import numba_dpex as dpex
import numba_dpex.experimental as dpex_exp
from numba_dpex.kernel_api import (
LocalAccessor,
MemoryScope,
NdItem,
group_barrier,
)
from numba_dpex.kernel_api import LocalAccessor, NdItem
from numba_dpex.kernel_api import call_kernel as kapi_call_kernel
from numba_dpex.tests._helper import get_all_dtypes

list_of_supported_dtypes = get_all_dtypes(
no_bool=True, no_float16=True, no_none=True, no_complex=True
)


@dpex_exp.kernel
def _kernel1(nd_item: NdItem, a, slm):
i = nd_item.get_global_linear_id()

# TODO: overload nd_item.get_local_id()
j = (nd_item.get_local_id(0),)

slm[j] = 0
group_barrier(nd_item.get_group(), MemoryScope.WORK_GROUP)

for m in range(100):
slm[j] += i * m
group_barrier(nd_item.get_group(), MemoryScope.WORK_GROUP)

a[i] = slm[j]


@dpex_exp.kernel
def _kernel2(nd_item: NdItem, a, slm):
i = nd_item.get_global_linear_id()

# TODO: overload nd_item.get_local_id()
j = (nd_item.get_local_id(0), nd_item.get_local_id(1))

slm[j] = 0
group_barrier(nd_item.get_group(), MemoryScope.WORK_GROUP)

for m in range(100):
slm[j] += i * m
group_barrier(nd_item.get_group(), MemoryScope.WORK_GROUP)

a[i] = slm[j]


@dpex_exp.kernel
def _kernel3(nd_item: NdItem, a, slm):
i = nd_item.get_global_linear_id()

Expand All @@ -68,15 +57,23 @@ def _kernel3(nd_item: NdItem, a, slm):
)

slm[j] = 0
group_barrier(nd_item.get_group(), MemoryScope.WORK_GROUP)

for m in range(100):
slm[j] += i * m
group_barrier(nd_item.get_group(), MemoryScope.WORK_GROUP)

a[i] = slm[j]


def device_func_kernel(func):
_df = dpex_exp.device_func(func)

@dpex_exp.kernel
def _kernel(item, a, slm):
_df(item, a, slm)

return _kernel


@pytest.mark.parametrize("supported_dtype", list_of_supported_dtypes)
@pytest.mark.parametrize(
"nd_range, _kernel",
Expand All @@ -86,7 +83,17 @@ def _kernel3(nd_item: NdItem, a, slm):
(dpex.NdRange((1, 32, 1), (1, 32, 1)), _kernel3),
],
)
def test_local_accessor(supported_dtype, nd_range: dpex.NdRange, _kernel):
@pytest.mark.parametrize(
"call_kernel, kernel",
[
(dpex_exp.call_kernel, dpex_exp.kernel),
(dpex_exp.call_kernel, device_func_kernel),
(kapi_call_kernel, lambda f: f),
],
)
def test_local_accessor(
supported_dtype, nd_range: dpex.NdRange, _kernel, call_kernel, kernel
):
"""A test for passing a LocalAccessor object as a kernel argument."""

N = 32
Expand All @@ -98,7 +105,7 @@ def test_local_accessor(supported_dtype, nd_range: dpex.NdRange, _kernel):
# `4950 * get_global_linear_id` and stores it into the work groups local
# memory. The local memory is of size 32*64 elements of the requested dtype.
# The result is then stored into `a` in global memory
dpex_exp.call_kernel(_kernel, nd_range, a, slm)
call_kernel(kernel(_kernel), nd_range, a, slm)

for idx in range(N):
assert a[idx] == 4950 * idx
Expand Down

0 comments on commit 2200c3c

Please sign in to comment.