Skip to content

Commit

Permalink
Merge pull request #912 from adarshyoga/kmeans_perf_improvements
Browse files Browse the repository at this point in the history
Dispatcher/caching rewrite to address performance regression
  • Loading branch information
Diptorup Deb authored Feb 17, 2023
2 parents a3f5604 + 1056a58 commit ae994cd
Show file tree
Hide file tree
Showing 5 changed files with 117 additions and 91 deletions.
74 changes: 7 additions & 67 deletions numba_dpex/core/caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,77 +2,12 @@
#
# SPDX-License-Identifier: Apache-2.0

import hashlib
import sys
from abc import ABCMeta, abstractmethod

from numba.core.caching import CacheImpl, IndexDataCacheFile
from numba.core.serialize import dumps

from numba_dpex import config
from numba_dpex.core.types import USMNdArray


def build_key(
argtypes, pyfunc, codegen, backend=None, device_type=None, exec_queue=None
):
"""Constructs a key from python function, context, backend, the device
type and execution queue.
Compute index key for the given argument types and codegen. It includes a
description of the OS, target architecture and hashes of the bytecode for
the function and, if the function has a __closure__, a hash of the
cell_contents.type
Args:
argtypes : A tuple of numba types corresponding to the arguments to the
compiled function.
pyfunc : The Python function that is to be compiled and cached.
codegen (numba.core.codegen.Codegen):
The codegen object found from the target context.
backend (enum, optional): A 'backend_type' enum.
Defaults to None.
device_type (enum, optional): A 'device_type' enum.
Defaults to None.
exec_queue (dpctl._sycl_queue.SyclQueue', optional): A SYCL queue object.
Returns:
tuple: A tuple of return type, argtpes, magic_tuple of codegen
and another tuple of hashcodes from bytecode and cell_contents.
"""

codebytes = pyfunc.__code__.co_code
if pyfunc.__closure__ is not None:
try:
cvars = tuple([x.cell_contents for x in pyfunc.__closure__])
# Note: cloudpickle serializes a function differently depending
# on how the process is launched; e.g. multiprocessing.Process
cvarbytes = dumps(cvars)
except:
cvarbytes = b"" # a temporary solution for function template
else:
cvarbytes = b""

argtylist = list(argtypes)
for i, argty in enumerate(argtylist):
if isinstance(argty, USMNdArray):
# Convert the USMNdArray to an abridged type that disregards the
# usm_type, device, queue, address space attributes.
argtylist[i] = (argty.ndim, argty.dtype, argty.layout)

argtypes = tuple(argtylist)

return (
argtypes,
codegen.magic_tuple(),
backend,
device_type,
exec_queue,
(
hashlib.sha256(codebytes).hexdigest(),
hashlib.sha256(cvarbytes).hexdigest(),
),
)


class _CacheImpl(CacheImpl):
Expand Down Expand Up @@ -475,8 +410,13 @@ def put(self, key, value):
self._name, len(self._lookup), str(key)
)
)
self._lookup[key].value = value
self.get(key)
node = self._lookup[key]
node.value = value

if node is not self._tail:
self._unlink_node(node)
self._append_tail(node)

return

if key in self._evicted:
Expand Down
41 changes: 23 additions & 18 deletions numba_dpex/core/kernel_interface/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from numba.core.types import void

