diff --git a/numba_dpex/core/datamodel/models.py b/numba_dpex/core/datamodel/models.py index 471c28d4c7..9149fd6307 100644 --- a/numba_dpex/core/datamodel/models.py +++ b/numba_dpex/core/datamodel/models.py @@ -145,8 +145,8 @@ class SyclQueueModel(StructModel): def __init__(self, dmm, fe_type): members = [ ( - "parent", - types.CPointer(types.int8), + "meminfo", + types.MemInfoPointer(types.pyobject), ), ( "queue_ref", diff --git a/numba_dpex/core/runtime/_dpexrt_python.c b/numba_dpex/core/runtime/_dpexrt_python.c index 1f7d2dda6a..a45012e340 100644 --- a/numba_dpex/core/runtime/_dpexrt_python.c +++ b/numba_dpex/core/runtime/_dpexrt_python.c @@ -64,12 +64,14 @@ DPEXRT_sycl_usm_ndarray_to_python_acqref(usmarystruct_t *arystruct, int ndim, int writeable, PyArray_Descr *descr); -static int DPEXRT_sycl_queue_from_python(PyObject *obj, +static int DPEXRT_sycl_queue_from_python(NRT_api_functions *nrt, + PyObject *obj, queuestruct_t *queue_struct); static int DPEXRT_sycl_event_from_python(NRT_api_functions *nrt, PyObject *obj, eventstruct_t *event_struct); -static PyObject *DPEXRT_sycl_queue_to_python(queuestruct_t *queuestruct); +static PyObject *DPEXRT_sycl_queue_to_python(NRT_api_functions *nrt, + queuestruct_t *queuestruct); static PyObject *DPEXRT_sycl_event_to_python(NRT_api_functions *nrt, eventstruct_t *eventstruct); @@ -1216,7 +1218,8 @@ DPEXRT_sycl_usm_ndarray_to_python_acqref(usmarystruct_t *arystruct, * represent a dpctl.SyclQueue inside Numba. * @return {return} Return code indicating success (0) or failure (-1). */ -static int DPEXRT_sycl_queue_from_python(PyObject *obj, +static int DPEXRT_sycl_queue_from_python(NRT_api_functions *nrt, + PyObject *obj, queuestruct_t *queue_struct) { struct PySyclQueueObject *queue_obj = NULL; @@ -1246,7 +1249,13 @@ static int DPEXRT_sycl_queue_from_python(PyObject *obj, DPCTLDeviceMgr_GetDeviceInfoStr(device_ref)); DPCTLDevice_Delete(device_ref);); - queue_struct->parent = obj; + // We are doing incref here to ensure python does not release the object + // while NRT references it. Coresponding decref is called by NRT in + // NRT_MemInfo_pyobject_dtor once there is no reference to this object by + // the code managed by NRT. + Py_INCREF(queue_obj); + queue_struct->meminfo = + nrt->manage_memory(queue_obj, NRT_MemInfo_pyobject_dtor); queue_struct->queue_ref = queue_ref; return 0; @@ -1275,11 +1284,12 @@ static int DPEXRT_sycl_queue_from_python(PyObject *obj, * @return {return} A PyObject created from the queuestruct->parent, if * the PyObject could not be created return NULL. */ -static PyObject *DPEXRT_sycl_queue_to_python(queuestruct_t *queuestruct) +static PyObject *DPEXRT_sycl_queue_to_python(NRT_api_functions *nrt, + queuestruct_t *queuestruct) { PyObject *orig_queue = NULL; - orig_queue = queuestruct->parent; + orig_queue = nrt->get_data(queuestruct->meminfo); // FIXME: Better error checking is needed to enforce the boxing of the queue // object. For now, only the minimal is done as the returning of SyclQueue // from a dpjit function should not be a used often and the dpctl C API for @@ -1291,9 +1301,13 @@ static PyObject *DPEXRT_sycl_queue_to_python(queuestruct_t *queuestruct) return NULL; } + // TODO: is there any way to release meminfo without calling dtor so we dont + // call incref, decref one after another. // We need to increase reference count because we are returning new // reference to the same queue. Py_INCREF(orig_queue); + // We need to release meminfo since we are taking ownership back. + nrt->release(queuestruct->meminfo); return orig_queue; } diff --git a/numba_dpex/core/runtime/_queuestruct.h b/numba_dpex/core/runtime/_queuestruct.h index ebb088d65a..2488494384 100644 --- a/numba_dpex/core/runtime/_queuestruct.h +++ b/numba_dpex/core/runtime/_queuestruct.h @@ -11,10 +11,10 @@ #pragma once -#include +#include "numba/core/runtime/nrt_external.h" typedef struct { - PyObject *parent; + NRT_MemInfo *meminfo; void *queue_ref; } queuestruct_t; diff --git a/numba_dpex/core/runtime/context.py b/numba_dpex/core/runtime/context.py index d67cb9c04b..5ceb4da5ad 100644 --- a/numba_dpex/core/runtime/context.py +++ b/numba_dpex/core/runtime/context.py @@ -181,26 +181,32 @@ def arraystruct_from_python(self, pyapi, obj, ptr): def queuestruct_from_python(self, pyapi, obj, ptr): """Calls the c function DPEXRT_sycl_queue_from_python""" fnty = llvmir.FunctionType( - llvmir.IntType(32), [pyapi.pyobj, pyapi.voidptr] + llvmir.IntType(32), [pyapi.voidptr, pyapi.pyobj, pyapi.voidptr] ) + nrt_api = self._context.nrt.get_nrt_api(pyapi.builder) fn = pyapi._get_function(fnty, "DPEXRT_sycl_queue_from_python") fn.args[0].add_attribute("nocapture") fn.args[1].add_attribute("nocapture") + fn.args[2].add_attribute("nocapture") - self.error = pyapi.builder.call(fn, (obj, ptr)) + self.error = pyapi.builder.call(fn, (nrt_api, obj, ptr)) return self.error def queuestruct_to_python(self, pyapi, val): """Calls the c function DPEXRT_sycl_queue_to_python""" - fnty = llvmir.FunctionType(pyapi.pyobj, [pyapi.voidptr]) + fnty = llvmir.FunctionType(pyapi.pyobj, [pyapi.voidptr, pyapi.voidptr]) + nrt_api = self._context.nrt.get_nrt_api(pyapi.builder) fn = pyapi._get_function(fnty, "DPEXRT_sycl_queue_to_python") fn.args[0].add_attribute("nocapture") + fn.args[1].add_attribute("nocapture") + qptr = cgutils.alloca_once_value(pyapi.builder, val) ptr = pyapi.builder.bitcast(qptr, pyapi.voidptr) - self.error = pyapi.builder.call(fn, [ptr]) + + self.error = pyapi.builder.call(fn, [nrt_api, ptr]) return self.error