Skip to content

Commit

Permalink
implementations for atomic load, store and exchange; test cases for e…
Browse files Browse the repository at this point in the history
…ach impl
  • Loading branch information
adarshyoga committed Jan 26, 2024
1 parent 6ec9c63 commit 8ead4e8
Show file tree
Hide file tree
Showing 3 changed files with 362 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,157 @@ def codegen(context, builder, sig, args):
)


@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_load(
ty_context, ty_atomic_ref # pylint: disable=unused-argument
):
sig = ty_atomic_ref.dtype(ty_atomic_ref)

def _intrinsic_load_gen(context, builder, sig, args):
atomic_ref_ty = sig.args[0]
atomic_ref_dtype = atomic_ref_ty.dtype
retty = context.get_value_type(atomic_ref_dtype)

data_attr_pos = context.data_model_manager.lookup(
atomic_ref_ty
).get_field_position("ref")

ptr_type = retty.as_pointer()
ptr_type.addrspace = atomic_ref_ty.address_space

spirv_fn_arg_types = [
ptr_type,
llvmir.IntType(32),
llvmir.IntType(32),
]

mangled_fn_name = ext_itanium_mangler.mangle_ext(
"__spirv_AtomicLoad",
[
types.CPointer(atomic_ref_dtype, addrspace=ptr_type.addrspace),
"__spv.Scope.Flag",
"__spv.MemorySemanticsMask.Flag",
],
)

fn = cgutils.get_or_insert_function(
builder.module,
llvmir.FunctionType(retty, spirv_fn_arg_types),
mangled_fn_name,
)
fn.calling_convention = CC_SPIR_FUNC
spirv_memory_semantics_mask = get_memory_semantics_mask(
atomic_ref_ty.memory_order
)
spirv_scope = get_scope(atomic_ref_ty.memory_scope)

fn_args = [
builder.extract_value(args[0], data_attr_pos),
context.get_constant(types.int32, spirv_scope),
context.get_constant(types.int32, spirv_memory_semantics_mask),
]

return builder.call(fn, fn_args)

return sig, _intrinsic_load_gen


def _store_exchange_intrisic_helper(context, builder, sig, ol_info: dict):
atomic_ref_ty = sig.args[0]
atomic_ref_dtype = atomic_ref_ty.dtype

ptr_type = context.get_value_type(atomic_ref_dtype).as_pointer()
ptr_type.addrspace = atomic_ref_ty.address_space

spirv_fn_arg_types = [
ptr_type,
llvmir.IntType(32),
llvmir.IntType(32),
context.get_value_type(atomic_ref_dtype),
]

mangled_fn_name = ext_itanium_mangler.mangle_ext(
ol_info["name"],
[
types.CPointer(atomic_ref_dtype, addrspace=ptr_type.addrspace),
"__spv.Scope.Flag",
"__spv.MemorySemanticsMask.Flag",
atomic_ref_dtype,
],
)

fn = cgutils.get_or_insert_function(
builder.module,
llvmir.FunctionType(ol_info["retty"], spirv_fn_arg_types),
mangled_fn_name,
)
fn.calling_convention = CC_SPIR_FUNC

fn_args = [
builder.extract_value(
ol_info["args"][0],
context.data_model_manager.lookup(atomic_ref_ty).get_field_position(
"ref"
),
),
context.get_constant(
types.int32, get_scope(atomic_ref_ty.memory_scope)
),
context.get_constant(
types.int32, get_memory_semantics_mask(atomic_ref_ty.memory_order)
),
ol_info["args"][1],
]

return builder.call(fn, fn_args)


@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_store(
ty_context, ty_atomic_ref, ty_val
): # pylint: disable=unused-argument
sig = types.void(ty_atomic_ref, ty_val)

def _intrinsic_store_gen(context, builder, sig, args):
_store_exchange_intrisic_helper(
context,
builder,
sig,
# dict containing arguments, return type,
# spirv fn name driven by pylint too-many-args
{
"args": args,
"retty": llvmir.VoidType(),
"name": "__spirv_AtomicStore",
},
)

return sig, _intrinsic_store_gen


@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_exchange(
ty_context, ty_atomic_ref, ty_val # pylint: disable=unused-argument
):
sig = ty_atomic_ref.dtype(ty_atomic_ref, ty_val)

def _intrinsic_exchange_gen(context, builder, sig, args):
return _store_exchange_intrisic_helper(
context,
builder,
sig,
# dict containing arguments, return type,
# spirv fn name driven by pylint too-many-args
{
"args": args,
"retty": context.get_value_type(sig.args[0].dtype),
"name": "__spirv_AtomicExchange",
},
)

return sig, _intrinsic_exchange_gen


def _check_if_supported_ref(ref):
supported = True

Expand Down Expand Up @@ -516,3 +667,72 @@ def ol_fetch_xor_impl(atomic_ref, val):
return _intrinsic_fetch_xor(atomic_ref, val)

return ol_fetch_xor_impl


