Skip to content

Commit

Permalink
WIP for DpctlSyclQueue support
Browse files Browse the repository at this point in the history
Adding proper typing for dpctl.SyclQueue

Revert

Tried to follow the interval example from numba

Redo the implementation by following how the ArrayModel was done

Fix arg names

keep driver.py

Added unboxing function for SyclQueueType

Adding minimal test

Testing with different pattern
  • Loading branch information
chudur-budur authored and khaled committed Mar 24, 2023
1 parent 9d3d507 commit 5b176d0
Show file tree
Hide file tree
Showing 9 changed files with 249 additions and 20 deletions.
13 changes: 13 additions & 0 deletions driver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from dpctl import SyclQueue
from numba import njit

from numba_dpex import dpjit

if __name__ == "__main__":

@dpjit
def test(q):
pass

queue = SyclQueue()
test(queue)
22 changes: 19 additions & 3 deletions numba_dpex/core/datamodel/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from numba_dpex.utils import address_space

from ..types import Array, DpctlSyclQueue, DpnpNdArray, USMNdArray
from ..types import Array, DpnpNdArray, SyclQueueType, USMNdArray


class GenericPointerModel(PrimitiveModel):
Expand Down Expand Up @@ -54,6 +54,18 @@ def __init__(self, dmm, fe_type):
super(ArrayModel, self).__init__(dmm, fe_type, members)


class SyclQueueModel(StructModel):
def __init__(self, dmm, fe_type):
members = [
("parent", types.CPointer),
("queue_ref", types.PyObject),
("context", types.PyObject),
("device", types.PyObject),
]
# super(StructModel, self).__init__(dmm, fe_type, members)
StructModel.__init__(self, dmm, fe_type, members)


def _init_data_model_manager():
dmm = datamodel.default_manager.copy()
dmm.register(types.CPointer, GenericPointerModel)
Expand Down Expand Up @@ -84,5 +96,9 @@ def _init_data_model_manager():
dpex_data_model_manager.register(DpnpNdArray, DpnpNdArrayModel)

# Register the DpctlSyclQueue type with Numba's OpaqueModel
register_model(DpctlSyclQueue)(OpaqueModel)
dpex_data_model_manager.register(DpctlSyclQueue, OpaqueModel)
# register_model(DpctlSyclQueue)(OpaqueModel)
# dpex_data_model_manager.register(DpctlSyclQueue, OpaqueModel)

# Register the DpctlSyclQueue type with Numba's OpaqueModel
register_model(SyclQueueType)(SyclQueueModel)
dpex_data_model_manager.register(SyclQueueType, SyclQueueModel)
77 changes: 77 additions & 0 deletions numba_dpex/core/runtime/_dpexrt_python.c
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "_nrt_helper.h"
#include "_nrt_python_helper.h"

#include "_queuestruct.h"
#include "numba/_arraystruct.h"

/* Debugging facilities - enabled at compile-time */
Expand Down Expand Up @@ -66,6 +67,9 @@ DPEXRT_sycl_usm_ndarray_to_python_acqref(arystruct_t *arystruct,
int ndim,
int writeable,
PyArray_Descr *descr);
static struct PySyclQueueObject *to_py_syclqobject(PyObject *obj);
static int DPEXRT_sycl_queue_from_python(PyObject *obj,
queuestruct_t *queue_struct);

