Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initial set of changes to add a new dpjit decorator to support dpnp arrays #887

Merged
merged 7 commits into from
Jan 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions numba_dpex/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
"""
import numba.testing

import numba_dpex.core.offload_dispatcher

# Re-export types itself
import numba_dpex.core.types as types

Expand Down
File renamed without changes.
217 changes: 11 additions & 206 deletions numba_dpex/core/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,215 +6,15 @@

from numba.core import compiler, ir
from numba.core import types as numba_types
from numba.core.compiler import CompilerBase
from numba.core.compiler_lock import global_compiler_lock
from numba.core.compiler_machinery import PassManager
from numba.core.typed_passes import (
AnnotateTypes,
InlineOverloads,
IRLegalization,
NopythonRewrites,
NoPythonSupportedFeatureValidation,
NopythonTypeInference,
PreLowerStripPhis,
)
from numba.core.untyped_passes import (
DeadBranchPrune,
FindLiterallyCalls,
FixupArgs,
GenericRewrites,
InlineClosureLikes,
InlineInlinables,
IRProcessing,
LiteralPropagationSubPipelinePass,
LiteralUnroll,
MakeFunctionToJitFunction,
ReconstructSSA,
RewriteSemanticConstants,
TranslateByteCode,
WithLifting,
)

from numba_dpex import config
from numba_dpex.core.exceptions import (
KernelHasReturnValueError,
UnreachableError,
UnsupportedCompilationModeError,
)
from numba_dpex.core.passes.passes import (
ConstantSizeStaticLocalMemoryPass,
DpexLowering,
DumpParforDiagnostics,
NoPythonBackend,
ParforPass,
PreParforPass,
)
from numba_dpex.core.passes.rename_numpy_functions_pass import (
RewriteNdarrayFunctionsPass,
RewriteOverloadedNumPyFunctionsPass,
)
from numba_dpex.parfor_diagnostics import ExtendedParforDiagnostics


class PassBuilder(object):
"""
A pass builder to run dpex's code-generation and optimization passes.

Unlike Numba, dpex's pass builder does not offer objectmode and
interpreted passes.
"""

@staticmethod
def define_untyped_pipeline(state, name="dpex_untyped"):
"""Returns an untyped part of the nopython pipeline

