Skip to content

Commit

Permalink
Merge pull request #946 from AlexanderKalistratov/use_of_cached_queue
Browse files Browse the repository at this point in the history
Using cached queue instead of creating new one on type inference
  • Loading branch information
Diptorup Deb authored Apr 21, 2023
2 parents 3971788 + c81f816 commit 0915170
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 41 deletions.
1 change: 0 additions & 1 deletion numba_dpex/core/typeconv/array_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ def to_usm_ndarray(suai_attrs, addrspace=address_space.GLOBAL):
ndim=suai_attrs.dimensions,
layout=layout,
usm_type=suai_attrs.usm_type,
device=suai_attrs.device,
queue=suai_attrs.queue,
readonly=not suai_attrs.is_writable,
name=None,
Expand Down
61 changes: 26 additions & 35 deletions numba_dpex/core/types/usm_ndarray_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,46 +31,37 @@ def __init__(
aligned=True,
addrspace=address_space.GLOBAL,
):
if not isinstance(device, str):
raise TypeError(
"The device keyword arg should be a str object specifying "
"a SYCL filter selector"
)

if not isinstance(queue, dpctl.SyclQueue) and queue is not None:
raise TypeError(
"The queue keyword arg should be a dpctl.SyclQueue object or None"
)

self.usm_type = usm_type
self.addrspace = addrspace

if queue is not None and device != "unknown":
if not isinstance(device, str):
raise TypeError(
"The device keyword arg should be a str object specifying "
"a SYCL filter selector"
)
if not isinstance(queue, dpctl.SyclQueue):
raise TypeError(
"The queue keyword arg should be a dpctl.SyclQueue object"
)
d1 = queue.sycl_device
d2 = dpctl.SyclDevice(device)
if d1 != d2:
raise TypeError(
"The queue keyword arg and the device keyword arg specify "
"different SYCL devices"
)
self.queue = queue
self.device = device
elif queue is None and device != "unknown":
if not isinstance(device, str):
raise TypeError(
"The device keyword arg should be a str object specifying "
"a SYCL filter selector"
)
self.queue = dpctl.SyclQueue(device)
self.device = self.queue.sycl_device.filter_string
elif queue is not None and device == "unknown":
if not isinstance(queue, dpctl.SyclQueue):
raise TypeError(
"The queue keyword arg should be a dpctl.SyclQueue object"
)
self.device = self.queue.sycl_device.filter_string
if device == "unknown":
device = None

if queue is not None and device is not None:
raise TypeError(
"'queue' and 'device' keywords can not be both specified"
)

if queue is not None:
self.queue = queue
else:
self.queue = dpctl.SyclQueue()
self.device = self.queue.sycl_device.filter_string
if device is None:
device = dpctl.SyclDevice()

self.queue = dpctl.get_device_cached_queue(device)

self.device = self.queue.sycl_device.filter_string

if not dtype:
dummy_tensor = dpctl.tensor.empty(
Expand Down
6 changes: 1 addition & 5 deletions numba_dpex/core/typing/typeof.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,18 +42,14 @@ def _typeof_helper(val, array_class_type):
"The usm_type for the usm_ndarray could not be inferred"
)

try:
device = val.sycl_device.filter_string
except AttributeError:
raise ValueError("The device for the usm_ndarray could not be inferred")
assert val.sycl_queue is not None

return array_class_type(
dtype=dtype,
ndim=val.ndim,
layout=layout,
readonly=readonly,
usm_type=usm_type,
device=device,
queue=val.sycl_queue,
addrspace=address_space.GLOBAL,
)
Expand Down

0 comments on commit 0915170

Please sign in to comment.