Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/index function overloads #1323

Merged
merged 1 commit into from
Feb 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,131 +6,200 @@
Implements the SPIR-V overloads for the kernel_api.items class methods.
"""

import llvmlite.ir as llvmir
from numba.core import cgutils, types
from numba.core.errors import TypingError
from numba.extending import intrinsic, overload_method

from numba_dpex._kernel_api_impl.spirv.target import SPIRVTargetContext
from numba_dpex.experimental.core.types.kernel_api.items import (
GroupType,
ItemType,
NdItemType,
)
from numba_dpex.ocl._declare_function import _declare_function

from ..target import DPEX_KERNEL_EXP_TARGET_NAME


@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_get_global_id(
ty_context, ty_dim # pylint: disable=unused-argument
def spirv_name(name: str):
"""Converts name to spirv name by adding __spirv_ prefix."""
return "__spirv_" + name


def declare_spirv_const(
builder: llvmir.IRBuilder,
name: str,
):
"""Declares global external spirv constant"""
data = cgutils.add_global_variable(
builder.module,
llvmir.VectorType(llvmir.IntType(64), 3),
spirv_name(name),
addrspace=1,
)
data.linkage = "external"
data.global_constant = True
data.align = 32
data.storage_class = "dso_local local_unnamed_addr"
return data


def _intrinsic_spirv_global_index_const(
ty_context, # pylint: disable=unused-argument
ty_dim, # pylint: disable=unused-argument
const_name: str,
):
"""Generates instruction for spirv i64 get_global_id(i32) call."""
"""Generates instruction to get spirv index from const_name."""
sig = types.int64(types.int32)

def _intrinsic_get_global_id_gen(context, builder, sig, args):
[dim] = args
get_global_id = _declare_function(
# unsigned int - is what demangler returns from IR instruction
# generated by dpcpp. However int32 is passed as an argument.
# Most likely it does not matter, since argument can be from 0 to 3.
# TODO: https://github.com/IntelPython/numba-dpex/issues/936
# Numba generates unnecessary checks because of the type mismatch.
context,
def _intrinsic_spirv_global_index_const_gen(
context: SPIRVTargetContext,
builder: llvmir.IRBuilder,
sig, # pylint: disable=unused-argument
args,
):
index_const = declare_spirv_const(
builder,
"get_global_id",
sig,
["unsigned int"],
const_name,
)
[dim] = args
# TODO: llvmlite does not support gep on vector. Use this instead once
# supported.
# https://github.com/numba/llvmlite/issues/756
# res = builder.gep( # noqa: E800
# global_invocation_id, # noqa: E800
# [cgutils.int32_t(0), cgutils.int32_t(0)], # noqa: E800
# inbounds=True, # noqa: E800
# ) # noqa: E800
# res = builder.load(res, align=32) # noqa: E800

res = builder.extract_element(
builder.load(index_const),
dim,
)
res = builder.call(get_global_id, [dim])

return context.cast(builder, res, types.uintp, types.intp)

return sig, _intrinsic_get_global_id_gen
return sig, _intrinsic_spirv_global_index_const_gen


@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_get_group(
ty_context, ty_nd_item: NdItemType # pylint: disable=unused-argument
def _intrinsic_spirv_global_invocation_id(
ty_context, ty_dim # pylint: disable=unused-argument
):
"""Generates group with a dimension of nd_item."""
"""Generates instruction to get index from BuiltInGlobalInvocationId."""
return _intrinsic_spirv_global_index_const(
ty_context, ty_dim, "BuiltInGlobalInvocationId"
)

if not isinstance(ty_nd_item, NdItemType):
raise TypingError(
f"Expected an NdItemType value, but encountered {ty_nd_item}"
)

ty_group = GroupType(ty_nd_item.ndim)
sig = ty_group(ty_nd_item)
@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_spirv_local_invocation_id(
ty_context, ty_dim # pylint: disable=unused-argument
):
"""Generates instruction to get index from BuiltInLocalInvocationId."""
return _intrinsic_spirv_global_index_const(
ty_context, ty_dim, "BuiltInLocalInvocationId"
)

# pylint: disable=unused-argument
def _intrinsic_get_group_gen(context, builder, sig, args):
group_struct = cgutils.create_struct_proxy(ty_group)(context, builder)
# pylint: disable=protected-access
return group_struct._getvalue()

return sig, _intrinsic_get_group_gen
@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_spirv_global_size(
ty_context, ty_dim # pylint: disable=unused-argument
):
"""Generates instruction to get index from BuiltInGlobalSize."""
return _intrinsic_spirv_global_index_const(
ty_context, ty_dim, "BuiltInGlobalSize"
)


@overload_method(ItemType, "get_id", target=DPEX_KERNEL_EXP_TARGET_NAME)
def ol_item_get_id(item, dim):
"""SPIR-V overload for :meth:`numba_dpex.kernel_api.Item.get_id`.
@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_spirv_workgroup_size(
ty_context, ty_dim # pylint: disable=unused-argument
):
"""Generates instruction to get index from BuiltInWorkgroupSize."""
return _intrinsic_spirv_global_index_const(
ty_context, ty_dim, "BuiltInWorkgroupSize"
)

