Skip to content

Commit

Permalink
Merge pull request #1297 from IntelPython/experimental/ld_str_excg_ols
Browse files Browse the repository at this point in the history
Implementations for atomic load, store and exchange operations
  • Loading branch information
Diptorup Deb authored Jan 31, 2024
2 parents 57190e9 + fb916a2 commit f59d9e8
Show file tree
Hide file tree
Showing 4 changed files with 454 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@
get_memory_semantics_mask,
get_scope,
)
from .spv_fn_generator import (
get_or_insert_atomic_load_fn,
get_or_insert_spv_atomic_exchange_fn,
get_or_insert_spv_atomic_store_fn,
)


def _parse_enum_or_int_literal_(literal_int) -> int:
Expand Down Expand Up @@ -209,6 +214,107 @@ def codegen(context, builder, sig, args):
)


@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_load(
ty_context, ty_atomic_ref # pylint: disable=unused-argument
):
sig = ty_atomic_ref.dtype(ty_atomic_ref)

def _intrinsic_load_gen(context, builder, sig, args):
atomic_ref_ty = sig.args[0]
fn = get_or_insert_atomic_load_fn(
context, builder.module, atomic_ref_ty
)

spirv_memory_semantics_mask = get_memory_semantics_mask(
atomic_ref_ty.memory_order
)
spirv_scope = get_scope(atomic_ref_ty.memory_scope)

fn_args = [
builder.extract_value(
args[0],
context.data_model_manager.lookup(
atomic_ref_ty
).get_field_position("ref"),
),
context.get_constant(types.int32, spirv_scope),
context.get_constant(types.int32, spirv_memory_semantics_mask),
]

return builder.call(fn, fn_args)

return sig, _intrinsic_load_gen


@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_store(
ty_context, ty_atomic_ref, ty_val
): # pylint: disable=unused-argument
sig = types.void(ty_atomic_ref, ty_val)

def _intrinsic_store_gen(context, builder, sig, args):
atomic_ref_ty = sig.args[0]
atomic_store_fn = get_or_insert_spv_atomic_store_fn(
context, builder.module, atomic_ref_ty
)

atomic_store_fn_args = [
builder.extract_value(
args[0],
context.data_model_manager.lookup(
atomic_ref_ty
).get_field_position("ref"),
),
context.get_constant(
types.int32, get_scope(atomic_ref_ty.memory_scope)
),
context.get_constant(
types.int32,
get_memory_semantics_mask(atomic_ref_ty.memory_order),
),
args[1],
]

builder.call(atomic_store_fn, atomic_store_fn_args)

return sig, _intrinsic_store_gen


@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_exchange(
ty_context, ty_atomic_ref, ty_val # pylint: disable=unused-argument
):
sig = ty_atomic_ref.dtype(ty_atomic_ref, ty_val)

def _intrinsic_exchange_gen(context, builder, sig, args):
atomic_ref_ty = sig.args[0]
atomic_exchange_fn = get_or_insert_spv_atomic_exchange_fn(
context, builder.module, atomic_ref_ty
)

atomic_exchange_fn_args = [
builder.extract_value(
args[0],
context.data_model_manager.lookup(
atomic_ref_ty
).get_field_position("ref"),
),
context.get_constant(
types.int32, get_scope(atomic_ref_ty.memory_scope)
),
context.get_constant(
types.int32,
get_memory_semantics_mask(atomic_ref_ty.memory_order),
),
args[1],
]

return builder.call(atomic_exchange_fn, atomic_exchange_fn_args)

return sig, _intrinsic_exchange_gen


def _check_if_supported_ref(ref):
supported = True

Expand Down Expand Up @@ -516,3 +622,72 @@ def ol_fetch_xor_impl(atomic_ref, val):
return _intrinsic_fetch_xor(atomic_ref, val)

return ol_fetch_xor_impl


@overload_method(AtomicRefType, "load", target=DPEX_KERNEL_EXP_TARGET_NAME)
def ol_load(atomic_ref): # pylint: disable=unused-argument
"""SPIR-V overload for
:meth:`numba_dpex.experimental.kernel_iface.AtomicRef.load`.
Generates the same LLVM IR instruction as dpcpp for the
`atomic_ref::load` function.
"""

def ol_load_impl(atomic_ref):
# pylint: disable=no-value-for-parameter
return _intrinsic_load(atomic_ref)

return ol_load_impl


@overload_method(AtomicRefType, "store", target=DPEX_KERNEL_EXP_TARGET_NAME)
def ol_store(atomic_ref, val):
"""SPIR-V overload for
:meth:`numba_dpex.experimental.kernel_iface.AtomicRef.store`.
Generates the same LLVM IR instruction as dpcpp for the
`atomic_ref::store` function.
Raises:
TypingError: When the dtype of the value stored does not match the
dtype of the AtomicRef type.
"""

if atomic_ref.dtype != val:
raise errors.TypingError(
f"Type of value to store: {val} does not match the type of the "
f"reference: {atomic_ref.dtype} stored in the atomic ref."
)

