-
Notifications
You must be signed in to change notification settings - Fork 33
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
67 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
# SPDX-FileCopyrightText: 2020 - 2023 Intel Corporation | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from numba import types | ||
from numba.core import cgutils | ||
from numba.core.datamodel import default_manager | ||
from numba.extending import intrinsic, overload_method | ||
|
||
import numba_dpex.dpctl_iface.libsyclinterface_bindings as sycl | ||
from numba_dpex.core import types as dpex_types | ||
|
||
|
||
@intrinsic | ||
def sycl_event_wait(typingctx, ty_event): | ||
# check for accepted types | ||
if not isinstance(ty_event, dpex_types.DpctlSyclEvent): | ||
raise TypeError(f"Expected dpctl.SyclEvent, but got {ty_event}.") | ||
|
||
result_type = types.void | ||
sig = result_type(ty_event) | ||
|
||
# defines the custom code generation | ||
def codegen(context, builder, signature, args): | ||
event_struct_proxy = cgutils.create_struct_proxy(ty_event)( | ||
context, builder | ||
) | ||
event_struct_ptr = event_struct_proxy._getpointer() | ||
|
||
event_struct = builder.load(event_struct_ptr) | ||
sycl_event_dm = default_manager.lookup(ty_event) | ||
event_ref = builder.extract_value( | ||
event_struct, | ||
sycl_event_dm.get_field_position("event_ref"), | ||
) | ||
|
||
sycl.dpctl_event_wait(builder, event_ref) | ||
|
||
return sig, codegen | ||
|
||
|
||
@overload_method(dpex_types.DpctlSyclEvent, "wait") | ||
def ol_dpctl_sycl_event_wait( | ||
event, | ||
): | ||
"""Implementation of an overload to support dpctl.SyclEvent() inside | ||
a dpjit function. | ||
""" | ||
return lambda event: sycl_event_wait(event) |
16 changes: 16 additions & 0 deletions
16
numba_dpex/tests/core/types/DpctlSyclEvent/test_overloads.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
import dpctl | ||
|
||
from numba_dpex import dpjit | ||
|
||
|
||
@dpjit | ||
def wait_call(a): | ||
a.wait() | ||
return None | ||
|
||
|
||
def test_wait_DpctlSyclEvent(): | ||
"""Test the dpctl.SyclEvent.wait() call overload.""" | ||
|
||
e = dpctl.SyclEvent() | ||
wait_call(e) |