Skip to content

Commit

Permalink
Add dpnp.result_type() support (#1435)
Browse files Browse the repository at this point in the history
* Add dpnp.result_type() support

* Update dpnp/dpnp_iface_manipulation.py

Co-authored-by: Natalia Polina <[email protected]>

---------

Co-authored-by: Natalia Polina <[email protected]>
  • Loading branch information
antonwolfy and npolina4 authored Jun 14, 2023
1 parent c51b3ce commit ac25666
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 1 deletion.
2 changes: 1 addition & 1 deletion dpnp/dpnp_algo/dpnp_elementwise_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def _call_divide_inplace(lhs, rhs, sycl_queue, depends=[]):
"""In place workaround until dpctl.tensor provides the functionality."""

# allocate temporary memory for out array
out = dpt.empty_like(lhs, dtype=numpy.result_type((lhs.dtype, rhs.dtype)))
out = dpt.empty_like(lhs, dtype=dpnp.result_type(lhs.dtype, rhs.dtype))

# call a general callback
div_ht_, div_ev_ = _call_divide(lhs, rhs, out, sycl_queue, depends)
Expand Down
44 changes: 44 additions & 0 deletions dpnp/dpnp_iface_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
"ravel",
"repeat",
"reshape",
"result_type",
"rollaxis",
"shape",
"squeeze",
Expand Down Expand Up @@ -579,6 +580,49 @@ def reshape(x, /, newshape, order='C', copy=None):
return dpnp_array._create_from_usm_ndarray(usm_arr)


def result_type(*arrays_and_dtypes):
"""
Returns the type that results from applying the NumPy
type promotion rules to the arguments.
For full documentation refer to :obj:`numpy.result_type`.
Parameters
----------
arrays_and_dtypes : list of arrays and dtypes
An arbitrary length sequence of arrays or dtypes.
Returns
-------
out : dtype
The result type.
Limitations
-----------
An array in the input list is supported as either :class:`dpnp.ndarray`
or :class:`dpctl.tensor.usm_ndarray`.
Examples
--------
>>> import dpnp as dp
>>> dp.result_type(dp.arange(3, dtype=dp.int64), dp.arange(7, dtype=dp.int32))
dtype('int64')
>>> dp.result_type(dp.int64, dp.complex128)
dtype('complex128')
>>> dp.result_type(dp.ones(10, dtype=dp.float32), dp.float64)
dtype('float64')
"""

usm_arrays_and_dtypes = [
X.dtype if isinstance(X, (dpnp_array, dpt.usm_ndarray)) else X
for X in arrays_and_dtypes
]
return dpt.result_type(*usm_arrays_and_dtypes)


def rollaxis(x1, axis, start=0):
"""
Roll the specified axis backwards, until it lies in a given position.
Expand Down
19 changes: 19 additions & 0 deletions tests/test_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,25 @@ def test_repeat(arr):
assert_array_equal(expected, result)


def test_result_type():
X = [dpnp.ones((2), dtype=dpnp.int64), dpnp.int32, "float16"]
X_np = [numpy.ones((2), dtype=numpy.int64), numpy.int32, "float16"]

assert dpnp.result_type(*X) == numpy.result_type(*X_np)

def test_result_type_only_dtypes():
X = [dpnp.int64, dpnp.int32, dpnp.bool, dpnp.float32]
X_np = [numpy.int64, numpy.int32, numpy.bool_, numpy.float32]

assert dpnp.result_type(*X) == numpy.result_type(*X_np)

def test_result_type_only_arrays():
X = [dpnp.ones((2), dtype=dpnp.int64), dpnp.ones((7, 4), dtype=dpnp.int32)]
X_np = [numpy.ones((2), dtype=numpy.int64), numpy.ones((7, 4), dtype=numpy.int32)]

assert dpnp.result_type(*X) == numpy.result_type(*X_np)


@pytest.mark.usefixtures("allow_fall_back_on_numpy")
@pytest.mark.parametrize("array",
[[1, 2, 3],
Expand Down

0 comments on commit ac25666

Please sign in to comment.