Skip to content

Commit

Permalink
Move ocl._declare_function into core.utils.cgutils.extra
Browse files Browse the repository at this point in the history
  • Loading branch information
Diptorup Deb committed Apr 4, 2024
1 parent a455e46 commit 049db98
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 56 deletions.
40 changes: 40 additions & 0 deletions numba_dpex/core/utils/cgutils_extra.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from llvmlite import ir as llvmir
from numba.core import cgutils, types

from numba_dpex.core.utils.itanium_mangler import mangle_c


class LLVMTypes:
"""
Expand All @@ -21,6 +23,44 @@ class LLVMTypes:
void_t = llvmir.VoidType()


def declare_function(context, builder, name, sig, cargs, mangler=mangle_c):
"""Insert declaration for a opencl builtin function.
Uses the Itanium mangler.
Args
----
context: target context
builder: llvm builder
name: str
symbol name
sig: signature
function signature of the symbol being declared
cargs: sequence of str
C type names for the arguments
mangler: a mangler function
function to use to mangle the symbol
"""
mod = builder.module
if sig.return_type == types.void:
llretty = llvmir.VoidType()
else:
llretty = context.get_value_type(sig.return_type)
llargs = [context.get_value_type(t) for t in sig.args]
fnty = llvmir.FunctionType(llretty, llargs)
mangled = mangler(name, cargs)
fn = cgutils.get_or_insert_function(mod, fnty, mangled)
from numba_dpex import spirv_kernel_target

fn.calling_convention = spirv_kernel_target.CC_SPIR_FUNC
return fn


def get_llvm_type(context, type):
"""Returns the LLVM Value corresponding to a Numba type.
Expand Down
48 changes: 0 additions & 48 deletions numba_dpex/ocl/_declare_function.py

This file was deleted.

17 changes: 9 additions & 8 deletions numba_dpex/ocl/mathimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,11 @@
from numba.core import types
from numba.core.imputils import Registry

from numba_dpex.core.utils.itanium_mangler import mangle

from ._declare_function import _declare_function
from numba_dpex.core.utils import cgutils_extra, itanium_mangler

registry = Registry()
lower = registry.lower

# -----------------------------------------------------------------------------

_unary_b_f = types.int32(types.float32)
_unary_b_d = types.int32(types.float64)
_unary_f_f = types.float32(types.float32)
Expand Down Expand Up @@ -88,16 +84,21 @@


# some functions may be named differently by the underlying math
# library as oposed to the Python name.
# library as opposed to the Python name.
_lib_counterpart = {"gamma": "tgamma"}


def _mk_fn_decl(name, decl_sig):
sym = _lib_counterpart.get(name, name)

def core(context, builder, sig, args):
fn = _declare_function(
context, builder, sym, decl_sig, decl_sig.args, mangler=mangle
fn = cgutils_extra.declare_function(
context,
builder,
sym,
decl_sig,
decl_sig.args,
mangler=itanium_mangler.mangle,
)
res = builder.call(fn, args)
return context.cast(builder, res, decl_sig.return_type, sig.return_type)
Expand Down

0 comments on commit 049db98

Please sign in to comment.