Skip to content

Commit

Permalink
Update dpnp.ediff1d function (#1983)
Browse files Browse the repository at this point in the history
* Refactor ediff1d function

* Update ediff1d tests

* Apply review comments

---------

Co-authored-by: Anton <[email protected]>
  • Loading branch information
vlad-perevezentsev and antonwolfy authored Aug 15, 2024
1 parent 2dcfd09 commit b20a352
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 43 deletions.
66 changes: 34 additions & 32 deletions dpnp/dpnp_iface_mathematical.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,30 @@ def _gradient_num_diff_edges(
)


def _process_ediff1d_args(arg, arg_name, ary_dtype, ary_sycl_queue, usm_type):
"""Process the argument for ediff1d."""
if not dpnp.is_supported_array_type(arg):
arg = dpnp.asarray(arg, usm_type=usm_type, sycl_queue=ary_sycl_queue)
else:
usm_type = dpu.get_coerced_usm_type([usm_type, arg.usm_type])
# check that arrays have the same allocation queue
if dpu.get_execution_queue([ary_sycl_queue, arg.sycl_queue]) is None:
raise dpu.ExecutionPlacementError(
f"ary and {arg_name} must be allocated on the same SYCL queue"
)

if not dpnp.can_cast(arg, ary_dtype, casting="same_kind"):
raise TypeError(
f"dtype of {arg_name} must be compatible "
"with input ary under the `same_kind` rule."
)

if arg.ndim > 1:
arg = dpnp.ravel(arg)

return arg, usm_type


_ABS_DOCSTRING = """
Calculates the absolute value for each element `x_i` of input array `x`.
Expand Down Expand Up @@ -1332,52 +1356,30 @@ def ediff1d(ary, to_end=None, to_begin=None):
return ary[1:] - ary[:-1]

ary_dtype = ary.dtype
ary_usm_type = ary.usm_type
ary_sycl_queue = ary.sycl_queue
usm_type = ary.usm_type

if to_begin is None:
l_begin = 0
else:
if not dpnp.is_supported_array_type(to_begin):
to_begin = dpnp.asarray(
to_begin, usm_type=ary_usm_type, sycl_queue=ary_sycl_queue
)
if not dpnp.can_cast(to_begin, ary_dtype, casting="same_kind"):
raise TypeError(
"dtype of `to_begin` must be compatible "
"with input `ary` under the `same_kind` rule."
)

to_begin_ndim = to_begin.ndim

if to_begin_ndim > 1:
to_begin = dpnp.ravel(to_begin)

to_begin, usm_type = _process_ediff1d_args(
to_begin, "to_begin", ary_dtype, ary_sycl_queue, usm_type
)
l_begin = to_begin.size

if to_end is None:
l_end = 0
else:
if not dpnp.is_supported_array_type(to_end):
to_end = dpnp.asarray(
to_end, usm_type=ary_usm_type, sycl_queue=ary_sycl_queue
)
if not dpnp.can_cast(to_end, ary_dtype, casting="same_kind"):
raise TypeError(
"dtype of `to_end` must be compatible "
"with input `ary` under the `same_kind` rule."
)

to_end_ndim = to_end.ndim

if to_end_ndim > 1:
to_end = dpnp.ravel(to_end)

to_end, usm_type = _process_ediff1d_args(
to_end, "to_end", ary_dtype, ary_sycl_queue, usm_type
)
l_end = to_end.size

# calculating using in place operation
l_diff = max(len(ary) - 1, 0)
result = dpnp.empty_like(ary, shape=l_diff + l_begin + l_end)
result = dpnp.empty_like(
ary, shape=l_diff + l_begin + l_end, usm_type=usm_type
)

if l_begin > 0:
result[:l_begin] = to_begin
Expand Down
12 changes: 12 additions & 0 deletions tests/test_mathematical.py
Original file line number Diff line number Diff line change
Expand Up @@ -2151,6 +2151,18 @@ def test_ediff1d_errors(self):
to_end = dpnp.array([5], dtype="f4")
assert_raises(TypeError, dpnp.ediff1d, a_dp, to_end=to_end)

# another `to_begin` sycl queue
to_begin = dpnp.array([-20, -15], sycl_queue=dpctl.SyclQueue())
assert_raises(
ExecutionPlacementError, dpnp.ediff1d, a_dp, to_begin=to_begin
)

# another `to_end` sycl queue
to_end = dpnp.array([15, 20], sycl_queue=dpctl.SyclQueue())
assert_raises(
ExecutionPlacementError, dpnp.ediff1d, a_dp, to_end=to_end
)


@pytest.mark.usefixtures("allow_fall_back_on_numpy")
class TestTrapz:
Expand Down
15 changes: 5 additions & 10 deletions tests/test_sycl_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -2407,12 +2407,7 @@ def test_nan_to_num(copy, device):


@pytest.mark.parametrize(
"device_x",
valid_devices,
ids=[device.filter_string for device in valid_devices],
)
@pytest.mark.parametrize(
"device_args",
"device",
valid_devices,
ids=[device.filter_string for device in valid_devices],
)
Expand All @@ -2424,15 +2419,15 @@ def test_nan_to_num(copy, device):
(10, -10),
],
)
def test_ediff1d(device_x, device_args, to_end, to_begin):
def test_ediff1d(device, to_end, to_begin):
data = [1, 3, 5, 7]

x = dpnp.array(data, device=device_x)
x = dpnp.array(data, device=device)
if to_end:
to_end = dpnp.array(to_end, device=device_args)
to_end = dpnp.array(to_end, device=device)

if to_begin:
to_begin = dpnp.array(to_begin, device=device_args)
to_begin = dpnp.array(to_begin, device=device)

res = dpnp.ediff1d(x, to_end=to_end, to_begin=to_begin)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_usm_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -1431,4 +1431,4 @@ def test_ediff1d(usm_type_x, usm_type_args, to_end, to_begin):

res = dp.ediff1d(x, to_end=to_end, to_begin=to_begin)

assert res.usm_type == x.usm_type
assert res.usm_type == du.get_coerced_usm_type([usm_type_x, usm_type_args])

0 comments on commit b20a352

Please sign in to comment.