Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Group and overload NdItem.get_group() #1310

Merged
merged 1 commit into from
Feb 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@

from llvmlite import ir as llvmir
from numba.core import cgutils, types
from numba.core.errors import TypingError
from numba.extending import intrinsic, overload

from numba_dpex.core import itanium_mangler as ext_itanium_mangler
from numba_dpex.experimental.core.types.kernel_api.items import GroupType
from numba_dpex.experimental.target import DPEX_KERNEL_EXP_TARGET_NAME
from numba_dpex.kernel_api import group_barrier
from numba_dpex.kernel_api.memory_enums import MemoryOrder, MemoryScope
Expand All @@ -26,7 +28,7 @@
except ValueError:
warnings.warn(
"convergent attribute is supported only starting llvmlite "
+ "0.42. Not setting this attribute may result into unexpected behavior"
+ "0.42. Not setting this attribute may result in unexpected behavior"
+ "when using group_barrier"
)
_SUPPORT_CONVERGENT = False
Expand Down Expand Up @@ -76,7 +78,6 @@ def _intrinsic_barrier_codegen(
llvmir.IntType(32),
]

# TODO: split the function declaration from call
fn = cgutils.get_or_insert_function(
builder.module,
llvmir.FunctionType(llvmir.VoidType(), spirv_fn_arg_types),
Expand Down Expand Up @@ -105,34 +106,39 @@ def _intrinsic_barrier_codegen(
prefer_literal=True,
target=DPEX_KERNEL_EXP_TARGET_NAME,
)
def ol_group_barrier(fence_scope=MemoryScope.WORK_GROUP):
def ol_group_barrier(group, fence_scope=MemoryScope.WORK_GROUP):
"""SPIR-V overload for
:meth:`numba_dpex.kernel_api.group_barrier`.

Generates the same LLVM IR instruction as dpcpp for the
Generates the same LLVM IR instruction as DPC++ for the SYCL
`group_barrier` function.

Per SYCL spec, group_barrier must perform both control barrier and memory
fence operations. Hence, group_barrier requires two scopes and memory
consistency specification as three arguments.
fence operations. Hence, group_barrier requires two scopes and one memory
consistency specification as its three arguments.

mem_scope - scope of any memory consistency operations that are performed by
the barrier. By default, mem_scope is set to `work_group`.
exec_scope - scope that determines the set of work-items that synchronize at
barrier. Set to `work_group` for group_barrier always.
spirv_memory_semantics_mask - Based on sycl implementation.

Mask that is set to use sequential consistency memory order semantics
always.
spirv_memory_semantics_mask - Based on SYCL implementation. Always set to
use sequential consistency memory order.
"""

if not isinstance(group, GroupType):
raise TypingError(
"Expected a group should to be a GroupType value, but "
f"encountered {type(group)}"
)

mem_scope = _get_memory_scope(fence_scope)
exec_scope = get_scope(MemoryScope.WORK_GROUP.value)
spirv_memory_semantics_mask = get_memory_semantics_mask(
MemoryOrder.SEQ_CST.value
)

def _ol_group_barrier_impl(
group,
fence_scope=MemoryScope.WORK_GROUP,
): # pylint: disable=unused-argument
# pylint: disable=no-value-for-parameter
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@
Implements the SPIR-V overloads for the kernel_api.items class methods.
"""

from numba.core import types
from numba.core import cgutils, types
from numba.core.errors import TypingError
from numba.extending import intrinsic, overload_method

from numba_dpex.experimental.core.types.kernel_api.items import (
GroupType,
ItemType,
NdItemType,
)
Expand All @@ -26,7 +27,7 @@ def _intrinsic_get_global_id(
"""Generates instruction for spirv i64 get_global_id(i32) call."""
sig = types.int64(types.int32)

def _intrinsic_exchange_gen(context, builder, sig, args):
def _intrinsic_get_global_id_gen(context, builder, sig, args):
[dim] = args
get_global_id = _declare_function(
# unsigned int - is what demangler returns from IR instruction
Expand All @@ -43,13 +44,35 @@ def _intrinsic_exchange_gen(context, builder, sig, args):
res = builder.call(get_global_id, [dim])
return context.cast(builder, res, types.uintp, types.intp)

return sig, _intrinsic_exchange_gen
return sig, _intrinsic_get_global_id_gen


@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_get_group(
ty_context, ty_nd_item: NdItemType # pylint: disable=unused-argument
):
"""Generates group with a dimension of nd_item."""

if not isinstance(ty_nd_item, NdItemType):
raise TypingError(
f"Expected an NdItemType value, but encountered {ty_nd_item}"
)

ty_group = GroupType(ty_nd_item.ndim)
sig = ty_group(ty_nd_item)

# pylint: disable=unused-argument
def _intrinsic_get_group_gen(context, builder, sig, args):
group_struct = cgutils.create_struct_proxy(ty_group)(context, builder)
# pylint: disable=protected-access
return group_struct._getvalue()

return sig, _intrinsic_get_group_gen


@overload_method(ItemType, "get_id", target=DPEX_KERNEL_EXP_TARGET_NAME)
def ol_item_get_id(item, dim):
"""SPIR-V overload for
:meth:`numba_dpex.kernel_api.Item.get_id`.
"""SPIR-V overload for :meth:`numba_dpex.kernel_api.Item.get_id`.

Generates the same LLVM IR instruction as dpcpp for the
`sycl::item::get_id` function.
Expand All @@ -58,25 +81,30 @@ def ol_item_get_id(item, dim):
TypingError: When argument is not an integer.
"""
if not isinstance(item, ItemType):
raise TypingError("Only Item is supported")
raise TypingError(
"Expected an item should to be an Item value, but "
f"encountered {type(item)}"
)

if not isinstance(dim, types.Integer):
raise TypingError("Only integers supported")
raise TypingError(
diptorupd marked this conversation as resolved.
Show resolved Hide resolved
"Expected an Item's dim should to be an Integer value, but "
f"encountered {type(dim)}"
)

# pylint: disable=unused-argument
def ol_get_id(item, dim):
def ol_item_get_id_impl(item, dim):
# pylint: disable=no-value-for-parameter
return _intrinsic_get_global_id(dim)

return ol_get_id
return ol_item_get_id_impl


@overload_method(
NdItemType, "get_global_id", target=DPEX_KERNEL_EXP_TARGET_NAME
)
def ol_nd_item_get_global_id(nd_item, dim):
"""SPIR-V overload for
:meth:`numba_dpex.kernel_api.NdItem.get_global_id`.
"""SPIR-V overload for :meth:`numba_dpex.kernel_api.NdItem.get_global_id`.

Generates the same LLVM IR instruction as dpcpp for the
`sycl::nd_item::get_global_id` function.
Expand All @@ -85,14 +113,46 @@ def ol_nd_item_get_global_id(nd_item, dim):
TypingError: When argument is not an integer.
"""
if not isinstance(nd_item, NdItemType):
raise TypingError("Only NdItem is supported")
# since it is a method overload, this error should not be reached
raise TypingError(
"Expected a nd_item should to be a NdItem value, but "
f"encountered {type(nd_item)}"
)

if not isinstance(dim, types.Integer):
raise TypingError("Only integers supported")
raise TypingError(
"Expected a NdItem's dim should to be an Integer value, but "
f"encountered {type(dim)}"
)

# pylint: disable=unused-argument
def ol_get_global_id(nd_item, dim):
def ol_nd_item_get_global_id_impl(nd_item, dim):
# pylint: disable=no-value-for-parameter
return _intrinsic_get_global_id(dim)

return ol_get_global_id
return ol_nd_item_get_global_id_impl


@overload_method(NdItemType, "get_group", target=DPEX_KERNEL_EXP_TARGET_NAME)
def ol_nd_item_get_group(nd_item):
"""SPIR-V overload for :meth:`numba_dpex.kernel_api.NdItem.get_group`.

Generates the same LLVM IR instruction as dpcpp for the
`sycl::nd_item::get_group` function.

Raises:
TypingError: When argument is not NdItem.
"""
if not isinstance(nd_item, NdItemType):
# since it is a method overload, this error should not be reached
raise TypingError(
"Expected a nd_item should to be a NdItem value, but "
f"encountered {type(nd_item)}"
)

# pylint: disable=unused-argument
def ol_nd_item_get_group_impl(nd_item):
# pylint: disable=no-value-for-parameter
return _intrinsic_get_group(nd_item)

return ol_nd_item_get_group_impl
25 changes: 25 additions & 0 deletions numba_dpex/experimental/core/types/kernel_api/items.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,31 @@
from numba.core import errors, types


class GroupType(types.Type):
"""Numba-dpex type corresponding to :class:`numba_dpex.kernel_api.Group`"""

def __init__(self, ndim: int):
diptorupd marked this conversation as resolved.
Show resolved Hide resolved
self._ndim = ndim
if ndim < 1 or ndim > 3:
raise errors.TypingError(
"ItemType can only have 1, 2 or 3 dimensions"
)
super().__init__(name="Group<" + str(ndim) + ">")

@property
def ndim(self):
"""Returns number of dimensions"""
return self._ndim

@property
def key(self):
"""Numba type specific overload"""
return self._ndim

def cast_python_value(self, args):
raise NotImplementedError


class ItemType(types.Type):
"""Numba-dpex type corresponding to :class:`numba_dpex.kernel_api.Item`"""

Expand Down
7 changes: 7 additions & 0 deletions numba_dpex/experimental/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import numba_dpex.core.datamodel.models as dpex_core_models
from numba_dpex.experimental.core.types.kernel_api.items import (
GroupType,
ItemType,
NdItemType,
)
Expand Down Expand Up @@ -73,6 +74,9 @@ def _init_exp_data_model_manager() -> DataModelManager:
dmm.register(IntEnumLiteral, IntEnumLiteralModel)
dmm.register(AtomicRefType, AtomicRefModel)

# Register the GroupType type
dmm.register(GroupType, EmptyStructModel)

# Register the ItemType type
dmm.register(ItemType, EmptyStructModel)

Expand All @@ -87,6 +91,9 @@ def _init_exp_data_model_manager() -> DataModelManager:
# Register any new type that should go into numba.core.datamodel.default_manager
register_model(KernelDispatcherType)(models.OpaqueModel)

# Register the GroupType type
register_model(GroupType)(EmptyStructModel)

# Register the ItemType type
register_model(ItemType)(EmptyStructModel)

Expand Down
18 changes: 17 additions & 1 deletion numba_dpex/experimental/typeof.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@
from numba.extending import typeof_impl

from numba_dpex.experimental.core.types.kernel_api.items import (
GroupType,
ItemType,
NdItemType,
)
from numba_dpex.kernel_api import AtomicRef, Item, NdItem
from numba_dpex.kernel_api import AtomicRef, Group, Item, NdItem

from .dpcpp_types import AtomicRefType

Expand All @@ -40,6 +41,21 @@ def typeof_atomic_ref(val: AtomicRef, ctx) -> AtomicRefType:
)


@typeof_impl.register(Group)
def typeof_group(val: Group, c):
"""Registers the type inference implementation function for a
numba_dpex.kernel_api.Group PyObject.

Args:
val : An instance of numba_dpex.kernel_api.Group.
c : Unused argument used to be consistent with Numba API.

Returns: A numba_dpex.experimental.core.types.kernel_api.items.GroupType
instance.
"""
return GroupType(val.ndim)


@typeof_impl.register(Item)
def typeof_item(val: Item, c):
"""Registers the type inference implementation function for a
Expand Down
3 changes: 2 additions & 1 deletion numba_dpex/kernel_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from .atomic_ref import AtomicRef
from .barrier import group_barrier
from .index_space_ids import Item, NdItem
from .index_space_ids import Group, Item, NdItem
from .memory_enums import AddressSpace, MemoryOrder, MemoryScope
from .ranges import NdRange, Range

Expand All @@ -22,6 +22,7 @@
"MemoryScope",
"NdRange",
"Range",
"Group",
"NdItem",
"Item",
"group_barrier",
Expand Down
3 changes: 2 additions & 1 deletion numba_dpex/kernel_api/barrier.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
"""Python functions that simulate SYCL's barrier primitives.
"""

from .index_space_ids import Group
from .memory_enums import MemoryScope


def group_barrier(fence_scope=MemoryScope.WORK_GROUP):
def group_barrier(group: Group, fence_scope=MemoryScope.WORK_GROUP):
"""Performs a barrier operation across all work-items in a work group.

The function is modeled after the ``sycl::group_barrier`` function. It
Expand Down
19 changes: 18 additions & 1 deletion numba_dpex/kernel_api/index_space_ids.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,23 @@
from .ranges import Range


# pylint: disable=too-few-public-methods
class Group:
"""Analogue to the ``sycl::group`` type."""

def __init__(
self,
global_range: Range,
local_range: Range,
group_range: Range,
index: list,
):
self._global_range = global_range
self._local_range = local_range
self._group_range = group_range
self._index = index


class Item:
"""Analogue to the ``sycl::item`` type. Identifies an instance of the
function object executing at each point in an Range.
Expand Down Expand Up @@ -60,7 +77,7 @@ class NdItem:
"""

# TODO: define group type
def __init__(self, global_item: Item, local_item: Item, group: any):
def __init__(self, global_item: Item, local_item: Item, group: Group):
# TODO: assert offset and dimensions
self._global_item = global_item
self._local_item = local_item
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def _kernel(nd_item: NdItem, a):
i = nd_item.get_global_id(0)

a[i] += 1
group_barrier(MemoryScope.DEVICE)
group_barrier(nd_item.get_group(), MemoryScope.DEVICE)

if i == 0:
for idx in range(1, a.size):
Expand Down
Loading