Skip to content

Commit

Permalink
Use SPIRV style kernel code and overload get_id's and range's methods
Browse files Browse the repository at this point in the history
  • Loading branch information
ZzEeKkAa committed Feb 8, 2024
1 parent 8a633c4 commit 9b2ad84
Show file tree
Hide file tree
Showing 3 changed files with 281 additions and 81 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,131 +6,200 @@
Implements the SPIR-V overloads for the kernel_api.items class methods.
"""

import llvmlite.ir as llvmir
from numba.core import cgutils, types
from numba.core.errors import TypingError
from numba.extending import intrinsic, overload_method

from numba_dpex._kernel_api_impl.spirv.target import SPIRVTargetContext
from numba_dpex.experimental.core.types.kernel_api.items import (
GroupType,
ItemType,
NdItemType,
)
from numba_dpex.ocl._declare_function import _declare_function

from ..target import DPEX_KERNEL_EXP_TARGET_NAME


@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_get_global_id(
ty_context, ty_dim # pylint: disable=unused-argument
def spirv_name(name: str):
"""Converts name to spirv name by adding __spirv_ prefix."""
return "__spirv_" + name


def declare_spirv_const(
builder: llvmir.IRBuilder,
name: str,
):
"""Declares global external spirv constant"""
data = cgutils.add_global_variable(
builder.module,
llvmir.VectorType(llvmir.IntType(64), 3),
spirv_name(name),
addrspace=1,
)
data.linkage = "external"
data.global_constant = True
data.align = 32
data.storage_class = "dso_local local_unnamed_addr"
return data


def _intrinsic_spirv_global_index_const(
ty_context, # pylint: disable=unused-argument# pylint: disable=unused-argument
ty_dim, # pylint: disable=unused-argument
const_name: str,
):
"""Generates instruction for spirv i64 get_global_id(i32) call."""
"""Generates instruction to get spirv index from const_name."""
sig = types.int64(types.int32)

def _intrinsic_get_global_id_gen(context, builder, sig, args):
[dim] = args
get_global_id = _declare_function(
# unsigned int - is what demangler returns from IR instruction
# generated by dpcpp. However int32 is passed as an argument.
# Most likely it does not matter, since argument can be from 0 to 3.
# TODO: https://github.com/IntelPython/numba-dpex/issues/936
# Numba generates unnecessary checks because of the type mismatch.
context,
def _intrinsic_spirv_global_index_const_gen(
context: SPIRVTargetContext,
builder: llvmir.IRBuilder,
sig, # pylint: disable=unused-argument
args,
):
index_const = declare_spirv_const(
builder,
"get_global_id",
sig,
["unsigned int"],
const_name,
)
[dim] = args
# TODO: llvmlite does not support gep on vector. Use this instead once
# supported.
# https://github.com/numba/llvmlite/issues/756
# res = builder.gep( # noqa: E800
# global_invocation_id, # noqa: E800
# [cgutils.int32_t(0), cgutils.int32_t(0)], # noqa: E800
# inbounds=True, # noqa: E800
# ) # noqa: E800
# res = builder.load(res, align=32) # noqa: E800

res = builder.extract_element(
builder.load(index_const),
dim,
)
res = builder.call(get_global_id, [dim])

return context.cast(builder, res, types.uintp, types.intp)

return sig, _intrinsic_get_global_id_gen
return sig, _intrinsic_spirv_global_index_const_gen


@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_get_group(
ty_context, ty_nd_item: NdItemType # pylint: disable=unused-argument
def _intrinsic_spirv_global_invocation_id(
ty_context, ty_dim # pylint: disable=unused-argument
):
"""Generates group with a dimension of nd_item."""
"""Generates instruction to get index from BuiltInGlobalInvocationId."""
return _intrinsic_spirv_global_index_const(
ty_context, ty_dim, "BuiltInGlobalInvocationId"
)

if not isinstance(ty_nd_item, NdItemType):
raise TypingError(
f"Expected an NdItemType value, but encountered {ty_nd_item}"
)

ty_group = GroupType(ty_nd_item.ndim)
sig = ty_group(ty_nd_item)
@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_spirv_local_invocation_id(
ty_context, ty_dim # pylint: disable=unused-argument
):
"""Generates instruction to get index from BuiltInLocalInvocationId."""
return _intrinsic_spirv_global_index_const(
ty_context, ty_dim, "BuiltInLocalInvocationId"
)

# pylint: disable=unused-argument
def _intrinsic_get_group_gen(context, builder, sig, args):
group_struct = cgutils.create_struct_proxy(ty_group)(context, builder)
# pylint: disable=protected-access
return group_struct._getvalue()

return sig, _intrinsic_get_group_gen
@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_spirv_global_size(
ty_context, ty_dim # pylint: disable=unused-argument
):
"""Generates instruction to get index from BuiltInGlobalSize."""
return _intrinsic_spirv_global_index_const(
ty_context, ty_dim, "BuiltInGlobalSize"
)


@overload_method(ItemType, "get_id", target=DPEX_KERNEL_EXP_TARGET_NAME)
def ol_item_get_id(item, dim):
"""SPIR-V overload for :meth:`numba_dpex.kernel_api.Item.get_id`.
@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_spirv_workgroup_size(
ty_context, ty_dim # pylint: disable=unused-argument
):
"""Generates instruction to get index from BuiltInWorkgroupSize."""
return _intrinsic_spirv_global_index_const(
ty_context, ty_dim, "BuiltInWorkgroupSize"
)