/*
* Debugging printf function used internally
Expand Down Expand Up @@ -645,6 +649,21 @@ static struct PyUSMArrayObject *PyUSMNdArray_ARRAYOBJ(PyObject *obj)
return pyusmarrayobj;
}

static struct PySyclQueueObject *to_py_syclqobject(PyObject *obj)
{
if (!obj)
return NULL;
if (!PyObject_TypeCheck(obj, &PySyclQueueType))
return NULL;

struct PySyclQueueObject *pysyclqobj = (struct PySyclQueueObject *)(obj);
// struct Py_SyclQueueObject py_syclqobj = pysyclqobj->__pyx_base;

// return &py_syclqobj;

return pysyclqobj;
}

/*!
* @brief Returns the product of the elements in an array of a given
* length.
Expand Down Expand Up @@ -785,6 +804,62 @@ static int DPEXRT_sycl_usm_ndarray_from_python(PyObject *obj,
return -1;
}

static int DPEXRT_sycl_queue_from_python(PyObject *obj,
queuestruct_t *queue_struct)
{

struct PySyclQueueObject *queue_obj = NULL;
// DPCTLSyclQueueRef queue_ref = NULL;
PyGILState_STATE gstate;

// Increment the ref count on obj to prevent CPython from garbage
// collecting the array.
Py_IncRef(obj);

DPEXRT_DEBUG(
nrt_debug_print("DPEXRT-DEBUG: In DPEXRT_sycl_queue_from_python.\n"));

// Check if the PyObject obj has an _array_obj attribute that is of
// dpctl.tensor.usm_ndarray type.
if (!(queue_obj = to_py_syclqobject(obj))) {
DPEXRT_DEBUG(nrt_debug_print(
"DPEXRT-ERROR: to_py_syclqobject() check failed %d\n", __FILE__,
__LINE__));
goto error;
}

// if (!(queue_ref = SyclQueue_GetQueueRef(queue_obj))) {
// DPEXRT_DEBUG(nrt_debug_print(
// "DPEXRT-ERROR: SyclQueue_GetQueueRef returned NULL at "
// "%s, line %d.\n",
// __FILE__, __LINE__));
// goto error;
// }

queue_struct->parent = obj;
// queue_struct->queue_ref = queue_ref;
queue_struct->queue_ref = (PyObject *)queue_obj->__pyx_base._queue_ref;
queue_struct->cotext = (PyObject *)queue_obj->__pyx_base._context;
queue_struct->device = (PyObject *)queue_obj->__pyx_base._device;

error:
// If the check failed then decrement the refcount and return an error
// code of -1.
// Decref the Pyobject of the array
// ensure the GIL
DPEXRT_DEBUG(nrt_debug_print(
"DPEXRT-ERROR: Failed to unbox dpctl SyclQueue into a Numba "
"queuestruct at %s, line %d\n",
__FILE__, __LINE__));
gstate = PyGILState_Ensure();
// decref the python object
Py_DECREF(obj);
// release the GIL
PyGILState_Release(gstate);

return -1;
}

/*!
* @brief A helper function that boxes a Numba arystruct_t object into a
* dpnp.ndarray PyObject using the arystruct_t's parent attribute.
Expand Down Expand Up @@ -1082,6 +1157,8 @@ static PyObject *build_c_helpers_dict(void)
_declpointer("DPEXRT_MemInfo_fill", &DPEXRT_MemInfo_fill);
_declpointer("NRT_ExternalAllocator_new_for_usm",
&NRT_ExternalAllocator_new_for_usm);
_declpointer("DPEXRT_sycl_queue_from_python",
&DPEXRT_sycl_queue_from_python);

#undef _declpointer
return dct;
Expand Down
20 changes: 20 additions & 0 deletions numba_dpex/core/runtime/_queuestruct.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#ifndef NUMBA_DPEX_QUEUESTRUCT_H_
#define NUMBA_DPEX_QUEUESTRUCT_H_
/*
* Fill in the *queuestruct* with information from the Numpy array *obj*.
* *queuestruct*'s layout is defined in numba.targets.arrayobj (look
* for the ArrayTemplate class).
*/

#include "numpy/npy_common.h"
#include <Python.h>

typedef struct
{
PyObject *parent;
PyObject *queue_ref;
PyObject *cotext;
PyObject *device;
} queuestruct_t;

#endif /* NUMBA_DPEX_QUEUESTRUCT_H_ */
15 changes: 15 additions & 0 deletions numba_dpex/core/runtime/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,21 @@ def arraystruct_from_python(self, pyapi, obj, ptr):

return self.error

def queuestruct_from_python(self, pyapi, obj, ptr):
# call the c function DPEXRT_sycl_queue_from_python

fnty = llvmir.FunctionType(
llvmir.IntType(32), [pyapi.pyobj, pyapi.voidptr]
)

fn = pyapi._get_function(fnty, "DPEXRT_sycl_queue_from_python")
fn.args[0].add_attribute("nocapture")
fn.args[1].add_attribute("nocapture")

self.error = pyapi.builder.call(fn, (obj, ptr))

return self.error