Generates the same LLVM IR instruction as dpcpp for the
`sycl::item::get_id` function.

Raises:
TypingError: When argument is not an integer.
"""
if not isinstance(item, ItemType):
raise TypingError(
"Expected an item should to be an Item value, but "
f"encountered {type(item)}"
)
def generate_index_overload(_type, _intrinsic):
"""Generates overload for the index method that generates specific IR from
provided intrinsic."""

if not isinstance(dim, types.Integer):
raise TypingError(
"Expected an Item's dim should to be an Integer value, but "
f"encountered {type(dim)}"
)
def ol_item_gen_index(item, dim):
"""SPIR-V overload for :meth:`numba_dpex.kernel_api.<_type>.<method>`.

# pylint: disable=unused-argument
def ol_item_get_id_impl(item, dim):
# pylint: disable=no-value-for-parameter
return _intrinsic_get_global_id(dim)
Generates the same LLVM IR instruction as dpcpp for the
`sycl::<type>::<method>` function.

return ol_item_get_id_impl
Raises:
TypingError: When argument is not an integer.
"""
if not isinstance(item, _type):
raise TypingError(
f"Expected an item should to be an {_type} value, but "
f"encountered {type(item)}"
)

if not isinstance(dim, types.Integer):
raise TypingError(
f"Expected an {_type}'s dim should to be an Integer value, but "
f"encountered {type(dim)}"
)

@overload_method(
NdItemType, "get_global_id", target=DPEX_KERNEL_EXP_TARGET_NAME
)
def ol_nd_item_get_global_id(nd_item, dim):
"""SPIR-V overload for :meth:`numba_dpex.kernel_api.NdItem.get_global_id`.
# pylint: disable=unused-argument
def ol_item_get_index_impl(item, dim):
# TODO: call in reverse index once index reversing is removed from
# kernel submission
# pylint: disable=no-value-for-parameter
return _intrinsic(dim)

Generates the same LLVM IR instruction as dpcpp for the
`sycl::nd_item::get_global_id` function.
return ol_item_get_index_impl

Raises:
TypingError: When argument is not an integer.
"""
if not isinstance(nd_item, NdItemType):
# since it is a method overload, this error should not be reached
raise TypingError(
"Expected a nd_item should to be a NdItem value, but "
f"encountered {type(nd_item)}"
)
return ol_item_gen_index


_index_const_overload_methods = [
(ItemType, "get_id", _intrinsic_spirv_global_invocation_id),
(ItemType, "get_range", _intrinsic_spirv_global_size),
(NdItemType, "get_global_id", _intrinsic_spirv_global_invocation_id),
(NdItemType, "get_local_id", _intrinsic_spirv_local_invocation_id),
(NdItemType, "get_global_range", _intrinsic_spirv_global_size),
(NdItemType, "get_local_range", _intrinsic_spirv_workgroup_size),
]

for index_overload in _index_const_overload_methods:
_type, method, _intrinsic = index_overload

ol_index_func = generate_index_overload(_type, _intrinsic)

if not isinstance(dim, types.Integer):
overload_method(_type, method, target=DPEX_KERNEL_EXP_TARGET_NAME)(
ol_index_func
)


@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_get_group(
ty_context, ty_nd_item: NdItemType # pylint: disable=unused-argument
):
"""Generates group with a dimension of nd_item."""

if not isinstance(ty_nd_item, NdItemType):
raise TypingError(
"Expected a NdItem's dim should to be an Integer value, but "
f"encountered {type(dim)}"
f"Expected an NdItemType value, but encountered {ty_nd_item}"
)

ty_group = GroupType(ty_nd_item.ndim)
sig = ty_group(ty_nd_item)

# pylint: disable=unused-argument
def ol_nd_item_get_global_id_impl(nd_item, dim):
# pylint: disable=no-value-for-parameter
return _intrinsic_get_global_id(dim)
def _intrinsic_get_group_gen(context, builder, sig, args):
group_struct = cgutils.create_struct_proxy(ty_group)(context, builder)
# pylint: disable=protected-access
return group_struct._getvalue()

return ol_nd_item_get_global_id_impl
return sig, _intrinsic_get_group_gen


@overload_method(NdItemType, "get_group", target=DPEX_KERNEL_EXP_TARGET_NAME)
Expand Down
24 changes: 24 additions & 0 deletions numba_dpex/kernel_api/index_space_ids.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,14 @@ def get_id(self, idx):
"""
return self._index[idx]

def get_range(self, idx):
"""Get the range size for a specific dimension.

Returns:
int: The size
"""
return self._extent[idx]

@property
def ndim(self) -> int:
"""Returns the rank of a Item object.
Expand Down Expand Up @@ -117,6 +125,22 @@ def get_local_linear_id(self):
"""
return self._local_item.get_linear_id()

def get_global_range(self, idx):
"""Get the global range size for a specific dimension.

Returns:
int: The size
"""
return self._global_item.get_range(idx)

def get_local_range(self, idx):
"""Get the local range size for a specific dimension.

Returns:
int: The size
"""
return self._local_item.get_range(idx)

def get_group(self):
"""Returns the group.

Expand Down
Loading
Loading