Generates the same LLVM IR instruction as dpcpp for the
`sycl::item::get_id` function.

Raises:
TypingError: When argument is not an integer.
"""
if not isinstance(item, ItemType):
raise TypingError(
"Expected an item should to be an Item value, but "
f"encountered {type(item)}"
)
def generate_index_overload(_type, _intrinsic):
"""Generates overload for the index method that generates specific IR from
provided intrinsic."""

if not isinstance(dim, types.Integer):
raise TypingError(
"Expected an Item's dim should to be an Integer value, but "
f"encountered {type(dim)}"
)
def ol_item_gen_index(item, dim):
"""SPIR-V overload for :meth:`numba_dpex.kernel_api.<_type>.<method>`.
# pylint: disable=unused-argument
def ol_item_get_id_impl(item, dim):
# pylint: disable=no-value-for-parameter
return _intrinsic_get_global_id(dim)
Generates the same LLVM IR instruction as dpcpp for the
`sycl::<type>::<method>` function.
return ol_item_get_id_impl
Raises:
TypingError: When argument is not an integer.
"""
if not isinstance(item, _type):
raise TypingError(
f"Expected an item should to be an {_type} value, but "
f"encountered {type(item)}"
)

if not isinstance(dim, types.Integer):
raise TypingError(
f"Expected an {_type}'s dim should to be an Integer value, but "
f"encountered {type(dim)}"
)

@overload_method(
NdItemType, "get_global_id", target=DPEX_KERNEL_EXP_TARGET_NAME
)
def ol_nd_item_get_global_id(nd_item, dim):
"""SPIR-V overload for :meth:`numba_dpex.kernel_api.NdItem.get_global_id`.
# pylint: disable=unused-argument
def ol_item_get_index_impl(item, dim):
# TODO: call in reverse index once index reversing is removed from
# kernel submission
# pylint: disable=no-value-for-parameter
return _intrinsic(dim)

Generates the same LLVM IR instruction as dpcpp for the
`sycl::nd_item::get_global_id` function.
return ol_item_get_index_impl

Raises:
TypingError: When argument is not an integer.
"""
if not isinstance(nd_item, NdItemType):
# since it is a method overload, this error should not be reached
raise TypingError(
"Expected a nd_item should to be a NdItem value, but "
f"encountered {type(nd_item)}"
)
return ol_item_gen_index


_index_const_overload_methods = [
(ItemType, "get_id", _intrinsic_spirv_global_invocation_id),
(ItemType, "get_range", _intrinsic_spirv_global_size),
(NdItemType, "get_global_id", _intrinsic_spirv_global_invocation_id),
(NdItemType, "get_local_id", _intrinsic_spirv_local_invocation_id),
(NdItemType, "get_global_range", _intrinsic_spirv_global_size),
(NdItemType, "get_local_range", _intrinsic_spirv_workgroup_size),
]

for index_overload in _index_const_overload_methods:
_type, method, _intrinsic = index_overload

ol_index_func = generate_index_overload(_type, _intrinsic)

if not isinstance(dim, types.Integer):
overload_method(_type, method, target=DPEX_KERNEL_EXP_TARGET_NAME)(
ol_index_func
)


@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_get_group(
ty_context, ty_nd_item: NdItemType # pylint: disable=unused-argument
):
"""Generates group with a dimension of nd_item."""

if not isinstance(ty_nd_item, NdItemType):
raise TypingError(
"Expected a NdItem's dim should to be an Integer value, but "
f"encountered {type(dim)}"
f"Expected an NdItemType value, but encountered {ty_nd_item}"
)

ty_group = GroupType(ty_nd_item.ndim)
sig = ty_group(ty_nd_item)

# pylint: disable=unused-argument
def ol_nd_item_get_global_id_impl(nd_item, dim):
# pylint: disable=no-value-for-parameter
return _intrinsic_get_global_id(dim)
def _intrinsic_get_group_gen(context, builder, sig, args):
group_struct = cgutils.create_struct_proxy(ty_group)(context, builder)
# pylint: disable=protected-access
return group_struct._getvalue()

return ol_nd_item_get_global_id_impl
return sig, _intrinsic_get_group_gen


@overload_method(NdItemType, "get_group", target=DPEX_KERNEL_EXP_TARGET_NAME)
Expand Down
24 changes: 24 additions & 0 deletions numba_dpex/kernel_api/index_space_ids.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,14 @@ def get_id(self, idx):
"""
return self._index[idx]

def get_range(self, idx):
"""Get the range size for a specific dimension.
Returns:
int: The size
"""
return self._extent[idx]

@property
def ndim(self) -> int:
"""Returns the rank of a Item object.
Expand Down Expand Up @@ -117,6 +125,22 @@ def get_local_linear_id(self):
"""
return self._local_item.get_linear_id()

def get_global_range(self, idx):
"""Get the global range size for a specific dimension.
Returns:
int: The size
"""
return self._global_item.get_range(idx)

def get_local_range(self, idx):
"""Get the local range size for a specific dimension.
Returns:
int: The size
"""
return self._local_item.get_range(idx)

def get_group(self):
"""Returns the group.
Expand Down
Loading

0 comments on commit 9b2ad84

Please sign in to comment.