def ol_store_impl(atomic_ref, val):
# pylint: disable=no-value-for-parameter
return _intrinsic_store(atomic_ref, val)

return ol_store_impl


@overload_method(AtomicRefType, "exchange", target=DPEX_KERNEL_EXP_TARGET_NAME)
def ol_exchange(atomic_ref, val):
"""SPIR-V overload for
:meth:`numba_dpex.experimental.kernel_iface.AtomicRef.exchange`.
Generates the same LLVM IR instruction as dpcpp for the
`atomic_ref::exchange` function.
Raises:
TypingError: When the dtype of the value passed to `exchange`
does not match the dtype of the AtomicRef type.
"""

if atomic_ref.dtype != val:
raise errors.TypingError(
f"Type of value to exchange: {val} does not match the type of the "
f"reference: {atomic_ref.dtype} stored in the atomic ref."
)

def ol_exchange_impl(atomic_ref, val):
# pylint: disable=no-value-for-parameter
return _intrinsic_exchange(atomic_ref, val)

return ol_exchange_impl
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
# SPDX-FileCopyrightText: 2023 - 2024 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0

"""
Implements a set of helper functions to generate the LLVM IR for SPIR-V
functions and their use inside an LLVM module.
"""

from llvmlite import ir as llvmir
from numba.core import cgutils, types

from numba_dpex.core import itanium_mangler as ext_itanium_mangler
from numba_dpex.core.targets.kernel_target import CC_SPIR_FUNC


def get_or_insert_atomic_load_fn(context, module, atomic_ref_ty):
"""
Gets or inserts a declaration for a __spirv_AtomicLoad call into the
specified LLVM IR module.
"""
atomic_ref_dtype = atomic_ref_ty.dtype
atomic_load_fn_retty = context.get_value_type(atomic_ref_dtype)
ptr_type = atomic_load_fn_retty.as_pointer()
ptr_type.addrspace = atomic_ref_ty.address_space
atomic_load_fn_arg_types = [
ptr_type,
llvmir.IntType(32),
llvmir.IntType(32),
]
mangled_fn_name = ext_itanium_mangler.mangle_ext(
"__spirv_AtomicLoad",
[
types.CPointer(atomic_ref_dtype, addrspace=ptr_type.addrspace),
"__spv.Scope.Flag",
"__spv.MemorySemanticsMask.Flag",
],
)

fn = cgutils.get_or_insert_function(
module,
llvmir.FunctionType(atomic_load_fn_retty, atomic_load_fn_arg_types),
mangled_fn_name,
)
fn.calling_convention = CC_SPIR_FUNC

return fn


def get_or_insert_spv_atomic_store_fn(context, module, atomic_ref_ty):
"""
Gets or inserts a declaration for a __spirv_AtomicStore call into the
specified LLVM IR module.
"""
atomic_ref_dtype = atomic_ref_ty.dtype
ptr_type = context.get_value_type(atomic_ref_dtype).as_pointer()
ptr_type.addrspace = atomic_ref_ty.address_space
atomic_store_fn_retty = llvmir.VoidType()
atomic_store_fn_arg_types = [
ptr_type,
llvmir.IntType(32),
llvmir.IntType(32),
context.get_value_type(atomic_ref_dtype),
]

mangled_fn_name = ext_itanium_mangler.mangle_ext(
"__spirv_AtomicStore",
[
types.CPointer(atomic_ref_dtype, addrspace=ptr_type.addrspace),
"__spv.Scope.Flag",
"__spv.MemorySemanticsMask.Flag",
atomic_ref_dtype,
],
)

fn = cgutils.get_or_insert_function(
module,
llvmir.FunctionType(atomic_store_fn_retty, atomic_store_fn_arg_types),
mangled_fn_name,
)
fn.calling_convention = CC_SPIR_FUNC

return fn


def get_or_insert_spv_atomic_exchange_fn(context, module, atomic_ref_ty):
"""
Gets or inserts a declaration for a __spirv_AtomicExchange call into the
specified LLVM IR module.
"""
atomic_ref_dtype = atomic_ref_ty.dtype
ptr_type = context.get_value_type(atomic_ref_dtype).as_pointer()
ptr_type.addrspace = atomic_ref_ty.address_space
atomic_exchange_fn_retty = context.get_value_type(atomic_ref_ty.dtype)
atomic_exchange_fn_arg_types = [
ptr_type,
llvmir.IntType(32),
llvmir.IntType(32),
context.get_value_type(atomic_ref_dtype),
]

mangled_fn_name = ext_itanium_mangler.mangle_ext(
"__spirv_AtomicExchange",
[
types.CPointer(atomic_ref_dtype, addrspace=ptr_type.addrspace),
"__spv.Scope.Flag",
"__spv.MemorySemanticsMask.Flag",
atomic_ref_dtype,
],
)

fn = cgutils.get_or_insert_function(
module,
llvmir.FunctionType(
atomic_exchange_fn_retty, atomic_exchange_fn_arg_types
),
mangled_fn_name,
)
fn.calling_convention = CC_SPIR_FUNC

return fn
Loading

0 comments on commit f59d9e8

Please sign in to comment.