@overload_method(AtomicRefType, "load", target=DPEX_KERNEL_EXP_TARGET_NAME)
def ol_load(atomic_ref): # pylint: disable=unused-argument
"""SPIR-V overload for
:meth:`numba_dpex.experimental.kernel_iface.AtomicRef.load`.
Generates the same LLVM IR instruction as dpcpp for the
`atomic_ref::load` function.
"""

def ol_load_impl(atomic_ref):
# pylint: disable=no-value-for-parameter
return _intrinsic_load(atomic_ref)

return ol_load_impl


@overload_method(AtomicRefType, "store", target=DPEX_KERNEL_EXP_TARGET_NAME)
def ol_store(atomic_ref, val):
"""SPIR-V overload for
:meth:`numba_dpex.experimental.kernel_iface.AtomicRef.store`.
Generates the same LLVM IR instruction as dpcpp for the
`atomic_ref::store` function.
Raises:
TypingError: When the dtype of the value stored does not match the
dtype of the AtomicRef type.
"""

if atomic_ref.dtype != val:
raise errors.TypingError(
f"Type of value to store: {val} does not match the type of the "
f"reference: {atomic_ref.dtype} stored in the atomic ref."
)

def ol_store_impl(atomic_ref, val):
# pylint: disable=no-value-for-parameter
return _intrinsic_store(atomic_ref, val)

return ol_store_impl


@overload_method(AtomicRefType, "exchange", target=DPEX_KERNEL_EXP_TARGET_NAME)
def ol_exchange(atomic_ref, val):
"""SPIR-V overload for
:meth:`numba_dpex.experimental.kernel_iface.AtomicRef.exchange`.
Generates the same LLVM IR instruction as dpcpp for the
`atomic_ref::exchange` function.
Raises:
TypingError: When the dtype of the value passed to `exchange`
does not match the dtype of the AtomicRef type.
"""

if atomic_ref.dtype != val:
raise errors.TypingError(
f"Type of value to exchange: {val} does not match the type of the "
f"reference: {atomic_ref.dtype} stored in the atomic ref."
)

def ol_exchange_impl(atomic_ref, val):
# pylint: disable=no-value-for-parameter
return _intrinsic_exchange(atomic_ref, val)

return ol_exchange_impl
45 changes: 39 additions & 6 deletions numba_dpex/experimental/kernel_iface/atomic_ref.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def fetch_sub(self, val):
Returns the original value of the object.
Args:
val : Value to be subtracted to the object referenced by the
val : Value to be subtracted from the object referenced by the
AtomicRef.
Returns: The original value of the object referenced by the AtomicRef.
Expand All @@ -94,7 +94,7 @@ def fetch_min(self, val):
referenced object. Returns the original value of the object.
Args:
val : Value to be compared against to the object referenced by the
val : Value to be compared against the object referenced by the
AtomicRef.
Returns: The original value of the object referenced by the AtomicRef.
Expand All @@ -110,7 +110,7 @@ def fetch_max(self, val):
referenced object. Returns the original value of the object.
Args:
val : Value to be compared against to the object referenced by the
val : Value to be compared against the object referenced by the
AtomicRef.
Returns: The original value of the object referenced by the AtomicRef.
Expand All @@ -126,7 +126,7 @@ def fetch_and(self, val):
referenced object. Returns the original value of the object.
Args:
val : Value to be bitwise ANDed against to the object referenced by
val : Value to be bitwise ANDed against the object referenced by
the AtomicRef.
Returns: The original value of the object referenced by the AtomicRef.
Expand All @@ -142,7 +142,7 @@ def fetch_or(self, val):
referenced object. Returns the original value of the object.
Args:
val : Value to be bitwise ORed against to the object referenced by
val : Value to be bitwise ORed against the object referenced by
the AtomicRef.
Returns: The original value of the object referenced by the AtomicRef.
Expand All @@ -158,7 +158,7 @@ def fetch_xor(self, val):
referenced object. Returns the original value of the object.
Args:
val : Value to be bitwise XORed against to the object referenced by
val : Value to be bitwise XORed against the object referenced by
the AtomicRef.
Returns: The original value of the object referenced by the AtomicRef.
Expand All @@ -167,3 +167,36 @@ def fetch_xor(self, val):
old = self._ref[self._index]
self._ref[self._index] ^= val
return old

def load(self):
"""Loads the value of the object referenced by the AtomicRef.
Returns: The value of the object referenced by the AtomicRef.
"""
return self._ref[self._index]

def store(self, val):
"""Stores operand ``val`` to the object referenced by the AtomicRef.
Args:
val : Value to be stored in the object referenced by
the AtomicRef.
"""
self._ref[self._index] = val

def exchange(self, val):
"""Replaces the value of the object referenced by the AtomicRef
with value of ``val``. Returns the original value of the referenced object.
Args:
val : Value to be exchanged against the object referenced by
the AtomicRef.
Returns: The original value of the object referenced by the AtomicRef.
"""
old = self._ref[self._index]
self._ref[self._index] = val
return old
Loading

0 comments on commit 8ead4e8

Please sign in to comment.