Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Register the experimental kernel target as a fully standalone Numba hardware target #1225

Merged
merged 4 commits into from
Nov 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 12 additions & 7 deletions numba_dpex/experimental/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,13 @@
import inspect

from numba.core import sigutils
from numba.core.target_extension import jit_registry, target_registry
from numba.core.target_extension import (
jit_registry,
resolve_dispatcher_from_str,
target_registry,
)

from .kernel_dispatcher import KernelDispatcher
from .target import DPEX_KERNEL_EXP_TARGET_NAME


def kernel(func_or_sig=None, **options):
Expand All @@ -24,11 +28,14 @@ def kernel(func_or_sig=None, **options):
* All array arguments passed to a kernel should adhere to compute
follows data programming model.
"""

dispatcher = resolve_dispatcher_from_str(DPEX_KERNEL_EXP_TARGET_NAME)

# FIXME: The options need to be evaluated and checked here like it is
# done in numba.core.decorators.jit

def _kernel_dispatcher(pyfunc):
return KernelDispatcher(
return dispatcher(
pyfunc=pyfunc,
targetoptions=options,
)
Expand Down Expand Up @@ -59,9 +66,7 @@ def _kernel_dispatcher(pyfunc):
func_or_sig = [func_or_sig]

def _specialized_kernel_dispatcher(pyfunc):
return KernelDispatcher(
pyfunc=pyfunc,
)
return dispatcher(pyfunc=pyfunc)

return _specialized_kernel_dispatcher
func = func_or_sig
Expand All @@ -75,4 +80,4 @@ def _specialized_kernel_dispatcher(pyfunc):
return _kernel_dispatcher(func)


jit_registry[target_registry["dpex_kernel"]] = kernel
jit_registry[target_registry[DPEX_KERNEL_EXP_TARGET_NAME]] = kernel
68 changes: 45 additions & 23 deletions numba_dpex/experimental/kernel_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
"""
from collections import namedtuple
from contextlib import ExitStack
from typing import Tuple

import numba.core.event as ev
from numba.core import errors, sigutils, types
Expand All @@ -24,13 +25,12 @@
from numba_dpex.core.pipelines import kernel_compiler
from numba_dpex.core.types import DpnpNdArray

from .target import dpex_exp_kernel_target
from .target import DPEX_KERNEL_EXP_TARGET_NAME, dpex_exp_kernel_target

_KernelModule = namedtuple("_KernelModule", ["kernel_name", "kernel_bitcode"])

_KernelCompileResult = namedtuple(
"_KernelCompileResult",
["status", "cres_or_error", "entry_point"],
"_KernelCompileResult", CompileResult._fields + ("kernel_device_ir_module",)
)


Expand Down Expand Up @@ -96,15 +96,15 @@ def _compile_to_spirv(
)

def compile(self, args, return_type):
kcres = self._compile_cached(args, return_type)
if kcres.status:
status, kcres = self._compile_cached(args, return_type)
if status:
return kcres

raise kcres.cres_or_error
raise kcres

def _compile_cached(
self, args, return_type: types.Type
) -> _KernelCompileResult:
) -> Tuple[bool, _KernelCompileResult]:
"""Compiles the kernel function to bitcode and generates a host-callable
wrapper to submit the kernel to a SYCL queue.

Expand Down Expand Up @@ -137,34 +137,45 @@ def _compile_cached(
"""
key = tuple(args), return_type
try:
return _KernelCompileResult(False, self._failed_cache[key], None)
return False, self._failed_cache[key]
except KeyError:
pass

try:
kernel_cres: CompileResult = self._compile_core(args, return_type)
cres: CompileResult = self._compile_core(args, return_type)

kernel_library = kernel_cres.library
kernel_fndesc = kernel_cres.fndesc
kernel_targetctx = kernel_cres.target_context

