Skip to content

Commit

Permalink
Add Group and overload NdItem.get_group()
Browse files Browse the repository at this point in the history
  • Loading branch information
ZzEeKkAa committed Feb 3, 2024
1 parent b494492 commit 230f170
Show file tree
Hide file tree
Showing 9 changed files with 124 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@

from llvmlite import ir as llvmir
from numba.core import cgutils, types
from numba.core.errors import TypingError
from numba.extending import intrinsic, overload

from numba_dpex.core import itanium_mangler as ext_itanium_mangler
from numba_dpex.experimental.core.types.kernel_api.items import GroupType
from numba_dpex.experimental.target import DPEX_KERNEL_EXP_TARGET_NAME
from numba_dpex.kernel_api import group_barrier
from numba_dpex.kernel_api.memory_enums import MemoryOrder, MemoryScope
Expand Down Expand Up @@ -105,7 +107,7 @@ def _intrinsic_barrier_codegen(
prefer_literal=True,
target=DPEX_KERNEL_EXP_TARGET_NAME,
)
def ol_group_barrier(fence_scope=MemoryScope.WORK_GROUP):
def ol_group_barrier(group, fence_scope=MemoryScope.WORK_GROUP):
"""SPIR-V overload for
:meth:`numba_dpex.kernel_api.group_barrier`.
Expand All @@ -126,13 +128,17 @@ def ol_group_barrier(fence_scope=MemoryScope.WORK_GROUP):
always.
"""

if not isinstance(group, GroupType):
raise TypingError("Only Group is supported")

mem_scope = _get_memory_scope(fence_scope)
exec_scope = get_scope(MemoryScope.WORK_GROUP.value)
spirv_memory_semantics_mask = get_memory_semantics_mask(
MemoryOrder.SEQ_CST.value
)

def _ol_group_barrier_impl(
group,
fence_scope=MemoryScope.WORK_GROUP,
): # pylint: disable=unused-argument
# pylint: disable=no-value-for-parameter
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@
Implements the SPIR-V overloads for the kernel_api.items class methods.
"""

from numba.core import types
from numba.core import cgutils, types
from numba.core.errors import TypingError
from numba.extending import intrinsic, overload_method

from numba_dpex.experimental.core.types.kernel_api.items import (
GroupType,
ItemType,
NdItemType,
)
Expand Down Expand Up @@ -46,6 +47,27 @@ def _intrinsic_exchange_gen(context, builder, sig, args):
return sig, _intrinsic_exchange_gen


@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_get_group(
ty_context, ty_nd_item: NdItemType # pylint: disable=unused-argument
):
"""Generates group with a dimension of nd_item."""

if not isinstance(ty_nd_item, NdItemType):
raise TypingError("Only NdItem is supported")

ty_group = GroupType(ty_nd_item.ndim)
sig = ty_group(ty_nd_item)

# pylint: disable=unused-argument
def _intrinsic_exchange_gen(context, builder, sig, args):
group_struct = cgutils.create_struct_proxy(ty_group)(context, builder)
# pylint: disable=protected-access
return group_struct._getvalue()

return sig, _intrinsic_exchange_gen


@overload_method(ItemType, "get_id", target=DPEX_KERNEL_EXP_TARGET_NAME)
def ol_item_get_id(item, dim):
"""SPIR-V overload for
Expand Down Expand Up @@ -96,3 +118,25 @@ def ol_get_global_id(nd_item, dim):
return _intrinsic_get_global_id(dim)

return ol_get_global_id


@overload_method(NdItemType, "get_group", target=DPEX_KERNEL_EXP_TARGET_NAME)
def ol_nd_item_get_group(nd_item):
"""SPIR-V overload for
:meth:`numba_dpex.kernel_api.NdItem.get_global_id`.
Generates the same LLVM IR instruction as dpcpp for the
`sycl::nd_item::get_global_id` function.
Raises:
TypingError: When argument is not an integer.
"""
if not isinstance(nd_item, NdItemType):
raise TypingError("Only NdItem is supported")

# pylint: disable=unused-argument
def ol_get_global_id(nd_item):
# pylint: disable=no-value-for-parameter
return _intrinsic_get_group(nd_item)

return ol_get_global_id
25 changes: 25 additions & 0 deletions numba_dpex/experimental/core/types/kernel_api/items.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,31 @@
from numba.core import errors, types


class GroupType(types.Type):
"""Numba-dpex type corresponding to :class:`numba_dpex.kernel_api.Group`"""

def __init__(self, ndim: int):
self._ndim = ndim
if ndim < 1 or ndim > 3:
raise errors.TypingError(
"ItemType can only have 1, 2 or 3 dimensions"
)
super().__init__(name="Group<" + str(ndim) + ">")

@property
def ndim(self):
"""Returns number of dimensions"""
return self._ndim

