Skip to content

Commit

Permalink
Avoid creating SyclDevice from filter_string
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexanderKalistratov committed Apr 4, 2023
1 parent 58b7b84 commit 9d8ae9e
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 37 deletions.
77 changes: 41 additions & 36 deletions numba_dpex/core/types/usm_ndarray_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,48 +31,53 @@ def __init__(
aligned=True,
addrspace=address_space.GLOBAL,
):
# Creating SyclDevice from filter_string is expensive. So, USMNdArray should be able to
# accept and SyclDevice from usm_ndarray as device parameter
if not isinstance(device, (str, dpctl.SyclDevice)):
raise TypeError(
"The device keyword arg should be a str object specifying "
"a SYCL filter selector"
)

if not isinstance(queue, dpctl.SyclQueue) and queue is not None:
raise TypeError(
"The queue keyword arg should be a dpctl.SyclQueue object or None"
)

self.usm_type = usm_type
self.addrspace = addrspace

if queue is not None and device != "unknown":
if not isinstance(device, str):
raise TypeError(
"The device keyword arg should be a str object specifying "
"a SYCL filter selector"
)
if not isinstance(queue, dpctl.SyclQueue):
raise TypeError(
"The queue keyword arg should be a dpctl.SyclQueue object"
)
d1 = queue.sycl_device
d2 = dpctl.SyclDevice(device)
if d1 != d2:
raise TypeError(
"The queue keyword arg and the device keyword arg specify "
"different SYCL devices"
)
def to_device(dev):
if isinstance(dev, dpctl.SyclDevice):
return dev

return dpctl.SyclDevice(dev)

def device_as_string(dev):
if isinstance(dev, dpctl.SyclDevice):
return dev.filter_string

return dev

if queue is not None:
if device != "unknown":
if queue.sycl_device != to_device(device):
raise TypeError(
"The queue keyword arg and the device keyword arg specify "
"different SYCL devices"
)

self.queue = queue
self.device = device
elif queue is None and device != "unknown":
if not isinstance(device, str):
raise TypeError(
"The device keyword arg should be a str object specifying "
"a SYCL filter selector"
)
else:
if device == "unknown":
device = None

device_str = device_as_string(device)
self.queue = dpctl.tensor._device.normalize_queue_device(
device=device
device=device_str
)
self.device = device
elif queue is not None and device == "unknown":
if not isinstance(queue, dpctl.SyclQueue):
raise TypeError(
"The queue keyword arg should be a dpctl.SyclQueue object"
)
self.device = self.queue.sycl_device.filter_string
self.queue = queue
else:
self.queue = dpctl.tensor._device.normalize_queue_device()
self.device = self.queue.sycl_device.filter_string

self.device = self.queue.sycl_device.filter_string

if not dtype:
dummy_tensor = dpctl.tensor.empty(
Expand Down
2 changes: 1 addition & 1 deletion numba_dpex/core/typing/typeof.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def _typeof_helper(val, array_class_type):
)

try:
device = val.sycl_device.filter_string
device = val.sycl_device
except AttributeError:
raise ValueError("The device for the usm_ndarray could not be inferred")

Expand Down

0 comments on commit 9d8ae9e

Please sign in to comment.