Skip to content

Commit

Permalink
Register index range methods as jitable
Browse files Browse the repository at this point in the history
  • Loading branch information
ZzEeKkAa committed Feb 28, 2024
1 parent 5a73d0b commit a425666
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,11 @@ def register_jitable_method(type_, method):


register_jitable_method(ItemType, Item.get_linear_id)
register_jitable_method(ItemType, Item.get_linear_range)
register_jitable_method(NdItemType, NdItem.get_global_linear_id)
register_jitable_method(NdItemType, NdItem.get_global_linear_range)
register_jitable_method(NdItemType, NdItem.get_local_linear_range)
register_jitable_method(NdItemType, NdItem.get_local_linear_id)
register_jitable_method(GroupType, Group.get_group_linear_id)
register_jitable_method(GroupType, Group.get_group_linear_range)
register_jitable_method(GroupType, Group.get_local_linear_range)
127 changes: 107 additions & 20 deletions numba_dpex/tests/experimental/test_index_space_ids.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,19 @@ def set_last_one_item(item: Item, a):
a[i] = 1


@dpex_exp.kernel
def set_last_one_linear_item(item: Item, a):
i = item.get_linear_range() - 1
a[i] = 1


@dpex_exp.kernel
def set_last_one_linear_nd_item(nd_item: NdItem, a):
i = nd_item.get_global_linear_range() - 1
a[0] = i
a[i] = 1


@dpex_exp.kernel
def set_last_one_nd_item(item: NdItem, a):
if item.get_global_id(0) == 0:
Expand All @@ -43,6 +56,20 @@ def set_last_one_nd_item(item: NdItem, a):
a[i] = 1


@dpex_exp.kernel
def set_last_group_one_linear_nd_item(nd_item: NdItem, a):
i = nd_item.get_local_linear_range() - 1
a[0] = i
a[i] = 1


@dpex_exp.kernel
def set_last_group_one_group_linear_nd_item(nd_item: NdItem, a):
i = nd_item.get_group().get_local_linear_range() - 1
a[0] = i
a[i] = 1


@dpex_exp.kernel
def set_last_group_one_nd_item(item: NdItem, a):
if item.get_global_id(0) == 0:
Expand Down Expand Up @@ -99,6 +126,12 @@ def _get_group_range_driver(nditem: NdItem, a):
a[i] = g.get_group_range(0)


def _get_group_linear_range_driver(nditem: NdItem, a):
i = nditem.get_global_linear_id()
g = nditem.get_group()
a[i] = g.get_group_linear_range()


def _get_group_local_range_driver(nditem: NdItem, a):
i = nditem.get_global_id(0)
g = nditem.get_group()
Expand All @@ -122,11 +155,34 @@ def test_item_get_range():
assert np.array_equal(a.asnumpy(), want)


def test_nd_item_get_global_range():
@pytest.mark.parametrize(
"rng",
[dpex.Range(_SIZE), dpex.Range(1, _GROUP_SIZE, int(_SIZE / _GROUP_SIZE))],
)
def test_item_get_linear_range(rng):
a = dpnp.zeros(_SIZE, dtype=dpnp.float32)
dpex_exp.call_kernel(
set_last_one_nd_item, dpex.NdRange((a.size,), (_GROUP_SIZE,)), a
)
dpex_exp.call_kernel(set_last_one_linear_item, rng, a)

want = np.zeros(a.size, dtype=np.float32)
want[-1] = 1

assert np.array_equal(a.asnumpy(), want)


@pytest.mark.parametrize(
"kernel,rng",
[
(set_last_one_nd_item, dpex.NdRange((_SIZE,), (_GROUP_SIZE,))),
(set_last_one_linear_nd_item, dpex.NdRange((_SIZE,), (_GROUP_SIZE,))),
(
set_last_one_linear_nd_item,
dpex.NdRange((1, 1, _SIZE), (1, 1, _GROUP_SIZE)),
),
],
)
def test_nd_item_get_global_range(kernel, rng):
a = dpnp.zeros(_SIZE, dtype=dpnp.float32)
dpex_exp.call_kernel(kernel, rng, a)

want = np.zeros(a.size, dtype=np.float32)
want[-1] = 1
Expand All @@ -135,11 +191,31 @@ def test_nd_item_get_global_range():
assert np.array_equal(a.asnumpy(), want)


def test_nd_item_get_local_range():
@pytest.mark.parametrize(
"kernel,rng",
[
(set_last_group_one_nd_item, dpex.NdRange((_SIZE,), (_GROUP_SIZE,))),
(
set_last_group_one_linear_nd_item,
dpex.NdRange((_SIZE,), (_GROUP_SIZE,)),
),
(
set_last_group_one_linear_nd_item,
dpex.NdRange((1, 1, _SIZE), (1, 1, _GROUP_SIZE)),
),
(
set_last_group_one_group_linear_nd_item,
dpex.NdRange((_SIZE,), (_GROUP_SIZE,)),
),
(
set_last_group_one_group_linear_nd_item,
dpex.NdRange((1, 1, _SIZE), (1, 1, _GROUP_SIZE)),
),
],
)
def test_nd_item_get_local_range(kernel, rng):
a = dpnp.zeros(_SIZE, dtype=dpnp.float32)
dpex_exp.call_kernel(
set_last_group_one_nd_item, dpex.NdRange((a.size,), (_GROUP_SIZE,)), a
)
dpex_exp.call_kernel(kernel, rng, a)

want = np.zeros(a.size, dtype=np.float32)
want[_GROUP_SIZE - 1] = 1
Expand Down Expand Up @@ -240,21 +316,32 @@ def test_get_group_id(driver, rng):
assert np.array_equal(ka.asnumpy(), expected)


def test_get_group_range():
global_size = 100
group_size = 20
num_groups = global_size // group_size
@pytest.mark.parametrize(
"driver,rng",
[
(_get_group_range_driver, dpex.NdRange((_SIZE,), (_GROUP_SIZE,))),
(
_get_group_linear_range_driver,
dpex.NdRange((_SIZE,), (_GROUP_SIZE,)),
),
(
_get_group_linear_range_driver,
dpex.NdRange((1, 1, _SIZE), (1, 1, _GROUP_SIZE)),
),
],
)
def test_get_group_range(driver, rng):
num_groups = _SIZE // _GROUP_SIZE

a = dpnp.empty(global_size, dtype=dpnp.int32)
ka = dpnp.empty(global_size, dtype=dpnp.int32)
expected = np.empty(global_size, dtype=np.int32)
ndrange = NdRange((global_size,), (group_size,))
dpex_exp.call_kernel(dpex_exp.kernel(_get_group_range_driver), ndrange, a)
kapi_call_kernel(_get_group_range_driver, ndrange, ka)
a = dpnp.empty(_SIZE, dtype=dpnp.int32)
ka = dpnp.empty(_SIZE, dtype=dpnp.int32)
expected = np.empty(_SIZE, dtype=np.int32)
dpex_exp.call_kernel(dpex_exp.kernel(driver), rng, a)
kapi_call_kernel(driver, rng, ka)

for gid in range(num_groups):
for lid in range(group_size):
expected[gid * group_size + lid] = num_groups
for lid in range(_GROUP_SIZE):
expected[gid * _GROUP_SIZE + lid] = num_groups

assert np.array_equal(a.asnumpy(), expected)
assert np.array_equal(ka.asnumpy(), expected)
Expand Down

0 comments on commit a425666

Please sign in to comment.