@property
def key(self):
"""Numba type specific overload"""
return self._ndim

def cast_python_value(self, args):
raise NotImplementedError


class ItemType(types.Type):
"""Numba-dpex type corresponding to :class:`numba_dpex.kernel_api.Item`"""

Expand Down
7 changes: 7 additions & 0 deletions numba_dpex/experimental/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import numba_dpex.core.datamodel.models as dpex_core_models
from numba_dpex.experimental.core.types.kernel_api.items import (
GroupType,
ItemType,
NdItemType,
)
Expand Down Expand Up @@ -73,6 +74,9 @@ def _init_exp_data_model_manager() -> DataModelManager:
dmm.register(IntEnumLiteral, IntEnumLiteralModel)
dmm.register(AtomicRefType, AtomicRefModel)

# Register the GroupType type
dmm.register(GroupType, EmptyStructModel)

# Register the ItemType type
dmm.register(ItemType, EmptyStructModel)

Expand All @@ -87,6 +91,9 @@ def _init_exp_data_model_manager() -> DataModelManager:
# Register any new type that should go into numba.core.datamodel.default_manager
register_model(KernelDispatcherType)(models.OpaqueModel)

# Register the GroupType type
register_model(GroupType)(EmptyStructModel)

# Register the ItemType type
register_model(ItemType)(EmptyStructModel)

Expand Down
18 changes: 17 additions & 1 deletion numba_dpex/experimental/typeof.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@
from numba.extending import typeof_impl

from numba_dpex.experimental.core.types.kernel_api.items import (
GroupType,
ItemType,
NdItemType,
)
from numba_dpex.kernel_api import AtomicRef, Item, NdItem
from numba_dpex.kernel_api import AtomicRef, Group, Item, NdItem

from .dpcpp_types import AtomicRefType

Expand All @@ -40,6 +41,21 @@ def typeof_atomic_ref(val: AtomicRef, ctx) -> AtomicRefType:
)


@typeof_impl.register(Group)
def typeof_group(val: Group, c):
"""Registers the type inference implementation function for a
numba_dpex.kernel_api.Group PyObject.
Args:
val : An instance of numba_dpex.kernel_api.Group.
c : Unused argument used to be consistent with Numba API.
Returns: A numba_dpex.experimental.core.types.kernel_api.items.GroupType
instance.
"""
return GroupType(val.ndim)


@typeof_impl.register(Item)
def typeof_item(val: Item, c):
"""Registers the type inference implementation function for a
Expand Down
3 changes: 2 additions & 1 deletion numba_dpex/kernel_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from .atomic_ref import AtomicRef
from .barrier import group_barrier
from .index_space_ids import Item, NdItem
from .index_space_ids import Group, Item, NdItem
from .memory_enums import AddressSpace, MemoryOrder, MemoryScope
from .ranges import NdRange, Range

Expand All @@ -22,6 +22,7 @@
"MemoryScope",
"NdRange",
"Range",
"Group",
"NdItem",
"Item",
"group_barrier",
Expand Down
3 changes: 2 additions & 1 deletion numba_dpex/kernel_api/barrier.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
"""Python functions that simulate SYCL's barrier primitives.
"""

from .index_space_ids import Group
from .memory_enums import MemoryScope


def group_barrier(fence_scope=MemoryScope.WORK_GROUP):
def group_barrier(group: Group, fence_scope=MemoryScope.WORK_GROUP):
"""Performs a barrier operation across all work-items in a work group.
The function is modeled after the ``sycl::group_barrier`` function. It
Expand Down
19 changes: 18 additions & 1 deletion numba_dpex/kernel_api/index_space_ids.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,23 @@
from .ranges import Range


# pylint: disable=too-few-public-methods
class Group:
"""Analogue to the ``sycl::group`` type."""

def __init__(
self,
global_range: Range,
local_range: Range,
group_range: Range,
index: list,
):
self._global_range = global_range
self._local_range = local_range
self._group_range = group_range
self._index = index


class Item:
"""Analogue to the ``sycl::item`` type. Identifies an instance of the
function object executing at each point in an Range.
Expand Down Expand Up @@ -60,7 +77,7 @@ class NdItem:
"""

# TODO: define group type
def __init__(self, global_item: Item, local_item: Item, group: any):
def __init__(self, global_item: Item, local_item: Item, group: Group):
# TODO: assert offset and dimensions
self._global_item = global_item
self._local_item = local_item
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def _kernel(nd_item: NdItem, a):
i = nd_item.get_global_id(0)

a[i] += 1
group_barrier(MemoryScope.DEVICE)
group_barrier(nd_item.get_group(), MemoryScope.DEVICE)

if i == 0:
for idx in range(1, a.size):
Expand Down

0 comments on commit 230f170

Please sign in to comment.