kernel_module = self._compile_to_spirv(
kernel_library, kernel_fndesc, kernel_targetctx
kernel_device_ir_module = self._compile_to_spirv(
cres.library, cres.fndesc, cres.target_context
)

kcres_attrs = []

for cres_field in cres._fields:
cres_attr = getattr(cres, cres_field)
if cres_field == "entry_point":
if cres_attr is not None:
raise AssertionError(
"Compiled kernel and device_func should be "
"compiled with compile_cfunc option turned off"
)
cres_attr = cres.fndesc.qualname
kcres_attrs.append(cres_attr)

kcres_attrs.append(kernel_device_ir_module)

if config.DUMP_KERNEL_LLVM:
with open(
kernel_cres.fndesc.llvm_func_name + ".ll",
cres.fndesc.llvm_func_name + ".ll",
"w",
encoding="UTF-8",
) as f:
f.write(kernel_cres.library.final_module)
f.write(cres.library.final_module)

except errors.TypingError as e:
self._failed_cache[key] = e
return _KernelCompileResult(False, e, None)
return False, e

return _KernelCompileResult(True, kernel_cres, kernel_module)
return True, _KernelCompileResult(*kcres_attrs)
diptorupd marked this conversation as resolved.
Show resolved Hide resolved


class KernelDispatcher(Dispatcher):
Expand Down Expand Up @@ -234,7 +245,14 @@ def typeof_pyval(self, val):

def add_overload(self, cres):
args = tuple(cres.signature.args)
self.overloads[args] = cres.entry_point
self.overloads[args] = cres

def get_overload_device_ir(self, sig):
    """Look up the device IR module stored for a compiled signature.

    Args:
        sig: A signature in any form accepted by
            ``sigutils.normalize_signature``. Only its argument types are
            used for the lookup; the return type is ignored.

    Returns:
        The ``kernel_device_ir_module`` attribute of the entry stored in
        ``self.overloads`` under the tuple of argument types.
    """
    arg_types, _return_type = sigutils.normalize_signature(sig)
    overload = self.overloads[tuple(arg_types)]
    return overload.kernel_device_ir_module

def compile(self, sig) -> _KernelCompileResult:
disp = self._get_dispatcher_for_current_target()
Expand Down Expand Up @@ -274,7 +292,7 @@ def cb_llvm(dur):
# Don't recompile if signature already exists
existing = self.overloads.get(tuple(args))
if existing is not None:
return existing
return existing.entry_point

# TODO: Enable caching
# Add code to enable on disk caching of a binary spirv kernel.
Expand All @@ -298,7 +316,11 @@ def folded(args, kws):
)[1]

raise e.bind_fold_arguments(folded)
self.add_overload(kcres.cres_or_error)
self.add_overload(kcres)

kcres.target_context.insert_user_function(
kcres.entry_point, kcres.fndesc, [kcres.library]
)

# TODO: enable caching of kernel_module
# https://github.com/IntelPython/numba-dpex/issues/1197
Expand All @@ -318,5 +340,5 @@ def __call__(self, *args, **kw_args):
raise NotImplementedError


_dpex_target = target_registry["dpex_kernel"]
_dpex_target = target_registry[DPEX_KERNEL_EXP_TARGET_NAME]
dispatcher_registry[_dpex_target] = KernelDispatcher
28 changes: 18 additions & 10 deletions numba_dpex/experimental/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from llvmlite import ir as llvmir
from numba.core import cgutils, cpu, types
from numba.core.datamodel import default_manager as numba_default_dmm
from numba.extending import intrinsic, overload

