Skip to content

Commit

Permalink
Merge pull request #1125 from IntelPython/fix_gh_1122
Browse files Browse the repository at this point in the history
Added support for argument axis with value None for dpctl.tensor.concat().
  • Loading branch information
oleksandr-pavlyk authored Mar 17, 2023
2 parents 2fafa76 + cc7f714 commit 2273287
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 13 deletions.
65 changes: 52 additions & 13 deletions dpctl/tensor/_manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,10 +411,10 @@ def roll(X, shift, axis=None):
return res


def _arrays_validation(arrays):
def _arrays_validation(arrays, check_ndim=True):
n = len(arrays)
if n == 0:
raise TypeError("Missing 1 required positional argument: 'arrays'")
raise TypeError("Missing 1 required positional argument: 'arrays'.")

if not isinstance(arrays, (list, tuple)):
raise TypeError(f"Expected tuple or list type, got {type(arrays)}.")
Expand All @@ -425,11 +425,11 @@ def _arrays_validation(arrays):

exec_q = dputils.get_execution_queue([X.sycl_queue for X in arrays])
if exec_q is None:
raise ValueError("All the input arrays must have same sycl queue")
raise ValueError("All the input arrays must have same sycl queue.")

res_usm_type = dputils.get_coerced_usm_type([X.usm_type for X in arrays])
if res_usm_type is None:
raise ValueError("All the input arrays must have usm_type")
raise ValueError("All the input arrays must have usm_type.")

X0 = arrays[0]
_supported_dtype(Xi.dtype for Xi in arrays)
Expand All @@ -438,13 +438,14 @@ def _arrays_validation(arrays):
for i in range(1, n):
res_dtype = np.promote_types(res_dtype, arrays[i])

for i in range(1, n):
if X0.ndim != arrays[i].ndim:
raise ValueError(
"All the input arrays must have same number of dimensions, "
f"but the array at index 0 has {X0.ndim} dimension(s) and the "
f"array at index {i} has {arrays[i].ndim} dimension(s)"
)
if check_ndim:
for i in range(1, n):
if X0.ndim != arrays[i].ndim:
raise ValueError(
"All the input arrays must have same number of dimensions, "
f"but the array at index 0 has {X0.ndim} dimension(s) and "
f"the array at index {i} has {arrays[i].ndim} dimension(s)."
)
return res_dtype, res_usm_type, exec_q


Expand All @@ -457,10 +458,46 @@ def _check_same_shapes(X0_shape, axis, n, arrays):
"All the input array dimensions for the concatenation "
f"axis must match exactly, but along dimension {j}, the "
f"array at index 0 has size {X0j} and the array "
f"at index {i} has size {Xi_shape[j]}"
f"at index {i} has size {Xi_shape[j]}."
)


def _concat_axis_None(arrays):
"Implementation of concat(arrays, axis=None)."
res_dtype, res_usm_type, exec_q = _arrays_validation(
arrays, check_ndim=False
)
res_shape = 0
for array in arrays:
res_shape += array.size
res = dpt.empty(
res_shape, dtype=res_dtype, usm_type=res_usm_type, sycl_queue=exec_q
)

hev_list = []
fill_start = 0
for array in arrays:
fill_end = fill_start + array.size
if array.flags.c_contiguous:
hev, _ = ti._copy_usm_ndarray_into_usm_ndarray(
src=dpt.reshape(array, -1),
dst=res[fill_start:fill_end],
sycl_queue=exec_q,
)
else:
hev, _ = ti._copy_usm_ndarray_for_reshape(
src=array,
dst=res[fill_start:fill_end],
shift=0,
sycl_queue=exec_q,
)
fill_start = fill_end
hev_list.append(hev)

dpctl.SyclEvent.wait_for(hev_list)
return res


def concat(arrays, axis=0):
"""concat(arrays, axis)
Expand All @@ -486,8 +523,10 @@ def concat(arrays, axis=0):
of the output array is determined by USM allocation type promotion
rules.
"""
res_dtype, res_usm_type, exec_q = _arrays_validation(arrays)
if axis is None:
return _concat_axis_None(arrays)

res_dtype, res_usm_type, exec_q = _arrays_validation(arrays)
n = len(arrays)
X0 = arrays[0]

Expand Down
18 changes: 18 additions & 0 deletions dpctl/tests/test_usm_ndarray_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -839,6 +839,7 @@ def test_concat_1array(data):
[(0, 2), (2, 2), 0],
[(2, 1), (2, 2), -1],
[(2, 2, 2), (2, 1, 2), 1],
[(3, 3, 3), (2, 2), None],
],
)
def test_concat_2arrays(data):
Expand Down Expand Up @@ -892,6 +893,23 @@ def test_concat_3arrays(data):
assert_array_equal(Rnp, dpt.asnumpy(R))


def test_concat_axis_none_strides():
try:
q = dpctl.SyclQueue()
except dpctl.SyclQueueCreationError:
pytest.skip("Queue could not be created")
Xnp = np.arange(0, 18).reshape((6, 3))
X = dpt.asarray(Xnp, sycl_queue=q)

Ynp = np.arange(20, 36).reshape((4, 2, 2))
Y = dpt.asarray(Ynp, sycl_queue=q)

Znp = np.concatenate([Xnp[::2], Ynp[::2]], axis=None)
Z = dpt.concat([X[::2], Y[::2]], axis=None)

assert_array_equal(Znp, dpt.asnumpy(Z))


def test_stack_incorrect_shape():
try:
q = dpctl.SyclQueue()
Expand Down

0 comments on commit 2273287

Please sign in to comment.