diff --git a/dpctl/tensor/_usmarray.pyx b/dpctl/tensor/_usmarray.pyx index 550b733186..62cbfd0665 100644 --- a/dpctl/tensor/_usmarray.pyx +++ b/dpctl/tensor/_usmarray.pyx @@ -149,6 +149,19 @@ cdef bint _is_host_cpu(object dl_device): return (dl_type == DLDeviceType.kDLCPU) and (dl_id == 0) +cdef void _validate_and_use_stream(object stream, c_dpctl.SyclQueue self_queue) except *: + if (stream is None or stream == self_queue): + pass + else: + if not isinstance(stream, dpctl.SyclQueue): + raise TypeError( + "stream argument type was expected to be dpctl.SyclQueue," + f" got {type(stream)} instead" + ) + ev = self_queue.submit_barrier() + stream.submit_barrier(dependent_events=[ev]) + + cdef class usm_ndarray: """ usm_ndarray(shape, dtype=None, strides=None, buffer="device", \ offset=0, order="C", buffer_ctor_kwargs=dict(), \ @@ -1025,12 +1038,7 @@ cdef class usm_ndarray: cdef c_dpmem._Memory arr_buf d = Device.create_device(target_device) - if (stream is None or not isinstance(stream, dpctl.SyclQueue) or - stream == self.sycl_queue): - pass - else: - ev = self.sycl_queue.submit_barrier() - stream.submit_barrier(dependent_events=[ev]) + _validate_and_use_stream(stream, self.sycl_queue) if (d.sycl_context == self.sycl_context): arr_buf = self.usm_data @@ -1203,12 +1211,7 @@ cdef class usm_ndarray: # legacy path for DLManagedTensor # copy kwarg ignored because copy flag can't be set _caps = c_dlpack.to_dlpack_capsule(self) - if (stream is None or type(stream) is not dpctl.SyclQueue or - stream == self.sycl_queue): - pass - else: - ev = self.sycl_queue.submit_barrier() - stream.submit_barrier(dependent_events=[ev]) + _validate_and_use_stream(stream, self.sycl_queue) return _caps else: if not isinstance(max_version, tuple) or len(max_version) != 2: @@ -1250,12 +1253,7 @@ cdef class usm_ndarray: copy = False # TODO: strategy for handling stream on different device from dl_device if copy: - if (stream is None or type(stream) is not dpctl.SyclQueue or - stream == self.sycl_queue): - pass - else: - ev = self.sycl_queue.submit_barrier() - stream.submit_barrier(dependent_events=[ev]) + _validate_and_use_stream(stream, self.sycl_queue) nbytes = self.usm_data.nbytes copy_buffer = type(self.usm_data)( nbytes, queue=self.sycl_queue @@ -1272,22 +1270,12 @@ cdef class usm_ndarray: _caps = c_dlpack.to_dlpack_versioned_capsule(_copied_arr, copy) else: _caps = c_dlpack.to_dlpack_versioned_capsule(self, copy) - if (stream is None or type(stream) is not dpctl.SyclQueue or - stream == self.sycl_queue): - pass - else: - ev = self.sycl_queue.submit_barrier() - stream.submit_barrier(dependent_events=[ev]) + _validate_and_use_stream(stream, self.sycl_queue) return _caps else: # legacy path for DLManagedTensor _caps = c_dlpack.to_dlpack_capsule(self) - if (stream is None or type(stream) is not dpctl.SyclQueue or - stream == self.sycl_queue): - pass - else: - ev = self.sycl_queue.submit_barrier() - stream.submit_barrier(dependent_events=[ev]) + _validate_and_use_stream(stream, self.sycl_queue) return _caps def __dlpack_device__(self): @@ -1555,17 +1543,17 @@ cdef class usm_ndarray: def __array__(self, dtype=None, /, *, copy=None): """NumPy's array protocol method to disallow implicit conversion. - Without this definition, `numpy.asarray(usm_ar)` converts - usm_ndarray instance into NumPy array with data type `object` - and every element being 0d usm_ndarray. + Without this definition, `numpy.asarray(usm_ar)` converts + usm_ndarray instance into NumPy array with data type `object` + and every element being 0d usm_ndarray. https://github.com/IntelPython/dpctl/pull/1384#issuecomment-1707212972 - """ + """ raise TypeError( "Implicit conversion to a NumPy array is not allowed. " - "Use `dpctl.tensor.asnumpy` to copy data from this " - "`dpctl.tensor.usm_ndarray` instance to NumPy array" - ) + "Use `dpctl.tensor.asnumpy` to copy data from this " + "`dpctl.tensor.usm_ndarray` instance to NumPy array" + ) cdef usm_ndarray _real_view(usm_ndarray ary): diff --git a/dpctl/tests/test_usm_ndarray_ctor.py b/dpctl/tests/test_usm_ndarray_ctor.py index ccfc655bfc..81b398ab9b 100644 --- a/dpctl/tests/test_usm_ndarray_ctor.py +++ b/dpctl/tests/test_usm_ndarray_ctor.py @@ -1380,6 +1380,30 @@ def test_to_device(): assert Y.sycl_device == dev +def test_to_device_stream_validation(): + try: + X = dpt.usm_ndarray(1, "f4") + except dpctl.SyclDeviceCreationError: + pytest.skip("No SYCL devices available") + # invalid type of stream keyword + with pytest.raises(TypeError): + X.to_device(X.sycl_queue, stream=dict()) + # stream is keyword-only arg + with pytest.raises(TypeError): + X.to_device(X.sycl_queue, X.sycl_queue) + + +def test_to_device_stream_use(): + try: + X = dpt.usm_ndarray(1, "f4") + except dpctl.SyclDeviceCreationError: + pytest.skip("No SYCL devices available") + q1 = dpctl.SyclQueue( + X.sycl_context, X.sycl_device, property="enable_profiling" + ) + X.to_device(q1, stream=q1) + + def test_to_device_migration(): q1 = get_queue_or_skip() # two distinct copies of default-constructed queue q2 = get_queue_or_skip()