Skip to content

Commit

Permalink
Fix lifetime management for sycl event
Browse files Browse the repository at this point in the history
  • Loading branch information
ZzEeKkAa committed Oct 24, 2023
1 parent 067cbbb commit 8440510
Show file tree
Hide file tree
Showing 7 changed files with 47 additions and 16 deletions.
4 changes: 2 additions & 2 deletions numba_dpex/core/datamodel/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,8 @@ class SyclEventModel(StructModel):
def __init__(self, dmm, fe_type):
members = [
(
"parent",
types.CPointer(types.int8),
"meminfo",
types.MemInfoPointer(types.pyobject),
),
(
"event_ref",
Expand Down
21 changes: 16 additions & 5 deletions numba_dpex/core/runtime/_dpexrt_python.c
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
#include "_queuestruct.h"
#include "_usmarraystruct.h"

#include "numba/core/runtime/nrt_external.h"

// forward declarations
static struct PyUSMArrayObject *PyUSMNdArray_ARRAYOBJ(PyObject *obj);
static npy_intp product_of_shape(npy_intp *shape, npy_intp ndim);
Expand Down Expand Up @@ -64,9 +66,12 @@ DPEXRT_sycl_usm_ndarray_to_python_acqref(usmarystruct_t *arystruct,
PyArray_Descr *descr);
static int DPEXRT_sycl_queue_from_python(PyObject *obj,
queuestruct_t *queue_struct);
static int DPEXRT_sycl_event_from_python(PyObject *obj,
static int DPEXRT_sycl_event_from_python(NRT_api_functions *nrt,
PyObject *obj,
eventstruct_t *event_struct);
static PyObject *DPEXRT_sycl_queue_to_python(queuestruct_t *queuestruct);
static PyObject *DPEXRT_sycl_event_to_python(NRT_api_functions *nrt,
eventstruct_t *eventstruct);

/** An NRT_external_malloc_func implementation using DPCTLmalloc_device.
*
Expand Down Expand Up @@ -1306,7 +1311,8 @@ static PyObject *DPEXRT_sycl_queue_to_python(queuestruct_t *queuestruct)
* represent a dpctl.SyclEvent inside Numba.
* @return {return} Return code indicating success (0) or failure (-1).
*/
static int DPEXRT_sycl_event_from_python(PyObject *obj,
static int DPEXRT_sycl_event_from_python(NRT_api_functions *nrt,
PyObject *obj,
eventstruct_t *event_struct)
{
struct PySyclEventObject *event_obj = NULL;
Expand All @@ -1328,7 +1334,9 @@ static int DPEXRT_sycl_event_from_python(PyObject *obj,
goto error;
}

event_struct->parent = obj;
Py_INCREF(event_obj);
event_struct->meminfo =
nrt->manage_memory(event_obj, NRT_MemInfo_pyobject_dtor);
event_struct->event_ref = event_ref;

return 0;
Expand All @@ -1355,12 +1363,13 @@ static int DPEXRT_sycl_event_from_python(PyObject *obj,
* @return {return} A PyObject created from the eventstruct->parent, if
* the PyObject could not be created return NULL.
*/
static PyObject *DPEXRT_sycl_event_to_python(eventstruct_t *eventstruct)
static PyObject *DPEXRT_sycl_event_to_python(NRT_api_functions *nrt,
eventstruct_t *eventstruct)
{
PyObject *orig_event = NULL;
PyGILState_STATE gstate;

orig_event = eventstruct->parent;
orig_event = nrt->get_data(eventstruct->meminfo);
// FIXME: Better error checking is needed to enforce the boxing of the event
// object. For now, only the minimal is done as the returning of SyclEvent
// from a dpjit function should not be a used often and the dpctl C API for
Expand All @@ -1378,6 +1387,8 @@ static PyObject *DPEXRT_sycl_event_to_python(eventstruct_t *eventstruct)
// We need to increase reference count because we are returning new
// reference to the same event.
Py_INCREF(orig_event);
// We need to release meminfo since we are taking ownership back.
nrt->release(eventstruct->meminfo);

return orig_event;
}
Expand Down
4 changes: 2 additions & 2 deletions numba_dpex/core/runtime/_eventstruct.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@

#pragma once

#include <Python.h>
#include "numba/core/runtime/nrt_external.h"

typedef struct
{
PyObject *parent;
NRT_MemInfo *meminfo;
void *event_ref;
} eventstruct_t;
12 changes: 12 additions & 0 deletions numba_dpex/core/runtime/_nrt_helper.c
Original file line number Diff line number Diff line change
Expand Up @@ -134,3 +134,15 @@ void NRT_MemInfo_destroy(NRT_MemInfo *mi)
TheMSys.stats.mi_free++;
}
}

void NRT_MemInfo_pyobject_dtor(void *data)
{
PyGILState_STATE gstate;
PyObject *ownerobj = data;

gstate = PyGILState_Ensure(); /* ensure the GIL */
Py_DECREF(data); /* release the python object */
PyGILState_Release(gstate); /* release the GIL */

DPEXRT_DEBUG(drt_debug_print("DPEXRT-DEBUG: pyobject destructor\n"););
}
1 change: 1 addition & 0 deletions numba_dpex/core/runtime/_nrt_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,6 @@ size_t NRT_MemInfo_refcount(NRT_MemInfo *mi);
void NRT_Free(void *ptr);
void NRT_dealloc(NRT_MemInfo *mi);
void NRT_MemInfo_destroy(NRT_MemInfo *mi);
void NRT_MemInfo_pyobject_dtor(void *data);

#endif /* _NRT_HELPER_H_ */
15 changes: 11 additions & 4 deletions numba_dpex/core/runtime/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import functools

import numba.core.unsafe.nrt
from llvmlite import ir as llvmir
from numba.core import cgutils, types

Expand Down Expand Up @@ -206,26 +207,32 @@ def queuestruct_to_python(self, pyapi, val):
def eventstruct_from_python(self, pyapi, obj, ptr):
"""Calls the c function DPEXRT_sycl_event_from_python"""
fnty = llvmir.FunctionType(
llvmir.IntType(32), [pyapi.pyobj, pyapi.voidptr]
llvmir.IntType(32), [pyapi.voidptr, pyapi.pyobj, pyapi.voidptr]
)
nrt_api = self._context.nrt.get_nrt_api(pyapi.builder)

fn = pyapi._get_function(fnty, "DPEXRT_sycl_event_from_python")
fn.args[0].add_attribute("nocapture")
fn.args[1].add_attribute("nocapture")
fn.args[2].add_attribute("nocapture")

self.error = pyapi.builder.call(fn, (obj, ptr))
self.error = pyapi.builder.call(fn, (nrt_api, obj, ptr))
return self.error

def eventstruct_to_python(self, pyapi, val):
"""Calls the c function DPEXRT_sycl_event_to_python"""

fnty = llvmir.FunctionType(pyapi.pyobj, [pyapi.voidptr])
fnty = llvmir.FunctionType(pyapi.pyobj, [pyapi.voidptr, pyapi.voidptr])
nrt_api = self._context.nrt.get_nrt_api(pyapi.builder)

fn = pyapi._get_function(fnty, "DPEXRT_sycl_event_to_python")
fn.args[0].add_attribute("nocapture")
fn.args[1].add_attribute("nocapture")

qptr = cgutils.alloca_once_value(pyapi.builder, val)
ptr = pyapi.builder.bitcast(qptr, pyapi.voidptr)
self.error = pyapi.builder.call(fn, [ptr])

self.error = pyapi.builder.call(fn, [nrt_api, ptr])

return self.error

Expand Down
6 changes: 3 additions & 3 deletions numba_dpex/dpctl_iface/libsyclinterface_bindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def dpctl_event_wait(builder: llvmir.IRBuilder, *args):
mod = builder.module
fn = _build_dpctl_function(
llvm_module=mod,
return_ty=cgutils.voidptr_t,
return_ty=llvmir.VoidType(),
arg_list=[cgutils.voidptr_t],
func_name="DPCTLEvent_Wait",
)
Expand All @@ -85,7 +85,7 @@ def dpctl_event_delete(builder: llvmir.IRBuilder, *args):
mod = builder.module
fn = _build_dpctl_function(
llvm_module=mod,
return_ty=cgutils.voidptr_t,
return_ty=llvmir.VoidType(),
arg_list=[cgutils.voidptr_t],
func_name="DPCTLEvent_Delete",
)
Expand All @@ -99,7 +99,7 @@ def dpctl_queue_delete(builder: llvmir.IRBuilder, *args):
mod = builder.module
fn = _build_dpctl_function(
llvm_module=mod,
return_ty=cgutils.voidptr_t,
return_ty=llvmir.VoidType(),
arg_list=[cgutils.voidptr_t],
func_name="DPCTLQueue_Delete",
)
Expand Down

0 comments on commit 8440510

Please sign in to comment.