diff --git a/dpnp/dpnp_algo/dpnp_elementwise_common.py b/dpnp/dpnp_algo/dpnp_elementwise_common.py index 5115f39a8e7..7a54a194795 100644 --- a/dpnp/dpnp_algo/dpnp_elementwise_common.py +++ b/dpnp/dpnp_algo/dpnp_elementwise_common.py @@ -133,7 +133,7 @@ def _call_divide_inplace(lhs, rhs, sycl_queue, depends=[]): """In place workaround until dpctl.tensor provides the functionality.""" # allocate temporary memory for out array - out = dpt.empty_like(lhs, dtype=numpy.result_type((lhs.dtype, rhs.dtype))) + out = dpt.empty_like(lhs, dtype=dpnp.result_type(lhs.dtype, rhs.dtype)) # call a general callback div_ht_, div_ev_ = _call_divide(lhs, rhs, out, sycl_queue, depends) diff --git a/dpnp/dpnp_iface_manipulation.py b/dpnp/dpnp_iface_manipulation.py index cadf63af236..1ca879abe1a 100644 --- a/dpnp/dpnp_iface_manipulation.py +++ b/dpnp/dpnp_iface_manipulation.py @@ -67,6 +67,7 @@ "ravel", "repeat", "reshape", + "result_type", "rollaxis", "shape", "squeeze", @@ -579,6 +580,49 @@ def reshape(x, /, newshape, order='C', copy=None): return dpnp_array._create_from_usm_ndarray(usm_arr) +def result_type(*arrays_and_dtypes): + """ + Returns the type that results from applying the NumPy + type promotion rules to the arguments. + + For full documentation refer to :obj:`numpy.result_type`. + + Parameters + ---------- + arrays_and_dtypes : list of arrays and dtypes + An arbitrary length sequence of arrays or dtypes. + + Returns + ------- + out : dtype + The result type. + + Limitations + ----------- + An array in the input list is supported as either :class:`dpnp.ndarray` + or :class:`dpctl.tensor.usm_ndarray`. + + Examples + -------- + >>> import dpnp as dp + >>> dp.result_type(dp.arange(3, dtype=dp.int64), dp.arange(7, dtype=dp.int32)) + dtype('int64') + + >>> dp.result_type(dp.int64, dp.complex128) + dtype('complex128') + + >>> dp.result_type(dp.ones(10, dtype=dp.float32), dp.float64) + dtype('float64') + + """ + + usm_arrays_and_dtypes = [ + X.dtype if isinstance(X, (dpnp_array, dpt.usm_ndarray)) else X + for X in arrays_and_dtypes + ] + return dpt.result_type(*usm_arrays_and_dtypes) + + def rollaxis(x1, axis, start=0): """ Roll the specified axis backwards, until it lies in a given position. diff --git a/tests/test_manipulation.py b/tests/test_manipulation.py index 5c535d60f8e..b8ee2cfaa97 100644 --- a/tests/test_manipulation.py +++ b/tests/test_manipulation.py @@ -40,6 +40,25 @@ def test_repeat(arr): assert_array_equal(expected, result) +def test_result_type(): + X = [dpnp.ones((2), dtype=dpnp.int64), dpnp.int32, "float16"] + X_np = [numpy.ones((2), dtype=numpy.int64), numpy.int32, "float16"] + + assert dpnp.result_type(*X) == numpy.result_type(*X_np) + +def test_result_type_only_dtypes(): + X = [dpnp.int64, dpnp.int32, dpnp.bool, dpnp.float32] + X_np = [numpy.int64, numpy.int32, numpy.bool_, numpy.float32] + + assert dpnp.result_type(*X) == numpy.result_type(*X_np) + +def test_result_type_only_arrays(): + X = [dpnp.ones((2), dtype=dpnp.int64), dpnp.ones((7, 4), dtype=dpnp.int32)] + X_np = [numpy.ones((2), dtype=numpy.int64), numpy.ones((7, 4), dtype=numpy.int32)] + + assert dpnp.result_type(*X) == numpy.result_type(*X_np) + + @pytest.mark.usefixtures("allow_fall_back_on_numpy") @pytest.mark.parametrize("array", [[1, 2, 3],