-
Notifications
You must be signed in to change notification settings - Fork 33
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Diptorup Deb
committed
Nov 22, 2022
1 parent
a601d3b
commit b43201c
Showing
5 changed files
with
227 additions
and
110 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
# SPDX-FileCopyrightText: 2022 Intel Corporation | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from types import FunctionType | ||
|
||
from numba.core import compiler, ir | ||
from numba.core import types as numba_types | ||
from numba.core.compiler_lock import global_compiler_lock | ||
|
||
from numba_dpex import config | ||
from numba_dpex.core import compiler as dpex_compiler | ||
from numba_dpex.core.descriptor import dpex_target | ||
from numba_dpex.core.exceptions import ( | ||
KernelHasReturnValueError, | ||
UnreachableError, | ||
) | ||
|
||
|
||
@global_compiler_lock | ||
def compile_with_dpex( | ||
pyfunc, | ||
pyfunc_name, | ||
args, | ||
return_type, | ||
debug=None, | ||
is_kernel=True, | ||
extra_compile_flags=None, | ||
): | ||
""" | ||
Compiles the function using the dpex compiler pipeline and returns the | ||
compiled result. | ||
Args: | ||
args: The list of arguments passed to the kernel. | ||
debug (bool): Optional flag to turn on debug mode compilation. | ||
extra_compile_flags: Extra flags passed to the compiler. | ||
Returns: | ||
cres: Compiled result. | ||
Raises: | ||
KernelHasReturnValueError: If the compiled function returns a | ||
non-void value. | ||
""" | ||
# First compilation will trigger the initialization of the backend. | ||
typingctx = dpex_target.typing_context | ||
targetctx = dpex_target.target_context | ||
|
||
flags = compiler.Flags() | ||
# Do not compile the function to a binary, just lower to LLVM | ||
flags.debuginfo = config.DEBUGINFO_DEFAULT | ||
flags.no_compile = True | ||
flags.no_cpython_wrapper = True | ||
flags.nrt = False | ||
|
||
if debug is not None: | ||
flags.debuginfo = debug | ||
|
||
# Run compilation pipeline | ||
if isinstance(pyfunc, FunctionType): | ||
cres = compiler.compile_extra( | ||
typingctx=typingctx, | ||
targetctx=targetctx, | ||
func=pyfunc, | ||
args=args, | ||
return_type=return_type, | ||
flags=flags, | ||
locals={}, | ||
pipeline_class=dpex_compiler.Compiler, | ||
) | ||
elif isinstance(pyfunc, ir.FunctionIR): | ||
cres = compiler.compile_ir( | ||
typingctx=typingctx, | ||
targetctx=targetctx, | ||
func_ir=pyfunc, | ||
args=args, | ||
return_type=return_type, | ||
flags=flags, | ||
locals={}, | ||
pipeline_class=dpex_compiler.Compiler, | ||
) | ||
else: | ||
raise UnreachableError() | ||
|
||
if ( | ||
is_kernel | ||
and cres.signature.return_type is not None | ||
and cres.signature.return_type != numba_types.void | ||
): | ||
raise KernelHasReturnValueError( | ||
kernel_name=pyfunc_name, | ||
return_type=cres.signature.return_type, | ||
) | ||
# Linking depending libraries | ||
library = cres.library | ||
library.finalize() | ||
|
||
return cres |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
# SPDX-FileCopyrightText: 2022 Intel Corporation | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
"""_summary_ | ||
""" | ||
|
||
|
||
from numba.core.typing.templates import AbstractTemplate, ConcreteTemplate | ||
|
||
from numba_dpex.core._compile_helper import compile_with_dpex | ||
|
||
|
||
def compile_func(pyfunc, return_type, args, debug=None): | ||
cres = compile_with_dpex( | ||
pyfunc=pyfunc, | ||
pyfunc_name=pyfunc.__name__, | ||
return_type=return_type, | ||
args=args, | ||
is_kernel=False, | ||
debug=debug, | ||
) | ||
func = cres.library.get_function(cres.fndesc.llvm_func_name) | ||
cres.target_context.mark_ocl_device(func) | ||
devfn = DpexFunction(cres) | ||
|
||
class _function_template(ConcreteTemplate): | ||
key = devfn | ||
cases = [cres.signature] | ||
|
||
cres.typing_context.insert_user_function(devfn, _function_template) | ||
libs = [cres.library] | ||
cres.target_context.insert_user_function(devfn, cres.fndesc, libs) | ||
return devfn | ||
|
||
|
||
def compile_func_template(pyfunc, debug=None): | ||
"""Compile a DpexFunctionTemplate""" | ||
from numba_dpex.core.descriptor import dpex_target | ||
|
||
dft = DpexFunctionTemplate(pyfunc, debug=debug) | ||
|
||
class _function_template(AbstractTemplate): | ||
key = dft | ||
|
||
def generic(self, args, kws): | ||
if kws: | ||
raise AssertionError("No keyword arguments allowed.") | ||
return dft.compile(args) | ||
|
||
typingctx = dpex_target.typing_context | ||
typingctx.insert_user_function(dft, _function_template) | ||
return dft | ||
|
||
|
||
class DpexFunctionTemplate(object): | ||
"""Unmaterialized dpex function""" | ||
|
||
def __init__(self, pyfunc, debug=None): | ||
self.py_func = pyfunc | ||
self.debug = debug | ||
self._compileinfos = {} | ||
|
||
def compile(self, args): | ||
"""Compile a dpex.func decorated Python function with the given | ||
argument types. | ||
Each signature is compiled once by caching the compiled function inside | ||
this object. | ||
""" | ||
if args not in self._compileinfos: | ||
cres = compile_with_dpex( | ||
pyfunc=self.py_func, | ||
pyfunc_name=self.py_func.__name__, | ||
return_type=None, | ||
args=args, | ||
is_kernel=False, | ||
debug=self.debug, | ||
) | ||
func = cres.library.get_function(cres.fndesc.llvm_func_name) | ||
cres.target_context.mark_ocl_device(func) | ||
first_definition = not self._compileinfos | ||
self._compileinfos[args] = cres | ||
libs = [cres.library] | ||
|
||
if first_definition: | ||
# First definition | ||
cres.target_context.insert_user_function( | ||
self, cres.fndesc, libs | ||
) | ||
else: | ||
cres.target_context.add_user_function(self, cres.fndesc, libs) | ||
|
||
else: | ||
cres = self._compileinfos[args] | ||
|
||
return cres.signature | ||
|
||
|
||
class DpexFunction(object): | ||
def __init__(self, cres): | ||
self.cres = cres |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.