Skip to content

Commit

Permalink
implementation of compare exchange with accompanying test cases
Browse files Browse the repository at this point in the history
  • Loading branch information
adarshyoga authored and Diptorup Deb committed Feb 3, 2024
1 parent df32f71 commit 16e2488
Show file tree
Hide file tree
Showing 4 changed files with 346 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
)
from .spv_fn_generator import (
get_or_insert_atomic_load_fn,
get_or_insert_spv_atomic_compare_exchange_fn,
get_or_insert_spv_atomic_exchange_fn,
get_or_insert_spv_atomic_store_fn,
)
Expand Down Expand Up @@ -323,6 +324,108 @@ def _intrinsic_exchange_gen(context, builder, sig, args):
return sig, _intrinsic_exchange_gen


@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_compare_exchange(
ty_context, # pylint: disable=unused-argument
ty_atomic_ref,
ty_expected_ref,
ty_desired,
ty_expected_idx,
):
sig = types.boolean(
ty_atomic_ref, ty_expected_ref, ty_desired, ty_expected_idx
)

def _intrinsic_compare_exchange_gen(context, builder, sig, args):
# get pointer to expected[expected_idx]
data_attr = builder.extract_value(
args[1],
context.data_model_manager.lookup(sig.args[1]).get_field_position(
"data"
),
)
with builder.goto_entry_block():
ptr_to_data_attr = builder.alloca(data_attr.type)
builder.store(data_attr, ptr_to_data_attr)
expected_ref_ptr = builder.gep(
builder.load(ptr_to_data_attr), [args[3]]
)

expected_arg = builder.load(expected_ref_ptr)
desired_arg = args[2]
atomic_ref_ptr = builder.extract_value(
args[0],
context.data_model_manager.lookup(sig.args[0]).get_field_position(
"ref"
),
)
# add conditional bitcast for atomic_ref pointer,
# expected[expected_idx], and desired
if sig.args[0].dtype == types.float32:
atomic_ref_ptr = builder.bitcast(
atomic_ref_ptr,
llvmir.PointerType(
llvmir.IntType(32), addrspace=sig.args[0].address_space
),
)
expected_arg = builder.bitcast(expected_arg, llvmir.IntType(32))
desired_arg = builder.bitcast(desired_arg, llvmir.IntType(32))
elif sig.args[0].dtype == types.float64:
atomic_ref_ptr = builder.bitcast(
atomic_ref_ptr,
llvmir.PointerType(
llvmir.IntType(64), addrspace=sig.args[0].address_space
),
)
expected_arg = builder.bitcast(expected_arg, llvmir.IntType(64))
desired_arg = builder.bitcast(desired_arg, llvmir.IntType(64))

atomic_cmpexchg_fn_args = [
atomic_ref_ptr,
context.get_constant(
types.int32, get_scope(sig.args[0].memory_scope)
),
context.get_constant(
types.int32,
get_memory_semantics_mask(sig.args[0].memory_order),
),
context.get_constant(
types.int32,
get_memory_semantics_mask(sig.args[0].memory_order),
),
desired_arg,
expected_arg,
]

ret_val = builder.call(
get_or_insert_spv_atomic_compare_exchange_fn(
context, builder.module, sig.args[0]
),
atomic_cmpexchg_fn_args,
)

# compare_exchange returns the old value stored in AtomicRef object.
# If the return value is same as expected, then compare_exchange
# succeeded in replacing AtomicRef object with desired.
# If the return value is not same as expected, then store return
# value in expected.
# In either case, return result of cmp instruction.
is_cmp_exchg_success = builder.icmp_signed("==", ret_val, expected_arg)

with builder.if_else(is_cmp_exchg_success) as (then, otherwise):
with then:
pass
with otherwise:
if sig.args[0].dtype == types.float32:
ret_val = builder.bitcast(ret_val, llvmir.FloatType())
elif sig.args[0].dtype == types.float64:
ret_val = builder.bitcast(ret_val, llvmir.DoubleType())
builder.store(ret_val, expected_ref_ptr)
return is_cmp_exchg_success

return sig, _intrinsic_compare_exchange_gen


def _check_if_supported_ref(ref):
supported = True

Expand Down Expand Up @@ -689,3 +792,94 @@ def ol_exchange_impl(atomic_ref, val):
return _intrinsic_exchange(atomic_ref, val)

return ol_exchange_impl


@overload_method(
AtomicRefType,
"compare_exchange_weak",
target=DPEX_KERNEL_EXP_TARGET_NAME,
)
def ol_compare_exchange_weak(
atomic_ref, expected_ref, desired, expected_idx=0
): # pylint: disable=unused-argument
"""SPIR-V overload for
:meth:`numba_dpex.experimental.kernel_iface.AtomicRef.compare_exchange_weak`.
Generates the same LLVM IR instruction as dpcpp for the
`atomic_ref::compare_exchange_weak` function.
Raises:
TypingError: When the dtype of the value passed to `compare_exchange_weak`
does not match the dtype of the AtomicRef type.
"""

_check_if_supported_ref(expected_ref)

if atomic_ref.dtype != expected_ref.dtype:
raise errors.TypingError(
f"Type of value to compare_exchange_weak: {expected_ref} does not match the "
f"type of the reference: {atomic_ref.dtype} stored in the atomic ref."
)

if atomic_ref.dtype != desired:
raise errors.TypingError(
f"Type of value to compare_exchange_weak: {desired} does not match the "
f"type of the reference: {atomic_ref.dtype} stored in the atomic ref."
)

def ol_compare_exchange_weak_impl(
atomic_ref, expected_ref, desired, expected_idx=0
):
# pylint: disable=no-value-for-parameter
return _intrinsic_compare_exchange(
atomic_ref, expected_ref, desired, expected_idx
)