def usm_ndarray_to_python_acqref(self, pyapi, aryty, ary, dtypeptr):
"""Boxes a DpnpNdArray native object into a Python dpnp.ndarray.
Expand Down
4 changes: 2 additions & 2 deletions numba_dpex/core/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# SPDX-License-Identifier: Apache-2.0

from .array_type import Array
from .dpctl_types import DpctlSyclQueue
from .dpctl_types import SyclQueueType
from .dpnp_ndarray_type import DpnpNdArray
from .numba_types_short_names import (
b1,
Expand Down Expand Up @@ -32,7 +32,7 @@

__all__ = [
"Array",
"DpctlSyclQueue",
"SyclQueueType",
"DpnpNdArray",
"USMNdArray",
"none",
Expand Down
109 changes: 97 additions & 12 deletions numba_dpex/core/types/dpctl_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,21 @@

from dpctl import SyclQueue
from numba import types
from numba.extending import NativeValue, box, type_callable, unbox
from numba.core import cgutils
from numba.extending import (
NativeValue,
as_numba_type,
box,
type_callable,
typeof_impl,
unbox,
)

from numba_dpex.core.exceptions import UnreachableError
from numba_dpex.core.runtime import context as dpxrtc

class DpctlSyclQueue(types.Type):

class SyclQueueType(types.Type):
"""A Numba type to represent a dpctl.SyclQueue PyObject.
For now, a dpctl.SyclQueue is represented as a Numba opaque type that allows
Expand All @@ -16,25 +27,99 @@ class DpctlSyclQueue(types.Type):
"""

def __init__(self):
super().__init__(name="DpctlSyclQueueType")
super(SyclQueueType, self).__init__(name="SyclQueue")


# sycl_queue_type = SyclQueueType()


sycl_queue_ty = DpctlSyclQueue()
# @typeof_impl.register(SyclQueue)
# def typeof_index(val, c):
# return sycl_queue_type


# as_numba_type.register(SyclQueue, sycl_queue_type)


@type_callable(SyclQueue)
def type_interval(context):
def typer():
return sycl_queue_ty
def type_sycl_queue(context):
def typer(args):
if isinstance(args, types.Tuple):
if len(args) > 0:
if (
isinstance(args[0], types.PyObject)
and isinstance(args[1], types.StringLiteral)
and isinstance(args[2], types.PyObject)
):
return SyclQueueType()
else:
return SyclQueueType()
elif isinstance(args, types.NoneType):
return SyclQueueType()
else:
raise ValueError("Couldn't do type inference for 'SycleQueue'.")

return typer


@unbox(DpctlSyclQueue)
# @lower_builtin(SyclQueue, types.PyObject, types.StringLiteral, types.PyObject)
# def impl_interval(context, builder, sig, args):
# typ = sig.return_type
# if len(args) > 0:
# ctx, dev, property = args
# sycl_queue = cgutils.create_struct_proxy(typ)(context, builder)
# sycl_queue.ctx = ctx
# sycl_queue.dev = dev
# sycl_queue.property = property
# else:
# sycl_queue = cgutils.create_struct_proxy(typ)(context, builder)
# return sycl_queue._getvalue()


@unbox(SyclQueueType)
def unbox_sycl_queue(typ, obj, c):
return NativeValue(obj)
"""
Convert a SyclQueue object to a native structure.
"""
qstruct = cgutils.create_struct_proxy(typ)(c.context, c.builder)
qptr = qstruct._getpointer()
ptr = c.builder.bitcast(qptr, c.pyapi.voidptr)
if c.context.enable_nrt:
dpexrtCtx = dpxrtc.DpexRTContext(c.context)
errcode = dpexrtCtx.queuestruct_from_python(c.pyapi, obj, ptr)
else:
raise UnreachableError

is_error = cgutils.is_not_null(c.builder, errcode)
# Handle error
with c.builder.if_then(is_error, likely=False):
c.pyapi.err_set_string(
"PyExc_TypeError",
"can't unbox array from PyObject into "
"native value. The object maybe of a "
"different type",
)

return NativeValue(c.builder.load(qptr), is_error=is_error)


@box(DpctlSyclQueue)
def box_pyobject(typ, val, c):
return val
# @box(SyclQueueType)
# def box_sycl_queue_(typ, val, c):
# """
# Convert a native interval structure to an Interval object.
# """
# sycl_queue = cgutils.create_struct_proxy(typ)(
# c.context, c.builder, value=val
# )
# ctx_obj = sycl_queue.ctx
# dev_obj = sycl_queue.dev
# property_obj = sycl_queue.property
# class_obj = c.pyapi.unserialize(c.pyapi.serialize_object(SyclQueue))
# res = c.pyapi.call_function_objargs(
# class_obj, (ctx_obj, dev_obj, property_obj)
# )
# c.pyapi.decref(ctx_obj)
# c.pyapi.decref(dev_obj)
# c.pyapi.decref(property_obj)
# c.pyapi.decref(class_obj)
# return res
3 changes: 2 additions & 1 deletion numba_dpex/core/types/dpnp_ndarray_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,8 @@ def unbox_dpnp_nd_array(typ, obj, c):
# potential memory corruption
#
# --------------- End of Numba comment from @ubox(types.Array)
nativearycls = c.context.make_array(typ)

nativearycls = c.context.make_array(typ) # make_array is in numba.core.base
nativeary = nativearycls(c.context, c.builder)
aryptr = nativeary._getpointer()

Expand Down
6 changes: 4 additions & 2 deletions numba_dpex/core/typing/typeof.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from numba_dpex.utils import address_space

from ..types.dpctl_types import sycl_queue_ty
from ..types.dpctl_types import SyclQueueType
from ..types.dpnp_ndarray_type import DpnpNdArray
from ..types.usm_ndarray_type import USMNdArray

Expand Down Expand Up @@ -107,4 +107,6 @@ def typeof_dpctl_sycl_queue(val, c):
Returns: A numba_dpex.core.types.dpctl_types.DpctlSyclQueue instance.
"""
return sycl_queue_ty
# return sycl_queue_type
# return _typeof_helper(val, SyclQueueType)
return SyclQueueType()

0 comments on commit 5b176d0

Please sign in to comment.