Skip to content

Commit

Permalink
Enable device caching for kernels
Browse files Browse the repository at this point in the history
  • Loading branch information
ZzEeKkAa committed Dec 7, 2023
1 parent 50b8a46 commit f441804
Show file tree
Hide file tree
Showing 8 changed files with 284 additions and 48 deletions.
5 changes: 5 additions & 0 deletions numba_dpex/core/runtime/_dpexrt_python.c
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "_queuestruct.h"
#include "_usmarraystruct.h"

#include "experimental/kernel_caching.h"
#include "experimental/nrt_reserve_meminfo.h"
#include "numba/core/runtime/nrt_external.h"

Expand Down Expand Up @@ -1493,6 +1494,7 @@ static PyObject *build_c_helpers_dict(void)
_declpointer("DPEXRT_sycl_event_init", &DPEXRT_sycl_event_init);
_declpointer("DPEXRT_nrt_acquire_meminfo_and_schedule_release",
&DPEXRT_nrt_acquire_meminfo_and_schedule_release);
_declpointer("DPEXRT_build_or_get_kernel", &DPEXRT_build_or_get_kernel);

#undef _declpointer
return dct;
Expand Down Expand Up @@ -1563,6 +1565,9 @@ MOD_INIT(_dpexrt_python)
PyModule_AddObject(
m, "DPEXRT_nrt_acquire_meminfo_and_schedule_release",
PyLong_FromVoidPtr(&DPEXRT_nrt_acquire_meminfo_and_schedule_release));
PyModule_AddObject(m, "DPEXRT_build_or_get_kernel",
PyLong_FromVoidPtr(&DPEXRT_build_or_get_kernel));

