Skip to content

Commit

Permalink
Merge pull request #1314 from IntelPython/refactor/kernel_api_impl
Browse files Browse the repository at this point in the history
A new _kernel_api_impl module
  • Loading branch information
Diptorup Deb authored Feb 3, 2024
2 parents b494492 + 7d8c135 commit df32f71
Show file tree
Hide file tree
Showing 18 changed files with 94 additions and 76 deletions.
3 changes: 2 additions & 1 deletion numba_dpex/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from numba_dpex.core.kernel_interface.launcher import call_kernel

from ._kernel_api_impl.spirv import target as spirv_kernel_target
from .numba_patches import patch_arrayexpr_tree_to_ir, patch_is_ufunc


Expand Down Expand Up @@ -107,7 +108,7 @@ def parse_sem_version(version_string: str) -> Tuple[int, int, int]:
# backward compatibility
from numba_dpex.kernel_api import NdRange, Range # noqa E402

from .core.targets import dpjit_target, kernel_target # noqa E402
from .core.targets import dpjit_target # noqa E402
from .decorators import dpjit, func, kernel # noqa E402
from .ocl.stubs import ( # noqa E402
GLOBAL_MEM_FENCE,
Expand Down
7 changes: 7 additions & 0 deletions numba_dpex/_kernel_api_impl/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# SPDX-FileCopyrightText: 2023 - 2024 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0

"""The module stores the numba_dpex backends implementing the target-specific
code generation for the kernel_api Python functions.
"""
6 changes: 6 additions & 0 deletions numba_dpex/_kernel_api_impl/spirv/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# SPDX-FileCopyrightText: 2023 - 2024 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0

"""A SPIR-V backend to compile the numba_dpex.kernel_api functions to SPIR-V.
"""
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -26,29 +26,31 @@
from numba.core.typing.typeof import Purpose, typeof

from numba_dpex import config, numba_sem_version, spirv_generator
from numba_dpex.core.codegen import SPIRVCodeLibrary
from numba_dpex._kernel_api_impl.spirv.codegen import SPIRVCodeLibrary
from numba_dpex._kernel_api_impl.spirv.target import (
CompilationMode,
SPIRVTargetContext,
)
from numba_dpex.core.exceptions import (
ExecutionQueueInferenceError,
InvalidKernelSpecializationError,
KernelHasReturnValueError,
UnsupportedKernelArgumentError,
)
from numba_dpex.core.pipelines import kernel_compiler
from numba_dpex.core.targets.kernel_target import (
CompilationMode,
DpexKernelTargetContext,
)
from numba_dpex.core.types import USMNdArray
from numba_dpex.core.utils import kernel_launcher as kl
from numba_dpex.experimental.target import (
DPEX_KERNEL_EXP_TARGET_NAME,
dpex_exp_kernel_target,
)

from .target import DPEX_KERNEL_EXP_TARGET_NAME, dpex_exp_kernel_target

_KernelCompileResult = namedtuple(
_SPIRVKernelCompileResult = namedtuple(
"_KernelCompileResult", CompileResult._fields + ("kernel_device_ir_module",)
)


class _KernelCompiler(_FunctionCompiler):
class _SPIRVKernelCompiler(_FunctionCompiler):
"""A special compiler class used to compile numba_dpex.kernel decorated
functions.
"""
Expand Down Expand Up @@ -155,7 +157,7 @@ def _compile_to_spirv(
self,
kernel_library: SPIRVCodeLibrary,
kernel_fndesc: PythonFunctionDescriptor,
kernel_targetctx: DpexKernelTargetContext,
kernel_targetctx: SPIRVTargetContext,
):
kernel_func: ValueRef = kernel_library.get_function(
kernel_fndesc.llvm_func_name
Expand Down Expand Up @@ -188,7 +190,7 @@ def _compile_to_spirv(
kernel_name=kernel_fn.name, kernel_bitcode=kernel_spirv_module
)

def compile(self, args, return_type) -> _KernelCompileResult:
def compile(self, args, return_type) -> _SPIRVKernelCompileResult:
status, kcres = self._compile_cached(args, return_type)
if status:
return kcres
Expand All @@ -197,7 +199,7 @@ def compile(self, args, return_type) -> _KernelCompileResult:

def _compile_cached(
self, args, return_type: types.Type
) -> Tuple[bool, _KernelCompileResult]:
) -> Tuple[bool, _SPIRVKernelCompileResult]:
"""Compiles the kernel function to bitcode and generates a host-callable
wrapper to submit the kernel to a SYCL queue.
Expand Down Expand Up @@ -277,10 +279,10 @@ def _compile_cached(
self._failed_cache[key] = err
return False, err

return True, _KernelCompileResult(*kcres_attrs)
return True, _SPIRVKernelCompileResult(*kcres_attrs)


class KernelDispatcher(Dispatcher):
class SPIRVKernelDispatcher(Dispatcher):
"""Dispatcher class designed to compile kernel decorated functions. The
dispatcher inherits the Numba Dispatcher class, but has a different
compilation strategy. Instead of compiling a kernel decorated function to
Expand Down Expand Up @@ -325,7 +327,7 @@ def __init__(
targetoptions=targetoptions,
pipeline_class=pipeline_class,
)
self._compiler = _KernelCompiler(
self._compiler = _SPIRVKernelCompiler(
pyfunc,
self.targetdescr,
targetoptions,
Expand Down Expand Up @@ -426,8 +428,8 @@ def cb_llvm(dur):
},
):
try:
compiler: _KernelCompiler = self._compiler
kcres: _KernelCompileResult = compiler.compile(
compiler: _SPIRVKernelCompiler = self._compiler
kcres: _SPIRVKernelCompileResult = compiler.compile(
args, return_type
)
except errors.ForceLiteralArg as err:
Expand Down Expand Up @@ -463,4 +465,4 @@ def __call__(self, *args, **kw_args):


_dpex_target = target_registry[DPEX_KERNEL_EXP_TARGET_NAME]
dispatcher_registry[_dpex_target] = KernelDispatcher
dispatcher_registry[_dpex_target] = SPIRVKernelDispatcher
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from numba_dpex.core.utils import get_info_from_suai
from numba_dpex.utils import address_space, calling_conv

from .. import codegen
from . import codegen

CC_SPIR_KERNEL = "spir_kernel"
CC_SPIR_FUNC = "spir_func"
Expand Down Expand Up @@ -52,7 +52,7 @@ class CompilationMode(IntEnum):
DEVICE_FUNC = 2


class DpexKernelTypingContext(typing.BaseContext):
class SPIRVTypingContext(typing.BaseContext):
"""Custom typing context to support kernel compilation.
The customized typing context provides two features required to compile
Expand Down Expand Up @@ -124,20 +124,20 @@ def load_additional_registries(self):
self.install_registry(enumdecl.registry)


class SyclDevice(GPU):
"""Mark the hardware target as SYCL Device."""
class SPIRVDevice(GPU):
"""Mark the hardware target as device that supports SPIR-V bitcode."""

pass


DPEX_KERNEL_TARGET_NAME = "dpex_kernel"
SPIRV_TARGET_NAME = "spirv"

target_registry[DPEX_KERNEL_TARGET_NAME] = SyclDevice
target_registry[SPIRV_TARGET_NAME] = SPIRVDevice


class DpexKernelTargetContext(BaseContext):
class SPIRVTargetContext(BaseContext):
"""A target context inheriting Numba's ``BaseContext`` that is customized
for generating SYCL kernels.
for generating SPIR-V kernels.
A customized target context for generating SPIR-V kernels. The class defines
helper functions to generates SPIR-V kernels as LLVM IR using the required
Expand Down Expand Up @@ -243,7 +243,7 @@ def _generate_spir_kernel_wrapper(self, func, argtypes):
module.get_function(func.name).linkage = "internal"
return wrapper

def __init__(self, typingctx, target=DPEX_KERNEL_TARGET_NAME):
def __init__(self, typingctx, target=SPIRV_TARGET_NAME):
super().__init__(typingctx, target)

def init(self):
Expand Down Expand Up @@ -338,7 +338,7 @@ def load_additional_registries(self):

@cached_property
def call_conv(self):
return DpexCallConv(self)
return SPIRVCallConv(self)

def codegen(self):
return self._internal_codegen
Expand Down Expand Up @@ -385,9 +385,7 @@ def declare_function(self, module, fndesc):
)
if not self.enable_debuginfo:
fn.attributes.add("alwaysinline")
ret = super(DpexKernelTargetContext, self).declare_function(
module, fndesc
)
ret = super(SPIRVTargetContext, self).declare_function(module, fndesc)
ret.calling_convention = calling_conv.CC_SPIR_FUNC
return ret

Expand Down Expand Up @@ -444,12 +442,12 @@ def populate_array(self, arr, **kwargs):
return arrayobj.populate_array(arr, **kwargs)


class DpexCallConv(MinimalCallConv):
class SPIRVCallConv(MinimalCallConv):
"""Custom calling convention class used by numba-dpex.
numba_dpex's calling convention derives from
:class:`numba.core.callconv import MinimalCallConv`. The
:class:`DpexCallConv` overrides :func:`call_function`.
:class:`SPIRVCallConv` overrides :func:`call_function`.
"""

Expand Down
18 changes: 9 additions & 9 deletions numba_dpex/core/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,19 @@
from numba.core.cpu import CPUTargetOptions
from numba.core.descriptors import TargetDescriptor

from numba_dpex._kernel_api_impl.spirv.target import (
SPIRV_TARGET_NAME,
CompilationMode,
SPIRVTargetContext,
SPIRVTypingContext,
)
from numba_dpex.core import config

from .targets.dpjit_target import (
DPEX_TARGET_NAME,
DpexTargetContext,
DpexTypingContext,
)
from .targets.kernel_target import (
DPEX_KERNEL_TARGET_NAME,
CompilationMode,
DpexKernelTargetContext,
DpexKernelTypingContext,
)

_option_mapping = options._mapping

Expand Down Expand Up @@ -77,12 +77,12 @@ class DpexKernelTarget(TargetDescriptor):
@cached_property
def _toplevel_target_context(self):
"""Lazily-initialized top-level target context, for all threads."""
return DpexKernelTargetContext(self.typing_context, self._target_name)
return SPIRVTargetContext(self.typing_context, self._target_name)

@cached_property
def _toplevel_typing_context(self):
"""Lazily-initialized top-level typing context, for all threads."""
return DpexKernelTypingContext()
return SPIRVTypingContext()

@property
def target_context(self):
Expand Down Expand Up @@ -132,7 +132,7 @@ def typing_context(self):


# A global instance of the DpexKernelTarget
dpex_kernel_target = DpexKernelTarget(DPEX_KERNEL_TARGET_NAME)
dpex_kernel_target = DpexKernelTarget(SPIRV_TARGET_NAME)

# A global instance of the DpexTarget
dpex_target = DpexTarget(DPEX_TARGET_NAME)
4 changes: 2 additions & 2 deletions numba_dpex/core/kernel_interface/spirv_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
from numba.core import ir

from numba_dpex import spirv_generator
from numba_dpex._kernel_api_impl.spirv.target import SPIRVTargetContext
from numba_dpex.core import config
from numba_dpex.core.compiler import compile_with_dpex
from numba_dpex.core.exceptions import UncompiledKernelError, UnreachableError
from numba_dpex.core.targets.kernel_target import DpexKernelTargetContext

from .kernel_base import KernelInterface

Expand Down Expand Up @@ -135,7 +135,7 @@ def compile(
)

func = cres.library.get_function(cres.fndesc.llvm_func_name)
kernel_targetctx: DpexKernelTargetContext = cres.target_context
kernel_targetctx: SPIRVTargetContext = cres.target_context
kernel = kernel_targetctx.prepare_spir_kernel(func, cres.signature.args)

# XXX: Setting the inline_threshold in the following way is a temporary
Expand Down
4 changes: 2 additions & 2 deletions numba_dpex/dpnp_iface/arrayobj.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@
from numba.np.arrayobj import make_array
from numba.np.numpy_support import is_nonelike

from numba_dpex._kernel_api_impl.spirv.target import SPIRVTargetContext
from numba_dpex.core.kernel_interface.arrayobj import (
_getitem_array_generic as kernel_getitem_array_generic,
)
from numba_dpex.core.targets.kernel_target import DpexKernelTargetContext
from numba_dpex.core.types import DpnpNdArray

from ._intrinsic import (
Expand Down Expand Up @@ -1082,7 +1082,7 @@ def getitem_arraynd_intp(context, builder, sig, args):
that when returning a view of a dpnp.ndarray the sycl::queue pointer
member in the LLVM IR struct gets properly updated.
"""
getitem_call_in_kernel = isinstance(context, DpexKernelTargetContext)
getitem_call_in_kernel = isinstance(context, SPIRVTargetContext)
_getitem_array_generic = np_getitem_array_generic

if getitem_call_in_kernel:
Expand Down
5 changes: 3 additions & 2 deletions numba_dpex/experimental/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

from numba.core.imputils import Registry

from numba_dpex._kernel_api_impl.spirv.dispatcher import SPIRVKernelDispatcher

# Temporary so that Range and NdRange work in experimental call_kernel
from numba_dpex.core.boxing import *

Expand All @@ -17,7 +19,6 @@
_index_space_id_overloads,
)
from .decorators import device_func, kernel
from .kernel_dispatcher import KernelDispatcher
from .launcher import call_kernel, call_kernel_async
from .literal_intenum_type import IntEnumLiteral
from .models import *
Expand All @@ -41,5 +42,5 @@ def dpex_dispatcher_const(context):
"call_kernel",
"call_kernel_async",
"IntEnumLiteral",
"KernelDispatcher",
"SPIRVKernelDispatcher",
]
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,11 @@
from numba.core import cgutils, types
from numba.extending import intrinsic, overload, overload_method

from numba_dpex._kernel_api_impl.spirv.target import (
CC_SPIR_FUNC,
LLVM_SPIRV_ARGS,
)
from numba_dpex.core import itanium_mangler as ext_itanium_mangler
from numba_dpex.core.targets.kernel_target import CC_SPIR_FUNC, LLVM_SPIRV_ARGS
from numba_dpex.core.types import USMNdArray
from numba_dpex.kernel_api import (
AddressSpace,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from llvmlite import ir as llvmir
from numba.core import cgutils, types

from numba_dpex._kernel_api_impl.spirv.target import CC_SPIR_FUNC
from numba_dpex.core import itanium_mangler as ext_itanium_mangler
from numba_dpex.core.targets.kernel_target import CC_SPIR_FUNC


def get_or_insert_atomic_load_fn(context, module, atomic_ref_ty):
Expand Down
8 changes: 4 additions & 4 deletions numba_dpex/experimental/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
target_registry,
)

from numba_dpex.core.targets.kernel_target import CompilationMode
from numba_dpex.experimental.kernel_dispatcher import KernelDispatcher
from numba_dpex._kernel_api_impl.spirv.dispatcher import SPIRVKernelDispatcher
from numba_dpex._kernel_api_impl.spirv.target import CompilationMode

from .target import DPEX_KERNEL_EXP_TARGET_NAME

Expand Down Expand Up @@ -74,11 +74,11 @@ def kernel(func_or_sig=None, **options):
for sig in sigs:
if isinstance(sig, str):
raise NotImplementedError(
"Specifying signatures as string is not yet supported by numba-dpex"
"Specifying signatures as string is not yet supported"
)

def _kernel_dispatcher(pyfunc):
disp: KernelDispatcher = dispatcher(
disp: SPIRVKernelDispatcher = dispatcher(
pyfunc=pyfunc,
targetoptions=options,
)
Expand Down
Loading

0 comments on commit df32f71

Please sign in to comment.