The pipeline of untyped passes is duplicated from Numba's compiler. We
are adding couple of passes to the pipeline to change specific numpy
overloads.
"""
pm = PassManager(name)
if state.func_ir is None:
pm.add_pass(TranslateByteCode, "analyzing bytecode")
pm.add_pass(FixupArgs, "fix up args")
pm.add_pass(IRProcessing, "processing IR")
pm.add_pass(WithLifting, "Handle with contexts")

# --- Begin dpex passes added to the untyped pipeline --#

# The RewriteOverloadedNumPyFunctionsPass rewrites the module namespace
# of specific NumPy functions to dpnp, as we overload these functions
# differently.
pm.add_pass(
RewriteOverloadedNumPyFunctionsPass,
"Rewrite name of Numpy functions to overload already overloaded "
+ "function",
)
# Add pass to ensure when users allocate static constant memory the
# size of the allocation is a constant and not specified by a closure
# variable.
pm.add_pass(
ConstantSizeStaticLocalMemoryPass,
"dpex constant size for static local memory",
)

# --- End of dpex passes added to the untyped pipeline --#

# inline closures early in case they are using nonlocal's
# see issue #6585.
pm.add_pass(
InlineClosureLikes, "inline calls to locally defined closures"
)

# pre typing
if not state.flags.no_rewrites:
pm.add_pass(RewriteSemanticConstants, "rewrite semantic constants")
pm.add_pass(DeadBranchPrune, "dead branch pruning")
pm.add_pass(GenericRewrites, "nopython rewrites")

# convert any remaining closures into functions
pm.add_pass(
MakeFunctionToJitFunction,
"convert make_function into JIT functions",
)
# inline functions that have been determined as inlinable and rerun
# branch pruning, this needs to be run after closures are inlined as
# the IR repr of a closure masks call sites if an inlinable is called
# inside a closure
pm.add_pass(InlineInlinables, "inline inlinable functions")
if not state.flags.no_rewrites:
pm.add_pass(DeadBranchPrune, "dead branch pruning")

pm.add_pass(FindLiterallyCalls, "find literally calls")
pm.add_pass(LiteralUnroll, "handles literal_unroll")

if state.flags.enable_ssa:
pm.add_pass(ReconstructSSA, "ssa")

pm.add_pass(LiteralPropagationSubPipelinePass, "Literal propagation")

pm.finalize()
return pm

@staticmethod
def define_typed_pipeline(state, name="dpex_typed"):
"""Returns the typed part of the nopython pipeline"""
pm = PassManager(name)
# typing
pm.add_pass(NopythonTypeInference, "nopython frontend")
# Annotate only once legalized
pm.add_pass(AnnotateTypes, "annotate types")
pm.add_pass(
RewriteNdarrayFunctionsPass,
"Rewrite numpy.ndarray functions to dpnp.ndarray functions",
)

# strip phis
pm.add_pass(PreLowerStripPhis, "remove phis nodes")

# optimization
pm.add_pass(InlineOverloads, "inline overloaded functions")
pm.add_pass(PreParforPass, "Preprocessing for parfors")
if not state.flags.no_rewrites:
pm.add_pass(NopythonRewrites, "nopython rewrites")
pm.add_pass(ParforPass, "convert to parfors")

pm.finalize()
return pm

@staticmethod
def define_nopython_lowering_pipeline(state, name="dpex_nopython_lowering"):
"""Returns an nopython mode pipeline based PassManager"""
pm = PassManager(name)

# legalize
pm.add_pass(
NoPythonSupportedFeatureValidation,
"ensure features that are in use are in a valid form",
)
pm.add_pass(IRLegalization, "ensure IR is legal prior to lowering")

# lower
pm.add_pass(DpexLowering, "Custom Lowerer with auto-offload support")
pm.add_pass(NoPythonBackend, "nopython mode backend")
pm.add_pass(DumpParforDiagnostics, "dump parfor diagnostics")

pm.finalize()
return pm

@staticmethod
def define_nopython_pipeline(state, name="dpex_nopython"):
"""Returns an nopython mode pipeline based PassManager"""
# compose pipeline from untyped, typed and lowering parts
dpb = PassBuilder
pm = PassManager(name)
untyped_passes = dpb.define_untyped_pipeline(state)
pm.passes.extend(untyped_passes.passes)

typed_passes = dpb.define_typed_pipeline(state)
pm.passes.extend(typed_passes.passes)

lowering_passes = dpb.define_nopython_lowering_pipeline(state)
pm.passes.extend(lowering_passes.passes)

pm.finalize()
return pm


class Compiler(CompilerBase):
"""Dpex's compiler pipeline."""

def define_pipelines(self):
# this maintains the objmode fallback behaviour
pms = []
self.state.parfor_diagnostics = ExtendedParforDiagnostics()
self.state.metadata[
"parfor_diagnostics"
] = self.state.parfor_diagnostics
if not self.state.flags.force_pyobject:
pms.append(PassBuilder.define_nopython_pipeline(self.state))
if self.state.status.can_fallback or self.state.flags.force_pyobject:
raise UnsupportedCompilationModeError()
return pms
from numba_dpex.core.pipelines.kernel_compiler import KernelCompiler
from numba_dpex.core.pipelines.offload_compiler import OffloadCompiler