from numba_dpex import NdRange, Range, config
from numba_dpex.core.caching import LRUCache, NullCache, build_key
from numba_dpex.core.caching import LRUCache, NullCache
from numba_dpex.core.descriptor import dpex_kernel_target
from numba_dpex.core.exceptions import (
ComputeFollowsDataInferenceError,
Expand All @@ -34,6 +34,11 @@
from numba_dpex.core.kernel_interface.arg_pack_unpacker import Packer
from numba_dpex.core.kernel_interface.spirv_kernel import SpirvKernel
from numba_dpex.core.types import USMNdArray
from numba_dpex.core.utils import (
build_key,
create_func_hash,
strip_usm_metadata,
)


def get_ordered_arg_access_types(pyfunc, access_types):
Expand Down Expand Up @@ -85,6 +90,8 @@ def __init__(
self._global_range = None
self._local_range = None

self._func_hash = create_func_hash(pyfunc)

# caching related attributes
if not config.ENABLE_CACHE:
self._cache = NullCache()
Expand Down Expand Up @@ -151,7 +158,7 @@ def cache(self):
def cache_hits(self):
return self._cache_hits

def _compile_and_cache(self, argtypes, cache):
def _compile_and_cache(self, argtypes, cache, key=None):
"""Helper function to compile the Python function or Numba FunctionIR
object passed to a JitKernel and store it in an internal cache.
"""
Expand All @@ -171,11 +178,13 @@ def _compile_and_cache(self, argtypes, cache):
device_driver_ir_module = kernel.device_driver_ir_module
kernel_module_name = kernel.module_name

key = build_key(
tuple(argtypes),
self.pyfunc,
kernel.target_context.codegen(),
)
if not key:
stripped_argtypes = strip_usm_metadata(argtypes)
codegen_magic_tuple = kernel.target_context.codegen().magic_tuple()
key = build_key(
stripped_argtypes, codegen_magic_tuple, self._func_hash
)

cache.put(key, (device_driver_ir_module, kernel_module_name))

return device_driver_ir_module, kernel_module_name
Expand Down Expand Up @@ -604,12 +613,12 @@ def __call__(self, *args):
self.kernel_name, backend, JitKernel._supported_backends
)

# load the kernel from cache
key = build_key(
tuple(argtypes),
self.pyfunc,
dpex_kernel_target.target_context.codegen(),
# Generate key used for cache lookup
stripped_argtypes = strip_usm_metadata(argtypes)
codegen_magic_tuple = (
dpex_kernel_target.target_context.codegen().magic_tuple()
)
key = build_key(stripped_argtypes, codegen_magic_tuple, self._func_hash)

# If the JitKernel was specialized then raise exception if argtypes
# do not match one of the specialized versions.
Expand All @@ -630,15 +639,11 @@ def __call__(self, *args):
device_driver_ir_module,
kernel_module_name,
) = self._compile_and_cache(
argtypes=argtypes,
cache=self._cache,
argtypes=argtypes, cache=self._cache, key=key
)

kernel_bundle_key = build_key(
tuple(argtypes),
self.pyfunc,
dpex_kernel_target.target_context.codegen(),
exec_queue=exec_queue,
stripped_argtypes, codegen_magic_tuple, exec_queue, self._func_hash
)

artifact = self._kernel_bundle_cache.get(kernel_bundle_key)
Expand Down
21 changes: 15 additions & 6 deletions numba_dpex/core/kernel_interface/func.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,18 @@
"""_summary_
"""


from numba.core import sigutils, types
from numba.core.typing.templates import AbstractTemplate, ConcreteTemplate

from numba_dpex import config
from numba_dpex.core.caching import LRUCache, NullCache, build_key
from numba_dpex.core.caching import LRUCache, NullCache
from numba_dpex.core.compiler import compile_with_dpex
from numba_dpex.core.descriptor import dpex_kernel_target
from numba_dpex.core.utils import (
build_key,
create_func_hash,
strip_usm_metadata,
)
from numba_dpex.utils import npytypes_array_to_dpex_array


Expand Down Expand Up @@ -91,6 +95,8 @@ def __init__(self, pyfunc, debug=False, enable_cache=True):
self._debug = debug
self._enable_cache = enable_cache

self._func_hash = create_func_hash(pyfunc)

if not config.ENABLE_CACHE:
self._cache = NullCache()
elif self._enable_cache:
Expand Down Expand Up @@ -132,11 +138,14 @@ def compile(self, args):
dpex_kernel_target.typing_context.resolve_argument_type(arg)
for arg in args
]
key = build_key(
tuple(argtypes),
self._pyfunc,
dpex_kernel_target.target_context.codegen(),

# Generate key used for cache lookup
stripped_argtypes = strip_usm_metadata(argtypes)
codegen_magic_tuple = (
dpex_kernel_target.target_context.codegen().magic_tuple()
)
key = build_key(stripped_argtypes, codegen_magic_tuple, self._func_hash)

cres = self._cache.get(key)
if cres is None:
self._cache_hits += 1
Expand Down
4 changes: 4 additions & 0 deletions numba_dpex/core/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,13 @@
#
# SPDX-License-Identifier: Apache-2.0

from .caching_utils import build_key, create_func_hash, strip_usm_metadata
from .suai_helper import SyclUSMArrayInterface, get_info_from_suai

__all__ = [
"get_info_from_suai",
"SyclUSMArrayInterface",
"create_func_hash",
"strip_usm_metadata",
"build_key",
]
68 changes: 68 additions & 0 deletions numba_dpex/core/utils/caching_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# SPDX-FileCopyrightText: 2020 - 2022 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0

import hashlib

from numba.core.serialize import dumps

from numba_dpex.core.types import USMNdArray


def build_key(*args):
"""Constructs key from variable list of args
Args:
*args: List of components to construct key
Return:
Tuple of args
"""
return tuple(args)


def create_func_hash(pyfunc):
"""Creates a tuple of sha256 hashes out of code and
variable bytes extracted from the compiled funtion.
Args:
pyfunc: Python function object
Return:
Tuple of hashes of code and variable bytes
"""
codebytes = pyfunc.__code__.co_code
if pyfunc.__closure__ is not None:
try:
cvars = tuple([x.cell_contents for x in pyfunc.__closure__])
# Note: cloudpickle serializes a function differently depending
# on how the process is launched; e.g. multiprocessing.Process
cvarbytes = dumps(cvars)
except:
cvarbytes = b"" # a temporary solution for function template
else:
cvarbytes = b""

return (
hashlib.sha256(codebytes).hexdigest(),
hashlib.sha256(cvarbytes).hexdigest(),
)


def strip_usm_metadata(argtypes):
"""Convert the USMNdArray to an abridged type that disregards the
usm_type, device, queue, address space attributes.
Args:
argtypes: List of types
Return:
Tuple of types after removing USM metadata from USMNdArray type
"""

stripped_argtypes = []
for argty in argtypes:
if isinstance(argty, USMNdArray):
stripped_argtypes.append((argty.ndim, argty.dtype, argty.layout))
else:
stripped_argtypes.append(argty)

return tuple(stripped_argtypes)

0 comments on commit ae994cd

Please sign in to comment.