Skip to content

Commit

Permalink
Add dependant event support to async kernel submission
Browse files Browse the repository at this point in the history
  • Loading branch information
ZzEeKkAa committed Dec 18, 2023
1 parent 070df7d commit 34bc454
Show file tree
Hide file tree
Showing 5 changed files with 131 additions and 24 deletions.
2 changes: 1 addition & 1 deletion numba_dpex/core/parfors/parfor_lowerer.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def _submit_parfor_kernel(
kl_builder.set_arguments(
kernel_fn.kernel_arg_types, kernel_args=kernel_args
)
kl_builder.set_dependant_event_list([])
kl_builder.set_dependent_events([])
event_ref = kl_builder.submit()

sycl.dpctl_event_wait(lowerer.builder, event_ref)
Expand Down
46 changes: 33 additions & 13 deletions numba_dpex/core/utils/kernel_launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -668,22 +668,42 @@ def set_arguments_form_tuple(
kernel_args = self._extract_llvm_values_from_tuple(ll_kernel_args_tuple)
self.set_arguments(ty_kernel_args_tuple, kernel_args)

def set_dependant_event_list(self, dep_events: list[llvmir.Instruction]):
"""Sets dependant events to the argument list."""
if self.arguments.dep_events is not None:
return
def set_dependent_events(self, dep_events: list[llvmir.Instruction]):
"""Sets dependent events to the argument list."""
ll_dep_events = self._create_ll_from_py_list(types.voidptr, dep_events)
self.arguments.dep_events = ll_dep_events
self.arguments.dep_events_len = self.context.get_constant(
types.uintp, len(dep_events)
)

if len(dep_events) > 0:
# TODO: implement for non zero input
raise NotImplementedError
def set_dependent_events_from_tuple(
self,
ty_dependent_events: UniTuple,
ll_dependent_events: llvmir.Instruction,
):
"""Set's dependent events from tuple represented by LLVM IR.
self.arguments.dep_events = self.builder.bitcast(
utils.create_null_ptr(builder=self.builder, context=self.context),
utils.get_llvm_type(context=self.context, type=types.voidptr),
)
self.arguments.dep_events_len = self.context.get_constant(
types.uintp, 0
Args:
ll_dependent_events: tuple of numba's data models.
"""
if len(ty_dependent_events) == 0:
self.set_dependent_events([])
return

ty_event = ty_dependent_events[0]
dm_dependent_events = self._extract_llvm_values_from_tuple(
ll_dependent_events
)
dependent_events = []
for dm_dependent_event in dm_dependent_events:
event_struct_proxy = cgutils.create_struct_proxy(ty_event)(
self.context,
self.builder,
value=dm_dependent_event,
)
dependent_events.append(event_struct_proxy.event_ref)

self.set_dependent_events(dependent_events)

def submit(self) -> llvmir.Instruction:
"""Submits kernel by calling sycl.dpctl_queue_submit_range or
Expand Down
4 changes: 2 additions & 2 deletions numba_dpex/dpctl_iface/libsyclinterface_bindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def dpctl_queue_submit_range(builder: llvmir.IRBuilder, *args):
llvmir.IntType(64),
llvmir.IntType(64).as_pointer(),
llvmir.IntType(64),
cgutils.voidptr_t,
cgutils.voidptr_t.as_pointer(),
llvmir.IntType(64),
],
func_name="DPCTLQueue_SubmitRange",
Expand Down Expand Up @@ -195,7 +195,7 @@ def dpctl_queue_submit_ndrange(builder: llvmir.IRBuilder, *args):
llvmir.IntType(64).as_pointer(),
llvmir.IntType(64).as_pointer(),
llvmir.IntType(64),
cgutils.voidptr_t,
cgutils.voidptr_t.as_pointer(),
llvmir.IntType(64),
],
func_name="DPCTLQueue_SubmitNDRange",
Expand Down
50 changes: 42 additions & 8 deletions numba_dpex/experimental/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from llvmlite import ir as llvmir
from numba.core import cgutils, types
from numba.core.cpu import CPUContext
from numba.core.types.containers import UniTuple
from numba.core.types.containers import Tuple, UniTuple
from numba.core.types.functions import Dispatcher
from numba.extending import intrinsic

Expand Down Expand Up @@ -51,13 +51,15 @@ def _submit_kernel_async(
typingctx,
ty_kernel_fn: Dispatcher,
ty_index_space: Union[RangeType, NdRangeType],
ty_dependent_events: UniTuple,
ty_kernel_args_tuple: UniTuple,
):
"""Generates IR code for call_kernel_async dpjit function."""
return _submit_kernel(
typingctx,
ty_kernel_fn,
ty_index_space,
ty_dependent_events,
ty_kernel_args_tuple,
sync=False,
)
Expand All @@ -75,15 +77,17 @@ def _submit_kernel_sync(
typingctx,
ty_kernel_fn,
ty_index_space,
None,
ty_kernel_args_tuple,
sync=True,
)


def _submit_kernel(
typingctx, # pylint: disable=W0613
def _submit_kernel( # pylint: disable=too-many-arguments
typingctx, # pylint: disable=unused-argument
ty_kernel_fn: Dispatcher,
ty_index_space: Union[RangeType, NdRangeType],
ty_dependent_events: UniTuple,
ty_kernel_args_tuple: UniTuple,
sync: bool,
):
Expand All @@ -106,7 +110,21 @@ def _submit_kernel(
ty_event = DpctlSyclEvent()
ty_return = types.Tuple([ty_event, ty_event])

sig = ty_return(ty_kernel_fn, ty_index_space, ty_kernel_args_tuple)
if ty_dependent_events is not None:
if not isinstance(ty_dependent_events, UniTuple) and not isinstance(
ty_dependent_events, Tuple
):
raise ValueError("dependent events must be passed as a tuple")

sig = ty_return(
ty_kernel_fn,
ty_index_space,
ty_dependent_events,
ty_kernel_args_tuple,
)
else:
sig = ty_return(ty_kernel_fn, ty_index_space, ty_kernel_args_tuple)

kernel_sig = types.void(*ty_kernel_args_tuple)
# ty_kernel_fn is type specific to exact function, so we can get function
# directly from type and compile it. Thats why we don't need to get it in
Expand All @@ -123,8 +141,14 @@ def codegen(
):
ty_index_space: Union[RangeType, NdRangeType] = sig.args[1]
ll_index_space: llvmir.Instruction = llargs[1]
ty_kernel_args_tuple: UniTuple = sig.args[2]
ll_kernel_args_tuple: llvmir.Instruction = llargs[2]
ty_kernel_args_tuple: UniTuple = sig.args[-1]
ll_kernel_args_tuple: llvmir.Instruction = llargs[-1]

if len(llargs) == 4:
ty_dependent_events: UniTuple = sig.args[2]
ll_dependent_events: llvmir.Instruction = llargs[2]
else:
ty_dependent_events = None

kl_builder = kl.KernelLaunchIRBuilder(
cgctx,
Expand All @@ -140,7 +164,13 @@ def codegen(
)
kl_builder.set_queue_from_arguments()
kl_builder.set_kernel_from_spirv(kernel_module)
kl_builder.set_dependant_event_list([])
if ty_dependent_events is None:
kl_builder.set_dependent_events([])
else:
kl_builder.set_dependent_events_from_tuple(
ty_dependent_events,
ll_dependent_events,
)
device_event_ref = kl_builder.submit()

if not sync:
Expand Down Expand Up @@ -185,7 +215,10 @@ def call_kernel(kernel_fn, index_space, *kernel_args) -> None:

@dpjit
def call_kernel_async(
kernel_fn, index_space, *kernel_args
kernel_fn,
index_space,
dependent_events: list[dpctl.SyclEvent],
*kernel_args
) -> tuple[dpctl.SyclEvent, dpctl.SyclEvent]:
"""Calls a numba_dpex.kernel decorated function from CPython or from another
dpjit function. Kernel execution happens in asyncronous way, so the thread
Expand All @@ -210,5 +243,6 @@ def call_kernel_async(
return _submit_kernel_async( # pylint: disable=E1120
kernel_fn,
index_space,
dependent_events,
kernel_args,
)
53 changes: 53 additions & 0 deletions numba_dpex/tests/experimental/test_async_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

import dpctl
import dpnp
import pytest
from numba.core.errors import TypingError

import numba_dpex as dpex
import numba_dpex.experimental as exp_dpex
Expand Down Expand Up @@ -33,6 +35,7 @@ def test_async_add():
host_ref, event_ref = exp_dpex.call_kernel_async(
add,
r,
(),
a,
b,
c,
Expand All @@ -50,6 +53,56 @@ def test_async_add():
assert dpnp.array_equal(c, d)


def test_async_dependent_add_list_exception():
size = 10

# TODO: should capture ValueError, but numba captures it and generates
# TypingError. ValueError is still readable there.
with pytest.raises(TypingError):
exp_dpex.call_kernel_async(
add,
Range(size),
[dpctl.SyclEvent()],
dpnp.ones(size),
dpnp.ones(size),
dpnp.ones(size),
)


def test_async_dependent_add():
size = 10
a = dpnp.ones(size)
b = dpnp.ones(size)
c = dpnp.zeros(size)

r = Range(size)

host_ref, event_ref = exp_dpex.call_kernel_async(
add,
r,
(),
a,
b,
c,
)

host2_ref, event2_ref = exp_dpex.call_kernel_async(
add,
r,
(event_ref,),
a,
c,
b,
)

event2_ref.wait()
d = dpnp.ones(size) * 3
assert dpnp.array_equal(b, d)

host_ref.wait()
host2_ref.wait()


def test_async_add_from_cache():
test_async_add() # compile
old_size = testing.kernel_cache_size()
Expand Down

0 comments on commit 34bc454

Please sign in to comment.