From d25b525319e045348d85c8329e0eebab4a366fbd Mon Sep 17 00:00:00 2001 From: "akmkhale@ansatnuc04" Date: Wed, 11 Jan 2023 17:04:18 -0600 Subject: [PATCH] Adding cache and specialization to func decorator --- notes.md | 12 ++ numba_dpex/core/caching.py | 28 +-- .../core/kernel_interface/dispatcher.py | 8 +- numba_dpex/core/kernel_interface/func.py | 178 +++++++++++------- numba_dpex/decorators.py | 72 +++---- test-func-specialize.py | 85 +++++++++ 6 files changed, 272 insertions(+), 111 deletions(-) create mode 100644 notes.md create mode 100644 test-func-specialize.py diff --git a/notes.md b/notes.md new file mode 100644 index 0000000000..6f5d6d6aef --- /dev/null +++ b/notes.md @@ -0,0 +1,12 @@ +1. Specialized function gets compiled when the file is loaded +2. Dig into: + cres.typing_context.insert_user_function(devfn, _function_template) + libs = [cres.library] + cres.target_context.insert_user_function(devfn, cres.fndesc, libs) +3. why the numba code is downcasting? Look into the lowering logic + %".87" = fptosi double %".86" to i32 + store i32 0, i32* %".88" + %".91" = call spir_func i32 @"dpex_py_devfn__5F__5F_main_5F__5F__2E_a_5F_device_5F_function_2E_int32"(i32* %".88", i32 %".87") +4. This is what should happen: if a func is specialized, dpex should throw an error rather than doing an automatic typecasting. +5. look into numba's lowering and intercept the typecasting -- the goal is to stop lowering when type doesn't match in specialization case +6. Most likely 867 will be resolved if these problems are addressed. diff --git a/numba_dpex/core/caching.py b/numba_dpex/core/caching.py index 1c46485c6c..94b8afe111 100644 --- a/numba_dpex/core/caching.py +++ b/numba_dpex/core/caching.py @@ -224,7 +224,7 @@ class LRUCache(AbstractCache): with a dictionary as a lookup table. """ - def __init__(self, capacity=10, pyfunc=None): + def __init__(self, name="cache", capacity=10, pyfunc=None): """Constructor for LRUCache. Args: @@ -233,6 +233,7 @@ def __init__(self, capacity=10, pyfunc=None): pyfunc (NoneType, optional): A python function to be cached. Defaults to None. """ + self._name = name self._capacity = capacity self._lookup = {} self._evicted = {} @@ -432,8 +433,8 @@ def get(self, key): value = self._cache_file.load(key) if config.DEBUG_CACHE: print( - "[cache]: unpickled an evicted artifact, " - "key: {0:s}.".format(str(key)) + "[{0:s}]: unpickled an evicted artifact, " + "key: {1:s}.".format(self._name, str(key)) ) else: value = self._evicted[key] @@ -442,8 +443,8 @@ def get(self, key): else: if config.DEBUG_CACHE: print( - "[cache] size: {0:d}, loading artifact, key: {1:s}".format( - len(self._lookup), str(key) + "[{0:s}] size: {1:d}, loading artifact, key: {2:s}".format( + self._name, len(self._lookup), str(key) ) ) node = self._lookup[key] @@ -464,8 +465,8 @@ def put(self, key, value): if key in self._lookup: if config.DEBUG_CACHE: print( - "[cache] size: {0:d}, storing artifact, key: {1:s}".format( - len(self._lookup), str(key) + "[{0:s}] size: {1:d}, storing artifact, key: {2:s}".format( + self._name, len(self._lookup), str(key) ) ) self._lookup[key].value = value @@ -480,8 +481,9 @@ def put(self, key, value): if self._cache_file: if config.DEBUG_CACHE: print( - "[cache] size: {0:d}, pickling the LRU item, " - "key: {1:s}, indexed at {2:s}.".format( + "[{0:s}] size: {1:d}, pickling the LRU item, " + "key: {2:s}, indexed at {3:s}.".format( + self._name, len(self._lookup), str(self._head.key), self._cache_file._index_path, @@ -496,8 +498,8 @@ def put(self, key, value): self._lookup.pop(self._head.key) if config.DEBUG_CACHE: print( - "[cache] size: {0:d}, capacity exceeded, evicted".format( - len(self._lookup) + "[{0:s}] size: {1:d}, capacity exceeded, evicted".format( + self._name, len(self._lookup) ), self._head.key, ) @@ -509,7 +511,7 @@ def put(self, key, value): self._append_tail(new_node) if config.DEBUG_CACHE: print( - "[cache] size: {0:d}, saved artifact, key: {1:s}".format( - len(self._lookup), str(key) + "[{0:s}] size: {1:d}, saved artifact, key: {2:s}".format( + self._name, len(self._lookup), str(key) ) ) diff --git a/numba_dpex/core/kernel_interface/dispatcher.py b/numba_dpex/core/kernel_interface/dispatcher.py index ed602199ca..0b9b527ab3 100644 --- a/numba_dpex/core/kernel_interface/dispatcher.py +++ b/numba_dpex/core/kernel_interface/dispatcher.py @@ -90,7 +90,9 @@ def __init__( self._cache = NullCache() elif enable_cache: self._cache = LRUCache( - capacity=config.CACHE_SIZE, pyfunc=self.pyfunc + name="SPIRVKernelCache", + capacity=config.CACHE_SIZE, + pyfunc=self.pyfunc, ) else: self._cache = NullCache() @@ -119,7 +121,9 @@ def __init__( if specialization_sigs: self._has_specializations = True self._specialization_cache = LRUCache( - capacity=config.CACHE_SIZE, pyfunc=self.pyfunc + name="SPIRVKernelSpecializationCache", + capacity=config.CACHE_SIZE, + pyfunc=self.pyfunc, ) for sig in specialization_sigs: self._specialize(sig) diff --git a/numba_dpex/core/kernel_interface/func.py b/numba_dpex/core/kernel_interface/func.py index 42fe2e19c1..7e088adc34 100644 --- a/numba_dpex/core/kernel_interface/func.py +++ b/numba_dpex/core/kernel_interface/func.py @@ -6,61 +6,65 @@ """ +from numba.core import sigutils, types from numba.core.typing.templates import AbstractTemplate, ConcreteTemplate +from numba_dpex import config +from numba_dpex.core.caching import LRUCache, NullCache, build_key from numba_dpex.core.compiler import compile_with_dpex from numba_dpex.core.descriptor import dpex_target +from numba_dpex.utils import npytypes_array_to_dpex_array -def compile_func(pyfunc, return_type, args, debug=None): - cres = compile_with_dpex( - pyfunc=pyfunc, - pyfunc_name=pyfunc.__name__, - return_type=return_type, - target_context=dpex_target.target_context, - typing_context=dpex_target.typing_context, - args=args, - is_kernel=False, - debug=debug, - ) - func = cres.library.get_function(cres.fndesc.llvm_func_name) - cres.target_context.mark_ocl_device(func) - devfn = DpexFunction(cres) - - class _function_template(ConcreteTemplate): - key = devfn - cases = [cres.signature] - - cres.typing_context.insert_user_function(devfn, _function_template) - libs = [cres.library] - cres.target_context.insert_user_function(devfn, cres.fndesc, libs) - return devfn - - -def compile_func_template(pyfunc, debug=None): - """Compile a DpexFunctionTemplate""" - - dft = DpexFunctionTemplate(pyfunc, debug=debug) - - class _function_template(AbstractTemplate): - key = dft - - def generic(self, args, kws): - if kws: - raise AssertionError("No keyword arguments allowed.") - return dft.compile(args) - - dpex_target.typing_context.insert_user_function(dft, _function_template) - return dft +class DpexFunction(object): + def __init__(self, pyfunc, debug=None): + self._pyfunc = pyfunc + self._debug = debug + + def compile(self, arg_types, return_types): + cres = compile_with_dpex( + pyfunc=self._pyfunc, + pyfunc_name=self._pyfunc.__name__, + return_type=return_types, + target_context=dpex_target.target_context, + typing_context=dpex_target.typing_context, + args=arg_types, + is_kernel=False, + debug=self._debug, + ) + func = cres.library.get_function(cres.fndesc.llvm_func_name) + cres.target_context.mark_ocl_device(func) + + return cres class DpexFunctionTemplate(object): """Unmaterialized dpex function""" - def __init__(self, pyfunc, debug=None): - self.py_func = pyfunc - self.debug = debug - self._compileinfos = {} + def __init__(self, pyfunc, debug=None, enable_cache=True): + self._pyfunc = pyfunc + self._debug = debug + self._enable_cache = enable_cache + + if not config.ENABLE_CACHE: + self._cache = NullCache() + elif self._enable_cache: + self._cache = LRUCache( + name="DpexFunctionTemplateCache", + capacity=config.CACHE_SIZE, + pyfunc=self._pyfunc, + ) + else: + self._cache = NullCache() + self._cache_hits = 0 + + @property + def cache(self): + return self._cache + + @property + def cache_hits(self): + return self._cache_hits def compile(self, args): """Compile a dpex.func decorated Python function with the given @@ -69,37 +73,85 @@ def compile(self, args): Each signature is compiled once by caching the compiled function inside this object. """ - if args not in self._compileinfos: + argtypes = [ + dpex_target.typing_context.resolve_argument_type(arg) + for arg in args + ] + key = build_key( + tuple(argtypes), + self._pyfunc, + dpex_target.target_context.codegen(), + ) + cres = self._cache.get(key) + if cres is None: + self._cache_hits += 1 cres = compile_with_dpex( - pyfunc=self.py_func, - pyfunc_name=self.py_func.__name__, + pyfunc=self._pyfunc, + pyfunc_name=self._pyfunc.__name__, return_type=None, target_context=dpex_target.target_context, typing_context=dpex_target.typing_context, args=args, is_kernel=False, - debug=self.debug, + debug=self._debug, ) func = cres.library.get_function(cres.fndesc.llvm_func_name) cres.target_context.mark_ocl_device(func) - first_definition = not self._compileinfos - self._compileinfos[args] = cres libs = [cres.library] - if first_definition: - # First definition - cres.target_context.insert_user_function( - self, cres.fndesc, libs - ) - else: - cres.target_context.add_user_function(self, cres.fndesc, libs) - - else: - cres = self._compileinfos[args] + cres.target_context.insert_user_function(self, cres.fndesc, libs) + # cres.target_context.add_user_function(self, cres.fndesc, libs) + self._cache.put(key, cres) return cres.signature -class DpexFunction(object): - def __init__(self, cres): - self.cres = cres +def compile_func(pyfunc, signature, debug=None): + devfn = DpexFunction(pyfunc, debug=debug) + + cres = [] + for sig in signature: + arg_types, return_types = sigutils.normalize_signature(sig) + arg_types = tuple( + [ + npytypes_array_to_dpex_array(ty) + if isinstance(ty, types.npytypes.Array) + else ty + for ty in arg_types + ] + ) + c = devfn.compile(arg_types, return_types) + cres.append(c) + + class _function_template(ConcreteTemplate): + unsafe_casting = False + exact_match_required = True + key = devfn + cases = [c.signature for c in cres] + + # TODO: If context is only one copy, just make it global + cres[0].typing_context.insert_user_function(devfn, _function_template) + + for c in cres: + c.target_context.insert_user_function(devfn, c.fndesc, [c.library]) + + return devfn + + +def compile_func_template(pyfunc, debug=None): + """Compile a DpexFunctionTemplate""" + + dft = DpexFunctionTemplate(pyfunc, debug=debug) + + class _function_template(AbstractTemplate): + unsafe_casting = False + exact_match_required = True + key = dft + + def generic(self, args, kws): + if kws: + raise AssertionError("No keyword arguments allowed.") + return dft.compile(args) + + dpex_target.typing_context.insert_user_function(dft, _function_template) + return dft diff --git a/numba_dpex/decorators.py b/numba_dpex/decorators.py index ace5354e97..0abb0c37c5 100644 --- a/numba_dpex/decorators.py +++ b/numba_dpex/decorators.py @@ -4,7 +4,7 @@ import inspect -from numba.core import sigutils, types +from numba.core import sigutils from numba_dpex.core.kernel_interface.dispatcher import ( JitKernel, @@ -14,7 +14,6 @@ compile_func, compile_func_template, ) -from numba_dpex.utils import npytypes_array_to_dpex_array def kernel( @@ -69,6 +68,7 @@ def _kernel_dispatcher(pyfunc, sigs=None): if not isinstance(func_or_sig, list): func_or_sig = [func_or_sig] + # TODO: This function is 99% same as _kernel_dispatcher() def _specialized_kernel_dispatcher(pyfunc): ordered_arg_access_types = get_ordered_arg_access_types( pyfunc, access_types @@ -94,39 +94,45 @@ def _specialized_kernel_dispatcher(pyfunc): return _kernel_dispatcher(func) -def func(signature=None, debug=None): - if signature is None: - return _func_autojit_wrapper(debug=debug) - elif not sigutils.is_signature(signature): - func = signature - return _func_autojit(func, debug=debug) - else: - return _func_jit(signature, debug=debug) - - -def _func_jit(signature, debug=None): - argtypes, restype = sigutils.normalize_signature(signature) - argtypes = tuple( - [ - npytypes_array_to_dpex_array(ty) - if isinstance(ty, types.npytypes.Array) - else ty - for ty in argtypes - ] - ) - - def _wrapped(pyfunc): - return compile_func(pyfunc, restype, argtypes, debug=debug) - - return _wrapped - +def func(func_or_sig=None, debug=None, enable_cache=True): + """ + TODO: make similar to kernel version and when the signature is a list + remove the caching for specialization in func + """ -def _func_autojit_wrapper(debug=None): - def _func_autojit(pyfunc, debug=debug): + def _func_autojit(pyfunc, debug=None): return compile_func_template(pyfunc, debug=debug) - return _func_autojit + def _func_autojit_wrapper(debug=None): + return _func_autojit + + if func_or_sig is None: + return _func_autojit_wrapper(debug=debug) + elif isinstance(func_or_sig, str): + raise NotImplementedError( + "Specifying signatures as string is not yet supported by numba-dpex" + ) + elif isinstance(func_or_sig, list) or sigutils.is_signature(func_or_sig): + # String signatures are not supported as passing usm_ndarray type as + # a string is not possible. Numba's sigutils relies on the type being + # available in Numba's types.__dpct__ and dpex types are not registered + # there yet. + if isinstance(func_or_sig, list): + for sig in func_or_sig: + if isinstance(sig, str): + raise NotImplementedError( + "Specifying signatures as string is not yet supported " + "by numba-dpex" + ) + # Specialized signatures can either be a single signature or a list. + # In case only one signature is provided convert it to a list + if not isinstance(func_or_sig, list): + func_or_sig = [func_or_sig] + def _wrapped(pyfunc): + return compile_func(pyfunc, func_or_sig, debug=debug) -def _func_autojit(pyfunc, debug=None): - return compile_func_template(pyfunc, debug=debug) + return _wrapped + else: + func = func_or_sig + return _func_autojit(func, debug=debug) diff --git a/test-func-specialize.py b/test-func-specialize.py new file mode 100644 index 0000000000..cec4c6ffa0 --- /dev/null +++ b/test-func-specialize.py @@ -0,0 +1,85 @@ +# SPDX-FileCopyrightText: 2020 - 2022 Intel Corporation +# +# SPDX-License-Identifier: Apache-2.0 + +import dpnp as np + +import numba_dpex as ndpex +from numba_dpex import float32, int32, int64 + +# Array size +N = 10 + + +# A device callable function that can be invoked from ``kernel`` and other device functions +@ndpex.func # ([int32(int32), int64(int64)]) +def a_device_function(a): + return a + 1 + + +# A device callable function can call another device function +# @ndpex.func +# def another_device_function(a): +# return a_device_function(a * 2) + + +# A kernel function that calls the device function +@ndpex.kernel(enable_cache=True) +def a_kernel_function(a, b): + i = ndpex.get_global_id(0) + b[i] = a_device_function(a[i]) + + +# @ndpex.kernel(enable_cache=True) +# def another_kernel_function(a, b): +# i = ndpex.get_global_id(0) +# b[i] = a_device_function(a[i]) + + +# Utility function for printing +# def driver(a, b, N): +# print("A=", a) +# a_kernel_function[N, ndpex.DEFAULT_LOCAL_SIZE](a, b) +# print("B=", b) + + +# Main function +# def main(): + +a = np.ones(N, dtype=np.int32) +b = np.ones(N, dtype=np.int32) +a_kernel_function[N, ndpex.DEFAULT_LOCAL_SIZE](a, b) +print("int32") +print("a =", a) +print("b =", b) + + +# a = np.ones(N, dtype=np.int64) +# b = np.ones(N, dtype=np.int64) +# a_kernel_function[N, ndpex.DEFAULT_LOCAL_SIZE](a, b) +# print("int64") +# print("a =", a) +# print("b =", b) + + +# a = np.ones(N, dtype=np.float32) +# b = np.ones(N, dtype=np.float32) +# a_kernel_function[N, ndpex.DEFAULT_LOCAL_SIZE](a, b) +# print("float32") +# print("a =", a) +# print("b =", b) + + +# another_kernel_function[N, ndpex.DEFAULT_LOCAL_SIZE](a, b) +# a_kernel_function[N, ndpex.DEFAULT_LOCAL_SIZE](a, b) + +# print("Using device ...") +# print(a.device) +# driver(a, b, N) + +# print("Done...") + + +# main() +# if __name__ == "__main__": +# main()