Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Changes to the DpctlSyclQueue and USMNdArray types. #1064

Merged
merged 3 commits into from
Jun 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion numba_dpex/core/kernel_interface/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,9 +409,14 @@ def __call__(self, *args):
# FIXME: For specialized and ahead of time compiled and cached kernels,
# the CFD check was already done statically. The run-time check is
# redundant. We should avoid these checks for the specialized case.
exec_queue = determine_kernel_launch_queue(
ty_queue = determine_kernel_launch_queue(
args, argtypes, self.kernel_name
)

# FIXME: We need a better way than having to create a queue every time.
device = ty_queue.sycl_device
exec_queue = dpctl.get_device_cached_queue(device)

backend = exec_queue.backend

if exec_queue.backend not in [
Expand Down
6 changes: 5 additions & 1 deletion numba_dpex/core/parfors/kernel_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import sys
import warnings

import dpctl
import dpctl.program as dpctl_prog
from numba.core import ir, types
from numba.core.errors import NumbaParallelSafetyWarning
Expand Down Expand Up @@ -426,7 +427,10 @@ def create_kernel_for_parfor(
for arg in parfor_args:
obj = typemap[arg]
if isinstance(obj, DpnpNdArray):
exec_queue = obj.queue
filter_string = obj.queue.sycl_device
# FIXME: A better design is required so that we do not have to
# create a queue every time.
exec_queue = dpctl.get_device_cached_queue(filter_string)

if not exec_queue:
raise AssertionError(
Expand Down
22 changes: 16 additions & 6 deletions numba_dpex/core/parfors/reduction_kernel_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import warnings

import dpctl
from numba.core import types
from numba.core.errors import NumbaParallelSafetyWarning
from numba.core.ir_utils import (
Expand All @@ -18,6 +19,8 @@
)
from numba.core.typing import signature

from numba_dpex.core.types import DpctlSyclQueue

from ..utils.kernel_templates.reduction_template import (
RemainderReduceIntermediateKernelTemplate,
TreeReduceIntermediateKernelTemplate,
Expand Down Expand Up @@ -134,7 +137,13 @@ def create_reduction_main_kernel_for_parfor(
flags.noalias = True

kernel_sig = signature(types.none, *kernel_param_types)
exec_queue = typemap[reductionKernelVar.parfor_params[0]].queue

# FIXME: A better design is required so that we do not have to create a
# queue every time.
ty_queue: DpctlSyclQueue = typemap[
reductionKernelVar.parfor_params[0]
].queue
exec_queue = dpctl.get_device_cached_queue(ty_queue.sycl_device)

sycl_kernel = _compile_kernel_parfor(
exec_queue,
Expand Down Expand Up @@ -331,11 +340,12 @@ def create_reduction_remainder_kernel_for_parfor(

kernel_sig = signature(types.none, *kernel_param_types)

# FIXME: Enable check after CFD pass has been added
# exec_queue = determine_kernel_launch_queue(
# args=parfor_args, argtypes=kernel_param_types, kernel_name=kernel_name
# )
exec_queue = typemap[reductionKernelVar.parfor_params[0]].queue
# FIXME: A better design is required so that we do not have to create a
# queue every time.
ty_queue: DpctlSyclQueue = typemap[
reductionKernelVar.parfor_params[0]
].queue
exec_queue = dpctl.get_device_cached_queue(ty_queue.sycl_device)

sycl_kernel = _compile_kernel_parfor(
exec_queue,
Expand Down
2 changes: 1 addition & 1 deletion numba_dpex/core/runtime/_dpexrt_python.c
Original file line number Diff line number Diff line change
Expand Up @@ -1180,7 +1180,7 @@ static int DPEXRT_sycl_queue_from_python(PyObject *obj,
PyGILState_STATE gstate;

// Increment the ref count on obj to prevent CPython from garbage
// collecting the array.
// collecting the dpctl.SyclQueue object
Py_IncRef(obj);

// We are unconditionally casting obj to a struct PySyclQueueObject*. If
Expand Down
5 changes: 2 additions & 3 deletions numba_dpex/core/typeconv/array_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@

from numba.np import numpy_support

from numba_dpex.core.types import USMNdArray
from numba_dpex.core.utils import get_info_from_suai
from numba_dpex.core.types import DpctlSyclQueue, USMNdArray
from numba_dpex.utils.constants import address_space


Expand Down Expand Up @@ -37,7 +36,7 @@ def to_usm_ndarray(suai_attrs, addrspace=address_space.GLOBAL):
ndim=suai_attrs.dimensions,
layout=layout,
usm_type=suai_attrs.usm_type,
queue=suai_attrs.queue,
queue=DpctlSyclQueue(suai_attrs.queue),
readonly=not suai_attrs.is_writable,
name=None,
aligned=True,
Expand Down
31 changes: 16 additions & 15 deletions numba_dpex/core/types/dpctl_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,18 @@


class DpctlSyclQueue(types.Type):
"""A Numba type to represent a dpctl.SyclQueue PyObject.

For now, a dpctl.SyclQueue is represented as a Numba opaque type that allows
passing in and using a SyclQueue object as an opaque pointer type inside
Numba.
"""
"""A Numba type to represent a dpctl.SyclQueue PyObject."""

def __init__(self, sycl_queue):
if not isinstance(sycl_queue, SyclQueue):
raise TypeError("The argument sycl_queue is not of type SyclQueue.")

self._sycl_queue = sycl_queue
# XXX: Storing the device filter string is a temporary workaround till
# the compute follows data inference pass is fixed to use SyclQueue
self._device = sycl_queue.sycl_device.filter_string

try:
self._unique_id = hash(self._sycl_queue)
self._unique_id = hash(sycl_queue)
except Exception:
self._unique_id = self.rand_digit_str(16)
super(DpctlSyclQueue, self).__init__(name="DpctlSyclQueue")
Expand All @@ -38,8 +36,14 @@ def rand_digit_str(self, n):
)

@property
def sycl_queue(self):
return self._sycl_queue
def sycl_device(self):
"""Returns the SYCL oneAPI extension filter string associated with the
queue.

Returns:
str: A SYCL oneAPI extension filter string
"""
return self._device

@property
def key(self):
Expand Down Expand Up @@ -69,11 +73,8 @@ def unbox_sycl_queue(typ, obj, c):
qptr = qstruct._getpointer()
ptr = c.builder.bitcast(qptr, c.pyapi.voidptr)

if c.context.enable_nrt:
dpexrtCtx = dpexrt.DpexRTContext(c.context)
errcode = dpexrtCtx.queuestruct_from_python(c.pyapi, obj, ptr)
else:
raise UnreachableError
dpexrtCtx = dpexrt.DpexRTContext(c.context)
errcode = dpexrtCtx.queuestruct_from_python(c.pyapi, obj, ptr)
is_error = cgutils.is_not_null(c.builder, errcode)

# Handle error
Expand Down
46 changes: 26 additions & 20 deletions numba_dpex/core/types/usm_ndarray_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from numba.core.types.npytypes import Array
from numba.np.numpy_support import from_dtype

from numba_dpex.core.types.dpctl_types import DpctlSyclQueue
from numba_dpex.utils import address_space


Expand All @@ -31,22 +32,28 @@ def __init__(
aligned=True,
addrspace=address_space.GLOBAL,
):
if queue and not isinstance(queue, types.misc.Omitted) and device:
if (
queue is not None
and not (
isinstance(queue, types.misc.Omitted)
or isinstance(queue, types.misc.NoneType)
)
and device is not None
):
raise TypeError(
"numba_dpex.core.types.usm_ndarray_type.USMNdArray.__init__(): "
"`device` and `sycl_queue` are exclusive keywords, i.e. use one or other."
"`device` and `sycl_queue` are exclusive keywords, "
"i.e. use one or other."
)

self.usm_type = usm_type
self.addrspace = addrspace

if queue and not isinstance(queue, types.misc.Omitted):
if not isinstance(queue, dpctl.SyclQueue):
if queue is not None and not (
isinstance(queue, types.misc.Omitted)
or isinstance(queue, types.misc.NoneType)
):
if not isinstance(queue, DpctlSyclQueue):
raise TypeError(
"numba_dpex.core.types.usm_ndarray_type.USMNdArray.__init__(): "
"The queue keyword arg should be a dpctl.SyclQueue object or None."
"Found type(queue) ="
+ str(type(queue) + " and queue =" + queue)
"The queue keyword arg should be either DpctlSyclQueue or "
"NoneType. Found type(queue) = " + str(type(queue))
)
self.queue = queue
else:
Expand All @@ -55,24 +62,23 @@ def __init__(
else:
if not isinstance(device, str):
raise TypeError(
"numba_dpex.core.types.usm_ndarray_type.USMNdArray.__init__(): "
"The device keyword arg should be a str object specifying "
"a SYCL filter selector."
"The device keyword arg should be a str object "
"specifying a SYCL filter selector."
)
sycl_device = dpctl.SyclDevice(device)

self.queue = dpctl._sycl_queue_manager.get_device_cached_queue(
sycl_queue = dpctl._sycl_queue_manager.get_device_cached_queue(
sycl_device
)
self.queue = DpctlSyclQueue(sycl_queue=sycl_queue)

self.device = self.queue.sycl_device.filter_string
self.device = self.queue.sycl_device
self.usm_type = usm_type
self.addrspace = addrspace

if not dtype:
dummy_tensor = dpctl.tensor.empty(
1,
order=layout,
usm_type=usm_type,
sycl_queue=self.queue,
1, order=layout, usm_type=usm_type, device=self.device
)
# convert dpnp type to numba/numpy type
_dtype = dummy_tensor.dtype
Expand Down
7 changes: 5 additions & 2 deletions numba_dpex/core/typing/typeof.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,18 @@ def _typeof_helper(val, array_class_type):
"The usm_type for the usm_ndarray could not be inferred"
)

assert val.sycl_queue is not None
if not val.sycl_queue:
raise AssertionError

ty_queue = DpctlSyclQueue(sycl_queue=val.sycl_queue)

return array_class_type(
dtype=dtype,
ndim=val.ndim,
layout=layout,
readonly=readonly,
usm_type=usm_type,
queue=val.sycl_queue,
queue=ty_queue,
addrspace=address_space.GLOBAL,
)

Expand Down
Loading