Skip to content

Commit

Permalink
Add range methods with overloads and generelize overloads
Browse files Browse the repository at this point in the history
  • Loading branch information
ZzEeKkAa committed Feb 8, 2024
1 parent e09f8e7 commit 7f4ee66
Show file tree
Hide file tree
Showing 3 changed files with 144 additions and 178 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -44,22 +44,23 @@ def declare_spirv_const(
return data


# TODO: call in reverse index once reverse is removed from submission
@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_spirv_global_invocation_id(
ty_context, ty_dim # pylint: disable=unused-argument
def _intrinsic_spirv_global_index_const(
ty_context, # pylint: disable=unused-argument# pylint: disable=unused-argument
ty_dim, # pylint: disable=unused-argument
const_name: str,
):
"""Generates instruction for spirv BuiltInGlobalInvocationId call."""
"""Generates instruction to get spirv index from const_name."""
sig = types.int64(types.int32)

def _intrinsic_spirv_global_invocation_id_gen(
def _intrinsic_spirv_global_index_const_gen(
context: SPIRVTargetContext,
builder: llvmir.IRBuilder,
sig, # pylint: disable=unused-argument
args,
):
global_invocation_id = declare_spirv_const(
builder, "BuiltInGlobalInvocationId"
index_const = declare_spirv_const(
builder,
const_name,
)
[dim] = args
# TODO: llvmlite does not support gep on vector. Use this instead once
Expand All @@ -73,72 +74,109 @@ def _intrinsic_spirv_global_invocation_id_gen(
# res = builder.load(res, align=32) # noqa: E800

res = builder.extract_element(
builder.load(global_invocation_id),
builder.load(index_const),
dim,
)

return context.cast(builder, res, types.uintp, types.intp)

return sig, _intrinsic_spirv_global_invocation_id_gen
return sig, _intrinsic_spirv_global_index_const_gen


# TODO: call in reverse index once reverse is removed from submission
@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_spirv_local_invocation_id(
def _intrinsic_spirv_global_invocation_id(
ty_context, ty_dim # pylint: disable=unused-argument
):
"""Generates instruction for spirv BuiltInLocalInvocationId call."""
sig = types.int64(types.int32)
"""Generates instruction to get index from BuiltInGlobalInvocationId."""
return _intrinsic_spirv_global_index_const(
ty_context, ty_dim, "BuiltInGlobalInvocationId"
)

def _intrinsic_spirv_local_invocation_id_gen(
context: SPIRVTargetContext,
builder: llvmir.IRBuilder,
sig, # pylint: disable=unused-argument
args,
):
local_invocation_id = declare_spirv_const(
builder, "BuiltInLocalInvocationId"
)
[dim] = args
# TODO: llvmlite does not support gep on vector. Use this instead once
# supported.
# https://github.com/numba/llvmlite/issues/756

res = builder.extract_element(
builder.load(local_invocation_id),
dim,
)
@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_spirv_local_invocation_id(
ty_context, ty_dim # pylint: disable=unused-argument
):
"""Generates instruction to get index from BuiltInLocalInvocationId."""
return _intrinsic_spirv_global_index_const(
ty_context, ty_dim, "BuiltInLocalInvocationId"
)

return context.cast(builder, res, types.uintp, types.intp)

return sig, _intrinsic_spirv_local_invocation_id_gen
@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_spirv_global_size(
ty_context, ty_dim # pylint: disable=unused-argument
):
"""Generates instruction to get index from BuiltInGlobalSize."""
return _intrinsic_spirv_global_index_const(
ty_context, ty_dim, "BuiltInGlobalSize"
)


# TODO: call in reverse index once reverse is removed from submission
@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_spirv_global_size(
def _intrinsic_spirv_workgroup_size(
ty_context, ty_dim # pylint: disable=unused-argument
):
"""Generates instruction for spirv BuiltInGlobalSize call."""
sig = types.int64(types.int32)
"""Generates instruction to get index from BuiltInWorkgroupSize."""
return _intrinsic_spirv_global_index_const(
ty_context, ty_dim, "BuiltInWorkgroupSize"
)

def _intrinsic_spirv_global_size_gen(
context: SPIRVTargetContext,
builder: llvmir.IRBuilder,
sig, # pylint: disable=unused-argument
args,
):
global_size = declare_spirv_const(builder, "BuiltInGlobalSize")
[dim] = args
# TODO: llvmlite does not support gep on vector. Use this instead once
# supported.
# https://github.com/numba/llvmlite/issues/756

res = builder.extract_element(builder.load(global_size), dim)
def generate_index_overload(_type, _intrinsic):
"""Generates overload for the index method that generates specific IR from
provided intrinsic."""

return context.cast(builder, res, types.uintp, types.intp)
def ol_item_gen_index(item, dim):
"""SPIR-V overload for :meth:`numba_dpex.kernel_api.<_type>.<method>`.
Generates the same LLVM IR instruction as dpcpp for the
`sycl::<type>::<method>` function.
Raises:
TypingError: When argument is not an integer.
"""
if not isinstance(item, _type):
raise TypingError(
f"Expected an item should to be an {_type} value, but "
f"encountered {type(item)}"
)

if not isinstance(dim, types.Integer):
raise TypingError(
f"Expected an {_type}'s dim should to be an Integer value, but "
f"encountered {type(dim)}"
)

# pylint: disable=unused-argument
def ol_item_get_index_impl(item, dim):
# TODO: call in reverse index once index reversing is removed from
# kernel submission
# pylint: disable=no-value-for-parameter
return _intrinsic(dim)

return sig, _intrinsic_spirv_global_size_gen
return ol_item_get_index_impl

return ol_item_gen_index


_index_const_overload_methods = [
(ItemType, "get_id", _intrinsic_spirv_global_invocation_id),
(ItemType, "get_range", _intrinsic_spirv_global_size),
(NdItemType, "get_global_id", _intrinsic_spirv_global_invocation_id),
(NdItemType, "get_local_id", _intrinsic_spirv_local_invocation_id),
(NdItemType, "get_global_range", _intrinsic_spirv_global_size),
(NdItemType, "get_local_range", _intrinsic_spirv_workgroup_size),
]

for index_overload in _index_const_overload_methods:
_type, method, _intrinsic = index_overload

ol_index_func = generate_index_overload(_type, _intrinsic)

overload_method(_type, method, target=DPEX_KERNEL_EXP_TARGET_NAME)(
ol_index_func
)


@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
Expand All @@ -164,130 +202,6 @@ def _intrinsic_get_group_gen(context, builder, sig, args):
return sig, _intrinsic_get_group_gen


@overload_method(ItemType, "get_id", target=DPEX_KERNEL_EXP_TARGET_NAME)
def ol_item_get_id(item, dim):
"""SPIR-V overload for :meth:`numba_dpex.kernel_api.Item.get_id`.
Generates the same LLVM IR instruction as dpcpp for the
`sycl::item::get_id` function.
Raises:
TypingError: When argument is not an integer.
"""
if not isinstance(item, ItemType):
raise TypingError(
"Expected an item should to be an Item value, but "
f"encountered {type(item)}"
)

if not isinstance(dim, types.Integer):
raise TypingError(
"Expected an Item's dim should to be an Integer value, but "
f"encountered {type(dim)}"
)

# pylint: disable=unused-argument
def ol_item_get_id_impl(item, dim):
# pylint: disable=no-value-for-parameter
return _intrinsic_spirv_global_invocation_id(dim)

return ol_item_get_id_impl


@overload_method(ItemType, "get_range", target=DPEX_KERNEL_EXP_TARGET_NAME)
def ol_item_get_range(item, dim):
"""SPIR-V overload for :meth:`numba_dpex.kernel_api.Item.get_range`.
Generates the same LLVM IR instruction as dpcpp for the
`sycl::item::get_id` function.
Raises:
TypingError: When argument is not an integer.
"""
if not isinstance(item, ItemType):
raise TypingError(
"Expected an item should to be an Item value, but "
f"encountered {type(item)}"
)

if not isinstance(dim, types.Integer):
raise TypingError(
"Expected an Item's dim should to be an Integer value, but "
f"encountered {type(dim)}"
)

# pylint: disable=unused-argument
def ol_item_get_range_impl(item, dim):
# pylint: disable=no-value-for-parameter
return _intrinsic_spirv_global_size(dim)

return ol_item_get_range_impl


@overload_method(
NdItemType, "get_global_id", target=DPEX_KERNEL_EXP_TARGET_NAME
)
def ol_nd_item_get_global_id(nd_item, dim):
"""SPIR-V overload for :meth:`numba_dpex.kernel_api.NdItem.get_global_id`.
Generates the same LLVM IR instruction as dpcpp for the
`sycl::nd_item::get_global_id` function.
Raises:
TypingError: When argument is not an integer.
"""
if not isinstance(nd_item, NdItemType):
# since it is a method overload, this error should not be reached
raise TypingError(
"Expected a nd_item should to be a NdItem value, but "
f"encountered {type(nd_item)}"
)

if not isinstance(dim, types.Integer):
raise TypingError(
"Expected a NdItem's dim should to be an Integer value, but "
f"encountered {type(dim)}"
)

# pylint: disable=unused-argument
def ol_nd_item_get_global_id_impl(nd_item, dim):
# pylint: disable=no-value-for-parameter
return _intrinsic_spirv_global_invocation_id(dim)

return ol_nd_item_get_global_id_impl


@overload_method(NdItemType, "get_local_id", target=DPEX_KERNEL_EXP_TARGET_NAME)
def ol_nd_item_get_local_id(nd_item, dim):
"""SPIR-V overload for :meth:`numba_dpex.kernel_api.NdItem.get_local_id`.
Generates the same LLVM IR instruction as dpcpp for the
`sycl::nd_item::get_local_id` function.
Raises:
TypingError: When argument is not an integer.
"""
if not isinstance(nd_item, NdItemType):
# since it is a method overload, this error should not be reached
raise TypingError(
"Expected a nd_item should to be a NdItem value, but "
f"encountered {type(nd_item)}"
)

if not isinstance(dim, types.Integer):
raise TypingError(
"Expected a NdItem's dim should to be an Integer value, but "
f"encountered {type(dim)}"
)

# pylint: disable=unused-argument
def ol_nd_item_get_local_id_impl(nd_item, dim):
# pylint: disable=no-value-for-parameter
return _intrinsic_spirv_local_invocation_id(dim)

return ol_nd_item_get_local_id_impl


@overload_method(NdItemType, "get_group", target=DPEX_KERNEL_EXP_TARGET_NAME)
def ol_nd_item_get_group(nd_item):
"""SPIR-V overload for :meth:`numba_dpex.kernel_api.NdItem.get_group`.
Expand Down
16 changes: 16 additions & 0 deletions numba_dpex/kernel_api/index_space_ids.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,22 @@ def get_local_linear_id(self):
"""
return self._local_item.get_linear_id()

def get_global_range(self, idx):
"""Get the global range size for a specific dimension.
Returns:
int: The size
"""
return self._global_item.get_range(idx)

def get_local_range(self, idx):
"""Get the local range size for a specific dimension.
Returns:
int: The size
"""
return self._local_item.get_range(idx)

def get_group(self):
"""Returns the group.
Expand Down
Loading

0 comments on commit 7f4ee66

Please sign in to comment.