from numba_dpex import config, dpjit
Expand Down Expand Up @@ -192,23 +193,27 @@ def create_llvm_values_for_index_space(
ndim = indexer_argty.ndim
grange_extents = []
lrange_extents = []
datamodel = self._kernel_targetctx.data_model_manager.lookup(
indexer_argty
)
indexer_datamodel = numba_default_dmm.lookup(indexer_argty)

if isinstance(indexer_argty, RangeType):
for dim_num in range(ndim):
dim_pos = datamodel.get_field_position("dim" + str(dim_num))
dim_pos = indexer_datamodel.get_field_position(
"dim" + str(dim_num)
)
grange_extents.append(
self._builder.extract_value(index_space_arg, dim_pos)
)
elif isinstance(indexer_argty, NdRangeType):
for dim_num in range(ndim):
gdim_pos = datamodel.get_field_position("gdim" + str(dim_num))
gdim_pos = indexer_datamodel.get_field_position(
"gdim" + str(dim_num)
)
grange_extents.append(
self._builder.extract_value(index_space_arg, gdim_pos)
)
ldim_pos = datamodel.get_field_position("ldim" + str(dim_num))
ldim_pos = indexer_datamodel.get_field_position(
"ldim" + str(dim_num)
)
lrange_extents.append(
self._builder.extract_value(index_space_arg, ldim_pos)
)
Expand Down Expand Up @@ -308,7 +313,10 @@ def intrin_launch_trampoline(
sig = types.void(kernel_fn, index_space, kernel_args)
# signature of the kernel_fn
kernel_sig = types.void(*kernel_args_list)
kmodule: _KernelModule = kernel_fn.dispatcher.compile(kernel_sig)
kernel_fn.dispatcher.compile(kernel_sig)
kernel_module: _KernelModule = kernel_fn.dispatcher.get_overload_device_ir(
kernel_sig
)
kernel_targetctx = kernel_fn.dispatcher.targetctx

def codegen(cgctx, builder, sig, llargs):
Expand All @@ -324,7 +332,7 @@ def codegen(cgctx, builder, sig, llargs):
)

kernel_bc_byte_str = fn_body_gen.insert_kernel_bitcode_as_byte_str(
kmodule
kernel_module
)

populated_kernel_args = (
Expand All @@ -341,10 +349,10 @@ def codegen(cgctx, builder, sig, llargs):
kbref = fn_body_gen.create_kernel_bundle_from_spirv(
queue_ref=qref,
kernel_bc=kernel_bc_byte_str,
kernel_bc_size_in_bytes=len(kmodule.kernel_bitcode),
kernel_bc_size_in_bytes=len(kernel_module.kernel_bitcode),
)

kref = fn_body_gen.get_kernel(kmodule, kbref)
kref = fn_body_gen.get_kernel(kernel_module, kbref)

index_space_values = fn_body_gen.create_llvm_values_for_index_space(
indexer_argty=sig.args[1],
Expand Down
31 changes: 25 additions & 6 deletions numba_dpex/experimental/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,38 @@
#
# SPDX-License-Identifier: Apache-2.0

"""Provides Numba datamodel for the numba_dpex types introduced in the
"""Provides the Numba data models for the numba_dpex types introduced in the
numba_dpex.experimental module.
"""

from numba.core.datamodel import models
from numba.core.datamodel import DataModelManager, models
from numba.core.extending import register_model

from numba_dpex.core.datamodel.models import dpex_data_model_manager as dmm
import numba_dpex.core.datamodel.models as dpex_core_models

from .types import KernelDispatcherType

# Register the types and datamodel in the DpexKernelTargetContext
dmm.register(KernelDispatcherType, models.OpaqueModel)

# Register the types and datamodel in the DpexTargetContext
def _init_exp_data_model_manager() -> DataModelManager:
    """Create the data model manager for the DpexExpKernelTarget.

    The manager is obtained by calling ``copy()`` on the
    ``dpex_data_model_manager`` defined in
    ``numba_dpex.core.datamodel.models``. Experimental types that are being
    added to the kernel API are meant to be registered on that copy.

    Returns:
        DataModelManager: The data model manager specific to the
        DpexExpKernelTarget.
    """
    exp_manager = dpex_core_models.dpex_data_model_manager.copy()

    # Register the types and data model in the DpexExpTargetContext
    # Add here...

    return exp_manager


exp_dmm = _init_exp_data_model_manager()

# Register any new type that should go into numba.core.datamodel.default_manager
register_model(KernelDispatcherType)(models.OpaqueModel)
diptorupd marked this conversation as resolved.
Show resolved Hide resolved
19 changes: 17 additions & 2 deletions numba_dpex/experimental/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,24 @@
from functools import cached_property

from numba.core.descriptors import TargetDescriptor
from numba.core.target_extension import GPU, target_registry

from numba_dpex.core.descriptor import DpexTargetOptions
from numba_dpex.core.targets.kernel_target import (
DPEX_KERNEL_TARGET_NAME,
DpexKernelTargetContext,
DpexKernelTypingContext,
)
from numba_dpex.experimental.models import exp_dmm


# pylint: disable=R0903
class SyclDeviceExp(GPU):
    """Mark the hardware target as SYCL Device.

    Subclasses the ``GPU`` class imported from
    ``numba.core.target_extension`` and adds no members of its own. This
    module registers the class in ``target_registry`` under
    ``DPEX_KERNEL_EXP_TARGET_NAME``.
    """


DPEX_KERNEL_EXP_TARGET_NAME = "dpex_kernel_exp"

target_registry[DPEX_KERNEL_EXP_TARGET_NAME] = SyclDeviceExp


class DpexExpKernelTypingContext(DpexKernelTypingContext):
Expand All @@ -41,6 +52,10 @@ class DpexExpKernelTargetContext(DpexKernelTargetContext):
they are stable enough to be migrated to DpexKernelTargetContext.
"""

def __init__(self, typingctx, target=DPEX_KERNEL_EXP_TARGET_NAME):
    """Initialize the context and attach the experimental data models.

    Args:
        typingctx: Typing context forwarded to the parent initializer.
        target: Target name forwarded to the parent initializer. Defaults
            to ``DPEX_KERNEL_EXP_TARGET_NAME``.
    """
    super().__init__(typingctx, target)
    # Assigned after the parent initializer has run, so ``exp_dmm`` is the
    # data model manager left on the instance.
    self.data_model_manager = exp_dmm


class DpexExpKernelTarget(TargetDescriptor):
"""
Expand Down Expand Up @@ -77,4 +92,4 @@ def typing_context(self):


# A global instance of DpexExpKernelTarget, the kernel target with the experimental features
dpex_exp_kernel_target = DpexExpKernelTarget(DPEX_KERNEL_TARGET_NAME)
dpex_exp_kernel_target = DpexExpKernelTarget(DPEX_KERNEL_EXP_TARGET_NAME)
Loading