@global_compiler_lock
Expand All @@ -230,8 +30,8 @@ def compile_with_dpex(
extra_compile_flags=None,
):
"""
Compiles a function using the dpex compiler pipeline and returns the
compiled result.
Compiles a function using class:`numba_dpex.core.pipelines.KernelCompiler`
and returns the compiled result.

Args:
args: The list of arguments passed to the kernel.
Expand All @@ -254,6 +54,7 @@ def compile_with_dpex(
flags.debuginfo = config.DEBUGINFO_DEFAULT
flags.no_compile = True
flags.no_cpython_wrapper = True
flags.no_cfunc_wrapper = True
flags.nrt = False

if debug:
Expand All @@ -269,9 +70,13 @@ def compile_with_dpex(
return_type=return_type,
flags=flags,
locals={},
pipeline_class=Compiler,
pipeline_class=KernelCompiler,
)
elif isinstance(pyfunc, ir.FunctionIR):
# FIXME: Kernels in the form of Numba IR need to be compiled
# using the offload compiler due to them retaining parfor
# nodes due to the use of gufuncs. Once the kernel builder is
# ready we should be able to switch to the KernelCompiler.
cres = compiler.compile_ir(
typingctx=typingctx,
targetctx=targetctx,
Expand All @@ -280,7 +85,7 @@ def compile_with_dpex(
return_type=return_type,
flags=flags,
locals={},
pipeline_class=Compiler,
pipeline_class=OffloadCompiler,
)
else:
raise UnreachableError()
Expand Down
56 changes: 50 additions & 6 deletions numba_dpex/core/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,34 @@
#
# SPDX-License-Identifier: Apache-2.0

from numba.core import utils
from numba.core import typing, utils
from numba.core.cpu import CPUTargetOptions
from numba.core.descriptors import TargetDescriptor

from .target import DPEX_TARGET_NAME, DpexTargetContext, DpexTypingContext
from .targets.dpjit_target import DPEX_TARGET_NAME, DpexTargetContext
from .targets.kernel_target import (
DPEX_KERNEL_TARGET_NAME,
DpexKernelTargetContext,
DpexKernelTypingContext,
)


class DpexTarget(TargetDescriptor):
class DpexKernelTarget(TargetDescriptor):
"""
Implements a target descriptor for numba_dpex.kernel decorated functions.
"""

options = CPUTargetOptions

@utils.cached_property
def _toplevel_target_context(self):
"""Lazily-initialized top-level target context, for all threads."""
return DpexTargetContext(self.typing_context, self._target_name)
return DpexKernelTargetContext(self.typing_context, self._target_name)

@utils.cached_property
def _toplevel_typing_context(self):
"""Lazily-initialized top-level typing context, for all threads."""
return DpexTypingContext()
return DpexKernelTypingContext()

@property
def target_context(self):
Expand All @@ -37,5 +46,40 @@ def typing_context(self):
return self._toplevel_typing_context


# The global Dpex target
class DpexTarget(TargetDescriptor):
"""
Implements a target descriptor for numba_dpex.dpjit decorated functions.
"""

options = CPUTargetOptions

@utils.cached_property
def _toplevel_target_context(self):
# Lazily-initialized top-level target context, for all threads
return DpexTargetContext(self.typing_context, self._target_name)

@utils.cached_property
def _toplevel_typing_context(self):
# Lazily-initialized top-level typing context, for all threads
return typing.Context()

@property
def target_context(self):
"""
The target context for dpex targets.
"""
return self._toplevel_target_context

@property
def typing_context(self):
"""
The typing context for dpex targets.
"""
return self._toplevel_typing_context


# A global instance of the DpexKernelTarget
dpex_kernel_target = DpexKernelTarget(DPEX_KERNEL_TARGET_NAME)

# A global instance of the DpexTarget
dpex_target = DpexTarget(DPEX_TARGET_NAME)
35 changes: 35 additions & 0 deletions numba_dpex/core/dpjit_dispatcher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# SPDX-FileCopyrightText: 2020 - 2022 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0


from numba.core import compiler, dispatcher
from numba.core.target_extension import dispatcher_registry, target_registry

from numba_dpex.core.targets.dpjit_target import DPEX_TARGET_NAME

from .descriptor import dpex_target


class DpjitDispatcher(dispatcher.Dispatcher):
targetdescr = dpex_target

def __init__(
self,
py_func,
locals={},
targetoptions={},
impl_kind="direct",
pipeline_class=compiler.Compiler,
):
dispatcher.Dispatcher.__init__(
self,
py_func,
locals=locals,
targetoptions=targetoptions,
impl_kind=impl_kind,
pipeline_class=pipeline_class,
)


dispatcher_registry[target_registry[DPEX_TARGET_NAME]] = DpjitDispatcher
Loading