Skip to content

Commit

Permalink
Fix in usm_ndarray_types.py
Browse files Browse the repository at this point in the history
  • Loading branch information
khaled committed Apr 21, 2023
1 parent e7ea064 commit 05f14f7
Showing 1 changed file with 22 additions and 16 deletions.
38 changes: 22 additions & 16 deletions numba_dpex/core/types/usm_ndarray_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from numba.core.types.npytypes import Array
from numba.np.numpy_support import from_dtype

from numba_dpex.core.types.dpctl_types import DpctlSyclQueue
from numba_dpex.utils import address_space


Expand All @@ -31,15 +32,10 @@ def __init__(
aligned=True,
addrspace=address_space.GLOBAL,
):
if not isinstance(device, str):
if queue is not None and device != "unknown":
raise TypeError(
"The device keyword arg should be a str object specifying "
"a SYCL filter selector"
)

if not isinstance(queue, dpctl.SyclQueue) and queue is not None:
raise TypeError(
"The queue keyword arg should be a dpctl.SyclQueue object or None"
"numba_dpex.core.types.usm_ndarray_type.USMNdArray.__init__(): "
"`device` and `sycl_queue` are exclusive keywords, i.e. use one or other."
)

self.usm_type = usm_type
Expand All @@ -48,18 +44,28 @@ def __init__(
if device == "unknown":
device = None

if queue is not None and device is not None:
raise TypeError(
"'queue' and 'device' keywords can not be both specified"
)

if queue is not None:
if not isinstance(queue, dpctl.SyclQueue):
raise TypeError(
"numba_dpex.core.types.usm_ndarray_type.USMNdArray.__init__(): "
"The queue keyword arg should be a dpctl.SyclQueue object or None."
)
self.queue = queue
else:
if device is None:
device = dpctl.SyclDevice()

self.queue = dpctl.get_device_cached_queue(device)
sycl_device = dpctl.SyclDevice()
else:
if not isinstance(device, str):
raise TypeError(
"numba_dpex.core.types.usm_ndarray_type.USMNdArray.__init__(): "
"The device keyword arg should be a str object specifying "
"a SYCL filter selector."
)
sycl_device = dpctl.SyclDevice(device)

self.queue = dpctl._sycl_queue_manager.get_device_cached_queue(
sycl_device
)

self.device = self.queue.sycl_device.filter_string

Expand Down

0 comments on commit 05f14f7

Please sign in to comment.