Skip to content

Commit

Permalink
Adding cache and specialization to func decorator
Browse files Browse the repository at this point in the history
  • Loading branch information
chudur-budur committed Jan 19, 2023
1 parent 187782d commit d25b525
Show file tree
Hide file tree
Showing 6 changed files with 272 additions and 111 deletions.
12 changes: 12 additions & 0 deletions notes.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
1. Specialized function gets compiled when the file is loaded
2. Dig into:
cres.typing_context.insert_user_function(devfn, _function_template)
libs = [cres.library]
cres.target_context.insert_user_function(devfn, cres.fndesc, libs)
3. why the numba code is downcasting? Look into the lowering logic
%".87" = fptosi double %".86" to i32
store i32 0, i32* %".88"
%".91" = call spir_func i32 @"dpex_py_devfn__5F__5F_main_5F__5F__2E_a_5F_device_5F_function_2E_int32"(i32* %".88", i32 %".87")
4. This is what should happen: if a func is specialized, dpex should throw an error rather than doing an automatic typecasting.
5. look into numba's lowering and intercept the typecasting -- the goal is to stop lowering when type doesn't match in specialization case
6. Most likely 867 will be resolved if these problems are addressed.
28 changes: 15 additions & 13 deletions numba_dpex/core/caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ class LRUCache(AbstractCache):
with a dictionary as a lookup table.
"""

def __init__(self, capacity=10, pyfunc=None):
def __init__(self, name="cache", capacity=10, pyfunc=None):
"""Constructor for LRUCache.
Args:
Expand All @@ -233,6 +233,7 @@ def __init__(self, capacity=10, pyfunc=None):
pyfunc (NoneType, optional): A python function to be cached.
Defaults to None.
"""
self._name = name
self._capacity = capacity
self._lookup = {}
self._evicted = {}
Expand Down Expand Up @@ -432,8 +433,8 @@ def get(self, key):
value = self._cache_file.load(key)
if config.DEBUG_CACHE:
print(
"[cache]: unpickled an evicted artifact, "
"key: {0:s}.".format(str(key))
"[{0:s}]: unpickled an evicted artifact, "
"key: {1:s}.".format(self._name, str(key))
)
else:
value = self._evicted[key]
Expand All @@ -442,8 +443,8 @@ def get(self, key):
else:
if config.DEBUG_CACHE:
print(
"[cache] size: {0:d}, loading artifact, key: {1:s}".format(
len(self._lookup), str(key)
"[{0:s}] size: {1:d}, loading artifact, key: {2:s}".format(
self._name, len(self._lookup), str(key)
)
)
node = self._lookup[key]
Expand All @@ -464,8 +465,8 @@ def put(self, key, value):
if key in self._lookup:
if config.DEBUG_CACHE:
print(
"[cache] size: {0:d}, storing artifact, key: {1:s}".format(
len(self._lookup), str(key)
"[{0:s}] size: {1:d}, storing artifact, key: {2:s}".format(
self._name, len(self._lookup), str(key)
)
)
self._lookup[key].value = value
Expand All @@ -480,8 +481,9 @@ def put(self, key, value):
if self._cache_file:
if config.DEBUG_CACHE:
print(
"[cache] size: {0:d}, pickling the LRU item, "
"key: {1:s}, indexed at {2:s}.".format(
"[{0:s}] size: {1:d}, pickling the LRU item, "
"key: {2:s}, indexed at {3:s}.".format(
self._name,
len(self._lookup),
str(self._head.key),
self._cache_file._index_path,
Expand All @@ -496,8 +498,8 @@ def put(self, key, value):
self._lookup.pop(self._head.key)
if config.DEBUG_CACHE:
print(
"[cache] size: {0:d}, capacity exceeded, evicted".format(
len(self._lookup)
"[{0:s}] size: {1:d}, capacity exceeded, evicted".format(
self._name, len(self._lookup)
),
self._head.key,
)
Expand All @@ -509,7 +511,7 @@ def put(self, key, value):
self._append_tail(new_node)
if config.DEBUG_CACHE:
print(
"[cache] size: {0:d}, saved artifact, key: {1:s}".format(
len(self._lookup), str(key)
"[{0:s}] size: {1:d}, saved artifact, key: {2:s}".format(
self._name, len(self._lookup), str(key)
)
)
8 changes: 6 additions & 2 deletions numba_dpex/core/kernel_interface/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,9 @@ def __init__(
self._cache = NullCache()
elif enable_cache:
self._cache = LRUCache(
capacity=config.CACHE_SIZE, pyfunc=self.pyfunc
name="SPIRVKernelCache",
capacity=config.CACHE_SIZE,
pyfunc=self.pyfunc,
)
else:
self._cache = NullCache()
Expand Down Expand Up @@ -119,7 +121,9 @@ def __init__(
if specialization_sigs:
self._has_specializations = True
self._specialization_cache = LRUCache(
capacity=config.CACHE_SIZE, pyfunc=self.pyfunc
name="SPIRVKernelSpecializationCache",
capacity=config.CACHE_SIZE,
pyfunc=self.pyfunc,
)
for sig in specialization_sigs:
self._specialize(sig)
Expand Down
178 changes: 115 additions & 63 deletions numba_dpex/core/kernel_interface/func.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,61 +6,65 @@
"""


from numba.core import sigutils, types
from numba.core.typing.templates import AbstractTemplate, ConcreteTemplate

from numba_dpex import config
from numba_dpex.core.caching import LRUCache, NullCache, build_key
from numba_dpex.core.compiler import compile_with_dpex
from numba_dpex.core.descriptor import dpex_target
from numba_dpex.utils import npytypes_array_to_dpex_array