return ol_compare_exchange_weak_impl


@overload_method(
AtomicRefType,
"compare_exchange_strong",
target=DPEX_KERNEL_EXP_TARGET_NAME,
)
def ol_compare_exchange_strong(
atomic_ref,
expected_ref,
desired,
expected_idx=0, # pylint: disable=unused-argument
):
"""SPIR-V overload for
:meth:`numba_dpex.experimental.kernel_iface.AtomicRef.compare_exchange_strong`.
Generates the same LLVM IR instruction as dpcpp for the
`atomic_ref::compare_exchange_strong` function.
Raises:
TypingError: When the dtype of the value passed to `compare_exchange_strong`
does not match the dtype of the AtomicRef type.
"""

_check_if_supported_ref(expected_ref)

if atomic_ref.dtype != expected_ref.dtype:
raise errors.TypingError(
f"Type of value to compare_exchange_strong: {expected_ref} does not match the "
f"type of the reference: {atomic_ref.dtype} stored in the atomic ref."
)

if atomic_ref.dtype != desired:
raise errors.TypingError(
f"Type of value to compare_exchange_strong: {desired} does not match the "
f"type of the reference: {atomic_ref.dtype} stored in the atomic ref."
)

def ol_compare_exchange_strong_impl(
atomic_ref, expected_ref, desired, expected_idx=0
):
# pylint: disable=no-value-for-parameter
return _intrinsic_compare_exchange(
atomic_ref, expected_ref, desired, expected_idx
)

return ol_compare_exchange_strong_impl
Original file line number Diff line number Diff line change
Expand Up @@ -119,3 +119,59 @@ def get_or_insert_spv_atomic_exchange_fn(context, module, atomic_ref_ty):
fn.calling_convention = CC_SPIR_FUNC

return fn


def get_or_insert_spv_atomic_compare_exchange_fn(
context, module, atomic_ref_ty
):
"""
Gets or inserts a declaration for a __spirv_AtomicCompareExchange call into the
specified LLVM IR module.
"""
atomic_ref_dtype = atomic_ref_ty.dtype

# Spirv spec requires arguments and return type to be of integer types.
# That is why the type is changed from float to int
# while maintaining the bit-width.
# During function call, bitcasting is performed
# to adhere to this convention.
if atomic_ref_dtype == types.float32:
atomic_ref_dtype = types.uint32
elif atomic_ref_dtype == types.float64:
atomic_ref_dtype = types.uint64

ptr_type = context.get_value_type(atomic_ref_dtype).as_pointer()
ptr_type.addrspace = atomic_ref_ty.address_space
atomic_cmpexchg_fn_retty = context.get_value_type(atomic_ref_dtype)

atomic_cmpexchg_fn_arg_types = [
ptr_type,
llvmir.IntType(32),
llvmir.IntType(32),
llvmir.IntType(32),
context.get_value_type(atomic_ref_dtype),
context.get_value_type(atomic_ref_dtype),
]

mangled_fn_name = ext_itanium_mangler.mangle_ext(
"__spirv_AtomicCompareExchange",
[
types.CPointer(atomic_ref_dtype, addrspace=ptr_type.addrspace),
"__spv.Scope.Flag",
"__spv.MemorySemanticsMask.Flag",
"__spv.MemorySemanticsMask.Flag",
atomic_ref_dtype,
atomic_ref_dtype,
],
)

fn = cgutils.get_or_insert_function(
module,
llvmir.FunctionType(
atomic_cmpexchg_fn_retty, atomic_cmpexchg_fn_arg_types
),
mangled_fn_name,
)
fn.calling_convention = CC_SPIR_FUNC

return fn
52 changes: 52 additions & 0 deletions numba_dpex/kernel_api/atomic_ref.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,3 +200,55 @@ def exchange(self, val):
old = self._ref[self._index]
self._ref[self._index] = val
return old

def compare_exchange_weak(self, expected, desired, expected_idx=0):
"""Compares the value of the object referenced by the AtomicRef
against the value of ``expected[expected_idx]``.
If the values are equal, attempts to replace the value of the
referenced object with the value of ``desired``.
Otherwise assigns the original value of the
referenced object to ``expected[expected_idx]``.
Args:
expected : Vector containing the expected value of the
object referenced by the AtomicRef.
desired : Value that replaces the value of the object
referenced by the AtomicRef.
expected_idx: Offset in `expected` vector where the expected
value of the object referenced by the AtomicRef is present.
Returns: Returns ``True`` if the comparison operation and
replacement operation were successful.
"""
if self._ref[self._index] == expected[expected_idx]:
self._ref[self._index] = desired
return True
expected[expected_idx] = self._ref[self._index]
return False

def compare_exchange_strong(self, expected, desired, expected_idx=0):
"""Compares the value of the object referenced by the AtomicRef
against the value of ``expected[expected_idx]``.
If the values are equal, replaces the value of the
referenced object with the value of ``desired``.
Otherwise assigns the original value of the
referenced object to ``expected[expected_idx]``.
Args:
expected : Vector containing the expected value of the
object referenced by the AtomicRef.
desired : Value that replaces the value of the object
referenced by the AtomicRef.
expected_idx: Offset in `expected` vector where the expected
value of the object referenced by the AtomicRef is present.
Returns: Returns ``True`` if the comparison operation and
replacement operation were successful.
"""
if self._ref[self._index] == expected[expected_idx]:
self._ref[self._index] = desired
return True
expected[expected_idx] = self._ref[self._index]
return False
Loading

0 comments on commit 16e2488

Please sign in to comment.