Skip to content

Commit

Permalink
Add overloads for fetch_* functions
Browse files Browse the repository at this point in the history
  • Loading branch information
adarshyoga authored and Diptorup Deb committed Dec 28, 2023
1 parent 6497c26 commit ba7f25c
Showing 1 changed file with 195 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,36 @@ def _intrinsic_fetch_add(ty_context, ty_atomic_ref, ty_val):
return _intrinsic_helper(ty_context, ty_atomic_ref, ty_val, "fetch_add")


@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_fetch_sub(ty_context, ty_atomic_ref, ty_val):
return _intrinsic_helper(ty_context, ty_atomic_ref, ty_val, "fetch_sub")


@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_fetch_min(ty_context, ty_atomic_ref, ty_val):
return _intrinsic_helper(ty_context, ty_atomic_ref, ty_val, "fetch_min")


@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_fetch_max(ty_context, ty_atomic_ref, ty_val):
return _intrinsic_helper(ty_context, ty_atomic_ref, ty_val, "fetch_max")


@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_fetch_and(ty_context, ty_atomic_ref, ty_val):
return _intrinsic_helper(ty_context, ty_atomic_ref, ty_val, "fetch_and")


@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_fetch_or(ty_context, ty_atomic_ref, ty_val):
return _intrinsic_helper(ty_context, ty_atomic_ref, ty_val, "fetch_or")


@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_fetch_xor(ty_context, ty_atomic_ref, ty_val):
return _intrinsic_helper(ty_context, ty_atomic_ref, ty_val, "fetch_xor")