def compile_func(pyfunc, return_type, args, debug=None):
cres = compile_with_dpex(
pyfunc=pyfunc,
pyfunc_name=pyfunc.__name__,
return_type=return_type,
target_context=dpex_target.target_context,
typing_context=dpex_target.typing_context,
args=args,
is_kernel=False,
debug=debug,
)
func = cres.library.get_function(cres.fndesc.llvm_func_name)
cres.target_context.mark_ocl_device(func)
devfn = DpexFunction(cres)

class _function_template(ConcreteTemplate):
key = devfn
cases = [cres.signature]

cres.typing_context.insert_user_function(devfn, _function_template)
libs = [cres.library]
cres.target_context.insert_user_function(devfn, cres.fndesc, libs)
return devfn


def compile_func_template(pyfunc, debug=None):
"""Compile a DpexFunctionTemplate"""

dft = DpexFunctionTemplate(pyfunc, debug=debug)

class _function_template(AbstractTemplate):
key = dft

def generic(self, args, kws):
if kws:
raise AssertionError("No keyword arguments allowed.")
return dft.compile(args)

dpex_target.typing_context.insert_user_function(dft, _function_template)
return dft
class DpexFunction(object):
def __init__(self, pyfunc, debug=None):
self._pyfunc = pyfunc
self._debug = debug

def compile(self, arg_types, return_types):
cres = compile_with_dpex(
pyfunc=self._pyfunc,
pyfunc_name=self._pyfunc.__name__,
return_type=return_types,
target_context=dpex_target.target_context,
typing_context=dpex_target.typing_context,
args=arg_types,
is_kernel=False,
debug=self._debug,
)
func = cres.library.get_function(cres.fndesc.llvm_func_name)
cres.target_context.mark_ocl_device(func)

return cres


class DpexFunctionTemplate(object):
"""Unmaterialized dpex function"""

def __init__(self, pyfunc, debug=None):
self.py_func = pyfunc
self.debug = debug
self._compileinfos = {}
def __init__(self, pyfunc, debug=None, enable_cache=True):
self._pyfunc = pyfunc
self._debug = debug
self._enable_cache = enable_cache

if not config.ENABLE_CACHE:
self._cache = NullCache()
elif self._enable_cache:
self._cache = LRUCache(
name="DpexFunctionTemplateCache",
capacity=config.CACHE_SIZE,
pyfunc=self._pyfunc,
)
else:
self._cache = NullCache()
self._cache_hits = 0

@property
def cache(self):
return self._cache

@property
def cache_hits(self):
return self._cache_hits

def compile(self, args):
"""Compile a dpex.func decorated Python function with the given
Expand All @@ -69,37 +73,85 @@ def compile(self, args):
Each signature is compiled once by caching the compiled function inside
this object.
"""
if args not in self._compileinfos:
argtypes = [
dpex_target.typing_context.resolve_argument_type(arg)
for arg in args
]
key = build_key(
tuple(argtypes),
self._pyfunc,
dpex_target.target_context.codegen(),
)
cres = self._cache.get(key)
if cres is None:
self._cache_hits += 1
cres = compile_with_dpex(
pyfunc=self.py_func,
pyfunc_name=self.py_func.__name__,
pyfunc=self._pyfunc,
pyfunc_name=self._pyfunc.__name__,
return_type=None,
target_context=dpex_target.target_context,
typing_context=dpex_target.typing_context,
args=args,
is_kernel=False,
debug=self.debug,
debug=self._debug,
)
func = cres.library.get_function(cres.fndesc.llvm_func_name)
cres.target_context.mark_ocl_device(func)
first_definition = not self._compileinfos
self._compileinfos[args] = cres
libs = [cres.library]

if first_definition:
# First definition
cres.target_context.insert_user_function(
self, cres.fndesc, libs
)
else:
cres.target_context.add_user_function(self, cres.fndesc, libs)

else:
cres = self._compileinfos[args]
cres.target_context.insert_user_function(self, cres.fndesc, libs)
# cres.target_context.add_user_function(self, cres.fndesc, libs)
self._cache.put(key, cres)

return cres.signature


class DpexFunction(object):
def __init__(self, cres):
self.cres = cres
def compile_func(pyfunc, signature, debug=None):
devfn = DpexFunction(pyfunc, debug=debug)

cres = []
for sig in signature:
arg_types, return_types = sigutils.normalize_signature(sig)
arg_types = tuple(
[
npytypes_array_to_dpex_array(ty)
if isinstance(ty, types.npytypes.Array)
else ty
for ty in arg_types
]
)
c = devfn.compile(arg_types, return_types)
cres.append(c)

class _function_template(ConcreteTemplate):
unsafe_casting = False
exact_match_required = True
key = devfn
cases = [c.signature for c in cres]

# TODO: If context is only one copy, just make it global
cres[0].typing_context.insert_user_function(devfn, _function_template)

for c in cres:
c.target_context.insert_user_function(devfn, c.fndesc, [c.library])

return devfn


def compile_func_template(pyfunc, debug=None):
"""Compile a DpexFunctionTemplate"""

dft = DpexFunctionTemplate(pyfunc, debug=debug)

class _function_template(AbstractTemplate):
unsafe_casting = False
exact_match_required = True
key = dft

def generic(self, args, kws):
if kws:
raise AssertionError("No keyword arguments allowed.")
return dft.compile(args)

dpex_target.typing_context.insert_user_function(dft, _function_template)
return dft
Loading

0 comments on commit d25b525

Please sign in to comment.