Skip to content

Commit

Permalink
Merge pull request #1112 from IntelPython/fix/kernel_func_name
Browse files Browse the repository at this point in the history
Generate proper mangled name for kernel functions
  • Loading branch information
Diptorup Deb authored Aug 18, 2023
2 parents 4d332e8 + d3d7ef6 commit 547382d
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 9 deletions.
1 change: 1 addition & 0 deletions numba_dpex/core/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def _optimize_final_module(self):
pmb.opt_level = config.OPT

pmb.disable_unit_at_a_time = False
pmb.inlining_threshold = 2
pmb.disable_unroll_loops = True
pmb.loop_vectorize = False
pmb.slp_vectorize = False
Expand Down
1 change: 1 addition & 0 deletions numba_dpex/core/kernel_interface/spirv_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ def compile(
kernel = cres.target_context.prepare_ocl_kernel(
func, cres.signature.args
)
cres.library._optimize_final_module()
self._llvm_module = kernel.module.__str__()
self._module_name = kernel.name

Expand Down
14 changes: 5 additions & 9 deletions numba_dpex/core/targets/kernel_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from llvmlite import binding as ll
from llvmlite import ir as llvmir
from numba import typeof
from numba.core import cgutils, types, typing, utils
from numba.core import cgutils, funcdesc, types, typing, utils
from numba.core.base import BaseContext
from numba.core.callconv import MinimalCallConv
from numba.core.registry import cpu_target
Expand Down Expand Up @@ -240,7 +240,7 @@ def _generate_kernel_wrapper(self, func, argtypes):
llvmir.VoidType(), arginfo.argument_types
)
wrapper_module = self.create_module("dpex.kernel.wrapper")
wrappername = "dpexPy_{name}".format(name=func.name)
wrappername = func.name.replace("dpex_fn", "dpex_kernel")
argtys = list(arginfo.argument_types)
fnty = llvmir.FunctionType(
llvmir.IntType(32),
Expand Down Expand Up @@ -373,13 +373,9 @@ def target_data(self):
return self._target_data

def mangler(self, name, argtypes, abi_tags=(), uid=None):
def repl(m):
ch = m.group(0)
return "_%X_" % ord(ch)

qualified = name + "." + ".".join(str(a) for a in argtypes)
mangled = VALID_CHARS.sub(repl, qualified)
return "dpex_py_devfn_" + mangled
return funcdesc.default_mangler(
name + "dpex_fn", argtypes, abi_tags=abi_tags, uid=uid
)

def prepare_ocl_kernel(self, func, argtypes):
module = func.module
Expand Down
14 changes: 14 additions & 0 deletions numba_dpex/core/types/usm_ndarray_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,3 +222,17 @@ def as_array(self):
@property
def box_type(self):
return dpctl.tensor.usm_ndarray

@property
def mangling_args(self):
"""Returns a list of parameters used to create a mangled name for a
USMNdArray type.
"""
filter_str_splits = self.device.split(":")
args = [
self.dtype,
self.ndim,
self.layout,
filter_str_splits[0] + "_" + filter_str_splits[1],
]
return self.__class__.__name__, args

0 comments on commit 547382d

Please sign in to comment.