Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Atomic compare_exchange implementation #1312

Merged
merged 1 commit into from
Feb 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
)
from .spv_fn_generator import (
get_or_insert_atomic_load_fn,
get_or_insert_spv_atomic_compare_exchange_fn,
get_or_insert_spv_atomic_exchange_fn,
get_or_insert_spv_atomic_store_fn,
)
Expand Down Expand Up @@ -323,6 +324,108 @@ def _intrinsic_exchange_gen(context, builder, sig, args):
return sig, _intrinsic_exchange_gen


@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_compare_exchange(
ty_context, # pylint: disable=unused-argument
ty_atomic_ref,
ty_expected_ref,
ty_desired,
ty_expected_idx,
):
sig = types.boolean(
ty_atomic_ref, ty_expected_ref, ty_desired, ty_expected_idx
)

def _intrinsic_compare_exchange_gen(context, builder, sig, args):
# get pointer to expected[expected_idx]
data_attr = builder.extract_value(
args[1],
context.data_model_manager.lookup(sig.args[1]).get_field_position(
"data"
),
)
with builder.goto_entry_block():
ptr_to_data_attr = builder.alloca(data_attr.type)
builder.store(data_attr, ptr_to_data_attr)
expected_ref_ptr = builder.gep(
builder.load(ptr_to_data_attr), [args[3]]
)

expected_arg = builder.load(expected_ref_ptr)
desired_arg = args[2]
atomic_ref_ptr = builder.extract_value(
args[0],
context.data_model_manager.lookup(sig.args[0]).get_field_position(
"ref"
),
)
# add conditional bitcast for atomic_ref pointer,
# expected[expected_idx], and desired
if sig.args[0].dtype == types.float32:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can do it like follows as well to reduce code duplication:

if isinstance(sig.args[0].dtype, types.Float):
    bitwidth = sig.args[0].dtype.bitwidth
    atomic_ref_ptr = builder.bitcast(
                atomic_ref_ptr,
                llvmir.PointerType(
                    llvmir.IntType(bitwidth), addrspace=sig.args[0].address_space
                ),
            )
            expected_arg = builder.bitcast(expected_arg, llvmir.IntType(bitwidth))
            desired_arg = builder.bitcast(desired_arg, llvmir.IntType(bitwidth))

Then we do not need two separate cases for fp32 and fp64.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let us skip as we have the same pattern in multiple places and let us just be consistent.

atomic_ref_ptr = builder.bitcast(
atomic_ref_ptr,
llvmir.PointerType(
llvmir.IntType(32), addrspace=sig.args[0].address_space
),
)
expected_arg = builder.bitcast(expected_arg, llvmir.IntType(32))
desired_arg = builder.bitcast(desired_arg, llvmir.IntType(32))
elif sig.args[0].dtype == types.float64:
atomic_ref_ptr = builder.bitcast(
atomic_ref_ptr,
llvmir.PointerType(
llvmir.IntType(64), addrspace=sig.args[0].address_space
),
)
expected_arg = builder.bitcast(expected_arg, llvmir.IntType(64))
desired_arg = builder.bitcast(desired_arg, llvmir.IntType(64))

atomic_cmpexchg_fn_args = [
atomic_ref_ptr,
context.get_constant(
types.int32, get_scope(sig.args[0].memory_scope)
),
context.get_constant(
types.int32,
get_memory_semantics_mask(sig.args[0].memory_order),
),
context.get_constant(
types.int32,
get_memory_semantics_mask(sig.args[0].memory_order),
),
desired_arg,
expected_arg,
]

ret_val = builder.call(
get_or_insert_spv_atomic_compare_exchange_fn(
context, builder.module, sig.args[0]
),
atomic_cmpexchg_fn_args,
)

# compare_exchange returns the old value stored in AtomicRef object.
# If the return value is same as expected, then compare_exchange
# succeeded in replacing AtomicRef object with desired.
# If the return value is not same as expected, then store return
# value in expected.
# In either case, return result of cmp instruction.
is_cmp_exchg_success = builder.icmp_signed("==", ret_val, expected_arg)

with builder.if_else(is_cmp_exchg_success) as (then, otherwise):
with then:
pass
with otherwise:
if sig.args[0].dtype == types.float32:
diptorupd marked this conversation as resolved.
Show resolved Hide resolved
ret_val = builder.bitcast(ret_val, llvmir.FloatType())
elif sig.args[0].dtype == types.float64:
ret_val = builder.bitcast(ret_val, llvmir.DoubleType())
builder.store(ret_val, expected_ref_ptr)
return is_cmp_exchg_success

return sig, _intrinsic_compare_exchange_gen


def _check_if_supported_ref(ref):
supported = True

Expand Down Expand Up @@ -689,3 +792,50 @@ def ol_exchange_impl(atomic_ref, val):
return _intrinsic_exchange(atomic_ref, val)

return ol_exchange_impl


@overload_method(
AtomicRefType,
"compare_exchange",
target=DPEX_KERNEL_EXP_TARGET_NAME,
)
def ol_compare_exchange(
atomic_ref,
expected_ref,
desired,
expected_idx=0, # pylint: disable=unused-argument
):
"""SPIR-V overload for
:meth:`numba_dpex.experimental.kernel_iface.AtomicRef.compare_exchange`.

Generates the same LLVM IR instruction as dpcpp for the
`atomic_ref::compare_exchange_strong` function.

Raises:
TypingError: When the dtype of the value passed to `compare_exchange`
does not match the dtype of the AtomicRef type.
"""

_check_if_supported_ref(expected_ref)

