Skip to content

Commit

Permalink
Add pure Python kernel launchers
Browse files Browse the repository at this point in the history
  • Loading branch information
Diptorup Deb committed Feb 10, 2024
1 parent e627589 commit 0e85c6a
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 19 deletions.
2 changes: 2 additions & 0 deletions numba_dpex/kernel_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from .atomic_ref import AtomicRef
from .barrier import group_barrier
from .index_space_ids import Group, Item, NdItem
from .launcher import call_kernel
from .memory_enums import AddressSpace, MemoryOrder, MemoryScope
from .ranges import NdRange, Range

Expand All @@ -26,4 +27,5 @@
"NdItem",
"Item",
"group_barrier",
"call_kernel",
]
11 changes: 0 additions & 11 deletions numba_dpex/kernel_api/index_space_ids.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,6 @@ def __init__(
self._global_range = global_range
self._local_range = local_range
self._group_range = group_range
max_local_range = []
for i in self._global_range.ndim:
max_local_range.append(self._global_range[i] / self._local_range[i])
# pylint: disable=no-value-for-parameter
self._max_local_range = Range(*max_local_range)
self._index = index

def get_group_id(self, dim):
Expand Down Expand Up @@ -77,12 +72,6 @@ def get_local_range(self):
"""
return self._local_range

def get_max_local_range(self):
"""Return a range representing the maximum number of work-items in any
work-group in the nd_range.
"""
return self._max_local_range


class Item:
"""Analogue to the ``sycl::item`` type. Identifies an instance of the
Expand Down
83 changes: 75 additions & 8 deletions numba_dpex/kernel_api/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
#
# SPDX-License-Identifier: Apache-2.0

"""Implementation of mock kernel launcher functions
"""

from inspect import signature
from itertools import product

Expand All @@ -12,19 +15,26 @@
def _range_kernel_launcher(kernel_fn, index_range, *kernel_args):
"""Executes a function that mocks a range kernel.
Converts the range into a set of index tuple that represent an element in
the iteration domain over which the kernel will be executed. Then the
kernel is called sequentially over that set of indices after each index
value is used to construct an Item object.
Args:
kernel_fn (_type_): _description_
index_range (_type_): _description_
kernel_fn : A callable function object
index_range (numba_dpex.Range): An instance of a Range object
Raises:
ValueError: _description_
ValueError: If the number of passed in kernel arguments is not the
number of function parameters subtracted by one. The first kernel
argument is expected to be an Item object.
"""

range_sets = [range(ir) for ir in index_range]
index_tuples = list(product(*range_sets))

for idex in index_tuples:
it = Item(extent=index_range, index=idex)
for idx in index_tuples:
it = Item(extent=index_range, index=idx)

if len(signature(kernel_fn).parameters) - len(kernel_args) != 1:
raise ValueError(
Expand All @@ -36,12 +46,69 @@ def _range_kernel_launcher(kernel_fn, index_range, *kernel_args):


def _ndrange_kernel_launcher(kernel_fn, index_range, *kernel_args):
pass
group_range = tuple(
gr // lr
for gr, lr in zip(index_range.global_range, index_range.local_range)
)
local_range_sets = [range(ir) for ir in index_range.local_range]
group_range_sets = [range(gr) for gr in group_range]
local_index_tuples = list(product(*local_range_sets))
group_index_tuples = list(product(*group_range_sets))

# Loop over the groups (parallel loop)
for gidx in group_index_tuples:
# loop over work items in the group (parallel loop)
for lidx in local_index_tuples:
global_id = []
# to calculate global indices
for dim, gidx_val in enumerate(gidx):
global_id.append(
gidx_val * index_range.local_range[dim] + lidx[dim]
)
group = Group(
index_range.global_range,
index_range.local_range,
group_range,
gidx,
)
nditem = NdItem(
global_item=Item(
extent=index_range.global_range, index=global_id
),
local_item=Item(extent=index_range.local_range, index=lidx),
group=group,
)

if len(signature(kernel_fn).parameters) - len(kernel_args) != 1:
raise ValueError(
"Required number of kernel function arguments do not "
"match provided number of kernel args"
)

kernel_fn(nditem, *kernel_args)

def kernel_launcher(kernel_fn, index_range, *kernel_args):

def call_kernel(kernel_fn, index_range, *kernel_args):
"""Mocks the launching of a kernel function over either a Range or NdRange.
Args:
kernel_fn : A callable function object
index_range (numba_dpex.Range): An instance of a Range object
Raises:
ValueError: If the first positional argument is not callable
ValueError: If the second positional argument is not a Range or an
Ndrange object
"""
if not callable(kernel_fn):
raise ValueError(
"Expected the first positional argument to be a function object"
)
if isinstance(index_range, Range):
_range_kernel_launcher(kernel_fn, index_range, *kernel_args)
elif isinstance(index_range, NdRange):
_ndrange_kernel_launcher(kernel_fn, index_range, *kernel_args)
pass
else:
raise ValueError(
"Expected second positional argument to be Range or NdRange object"
)
21 changes: 21 additions & 0 deletions numba_dpex/kernel_api/ranges.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@

from collections.abc import Iterable

from numba_dpex.core.exceptions import (
UnmatchedNumberOfRangeDimsError,
UnsupportedGroupWorkItemSizeError,
)


class Range(tuple):
"""A data structure to encapsulate a single kernel launch parameter.
Expand Down Expand Up @@ -168,6 +173,22 @@ def __init__(self, global_size, local_size):
+ "must be of either type Range or Iterable of int's."
)

if len(self._local_range) != len(self._global_range):
raise UnmatchedNumberOfRangeDimsError(
kernel_name="",
global_ndims=len(self._global_range),
local_ndims=len(self._local_range),
)

for i, _ in enumerate(self._global_range):
if self._global_range[i] % self._local_range[i] != 0:
raise UnsupportedGroupWorkItemSizeError(
kernel_name="",
dim=i,
work_groups=self._global_range[i],
work_items=self._local_range[i],
)

@property
def global_range(self):
"""Accessor for global_range.
Expand Down

0 comments on commit 0e85c6a

Please sign in to comment.