Skip to content

Commit

Permalink
Reuse dpctl.tensor.where in dpnp.where
Browse files Browse the repository at this point in the history
  • Loading branch information
vlad-perevezentsev committed Apr 17, 2023
1 parent ef32d77 commit 1fb91ed
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 20 deletions.
37 changes: 20 additions & 17 deletions dpnp/dpnp_iface_searching.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,10 @@
from dpnp.dpnp_utils import *

import dpnp
from dpnp.dpnp_array import dpnp_array

import numpy
import dpctl.tensor as dpt


__all__ = [
Expand Down Expand Up @@ -181,7 +184,7 @@ def where(condition, x=None, y=None, /):
Return elements chosen from `x` or `y` depending on `condition`.
When only `condition` is provided, this function is a shorthand for
:obj:`dpnp.nonzero(condition)`.
:obj:`dpnp.nonzero(condition)`.
For full documentation refer to :obj:`numpy.where`.
Expand All @@ -193,12 +196,13 @@ def where(condition, x=None, y=None, /):
Limitations
-----------
Parameters `condition`, `x` and `y` are supported as either scalar, :class:`dpnp.ndarray`
Parameter `condition` is supported as either :class:`dpnp.ndarray`
or :class:`dpctl.tensor.usm_ndarray`.
Parameters `x` and `y` are supported as either scalar, :class:`dpnp.ndarray`
or :class:`dpctl.tensor.usm_ndarray`
Otherwise the function will be executed sequentially on CPU.
Data type of `condition` parameter is limited by :obj:`dpnp.bool`.
Input array data types of `x` and `y` are limited by supported DPNP :ref:`Data types`.
See Also
--------
:obj:`nonzero` : The function that is called when `x` and `y`are omitted.
Expand All @@ -220,18 +224,17 @@ def where(condition, x=None, y=None, /):
elif missing == 2:
return dpnp.nonzero(condition)
elif missing == 0:
# get USM type and queue to copy scalar from the host memory into a USM allocation
usm_type, queue = get_usm_allocations([condition, x, y])

c_desc = dpnp.get_dpnp_descriptor(condition, copy_when_strides=False, copy_when_nondefault_queue=False,
alloc_usm_type=usm_type, alloc_queue=queue)
x_desc = dpnp.get_dpnp_descriptor(x, copy_when_strides=False, copy_when_nondefault_queue=False,
alloc_usm_type=usm_type, alloc_queue=queue)
y_desc = dpnp.get_dpnp_descriptor(y, copy_when_strides=False, copy_when_nondefault_queue=False,
alloc_usm_type=usm_type, alloc_queue=queue)
if c_desc and x_desc and y_desc:
if c_desc.dtype != dpnp.bool:
raise TypeError("condition must be a boolean array")
return dpnp_where(c_desc, x_desc, y_desc).get_pyobj()
check_input_type = lambda x: isinstance(x, (dpnp_array, dpt.usm_ndarray))
if check_input_type(condition):
if numpy.isscalar(x) or numpy.isscalar(y):
# get USM type and queue to copy scalar from the host memory into a USM allocation
usm_type, queue = get_usm_allocations([condition, x, y])
x = dpt.asarray(x, usm_type=usm_type, sycl_queue=queue) if numpy.isscalar(x) else x
y = dpt.asarray(y, usm_type=usm_type, sycl_queue=queue) if numpy.isscalar(y) else y
if check_input_type(x) and check_input_type(y):
dpt_condition = condition.get_array() if isinstance(condition, dpnp_array) else condition
dpt_x = x.get_array() if isinstance(x, dpnp_array) else x
dpt_y = y.get_array() if isinstance(y, dpnp_array) else y
return dpnp_array._create_from_usm_ndarray(dpt.where(dpt_condition, dpt_x, dpt_y))

return call_origin(numpy.where, condition, x, y)
19 changes: 19 additions & 0 deletions tests/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,3 +581,22 @@ def test_triu_indices_from(array, k):
result = dpnp.triu_indices_from(ia, k)
expected = numpy.triu_indices_from(a, k)
assert_array_equal(expected, result)


@pytest.mark.parametrize("cond_dtype", get_all_dtypes())
@pytest.mark.parametrize("scalar_dtype", get_all_dtypes(no_none=True))
def test_where_with_scalars(cond_dtype, scalar_dtype):
a = numpy.array([-1, 0, 1, 0], dtype=cond_dtype)
ia = dpnp.array(a)

result = dpnp.where(ia, scalar_dtype(1), scalar_dtype(0))
expected = numpy.where(a, scalar_dtype(1), scalar_dtype(0))
assert_array_equal(expected, result)

result = dpnp.where(ia, ia*2, scalar_dtype(0))
expected = numpy.where(a, a*2, scalar_dtype(0))
assert_array_equal(expected, result)

result = dpnp.where(ia, scalar_dtype(1), dpnp.array(0))
expected = numpy.where(a, scalar_dtype(1), numpy.array(0))
assert_array_equal(expected, result)
3 changes: 0 additions & 3 deletions tests/third_party/cupy/sorting_tests/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,6 @@ def test_argminmax_dtype(self, in_dtype, result_dtype):
{'cond_shape': (2, 3, 4), 'x_shape': (2, 3, 4), 'y_shape': (3, 4)},
{'cond_shape': (3, 4), 'x_shape': (2, 3, 4), 'y_shape': (4,)},
)
@pytest.mark.usefixtures("allow_fall_back_on_numpy")
@testing.gpu
class TestWhereTwoArrays(unittest.TestCase):

Expand All @@ -274,8 +273,6 @@ def test_where_two_arrays(self, xp, cond_type, x_type, y_type):
# Almost all values of a matrix `shaped_random` makes are not zero.
# To make a sparse matrix, we need multiply `m`.
cond = testing.shaped_random(self.cond_shape, xp, cond_type) * m
if xp is cupy:
cond = cond.astype(cupy.bool)
x = testing.shaped_random(self.x_shape, xp, x_type, seed=0)
y = testing.shaped_random(self.y_shape, xp, y_type, seed=1)
return xp.where(cond, x, y)
Expand Down

0 comments on commit 1fb91ed

Please sign in to comment.