Skip to content

Commit

Permalink
Port func decorator to new API.
Browse files Browse the repository at this point in the history
  • Loading branch information
Diptorup Deb committed Nov 22, 2022
1 parent a601d3b commit b43201c
Show file tree
Hide file tree
Showing 5 changed files with 227 additions and 110 deletions.
99 changes: 99 additions & 0 deletions numba_dpex/core/_compile_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# SPDX-FileCopyrightText: 2022 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0

from types import FunctionType

from numba.core import compiler, ir
from numba.core import types as numba_types
from numba.core.compiler_lock import global_compiler_lock

from numba_dpex import config
from numba_dpex.core import compiler as dpex_compiler
from numba_dpex.core.descriptor import dpex_target
from numba_dpex.core.exceptions import (
KernelHasReturnValueError,
UnreachableError,
)


@global_compiler_lock
def compile_with_dpex(
pyfunc,
pyfunc_name,
args,
return_type,
debug=None,
is_kernel=True,
extra_compile_flags=None,
):
"""
Compiles the function using the dpex compiler pipeline and returns the
compiled result.
Args:
args: The list of arguments passed to the kernel.
debug (bool): Optional flag to turn on debug mode compilation.
extra_compile_flags: Extra flags passed to the compiler.
Returns:
cres: Compiled result.
Raises:
KernelHasReturnValueError: If the compiled function returns a
non-void value.
"""
# First compilation will trigger the initialization of the backend.
typingctx = dpex_target.typing_context
targetctx = dpex_target.target_context

flags = compiler.Flags()
# Do not compile the function to a binary, just lower to LLVM
flags.debuginfo = config.DEBUGINFO_DEFAULT
flags.no_compile = True
flags.no_cpython_wrapper = True
flags.nrt = False

if debug is not None:
flags.debuginfo = debug

# Run compilation pipeline
if isinstance(pyfunc, FunctionType):
cres = compiler.compile_extra(
typingctx=typingctx,
targetctx=targetctx,
func=pyfunc,
args=args,
return_type=return_type,
flags=flags,
locals={},
pipeline_class=dpex_compiler.Compiler,
)
elif isinstance(pyfunc, ir.FunctionIR):
cres = compiler.compile_ir(
typingctx=typingctx,
targetctx=targetctx,
func_ir=pyfunc,
args=args,
return_type=return_type,
flags=flags,
locals={},
pipeline_class=dpex_compiler.Compiler,
)
else:
raise UnreachableError()

if (
is_kernel
and cres.signature.return_type is not None
and cres.signature.return_type != numba_types.void
):
raise KernelHasReturnValueError(
kernel_name=pyfunc_name,
return_type=cres.signature.return_type,
)
# Linking depending libraries
library = cres.library
library.finalize()

return cres
23 changes: 6 additions & 17 deletions numba_dpex/core/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,8 @@ def define_typed_pipeline(state, name="dpex_typed"):
pm = PassManager(name)
# typing
pm.add_pass(NopythonTypeInference, "nopython frontend")

# Annotate only once legalized
pm.add_pass(AnnotateTypes, "annotate types")
pm.add_pass(
RewriteNdarrayFunctionsPass,
"Rewrite numpy.ndarray functions to dpnp.ndarray functions",
Expand Down Expand Up @@ -162,9 +163,6 @@ def define_nopython_lowering_pipeline(state, name="dpex_nopython_lowering"):
)
pm.add_pass(IRLegalization, "ensure IR is legal prior to lowering")

# Annotate only once legalized
pm.add_pass(AnnotateTypes, "annotate types")

# lower
pm.add_pass(DpexLowering, "Custom Lowerer with auto-offload support")
pm.add_pass(NoPythonBackend, "nopython mode backend")
Expand Down Expand Up @@ -196,23 +194,14 @@ class Compiler(CompilerBase):
"""Dpex's compiler pipeline."""

def define_pipelines(self):
dpb = PassBuilder
pm = PassManager("dpex")

# this maintains the objmode fallback behaviour
pms = []
self.state.parfor_diagnostics = ExtendedParforDiagnostics()
self.state.metadata[
"parfor_diagnostics"
] = self.state.parfor_diagnostics

passes = dpb.define_nopython_pipeline(self.state)
pm.passes.extend(passes.passes)

if not self.state.flags.force_pyobject:
pm.extend(PassBuilder.define_nopython_pipeline(self.state))

pms.append(PassBuilder.define_nopython_pipeline(self.state))
if self.state.status.can_fallback or self.state.flags.force_pyobject:
raise UnsupportedCompilationModeError()

pm.finalize()

return [pm]
return pms
102 changes: 102 additions & 0 deletions numba_dpex/core/kernel_interface/func.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# SPDX-FileCopyrightText: 2022 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0

"""_summary_
"""


from numba.core.typing.templates import AbstractTemplate, ConcreteTemplate

from numba_dpex.core._compile_helper import compile_with_dpex


def compile_func(pyfunc, return_type, args, debug=None):
cres = compile_with_dpex(
pyfunc=pyfunc,
pyfunc_name=pyfunc.__name__,
return_type=return_type,
args=args,
is_kernel=False,
debug=debug,
)
func = cres.library.get_function(cres.fndesc.llvm_func_name)
cres.target_context.mark_ocl_device(func)
devfn = DpexFunction(cres)

