Skip to content

Commit

Permalink
Added support for argument axis with value None for arrays with strid…
Browse files Browse the repository at this point in the history
…es for dpctl.tensor.concat().
  • Loading branch information
npolina4 committed Mar 16, 2023
1 parent bccb694 commit cc7f714
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 51 deletions.
116 changes: 65 additions & 51 deletions dpctl/tensor/_manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,72 +372,86 @@ def _check_same_shapes(X0_shape, axis, n, arrays):
)


def _concat_axis_None(arrays):
"Implementation of concat(arrays, axis=None)."
res_dtype, res_usm_type, exec_q = _arrays_validation(
arrays, check_ndim=False
)
res_shape = 0
for array in arrays:
res_shape += array.size
res = dpt.empty(
res_shape, dtype=res_dtype, usm_type=res_usm_type, sycl_queue=exec_q
)

hev_list = []
fill_start = 0
for array in arrays:
fill_end = fill_start + array.size
if array.flags.c_contiguous:
hev, _ = ti._copy_usm_ndarray_into_usm_ndarray(
src=dpt.reshape(array, -1),
dst=res[fill_start:fill_end],
sycl_queue=exec_q,
)
else:
hev, _ = ti._copy_usm_ndarray_for_reshape(
src=array,
dst=res[fill_start:fill_end],
shift=0,
sycl_queue=exec_q,
)
fill_start = fill_end
hev_list.append(hev)

dpctl.SyclEvent.wait_for(hev_list)
return res


def concat(arrays, axis=0):
"""
concat(arrays: tuple or list of usm_ndarrays, axis: int) -> usm_ndarray
Joins a sequence of arrays along an existing axis.
"""
if axis is None:
res_dtype, res_usm_type, exec_q = _arrays_validation(
arrays, check_ndim=False
)
res_shape = 0
for array in arrays:
res_shape += array.size
res = dpt.empty(
res_shape, dtype=res_dtype, usm_type=res_usm_type, sycl_queue=exec_q
)
return _concat_axis_None(arrays)

hev_list = []
fill_start = 0
for array in arrays:
fill_end = fill_start + array.size
hev, _ = ti._copy_usm_ndarray_into_usm_ndarray(
src=dpt.reshape(array, -1),
dst=res[fill_start:fill_end],
sycl_queue=exec_q,
)
fill_start = fill_end
hev_list.append(hev)
res_dtype, res_usm_type, exec_q = _arrays_validation(arrays)
n = len(arrays)
X0 = arrays[0]

dpctl.SyclEvent.wait_for(hev_list)
else:
res_dtype, res_usm_type, exec_q = _arrays_validation(arrays)
n = len(arrays)
X0 = arrays[0]
axis = normalize_axis_index(axis, X0.ndim)
X0_shape = X0.shape
_check_same_shapes(X0_shape, axis, n, arrays)

axis = normalize_axis_index(axis, X0.ndim)
X0_shape = X0.shape
_check_same_shapes(X0_shape, axis, n, arrays)
res_shape_axis = 0
for X in arrays:
res_shape_axis = res_shape_axis + X.shape[axis]

res_shape_axis = 0
for X in arrays:
res_shape_axis = res_shape_axis + X.shape[axis]
res_shape = tuple(
X0_shape[i] if i != axis else res_shape_axis for i in range(X0.ndim)
)

res_shape = tuple(
X0_shape[i] if i != axis else res_shape_axis for i in range(X0.ndim)
)
res = dpt.empty(
res_shape, dtype=res_dtype, usm_type=res_usm_type, sycl_queue=exec_q
)

res = dpt.empty(
res_shape, dtype=res_dtype, usm_type=res_usm_type, sycl_queue=exec_q
hev_list = []
fill_start = 0
for i in range(n):
fill_end = fill_start + arrays[i].shape[axis]
c_shapes_copy = tuple(
np.s_[fill_start:fill_end] if j == axis else np.s_[:]
for j in range(X0.ndim)
)
hev, _ = ti._copy_usm_ndarray_into_usm_ndarray(
src=arrays[i], dst=res[c_shapes_copy], sycl_queue=exec_q
)
fill_start = fill_end
hev_list.append(hev)

hev_list = []
fill_start = 0
for i in range(n):
fill_end = fill_start + arrays[i].shape[axis]
c_shapes_copy = tuple(
np.s_[fill_start:fill_end] if j == axis else np.s_[:]
for j in range(X0.ndim)
)
hev, _ = ti._copy_usm_ndarray_into_usm_ndarray(
src=arrays[i], dst=res[c_shapes_copy], sycl_queue=exec_q
)
fill_start = fill_end
hev_list.append(hev)

dpctl.SyclEvent.wait_for(hev_list)
dpctl.SyclEvent.wait_for(hev_list)

return res

Expand Down
17 changes: 17 additions & 0 deletions dpctl/tests/test_usm_ndarray_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -893,6 +893,23 @@ def test_concat_3arrays(data):
assert_array_equal(Rnp, dpt.asnumpy(R))


def test_concat_axis_none_strides():
try:
q = dpctl.SyclQueue()
except dpctl.SyclQueueCreationError:
pytest.skip("Queue could not be created")
Xnp = np.arange(0, 18).reshape((6, 3))
X = dpt.asarray(Xnp, sycl_queue=q)

Ynp = np.arange(20, 36).reshape((4, 2, 2))
Y = dpt.asarray(Ynp, sycl_queue=q)

Znp = np.concatenate([Xnp[::2], Ynp[::2]], axis=None)
Z = dpt.concat([X[::2], Y[::2]], axis=None)

assert_array_equal(Znp, dpt.asnumpy(Z))


def test_stack_incorrect_shape():
try:
q = dpctl.SyclQueue()
Expand Down

0 comments on commit cc7f714

Please sign in to comment.