@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_atomic_ref_ctor(
ty_context, ref, ty_index, ty_retty_ref # pylint: disable=unused-argument
Expand Down Expand Up @@ -294,3 +324,168 @@ def ol_fetch_add_impl(atomic_ref, val):
return _intrinsic_fetch_add(atomic_ref, val)

return ol_fetch_add_impl


@overload_method(AtomicRefType, "fetch_sub", target=DPEX_KERNEL_EXP_TARGET_NAME)
def ol_fetch_sub(atomic_ref, val):
"""SPIR-V overload for
:meth:`numba_dpex.experimental.kernel_iface.AtomicRef.fetch_sub`.
Generates the same LLVM IR instruction as dpcpp for the
`atomic_ref::fetch_sub` function.
Raises:
TypingError: When the dtype of the aggregator value does not match the
dtype of the AtomicRef type.
"""
if atomic_ref.dtype != val:
raise errors.TypingError(
f"Type of value to sub: {val} does not match the type of the "
f"reference: {atomic_ref.dtype} stored in the atomic ref."
)

def ol_fetch_sub_impl(atomic_ref, val):
# pylint: disable=no-value-for-parameter
return _intrinsic_fetch_sub(atomic_ref, val)

return ol_fetch_sub_impl


@overload_method(AtomicRefType, "fetch_min", target=DPEX_KERNEL_EXP_TARGET_NAME)
def ol_fetch_min(atomic_ref, val):
"""SPIR-V overload for
:meth:`numba_dpex.experimental.kernel_iface.AtomicRef.fetch_min`.
Generates the same LLVM IR instruction as dpcpp for the
`atomic_ref::fetch_min` function.
Raises:
TypingError: When the dtype of the aggregator value does not match the
dtype of the AtomicRef type.
"""
if atomic_ref.dtype != val:
raise errors.TypingError(
f"Type of value to find min: {val} does not match the type of the "
f"reference: {atomic_ref.dtype} stored in the atomic ref."
)

def ol_fetch_min_impl(atomic_ref, val):
# pylint: disable=no-value-for-parameter
return _intrinsic_fetch_min(atomic_ref, val)

return ol_fetch_min_impl


@overload_method(AtomicRefType, "fetch_max", target=DPEX_KERNEL_EXP_TARGET_NAME)
def ol_fetch_max(atomic_ref, val):
"""SPIR-V overload for
:meth:`numba_dpex.experimental.kernel_iface.AtomicRef.fetch_max`.
Generates the same LLVM IR instruction as dpcpp for the
`atomic_ref::fetch_max` function.
Raises:
TypingError: When the dtype of the aggregator value does not match the
dtype of the AtomicRef type.
"""
if atomic_ref.dtype != val:
raise errors.TypingError(
f"Type of value to find max: {val} does not match the type of the "
f"reference: {atomic_ref.dtype} stored in the atomic ref."
)

def ol_fetch_max_impl(atomic_ref, val):
# pylint: disable=no-value-for-parameter
return _intrinsic_fetch_max(atomic_ref, val)

return ol_fetch_max_impl


@overload_method(AtomicRefType, "fetch_and", target=DPEX_KERNEL_EXP_TARGET_NAME)
def ol_fetch_and(atomic_ref, val):
"""SPIR-V overload for
:meth:`numba_dpex.experimental.kernel_iface.AtomicRef.fetch_and`.
Generates the same LLVM IR instruction as dpcpp for the
`atomic_ref::fetch_and` function.
Raises:
TypingError: When the dtype of the aggregator value does not match the
dtype of the AtomicRef type.
"""
if atomic_ref.dtype != val:
raise errors.TypingError(
f"Type of value to and: {val} does not match the type of the "
f"reference: {atomic_ref.dtype} stored in the atomic ref."
)

if atomic_ref.dtype not in (types.int32, types.int64):
raise errors.TypingError(
"fetch_and operation only supported on int32 and int64 dtypes."
)

def ol_fetch_and_impl(atomic_ref, val):
# pylint: disable=no-value-for-parameter
return _intrinsic_fetch_and(atomic_ref, val)

return ol_fetch_and_impl


@overload_method(AtomicRefType, "fetch_or", target=DPEX_KERNEL_EXP_TARGET_NAME)
def ol_fetch_or(atomic_ref, val):
"""SPIR-V overload for
:meth:`numba_dpex.experimental.kernel_iface.AtomicRef.fetch_or`.
Generates the same LLVM IR instruction as dpcpp for the
`atomic_ref::fetch_or` function.
Raises:
TypingError: When the dtype of the aggregator value does not match the
dtype of the AtomicRef type.
"""
if atomic_ref.dtype != val:
raise errors.TypingError(
f"Type of value to or: {val} does not match the type of the "
f"reference: {atomic_ref.dtype} stored in the atomic ref."
)

if atomic_ref.dtype not in (types.int32, types.int64):
raise errors.TypingError(
"fetch_or operation only supported on int32 and int64 dtypes."
)

def ol_fetch_or_impl(atomic_ref, val):
# pylint: disable=no-value-for-parameter
return _intrinsic_fetch_or(atomic_ref, val)

return ol_fetch_or_impl


@overload_method(AtomicRefType, "fetch_xor", target=DPEX_KERNEL_EXP_TARGET_NAME)
def ol_fetch_xor(atomic_ref, val):
"""SPIR-V overload for
:meth:`numba_dpex.experimental.kernel_iface.AtomicRef.fetch_xor`.
Generates the same LLVM IR instruction as dpcpp for the
`atomic_ref::fetch_xor` function.
Raises:
TypingError: When the dtype of the aggregator value does not match the
dtype of the AtomicRef type.
"""
if atomic_ref.dtype != val:
raise errors.TypingError(
f"Type of value to xor: {val} does not match the type of the "
f"reference: {atomic_ref.dtype} stored in the atomic ref."
)

if atomic_ref.dtype not in (types.int32, types.int64):
raise errors.TypingError(
"fetch_xor operation only supported on int32 and int64 dtypes."
)

def ol_fetch_xor_impl(atomic_ref, val):
# pylint: disable=no-value-for-parameter
return _intrinsic_fetch_xor(atomic_ref, val)

return ol_fetch_xor_impl

0 comments on commit ba7f25c

Please sign in to comment.