Skip to content

Commit

Permalink
Overload generic item's attribute 'dimensions'
Browse files Browse the repository at this point in the history
  • Loading branch information
ZzEeKkAa committed Feb 23, 2024
1 parent aa319c3 commit da1775c
Showing 1 changed file with 22 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import llvmlite.ir as llvmir
from numba.core import cgutils, types
from numba.core.errors import TypingError
from numba.extending import intrinsic, overload_method
from numba.extending import intrinsic, overload_attribute, overload_method

from numba_dpex.core.types.kernel_api.index_space_ids import (
GroupType,
Expand Down Expand Up @@ -248,3 +248,24 @@ def ol_nd_item_get_group_impl(nd_item):
return _intrinsic_get_group(nd_item)

return ol_nd_item_get_group_impl


@overload_attribute(GroupType, "dimensions", target=DPEX_KERNEL_EXP_TARGET_NAME)
@overload_attribute(ItemType, "dimensions", target=DPEX_KERNEL_EXP_TARGET_NAME)
@overload_attribute(
NdItemType, "dimensions", target=DPEX_KERNEL_EXP_TARGET_NAME
)
def ol_nd_item_dimensions(item):
"""
SPIR-V overload for :meth:`numba_dpex.kernel_api.<generic_item>.dimensions`.
Generates the same LLVM IR instruction as dpcpp for the
`sycl::<generic_item>::dimensions` attribute.
"""
dimensions = item.ndim

# pylint: disable=unused-argument
def ol_nd_item_get_group_impl(item):
return dimensions

return ol_nd_item_get_group_impl

0 comments on commit da1775c

Please sign in to comment.