Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adopt linalg functions from BLAS extension to asynchronous dpctl execution #1919

Merged
merged 2 commits into from
Jul 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion dpnp/dpnp_iface_mathematical.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,6 +673,7 @@ def clip(a, a_min, a_max, *, out=None, order="K", **kwargs):

if kwargs:
raise NotImplementedError(f"kwargs={kwargs} is currently not supported")

if a_min is None and a_max is None:
raise ValueError("One of max or min must be given")

Expand Down Expand Up @@ -923,11 +924,13 @@ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None):
if not isinstance(axis, int):
raise TypeError(f"axis should be an integer but got, {type(axis)}.")
axisa, axisb, axisc = (axis,) * 3

dpnp.check_supported_arrays_type(a, b)
if a.dtype == dpnp.bool and b.dtype == dpnp.bool:
raise TypeError(
"Input arrays with boolean data type are not supported."
)

# Check axisa and axisb are within bounds
axisa = normalize_axis_index(axisa, a.ndim, msg_prefix="axisa")
axisb = normalize_axis_index(axisb, b.ndim, msg_prefix="axisb")
Expand All @@ -944,6 +947,7 @@ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None):
# Modify the shape of input arrays if necessary
a_shape = a.shape
b_shape = b.shape

# TODO: replace with dpnp.broadcast_shapes once implemented
res_shape = numpy.broadcast_shapes(a_shape[:-1], b_shape[:-1])
if a_shape[:-1] != res_shape:
Expand All @@ -957,6 +961,7 @@ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None):
res_shape += (3,)
# Check axisc is within bounds
axisc = normalize_axis_index(axisc, len(res_shape), msg_prefix="axisc")

# Create the output array
dtype = dpnp.result_type(a, b)
res_usm_type, exec_q = get_usm_allocations([a, b])
Expand All @@ -968,7 +973,7 @@ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None):
a = a.astype(dtype, copy=False)
b = b.astype(dtype, copy=False)

cp = dpnp_cross(a, b, cp, exec_q)
cp = dpnp_cross(a, b, cp)
if a_shape[-1] == 2 and b_shape[-1] == 2:
return cp

Expand Down Expand Up @@ -3184,6 +3189,9 @@ def sum(
sycl_sum = get_sum(input, output)

if sycl_sum:
# TODO: pass dep events into _get_sum_over_axis_0 to remove sync
dpnp.synchronize_array_data(input)

sycl_sum(input, output, []).wait()
result = dpnp_array._create_from_usm_ndarray(output)

Expand Down
Loading
Loading