Skip to content

Commit

Permalink
Enable overload caching in kernel_dispatcher.
Browse files Browse the repository at this point in the history
    - Use the function's qualname as the entry point (key) for
      caching the compiled function in the target's _defns dict.
    - KernelCompileResult now extends numba's CompileResult class
      and only stores the spirv bitcode as an extra field.
    - Added a new function to return just the device IR for a
      cached overload.
    - Updates to launcher to accommodate changes to kernel_dispatcher.
  • Loading branch information
Diptorup Deb committed Nov 21, 2023
1 parent 185a61b commit 2ab2608
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 24 deletions.
61 changes: 41 additions & 20 deletions numba_dpex/experimental/kernel_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@
_KernelModule = namedtuple("_KernelModule", ["kernel_name", "kernel_bitcode"])

_KernelCompileResult = namedtuple(
"_KernelCompileResult",
["status", "cres_or_error", "entry_point"],
"_KernelCompileResult", CompileResult._fields + ("kernel_device_ir_module",)
)


Expand Down Expand Up @@ -96,11 +95,11 @@ def _compile_to_spirv(
)

def compile(self, args, return_type):
kcres = self._compile_cached(args, return_type)
if kcres.status:
status, kcres = self._compile_cached(args, return_type)
if status:
return kcres

raise kcres.cres_or_error
raise kcres

def _compile_cached(
self, args, return_type: types.Type
Expand Down Expand Up @@ -137,34 +136,45 @@ def _compile_cached(
"""
key = tuple(args), return_type
try:
return _KernelCompileResult(False, self._failed_cache[key], None)
return False, self._failed_cache[key]
except KeyError:
pass

try:
kernel_cres: CompileResult = self._compile_core(args, return_type)
cres: CompileResult = self._compile_core(args, return_type)

kernel_library = kernel_cres.library
kernel_fndesc = kernel_cres.fndesc
kernel_targetctx = kernel_cres.target_context

kernel_module = self._compile_to_spirv(
kernel_library, kernel_fndesc, kernel_targetctx
kernel_device_ir_module = self._compile_to_spirv(
cres.library, cres.fndesc, cres.target_context
)

kcres_attrs = []

for cres_field in cres._fields:
cres_attr = getattr(cres, cres_field)
if cres_field == "entry_point":
if cres_attr is not None:
raise AssertionError(
"Compiled kernel and device_func should be "
"compiled with compile_cfunc option turned off"
)
cres_attr = cres.fndesc.qualname
kcres_attrs.append(cres_attr)

kcres_attrs.append(kernel_device_ir_module)

if config.DUMP_KERNEL_LLVM:
with open(
kernel_cres.fndesc.llvm_func_name + ".ll",
cres.fndesc.llvm_func_name + ".ll",
"w",
encoding="UTF-8",
) as f:
f.write(kernel_cres.library.final_module)
f.write(cres.library.final_module)

except errors.TypingError as e:
self._failed_cache[key] = e
return _KernelCompileResult(False, e, None)
return False, e

return _KernelCompileResult(True, kernel_cres, kernel_module)
return True, _KernelCompileResult(*kcres_attrs)


class KernelDispatcher(Dispatcher):
Expand Down Expand Up @@ -234,7 +244,14 @@ def typeof_pyval(self, val):

def add_overload(self, cres):
args = tuple(cres.signature.args)
self.overloads[args] = cres.entry_point
self.overloads[args] = cres

def get_overload_device_ir(self, sig):
"""
Return the compiled device bitcode for the given signature.
"""
args, _ = sigutils.normalize_signature(sig)
return self.overloads[tuple(args)].kernel_device_ir_module

def compile(self, sig) -> _KernelCompileResult:
disp = self._get_dispatcher_for_current_target()
Expand Down Expand Up @@ -274,7 +291,7 @@ def cb_llvm(dur):
# Don't recompile if signature already exists
existing = self.overloads.get(tuple(args))
if existing is not None:
return existing
return existing.entry_point

# TODO: Enable caching
# Add code to enable on disk caching of a binary spirv kernel.
Expand All @@ -298,7 +315,11 @@ def folded(args, kws):
)[1]

raise e.bind_fold_arguments(folded)
self.add_overload(kcres.cres_or_error)
self.add_overload(kcres)

kcres.target_context.insert_user_function(
kcres.entry_point, kcres.fndesc, [kcres.library]
)

# TODO: enable caching of kernel_module
# https://github.com/IntelPython/numba-dpex/issues/1197
Expand Down
11 changes: 7 additions & 4 deletions numba_dpex/experimental/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,10 @@ def intrin_launch_trampoline(
sig = types.void(kernel_fn, index_space, kernel_args)
# signature of the kernel_fn
kernel_sig = types.void(*kernel_args_list)
kmodule: _KernelModule = kernel_fn.dispatcher.compile(kernel_sig)
kernel_fn.dispatcher.compile(kernel_sig)
kernel_module: _KernelModule = kernel_fn.dispatcher.get_overload_device_ir(
kernel_sig
)
kernel_targetctx = kernel_fn.dispatcher.targetctx

def codegen(cgctx, builder, sig, llargs):
Expand All @@ -329,7 +332,7 @@ def codegen(cgctx, builder, sig, llargs):
)

kernel_bc_byte_str = fn_body_gen.insert_kernel_bitcode_as_byte_str(
kmodule
kernel_module
)

populated_kernel_args = (
Expand All @@ -346,10 +349,10 @@ def codegen(cgctx, builder, sig, llargs):
kbref = fn_body_gen.create_kernel_bundle_from_spirv(
queue_ref=qref,
kernel_bc=kernel_bc_byte_str,
kernel_bc_size_in_bytes=len(kmodule.kernel_bitcode),
kernel_bc_size_in_bytes=len(kernel_module.kernel_bitcode),
)

kref = fn_body_gen.get_kernel(kmodule, kbref)
kref = fn_body_gen.get_kernel(kernel_module, kbref)

index_space_values = fn_body_gen.create_llvm_values_for_index_space(
indexer_argty=sig.args[1],
Expand Down

0 comments on commit 2ab2608

Please sign in to comment.