diff --git a/dpctl/tensor/_slicing.pxi b/dpctl/tensor/_slicing.pxi index 353616b942..81a696328f 100644 --- a/dpctl/tensor/_slicing.pxi +++ b/dpctl/tensor/_slicing.pxi @@ -120,9 +120,11 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int): sh0 = _slice_len(sl_start, sl_stop, sl_step) str0 = sl_step * strides[0] new_strides = strides if (sl_step == 1 or sh0 == 0) else (str0,) + strides[1:] - new_offset = offset if sh0 == 0 else offset + sl_start * strides[0] + new_shape = (sh0, ) + shape[1:] + is_empty = any(sh_i == 0 for sh_i in new_shape) + new_offset = offset if is_empty else offset + sl_start * strides[0] return ( - (sh0, ) + shape[1:], + new_shape, new_strides, new_offset, _no_advanced_ind, @@ -135,11 +137,15 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int): return ((0,) + shape, (0,) + strides, offset, _no_advanced_ind, _no_advanced_pos) elif _is_integral(ind): ind = ind.__index__() + new_shape = shape[1:] + new_strides = strides[1:] + is_empty = any(sh_i == 0 for sh_i in new_shape) if 0 <= ind < shape[0]: - return (shape[1:], strides[1:], offset + ind * strides[0], _no_advanced_ind, _no_advanced_pos) + new_offset = offset if is_empty else offset + ind * strides[0] + return (new_shape, new_strides, new_offset, _no_advanced_ind, _no_advanced_pos) elif -shape[0] <= ind < 0: - return (shape[1:], strides[1:], - offset + (shape[0] + ind) * strides[0], _no_advanced_ind, _no_advanced_pos) + new_offset = offset if is_empty else offset + (shape[0] + ind) * strides[0] + return (new_shape, new_strides, new_offset, _no_advanced_ind, _no_advanced_pos) else: raise IndexError( "Index {0} is out of range for axes 0 with " diff --git a/dpctl/tests/test_usm_ndarray_ctor.py b/dpctl/tests/test_usm_ndarray_ctor.py index a0f2414fce..4117c7a967 100644 --- a/dpctl/tests/test_usm_ndarray_ctor.py +++ b/dpctl/tests/test_usm_ndarray_ctor.py @@ -452,6 +452,30 @@ def test_slicing_basic(): assert np.array_equal(Xh, Xnp[Xnp[2] : Xnp[5]]) +def test_slicing_empty(): + try: + X = dpt.usm_ndarray((0, 10), dtype="i4") + except dpctl.SyclDeviceCreationError: + pytest.skip("No SYCL devices available") + x = dpt.moveaxis(X, 1, 0) + # this used to raise ValueError + y = x[1] + assert y.ndim == 1 + assert y.shape == (0,) + assert y.dtype == X.dtype + assert y.usm_type == X.usm_type + assert y.sycl_queue == X.sycl_queue + w = x[1:3] + assert w.ndim == 2 + assert w.shape == ( + 2, + 0, + ) + assert w.dtype == X.dtype + assert w.usm_type == X.usm_type + assert w.sycl_queue == X.sycl_queue + + def test_ctor_invalid_shape(): with pytest.raises(TypeError): dpt.usm_ndarray(dict())