class _function_template(ConcreteTemplate):
key = devfn
cases = [cres.signature]

cres.typing_context.insert_user_function(devfn, _function_template)
libs = [cres.library]
cres.target_context.insert_user_function(devfn, cres.fndesc, libs)
return devfn


def compile_func_template(pyfunc, debug=None):
"""Compile a DpexFunctionTemplate"""
from numba_dpex.core.descriptor import dpex_target

dft = DpexFunctionTemplate(pyfunc, debug=debug)

class _function_template(AbstractTemplate):
key = dft

def generic(self, args, kws):
if kws:
raise AssertionError("No keyword arguments allowed.")
return dft.compile(args)

typingctx = dpex_target.typing_context
typingctx.insert_user_function(dft, _function_template)
return dft


class DpexFunctionTemplate(object):
"""Unmaterialized dpex function"""

def __init__(self, pyfunc, debug=None):
self.py_func = pyfunc
self.debug = debug
self._compileinfos = {}

def compile(self, args):
"""Compile a dpex.func decorated Python function with the given
argument types.
Each signature is compiled once by caching the compiled function inside
this object.
"""
if args not in self._compileinfos:
cres = compile_with_dpex(
pyfunc=self.py_func,
pyfunc_name=self.py_func.__name__,
return_type=None,
args=args,
is_kernel=False,
debug=self.debug,
)
func = cres.library.get_function(cres.fndesc.llvm_func_name)
cres.target_context.mark_ocl_device(func)
first_definition = not self._compileinfos
self._compileinfos[args] = cres
libs = [cres.library]

if first_definition:
# First definition
cres.target_context.insert_user_function(
self, cres.fndesc, libs
)
else:
cres.target_context.add_user_function(self, cres.fndesc, libs)

else:
cres = self._compileinfos[args]

return cres.signature


class DpexFunction(object):
def __init__(self, cres):
self.cres = cres
95 changes: 10 additions & 85 deletions numba_dpex/core/kernel_interface/spirv_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,11 @@
import logging
from types import FunctionType

from numba.core import compiler, ir
from numba.core import types as numba_types
from numba.core.compiler_lock import global_compiler_lock

from numba_dpex import compiler as dpex_compiler
from numba_dpex import config, spirv_generator
from numba_dpex.core.descriptor import dpex_target
from numba_dpex.core.exceptions import (
KernelHasReturnValueError,
UncompiledKernelError,
UnreachableError,
)
from numba.core import ir

from numba_dpex import spirv_generator
from numba_dpex.core import _compile_helper
from numba_dpex.core.exceptions import UncompiledKernelError, UnreachableError

from .kernel_base import KernelInterface

Expand Down Expand Up @@ -46,78 +39,6 @@ def __init__(self, func, func_name) -> None:
else:
raise UnreachableError()

@global_compiler_lock
def _compile(self, args, debug=None, extra_compile_flags=None):
"""
Compiles the function using the dpex compiler pipeline and returns the
compiled result.
Args:
args: The list of arguments passed to the kernel.
debug (bool): Optional flag to turn on debug mode compilation.
extra_compile_flags: Extra flags passed to the compiler.
Returns:
cres: Compiled result.
Raises:
KernelHasReturnValueError: If the compiled function returns a
non-void value.
"""
# First compilation will trigger the initialization of the backend.
typingctx = dpex_target.typing_context
targetctx = dpex_target.target_context

flags = compiler.Flags()
# Do not compile the function to a binary, just lower to LLVM
flags.debuginfo = config.DEBUGINFO_DEFAULT
flags.no_compile = True
flags.no_cpython_wrapper = True
flags.nrt = False

if debug is not None:
flags.debuginfo = debug

# Run compilation pipeline
if isinstance(self._func, FunctionType):
cres = compiler.compile_extra(
typingctx=typingctx,
targetctx=targetctx,
func=self._func,
args=args,
return_type=None,
flags=flags,
locals={},
pipeline_class=dpex_compiler.Compiler,
)
elif isinstance(self._func, ir.FunctionIR):
cres = compiler.compile_ir(
typingctx=typingctx,
targetctx=targetctx,
func_ir=self._func,
args=args,
return_type=None,
flags=flags,
locals={},
pipeline_class=dpex_compiler.Compiler,
)
else:
raise UnreachableError()

if (
cres.signature.return_type is not None
and cres.signature.return_type != numba_types.void
):
raise KernelHasReturnValueError(
kernel_name=self._pyfunc_name,
return_type=cres.signature.return_type,
)
# Linking depending libraries
library = cres.library
library.finalize()

return cres

@property
def llvm_module(self):
"""The LLVM IR Module corresponding to the Kernel instance."""
Expand Down Expand Up @@ -158,9 +79,13 @@ def compile(self, arg_types, debug, extra_compile_flags):

logging.debug("compiling SpirvKernel with arg types", arg_types)

cres = self._compile(
cres = _compile_helper.compile_with_dpex(
self._func,
self._pyfunc_name,
args=arg_types,
return_type=None,
debug=debug,
is_kernel=True,
extra_compile_flags=extra_compile_flags,
)

Expand Down
Loading

0 comments on commit b43201c

Please sign in to comment.