Skip to content

Commit

Permalink
Merge branch 'master' into dpnp_norm
Browse files Browse the repository at this point in the history
  • Loading branch information
vtavana authored Apr 2, 2024
2 parents 55d3e68 + 20264e8 commit 3f549a7
Show file tree
Hide file tree
Showing 4 changed files with 2 additions and 37 deletions.
2 changes: 1 addition & 1 deletion conda-recipe/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ requirements:
host:
- python
- setuptools
- numpy <1.27a0
- numpy
- cython
- cmake >=3.21
- ninja
Expand Down
28 changes: 1 addition & 27 deletions dpnp/dpnp_iface_mathematical.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@


import dpctl.tensor as dpt
import dpctl.utils as du
import numpy
from numpy.core.numeric import (
normalize_axis_index,
Expand Down Expand Up @@ -2266,32 +2265,7 @@ def prod(
"""

# Product reduction for complex output are known to fail for Gen9 with 2024.0 compiler
# TODO: get rid of this temporary work around when OneAPI 2024.1 is released
dpnp.check_supported_arrays_type(a)
_dtypes = (a.dtype, dtype)
_any_complex = any(
dpnp.issubdtype(dt, dpnp.complexfloating) for dt in _dtypes
)
device_mask = (
du.intel_device_info(a.sycl_device).get("device_id", 0) & 0xFF00
)
if _any_complex and device_mask in [0x3E00, 0x9B00]:
res = call_origin(
numpy.prod,
a,
axis=axis,
dtype=dtype,
out=out,
keepdims=keepdims,
initial=initial,
where=where,
)
if dpnp.isscalar(res):
# numpy may return a scalar, convert it back to dpnp array
return dpnp.array(res, sycl_queue=a.sycl_queue, usm_type=a.usm_type)
return res
elif initial is not None:
if initial is not None:
raise NotImplementedError(
"initial keyword argument is only supported with its default value."
)
Expand Down
7 changes: 0 additions & 7 deletions tests/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,9 +231,6 @@ def test_cond(arr, p):


class TestDet:
# TODO: Remove the use of fixture for test_det
# when dpnp.prod() will support complex dtypes on Gen9
@pytest.mark.usefixtures("allow_fall_back_on_numpy")
@pytest.mark.parametrize(
"array",
[
Expand Down Expand Up @@ -1379,9 +1376,6 @@ def test_solve_errors(self):


class TestSlogdet:
# TODO: Remove the use of fixture for test_slogdet_2d and test_slogdet_3d
# when dpnp.prod() will support complex dtypes on Gen9
@pytest.mark.usefixtures("allow_fall_back_on_numpy")
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True))
def test_slogdet_2d(self, dtype):
a_np = numpy.array([[1, 2], [3, 4]], dtype=dtype)
Expand All @@ -1393,7 +1387,6 @@ def test_slogdet_2d(self, dtype):
assert_allclose(sign_expected, sign_result)
assert_allclose(logdet_expected, logdet_result, rtol=1e-3, atol=1e-4)

@pytest.mark.usefixtures("allow_fall_back_on_numpy")
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True))
def test_slogdet_3d(self, dtype):
a_np = numpy.array(
Expand Down
2 changes: 0 additions & 2 deletions tests/test_mathematical.py
Original file line number Diff line number Diff line change
Expand Up @@ -752,7 +752,6 @@ def test_positive_boolean():


class TestProd:
@pytest.mark.usefixtures("allow_fall_back_on_numpy")
@pytest.mark.parametrize("func", ["prod", "nanprod"])
@pytest.mark.parametrize("axis", [None, 0, 1, -1, 2, -2, (1, 2), (0, -2)])
@pytest.mark.parametrize("keepdims", [False, True])
Expand Down Expand Up @@ -790,7 +789,6 @@ def test_prod_nanprod_bool(self, func, axis, keepdims):
dpnp_res = getattr(dpnp, func)(ia, axis=axis, keepdims=keepdims)
assert_dtype_allclose(dpnp_res, np_res)

@pytest.mark.usefixtures("allow_fall_back_on_numpy")
@pytest.mark.usefixtures("suppress_complex_warning")
@pytest.mark.usefixtures("suppress_invalid_numpy_warnings")
@pytest.mark.parametrize("func", ["prod", "nanprod"])
Expand Down

0 comments on commit 3f549a7

Please sign in to comment.