Skip to content

Commit

Permalink
Adds a kernel launcher function for mock kernels.
Browse files Browse the repository at this point in the history
  • Loading branch information
Diptorup Deb committed Feb 11, 2024
1 parent 38611a6 commit c7f897a
Show file tree
Hide file tree
Showing 3 changed files with 148 additions and 0 deletions.
2 changes: 2 additions & 0 deletions numba_dpex/kernel_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from .atomic_ref import AtomicRef
from .barrier import group_barrier
from .index_space_ids import Group, Item, NdItem
from .launcher import call_kernel
from .memory_enums import AddressSpace, MemoryOrder, MemoryScope
from .ranges import NdRange, Range

Expand All @@ -26,4 +27,5 @@
"NdItem",
"Item",
"group_barrier",
"call_kernel",
]
125 changes: 125 additions & 0 deletions numba_dpex/kernel_api/launcher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# SPDX-FileCopyrightText: 2024 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0

"""Implementation of mock kernel launcher functions
"""

from inspect import signature
from itertools import product

from .index_space_ids import Group, Item, NdItem
from .ranges import NdRange, Range


def _range_kernel_launcher(kernel_fn, index_range, *kernel_args):
"""Executes a function that mocks a range kernel.
Converts the range into a set of index tuple that represent an element in
the iteration domain over which the kernel will be executed. Then the
kernel is called sequentially over that set of indices after each index
value is used to construct an Item object.
Args:
kernel_fn : A callable function object
index_range (numba_dpex.Range): An instance of a Range object
Raises:
ValueError: If the number of passed in kernel arguments is not the
number of function parameters subtracted by one. The first kernel
argument is expected to be an Item object.
"""

range_sets = [range(ir) for ir in index_range]
index_tuples = list(product(*range_sets))

for idx in index_tuples:
it = Item(extent=index_range, index=idx)

if len(signature(kernel_fn).parameters) - len(kernel_args) != 1:
raise ValueError(
"Required number of kernel function arguments do not "
"match provided number of kernel args"
)

kernel_fn(it, *kernel_args)


def _ndrange_kernel_launcher(kernel_fn, index_range, *kernel_args):
"""Executes a function that mocks a range kernel.
Args:
kernel_fn : A callable function object
index_range (numba_dpex.NdRange): An instance of a NdRange object
Raises:
ValueError: If the number of passed in kernel arguments is not the
number of function parameters subtracted by one. The first kernel
argument is expected to be an Item object.
"""
group_range = tuple(
gr // lr
for gr, lr in zip(index_range.global_range, index_range.local_range)
)
local_range_sets = [range(ir) for ir in index_range.local_range]
group_range_sets = [range(gr) for gr in group_range]
local_index_tuples = list(product(*local_range_sets))
group_index_tuples = list(product(*group_range_sets))

# Loop over the groups (parallel loop)
for gidx in group_index_tuples:
# loop over work items in the group (parallel loop)
for lidx in local_index_tuples:
global_id = []
# to calculate global indices
for dim, gidx_val in enumerate(gidx):
global_id.append(
gidx_val * index_range.local_range[dim] + lidx[dim]
)
# Every NdItem has its own global Item, local Item and Group
nditem = NdItem(
global_item=Item(
extent=index_range.global_range, index=global_id
),
local_item=Item(extent=index_range.local_range, index=lidx),
group=Group(
index_range.global_range,
index_range.local_range,
group_range,
gidx,
),
)

if len(signature(kernel_fn).parameters) - len(kernel_args) != 1:
raise ValueError(
"Required number of kernel function arguments do not "
"match provided number of kernel args"
)

kernel_fn(nditem, *kernel_args)


def call_kernel(kernel_fn, index_range, *kernel_args):
"""Mocks the launching of a kernel function over either a Range or NdRange.
Args:
kernel_fn : A callable function object
index_range (numba_dpex.Range): An instance of a Range object
Raises:
ValueError: If the first positional argument is not callable
ValueError: If the second positional argument is not a Range or an
Ndrange object
"""
if not callable(kernel_fn):
raise ValueError(
"Expected the first positional argument to be a function object"
)
if isinstance(index_range, Range):
_range_kernel_launcher(kernel_fn, index_range, *kernel_args)
elif isinstance(index_range, NdRange):
_ndrange_kernel_launcher(kernel_fn, index_range, *kernel_args)
else:
raise ValueError(
"Expected second positional argument to be Range or NdRange object"
)
21 changes: 21 additions & 0 deletions numba_dpex/kernel_api/ranges.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@

from collections.abc import Iterable

from numba_dpex.core.exceptions import (
UnmatchedNumberOfRangeDimsError,
UnsupportedGroupWorkItemSizeError,
)


class Range(tuple):
"""A data structure to encapsulate a single kernel launch parameter.
Expand Down Expand Up @@ -168,6 +173,22 @@ def __init__(self, global_size, local_size):
+ "must be of either type Range or Iterable of int's."
)

if len(self._local_range) != len(self._global_range):
raise UnmatchedNumberOfRangeDimsError(
kernel_name="",
global_ndims=len(self._global_range),
local_ndims=len(self._local_range),
)

for i, _ in enumerate(self._global_range):
if self._global_range[i] % self._local_range[i] != 0:
raise UnsupportedGroupWorkItemSizeError(
kernel_name="",
dim=i,
work_groups=self._global_range[i],
work_items=self._local_range[i],
)

@property
def global_range(self):
"""Accessor for global_range.
Expand Down

0 comments on commit c7f897a

Please sign in to comment.