if atomic_ref.dtype != expected_ref.dtype:
raise errors.TypingError(
f"Type of value to compare_exchange: {expected_ref} does not match the "
f"type of the reference: {atomic_ref.dtype} stored in the atomic ref."
)

if atomic_ref.dtype != desired:
raise errors.TypingError(
f"Type of value to compare_exchange: {desired} does not match the "
f"type of the reference: {atomic_ref.dtype} stored in the atomic ref."
)

def ol_compare_exchange_impl(
atomic_ref, expected_ref, desired, expected_idx=0
):
# pylint: disable=no-value-for-parameter
return _intrinsic_compare_exchange(
atomic_ref, expected_ref, desired, expected_idx
)

return ol_compare_exchange_impl
Original file line number Diff line number Diff line change
Expand Up @@ -119,3 +119,59 @@ def get_or_insert_spv_atomic_exchange_fn(context, module, atomic_ref_ty):
fn.calling_convention = CC_SPIR_FUNC

return fn


def get_or_insert_spv_atomic_compare_exchange_fn(
context, module, atomic_ref_ty
):
"""
Gets or inserts a declaration for a __spirv_AtomicCompareExchange call into the
specified LLVM IR module.
"""
atomic_ref_dtype = atomic_ref_ty.dtype

# Spirv spec requires arguments and return type to be of integer types.
# That is why the type is changed from float to int
# while maintaining the bit-width.
# During function call, bitcasting is performed
# to adhere to this convention.
if atomic_ref_dtype == types.float32:
atomic_ref_dtype = types.uint32
elif atomic_ref_dtype == types.float64:
atomic_ref_dtype = types.uint64

ptr_type = context.get_value_type(atomic_ref_dtype).as_pointer()
ptr_type.addrspace = atomic_ref_ty.address_space
atomic_cmpexchg_fn_retty = context.get_value_type(atomic_ref_dtype)

atomic_cmpexchg_fn_arg_types = [
ptr_type,
llvmir.IntType(32),
llvmir.IntType(32),
llvmir.IntType(32),
context.get_value_type(atomic_ref_dtype),
context.get_value_type(atomic_ref_dtype),
]

mangled_fn_name = ext_itanium_mangler.mangle_ext(
"__spirv_AtomicCompareExchange",
[
types.CPointer(atomic_ref_dtype, addrspace=ptr_type.addrspace),
"__spv.Scope.Flag",
"__spv.MemorySemanticsMask.Flag",
"__spv.MemorySemanticsMask.Flag",
atomic_ref_dtype,
atomic_ref_dtype,
],
)

fn = cgutils.get_or_insert_function(
module,
llvmir.FunctionType(
atomic_cmpexchg_fn_retty, atomic_cmpexchg_fn_arg_types
),
mangled_fn_name,
)
fn.calling_convention = CC_SPIR_FUNC

return fn
26 changes: 26 additions & 0 deletions numba_dpex/kernel_api/atomic_ref.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,3 +200,29 @@ def exchange(self, val):
old = self._ref[self._index]
self._ref[self._index] = val
return old

def compare_exchange(self, expected, desired, expected_idx=0):
"""Compares the value of the object referenced by the AtomicRef
against the value of ``expected[expected_idx]``.
If the values are equal, replaces the value of the
referenced object with the value of ``desired``.
Otherwise assigns the original value of the
referenced object to ``expected[expected_idx]``.

Args:
expected : Array containing the expected value of the
object referenced by the AtomicRef.
desired : Value that replaces the value of the object
referenced by the AtomicRef.
expected_idx: Offset in `expected` array where the expected
value of the object referenced by the AtomicRef is present.

Returns: Returns ``True`` if the comparison operation and
replacement operation were successful.

"""
if self._ref[self._index] == expected[expected_idx]:
self._ref[self._index] = desired
return True
expected[expected_idx] = self._ref[self._index]
return False
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def _kernel(a, b):
and supported_dtype == dpnp.float64
):
pytest.xfail(
"Atomic load, store, and exchange operations not working "
"Atomic load and store operations not working "
" for fp64 on OpenCL CPU"
)

Expand Down Expand Up @@ -82,8 +82,7 @@ def _kernel(a, b):
and supported_dtype == dpnp.float64
):
pytest.xfail(
"Atomic load, store, and exchange operations not working "
" for fp64 on OpenCL CPU"
"Atomic exchange operation not working " " for fp64 on OpenCL CPU"
)

a_copy = dpnp.copy(a_orig)
Expand All @@ -100,6 +99,43 @@ def _kernel(a, b):
assert b_copy[i] == a_orig[i]


@pytest.mark.parametrize("supported_dtype", list_of_supported_dtypes)
def test_compare_exchange_fns(supported_dtype):
"""A test for compare exchange atomic functions."""

@dpex_exp.kernel
def _kernel(b):
b_ref = AtomicRef(b, index=1)
b[0] = b_ref.compare_exchange(
expected_ref=b, desired=b[3], expected_idx=2
)

b = dpnp.arange(4, dtype=supported_dtype)

dev = b.sycl_device
if (
dev.backend == dpctl.backend_type.opencl
and dev.device_type == dpctl.device_type.cpu
and supported_dtype == dpnp.float64
):
pytest.xfail(
"Atomic compare_exchange operation not working "
" for fp64 on OpenCL CPU"
)

dpex_exp.call_kernel(_kernel, dpex.Range(1), b)

# check for failure
assert b[0] == 0
assert b[2] == b[1]

dpex_exp.call_kernel(_kernel, dpex.Range(1), b)

# check for success
assert b[0] == 1
assert b[1] == b[3]


def test_store_exchange_diff_types(store_exchange_fn):
"""A negative test that verifies that a TypingError is raised if
AtomicRef type and value are of different types.
Expand Down
Loading