PyModule_AddObject(m, "c_helpers", build_c_helpers_dict());
return MOD_SUCCESS_VAL(m);
}
36 changes: 36 additions & 0 deletions numba_dpex/core/runtime/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,3 +471,39 @@ def acquire_meminfo_and_schedule_release(
ret = builder.call(fn, args)

return ret

def build_or_get_kernel(self, builder: llvmir.IRBuilder, args):
    """Emits LLVM IR that calls the ``DPEXRT_build_or_get_kernel`` runtime
    function and returns the value produced by that call.

    The function is declared in the builder's module (or reused if it is
    already declared there) with the following C signature::

        DPCTLSyclKernelRef
        DPEXRT_build_or_get_kernel(
            const DPCTLSyclContextRef ctx,
            const DPCTLSyclDeviceRef dev,
            size_t il_hash,
            const char *il,
            size_t il_length,
            const char *compile_opts,
            const char *kernel_name
        );

    Args:
        builder: IR builder used to emit the call.
        args: LLVM values passed to the call, in the order of the C
            signature above.

    Returns:
        The LLVM value of the emitted call.
    """
    # Every reference/pointer argument is modelled as an opaque void*,
    # both size_t arguments as 64-bit integers.
    opaque_ptr = cgutils.voidptr_t
    size_ty = llvmir.IntType(64)

    param_types = [
        opaque_ptr,  # ctx
        opaque_ptr,  # dev
        size_ty,  # il_hash
        opaque_ptr,  # il
        size_ty,  # il_length
        opaque_ptr,  # compile_opts
        opaque_ptr,  # kernel_name
    ]
    signature = llvmir.FunctionType(opaque_ptr, param_types)

    callee = cgutils.get_or_insert_function(
        builder.module, signature, "DPEXRT_build_or_get_kernel"
    )
    return builder.call(callee, args)
106 changes: 106 additions & 0 deletions numba_dpex/core/runtime/experimental/kernel_caching.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
// SPDX-FileCopyrightText: 2020 - 2023 Intel Corporation
//
// SPDX-License-Identifier: Apache-2.0

//===----------------------------------------------------------------------===//
///
/// \file
/// A Python module that provides constructors to create a Numba MemInfo
/// PyObject using a sycl USM allocator as the external memory allocator.
/// The Module also provides the Numba box and unbox implementations for a
/// dpnp.ndarray object.
///
//===----------------------------------------------------------------------===//

#include "kernel_caching.h"
#include <unordered_map>

extern "C"
{
#include "dpctl_capi.h"
#include "dpctl_sycl_interface.h"

#include "_dbg_printer.h"

#include "numba/core/runtime/nrt_external.h"
}

#include "syclinterface/dpctl_sycl_type_casters.hpp"
#include "tools/dpctl.hpp"
#include "tools/hash_tuple.hpp"

using CacheKey = std::tuple<DPCTLSyclContextRef, DPCTLSyclDeviceRef, size_t>;

/// Equality predicate for CacheKey values, used as the key-equality functor
/// of the kernel cache. Two keys are equal when their devices compare equal
/// via DPCTLDevice_AreEq, their contexts compare equal via DPCTLContext_AreEq
/// and their size_t components are identical.
class CacheKeysAreEqual
{
public:
    bool operator()(CacheKey const &lhs, CacheKey const &rhs) const
    {
        // TODO: implement full comparison
        // Checks run in the order device, context, hash and stop at the
        // first mismatch.
        if (!DPCTLDevice_AreEq(std::get<DPCTLSyclDeviceRef>(lhs),
                               std::get<DPCTLSyclDeviceRef>(rhs)))
        {
            return false;
        }
        if (!DPCTLContext_AreEq(std::get<DPCTLSyclContextRef>(lhs),
                                std::get<DPCTLSyclContextRef>(rhs)))
        {
            return false;
        }
        return std::get<size_t>(lhs) == std::get<size_t>(rhs);
    }
};

// TODO: add cache cleaning
// Global kernel cache: maps a CacheKey (context, device, IL hash) to a kernel
// reference. Entries are never removed.
std::unordered_map<CacheKey,
                   DPCTLSyclKernelRef,
                   std::hash<CacheKey>,
                   CacheKeysAreEqual>
    sycl_kernel_cache;

/// Returns a reference to the value stored under \p k in map \p m.
///
/// If \p k is not yet present, a value-initialized entry is inserted first
/// and \p f is invoked with a reference to that new entry so it can fill it
/// in. If \p k is already present, \p f is not called.
///
/// \param m Map to look up / insert into.
/// \param k Key to look up.
/// \param f Callable taking `M::mapped_type &`; invoked only on insertion.
/// \return Reference to the mapped value inside \p m.
template <class M, class Key, class F>
typename M::mapped_type &get_else_compute(M &m, Key const &k, F f)
{
    using Mapped = typename M::mapped_type;

    // insert() leaves an existing entry untouched and reports through
    // `.second` whether a new entry was created.
    auto inserted = m.insert(typename M::value_type(k, Mapped()));
    Mapped &slot = inserted.first->second;

    if (!inserted.second) {
        DPEXRT_DEBUG(drt_debug_print("DPEXRT-DEBUG: using cached kernel.\n"););
        return slot;
    }

    DPEXRT_DEBUG(drt_debug_print("DPEXRT-DEBUG: building kernel.\n"););
    f(slot);
    return slot;
}

extern "C"
{
    /// Returns a kernel for the given context, device and SPIR-V module,
    /// building it only if no matching entry exists in sycl_kernel_cache.
    ///
    /// The cache is keyed by (context, device, il_hash). On a miss the
    /// kernel bundle is created from \p il, the kernel named \p kernel_name
    /// is extracted from it, and the result is stored in the cache.
    ///
    /// \param ctx Context the kernel is built for. Borrowed: the caller may
    ///            delete it after this call returns.
    /// \param dev Device the kernel is built for. Borrowed, like \p ctx.
    /// \param il_hash Hash of the IL; the only IL-derived part of the key.
    /// \param il Pointer to the SPIR-V binary.
    /// \param il_length Length of \p il in bytes.
    /// \param compile_opts Options forwarded to
    ///                     DPCTLKernelBundle_CreateFromSpirv.
    /// \param kernel_name Name of the kernel to fetch from the bundle.
    /// \return A kernel reference distinct from the one held by the cache
    ///         (a DPCTLKernel_Copy of it), or a null reference if the bundle
    ///         or kernel could not be created.
    ///
    /// NOTE(review): \p kernel_name is not part of the cache key, so two
    /// different kernels sharing (ctx, dev, il_hash) would map to one entry.
    DPCTLSyclKernelRef DPEXRT_build_or_get_kernel(const DPCTLSyclContextRef ctx,
                                                  const DPCTLSyclDeviceRef dev,
                                                  size_t il_hash,
                                                  const char *il,
                                                  size_t il_length,
                                                  const char *compile_opts,
                                                  const char *kernel_name)
    {
        DPEXRT_DEBUG(
            drt_debug_print("DPEXRT-DEBUG: in build or get kernel.\n"););

        // Lookup-only key. It holds the caller's handles, which are only
        // guaranteed to be alive during this call, so it is never stored.
        const CacheKey lookup_key = std::make_tuple(ctx, dev, il_hash);

        // All three values are size_t, hence %zu rather than %d.
        DPEXRT_DEBUG(auto ctx_hash = std::hash<DPCTLSyclContextRef>{}(ctx);
                     auto dev_hash = std::hash<DPCTLSyclDeviceRef>{}(dev);
                     drt_debug_print("DPEXRT-DEBUG: key hashes: %zu %zu %zu.\n",
                                     ctx_hash, dev_hash, il_hash););

        const auto cached = sycl_kernel_cache.find(lookup_key);
        if (cached != sycl_kernel_cache.end()) {
            DPEXRT_DEBUG(
                drt_debug_print("DPEXRT-DEBUG: using cached kernel.\n"););
            return DPCTLKernel_Copy(cached->second);
        }

        DPEXRT_DEBUG(drt_debug_print("DPEXRT-DEBUG: building kernel.\n"););

        auto kb_ref = DPCTLKernelBundle_CreateFromSpirv(ctx, dev, il, il_length,
                                                        compile_opts);
        if (kb_ref == nullptr) {
            // Build failed: report it to the caller without caching a null
            // kernel.
            return nullptr;
        }
        DPCTLSyclKernelRef built_kernel =
            DPCTLKernelBundle_GetKernel(kb_ref, kernel_name);
        DPCTLKernelBundle_Delete(kb_ref);
        if (built_kernel == nullptr) {
            return nullptr;
        }

        // The cache outlives the caller's handles, so the stored key must own
        // its own copies of the context and device.
        DPCTLSyclContextRef owned_ctx = DPCTLContext_Copy(ctx);
        DPCTLSyclDeviceRef owned_dev = DPCTLDevice_Copy(dev);
        if (owned_ctx == nullptr || owned_dev == nullptr) {
            // Could not create an owned key: skip caching and hand the
            // freshly built kernel straight to the caller.
            if (owned_ctx != nullptr) {
                DPCTLContext_Delete(owned_ctx);
            }
            if (owned_dev != nullptr) {
                DPCTLDevice_Delete(owned_dev);
            }
            return built_kernel;
        }

        sycl_kernel_cache.emplace(
            std::make_tuple(owned_ctx, owned_dev, il_hash), built_kernel);

        // The cache keeps built_kernel; the caller receives its own copy.
        return DPCTLKernel_Copy(built_kernel);
    }
}
32 changes: 32 additions & 0 deletions numba_dpex/core/runtime/experimental/kernel_caching.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// SPDX-FileCopyrightText: 2020 - 2023 Intel Corporation
//
// SPDX-License-Identifier: Apache-2.0

//===----------------------------------------------------------------------===//
///
/// \file
/// A Python module that provides constructors to create a Numba MemInfo
/// PyObject using a sycl USM allocator as the external memory allocator.
/// The Module also provides the Numba box and unbox implementations for a
/// dpnp.ndarray object.
///
//===----------------------------------------------------------------------===//

#ifdef __cplusplus
extern "C"
{
#endif

#include "dpctl_capi.h"
#include "dpctl_sycl_interface.h"

/// Returns a kernel for the given context, device and SPIR-V module, using
/// the kernel cache implemented in kernel_caching.cpp.
///
/// The cache is keyed by (ctx, dev, il_hash). When no entry matches, a kernel
/// bundle is created from the SPIR-V with
/// DPCTLKernelBundle_CreateFromSpirv(ctx, dev, il, il_length, compile_opts)
/// and the kernel named kernel_name is fetched from it and cached.
///
/// @param ctx          Context the kernel is built for.
/// @param dev          Device the kernel is built for.
/// @param il_hash      Hash of the IL, used as part of the cache key.
/// @param il           Pointer to the SPIR-V binary.
/// @param il_length    Length of il in bytes.
/// @param compile_opts Compile options forwarded to the bundle builder.
/// @param kernel_name  Name of the kernel to fetch from the built bundle.
/// @return The result of DPCTLKernel_Copy on the cached kernel.
///         NOTE(review): presumably the caller owns the returned reference
///         and must delete it -- confirm against the call sites.
DPCTLSyclKernelRef DPEXRT_build_or_get_kernel(const DPCTLSyclContextRef ctx,
const DPCTLSyclDeviceRef dev,
size_t il_hash,
const char *il,
size_t il_length,
const char *compile_opts,
const char *kernel_name);
#ifdef __cplusplus
}
#endif
25 changes: 25 additions & 0 deletions numba_dpex/core/runtime/experimental/tools/dpctl.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#pragma once

#include "syclinterface/dpctl_sycl_type_casters.hpp"
#include <CL/sycl.hpp>

namespace std
{
/// Hash support for DPCTLSyclDeviceRef: the reference is unwrapped to the
/// sycl::device it points to and that device is hashed with
/// std::hash<sycl::device>.
template <> struct hash<DPCTLSyclDeviceRef>
{
    size_t operator()(const DPCTLSyclDeviceRef &DRef) const
    {
        const auto *device = dpctl::syclinterface::unwrap<sycl::device>(DRef);
        return hash<sycl::device>{}(*device);
    }
};

/// Hash support for DPCTLSyclContextRef: the reference is unwrapped to the
/// sycl::context it points to and that context is hashed with
/// std::hash<sycl::context>.
template <> struct hash<DPCTLSyclContextRef>
{
    size_t operator()(const DPCTLSyclContextRef &CRef) const
    {
        const auto *context =
            dpctl::syclinterface::unwrap<sycl::context>(CRef);
        return hash<sycl::context>{}(*context);
    }
};
} // namespace std
49 changes: 49 additions & 0 deletions numba_dpex/core/runtime/experimental/tools/hash_tuple.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
#pragma once

#include <tuple>
namespace std
{
namespace
{

// Code from boost
// Reciprocal of the golden ratio helps spread entropy
// and handles duplicates.
// See Mike Seymour in magic-numbers-in-boosthash-combine:
// http://stackoverflow.com/questions/4948780

/// Mixes the std::hash of \p v into the running hash value \p seed.
template <class T> inline void hash_combine(std::size_t &seed, T const &v)
{
    seed ^= std::hash<T>()(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}

} // namespace

/// std::hash support for std::tuple.
///
/// Starts from a seed of 0 and folds each element into it with hash_combine,
/// from the first element to the last. An empty tuple therefore hashes to 0.
template <typename... TT> struct hash<std::tuple<TT...>>
{
    size_t operator()(std::tuple<TT...> const &tt) const
    {
        size_t seed = 0;
        // The comma fold evaluates left to right, so elements are combined
        // in index order; with no elements it is a no-op.
        std::apply(
            [&](auto const &...elems) { (hash_combine(seed, elems), ...); },
            tt);
        return seed;
    }
};
} // namespace std
74 changes: 26 additions & 48 deletions numba_dpex/experimental/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ def __init__(
):
self.context = codegen_targetctx
self.builder = builder
# TODO: get dpex RT from cached property once the PR is merged
# https://github.com/IntelPython/numba-dpex/pull/1027
# and get rid of the global variable. Use self.context.dpexrt instead.
self.dpexrt = DpexRTContext(self.context)

if config.DEBUG_KERNEL_LAUNCHER:
cgutils.printf(
Expand Down Expand Up @@ -139,7 +143,7 @@ def get_queue_ref_val(

return ptr_to_queue_ref

def get_kernel(self, qref, kernel_module: _KernelModule):
def get_kernel(self, queue_ref, kernel_module: _KernelModule):
"""Returns the pointer to the sycl::kernel object in a passed in
sycl::kernel_bundle wrapper object.
"""
Expand All @@ -150,23 +154,32 @@ def get_kernel(self, qref, kernel_module: _KernelModule):
bytes=kernel_module.kernel_bitcode,
)

# Create a sycl::kernel_bundle object and return it as an opaque pointer
# using dpctl's libsyclinterface.
kbref = self.create_kernel_bundle_from_spirv(
queue_ref=qref,
kernel_bc=kernel_bc_byte_str,
kernel_bc_size_in_bytes=len(kernel_module.kernel_bitcode),
)

kernel_name = self.context.insert_const_string(
self.builder.module, kernel_module.kernel_name
)

kernel_ref = sycl.dpctl_kernel_bundle_get_kernel(
self.builder, kbref, kernel_name
context_ref = sycl.dpctl_queue_get_context(self.builder, queue_ref)
device_ref = sycl.dpctl_queue_get_device(self.builder, queue_ref)

kernel_ref = self.dpexrt.build_or_get_kernel(
self.builder,
[
context_ref,
device_ref,
llvmir.Constant(
llvmir.IntType(64), hash(kernel_module.kernel_bitcode)
),
kernel_bc_byte_str,
llvmir.Constant(
llvmir.IntType(64), len(kernel_module.kernel_bitcode)
),
self.builder.load(create_null_ptr(self.builder, self.context)),
kernel_name,
],
)

sycl.dpctl_kernel_bundle_delete(self.builder, kbref)
sycl.dpctl_context_delete(self.builder, context_ref)
sycl.dpctl_device_delete(self.builder, device_ref)

return kernel_ref

Expand Down Expand Up @@ -210,36 +223,6 @@ def create_llvm_values_for_index_space(

return LLRange(global_range_extents, local_range_extents)

def create_kernel_bundle_from_spirv(
    self,
    queue_ref: llvmir.PointerType,
    kernel_bc: llvmir.Constant,
    kernel_bc_size_in_bytes: int,
) -> llvmir.CallInstr:
    """Emits a call to DPCTLKernelBundle_CreateFromSpirv that produces an
    opaque pointer to a sycl::kernel_bundle built from a kernel's SPIR-V.

    The device and context are obtained from ``queue_ref`` and are deleted
    again once the bundle-creation call has been emitted.

    Args:
        queue_ref: Queue reference from which the device and context are
            taken.
        kernel_bc: The kernel's SPIR-V bitcode as an LLVM constant.
        kernel_bc_size_in_bytes: Size of the bitcode in bytes.

    Returns:
        The LLVM value of the emitted bundle-creation call.
    """
    builder = self.builder

    # Emission order matters for the generated IR: device first, then
    # context, then the null-pointer load used as the last argument.
    device_ref = sycl.dpctl_queue_get_device(builder, queue_ref)
    context_ref = sycl.dpctl_queue_get_context(builder, queue_ref)

    bc_size = llvmir.Constant(llvmir.IntType(64), kernel_bc_size_in_bytes)
    null_ptr = builder.load(create_null_ptr(builder, self.context))

    kb_ref = sycl.dpctl_kernel_bundle_create_from_spirv(
        builder, context_ref, device_ref, kernel_bc, bc_size, null_ptr
    )

    # The context and device references are no longer needed after the
    # bundle-creation call.
    sycl.dpctl_context_delete(builder, context_ref)
    sycl.dpctl_device_delete(builder, device_ref)

    if config.DEBUG_KERNEL_LAUNCHER:
        cgutils.printf(
            builder,
            "DPEX-DEBUG: Generated kernel_bundle from SPIR-V.\n",
        )

    return kb_ref

def acquire_meminfo_and_schedule_release(
self,
queue_ref,
Expand All @@ -259,12 +242,7 @@ def acquire_meminfo_and_schedule_release(
status_ptr = cgutils.alloca_once(
self.builder, self.context.get_value_type(types.uint64)
)
# TODO: get dpex RT from cached property once the PR is merged
# https://github.com/IntelPython/numba-dpex/pull/1027
# host_eref = ctx.dpexrt.acquire_meminfo_and_schedule_release( # noqa: W0621
host_eref = DpexRTContext(
self.context
).acquire_meminfo_and_schedule_release(
host_eref = self.dpexrt.acquire_meminfo_and_schedule_release(
self.builder,
[
self.context.nrt.get_nrt_api(self.builder),
Expand Down
Loading

0 comments on commit f441804

Please sign in to comment.