Skip to content

Commit

Permalink
Fix lifetime management for sycl queue
Browse files Browse the repository at this point in the history
  • Loading branch information
ZzEeKkAa committed Oct 25, 2023
1 parent af314f5 commit 8df9ea1
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 14 deletions.
4 changes: 2 additions & 2 deletions numba_dpex/core/datamodel/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,8 @@ class SyclQueueModel(StructModel):
def __init__(self, dmm, fe_type):
members = [
(
"parent",
types.CPointer(types.int8),
"meminfo",
types.MemInfoPointer(types.pyobject),
),
(
"queue_ref",
Expand Down
26 changes: 20 additions & 6 deletions numba_dpex/core/runtime/_dpexrt_python.c
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,14 @@ DPEXRT_sycl_usm_ndarray_to_python_acqref(usmarystruct_t *arystruct,
int ndim,
int writeable,
PyArray_Descr *descr);
static int DPEXRT_sycl_queue_from_python(PyObject *obj,
static int DPEXRT_sycl_queue_from_python(NRT_api_functions *nrt,
PyObject *obj,
queuestruct_t *queue_struct);
static int DPEXRT_sycl_event_from_python(NRT_api_functions *nrt,
PyObject *obj,
eventstruct_t *event_struct);
static PyObject *DPEXRT_sycl_queue_to_python(queuestruct_t *queuestruct);
static PyObject *DPEXRT_sycl_queue_to_python(NRT_api_functions *nrt,
queuestruct_t *queuestruct);
static PyObject *DPEXRT_sycl_event_to_python(NRT_api_functions *nrt,
eventstruct_t *eventstruct);

Expand Down Expand Up @@ -1216,7 +1218,8 @@ DPEXRT_sycl_usm_ndarray_to_python_acqref(usmarystruct_t *arystruct,
* represent a dpctl.SyclQueue inside Numba.
* @return {return} Return code indicating success (0) or failure (-1).
*/
static int DPEXRT_sycl_queue_from_python(PyObject *obj,
static int DPEXRT_sycl_queue_from_python(NRT_api_functions *nrt,
PyObject *obj,
queuestruct_t *queue_struct)
{
struct PySyclQueueObject *queue_obj = NULL;
Expand Down Expand Up @@ -1246,7 +1249,13 @@ static int DPEXRT_sycl_queue_from_python(PyObject *obj,
DPCTLDeviceMgr_GetDeviceInfoStr(device_ref));
DPCTLDevice_Delete(device_ref););

queue_struct->parent = obj;
// We are doing incref here to ensure python does not release the object
// while NRT references it. Coresponding decref is called by NRT in
// NRT_MemInfo_pyobject_dtor once there is no reference to this object by
// the code managed by NRT.
Py_INCREF(queue_obj);
queue_struct->meminfo =
nrt->manage_memory(queue_obj, NRT_MemInfo_pyobject_dtor);
queue_struct->queue_ref = queue_ref;

return 0;
Expand Down Expand Up @@ -1275,11 +1284,12 @@ static int DPEXRT_sycl_queue_from_python(PyObject *obj,
* @return {return} A PyObject created from the queuestruct->parent, if
* the PyObject could not be created return NULL.
*/
static PyObject *DPEXRT_sycl_queue_to_python(queuestruct_t *queuestruct)
static PyObject *DPEXRT_sycl_queue_to_python(NRT_api_functions *nrt,
queuestruct_t *queuestruct)
{
PyObject *orig_queue = NULL;

orig_queue = queuestruct->parent;
orig_queue = nrt->get_data(queuestruct->meminfo);
// FIXME: Better error checking is needed to enforce the boxing of the queue
// object. For now, only the minimal is done as the returning of SyclQueue
// from a dpjit function should not be a used often and the dpctl C API for
Expand All @@ -1291,9 +1301,13 @@ static PyObject *DPEXRT_sycl_queue_to_python(queuestruct_t *queuestruct)
return NULL;
}

// TODO: is there any way to release meminfo without calling dtor so we dont
// call incref, decref one after another.
// We need to increase reference count because we are returning new
// reference to the same queue.
Py_INCREF(orig_queue);
// We need to release meminfo since we are taking ownership back.
nrt->release(queuestruct->meminfo);

return orig_queue;
}
Expand Down
4 changes: 2 additions & 2 deletions numba_dpex/core/runtime/_queuestruct.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@

#pragma once

#include <Python.h>
#include "numba/core/runtime/nrt_external.h"

typedef struct
{
PyObject *parent;
NRT_MemInfo *meminfo;
void *queue_ref;
} queuestruct_t;
14 changes: 10 additions & 4 deletions numba_dpex/core/runtime/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,26 +181,32 @@ def arraystruct_from_python(self, pyapi, obj, ptr):
def queuestruct_from_python(self, pyapi, obj, ptr):
"""Calls the c function DPEXRT_sycl_queue_from_python"""
fnty = llvmir.FunctionType(
llvmir.IntType(32), [pyapi.pyobj, pyapi.voidptr]
llvmir.IntType(32), [pyapi.voidptr, pyapi.pyobj, pyapi.voidptr]
)
nrt_api = self._context.nrt.get_nrt_api(pyapi.builder)

fn = pyapi._get_function(fnty, "DPEXRT_sycl_queue_from_python")
fn.args[0].add_attribute("nocapture")
fn.args[1].add_attribute("nocapture")
fn.args[2].add_attribute("nocapture")

self.error = pyapi.builder.call(fn, (obj, ptr))
self.error = pyapi.builder.call(fn, (nrt_api, obj, ptr))
return self.error

def queuestruct_to_python(self, pyapi, val):
"""Calls the c function DPEXRT_sycl_queue_to_python"""

fnty = llvmir.FunctionType(pyapi.pyobj, [pyapi.voidptr])
fnty = llvmir.FunctionType(pyapi.pyobj, [pyapi.voidptr, pyapi.voidptr])
nrt_api = self._context.nrt.get_nrt_api(pyapi.builder)

fn = pyapi._get_function(fnty, "DPEXRT_sycl_queue_to_python")
fn.args[0].add_attribute("nocapture")
fn.args[1].add_attribute("nocapture")

qptr = cgutils.alloca_once_value(pyapi.builder, val)
ptr = pyapi.builder.bitcast(qptr, pyapi.voidptr)
self.error = pyapi.builder.call(fn, [ptr])

self.error = pyapi.builder.call(fn, [nrt_api, ptr])

return self.error

Expand Down

0 comments on commit 8df9ea1

Please sign in to comment.