Skip to content

Commit

Permalink
resolve gh-1871 (#1872)
Browse files Browse the repository at this point in the history
* update returned result when out is defined with order F

* address comments

* add test for out keyword in einsum

---------

Co-authored-by: Anton <[email protected]>
  • Loading branch information
vtavana and antonwolfy committed Jun 16, 2024
1 parent 77d387d commit 5698c06
Show file tree
Hide file tree
Showing 5 changed files with 191 additions and 35 deletions.
1 change: 0 additions & 1 deletion dpnp/dpnp_iface_linearalgebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -821,7 +821,6 @@ def matmul(
"""

dpnp.check_supported_arrays_type(x1, x2)
if subok is False:
raise NotImplementedError(
"subok keyword argument is only supported by its default value."
Expand Down
101 changes: 73 additions & 28 deletions dpnp/dpnp_utils/dpnp_utils_linearalgebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import dpctl.tensor._tensor_elementwise_impl as tei
import dpctl.tensor._tensor_impl as ti
import numpy
from dpctl.utils import ExecutionPlacementError
from numpy.core.numeric import normalize_axis_tuple

import dpnp
Expand Down Expand Up @@ -218,7 +219,9 @@ def _compute_size(start, shape):
return ret


def _copy_array(x, dep_events, host_events, copy_flag=False, dtype=None):
def _copy_array(
x, dep_events, host_events, copy_flag=False, dtype=None, order="C"
):
"""
Creating a copy of input array if needed.
Expand All @@ -236,7 +239,7 @@ def _copy_array(x, dep_events, host_events, copy_flag=False, dtype=None):
copy = x.dtype != dtype if dtype is not None else False

if copy:
x_copy = dpnp.empty_like(x, dtype=dtype, order="C")
x_copy = dpnp.empty_like(x, dtype=dtype, order=order)
ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
src=dpnp.get_usm_ndarray(x),
dst=x_copy.get_array(),
Expand All @@ -248,7 +251,9 @@ def _copy_array(x, dep_events, host_events, copy_flag=False, dtype=None):
return x


def _create_result_array(x1, x2, out, shape, dtype, usm_type, sycl_queue):
def _create_result_array(
x1, x2, out, shape, dtype, usm_type, sycl_queue, order="C"
):
"""
Create the result array.
Expand All @@ -263,13 +268,12 @@ def _create_result_array(x1, x2, out, shape, dtype, usm_type, sycl_queue):
x1_usm = dpnp.get_usm_ndarray(x1)
x2_usm = dpnp.get_usm_ndarray(x2)
out_usm = dpnp.get_usm_ndarray(out)
contig_flag = _define_contig_flag(out)
contig_flag, _, _ = _define_contig_flag(out)

if (
out.dtype == dtype
and out.shape == shape
and out.usm_type == usm_type
and out.sycl_queue == sycl_queue
and contig_flag
and not ti._array_overlap(x1_usm, out_usm)
and not ti._array_overlap(x2_usm, out_usm)
Expand All @@ -279,6 +283,7 @@ def _create_result_array(x1, x2, out, shape, dtype, usm_type, sycl_queue):
return dpnp.empty(
shape,
dtype=dtype,
order=order,
usm_type=usm_type,
sycl_queue=sycl_queue,
)
Expand All @@ -295,14 +300,14 @@ def _define_contig_flag(x):
x_strides = x.strides
x_shape = x.shape
if x.ndim < 2:
return True
return True, True, True

x_strides = _standardize_strides_to_nonzero(x_strides, x_shape)
x_is_c_contiguous = x_strides[-1] == 1 and x_strides[-2] == x_shape[-1]
x_is_f_contiguous = x_strides[-2] == 1 and x_strides[-1] == x_shape[-2]
if x_is_c_contiguous or x_is_f_contiguous:
flag = True
return flag
return flag, x_is_c_contiguous, x_is_f_contiguous


def _define_dim_flags(x, pos):
Expand Down Expand Up @@ -746,17 +751,26 @@ def _gemm_batch_matmul(exec_q, x1, x2, res, dev_tasks_list):
)
ht_tasks_list.append(ht_blas_ev)
dpctl.SyclEvent.wait_for(ht_tasks_list)

res_shape = res.shape
if not row_major:
res = dpnp.reshape(
res.ravel(), (batch_size, res_shape[2], res_shape[1])
).transpose(0, 2, 1)
_, res_is_c_contig, res_is_f_contig = _define_contig_flag(res)
if row_major:
if res_is_f_contig:
res = dpnp.reshape(
dpnp.ravel(res, order="F"),
(res_shape[1], res_shape[2], batch_size),
).transpose(2, 0, 1)
else:
if res_is_c_contig:
res = dpnp.reshape(
dpnp.ravel(res, order="C"),
(batch_size, res_shape[2], res_shape[1]),
).transpose(0, 2, 1)

if res_shape != orig_shape:
res = res.reshape(orig_shape)

res = dpnp.ascontiguousarray(res)
return res
return dpnp.ascontiguousarray(res)


def _gemm_matmul(exec_q, x1, x2, res, dev_tasks_list):
Expand All @@ -769,14 +783,16 @@ def _gemm_matmul(exec_q, x1, x2, res, dev_tasks_list):
)
ht_blas_ev.wait()

if not row_major:
# TODO: investigate the possibility of defining result
# array with "F" order for this case
res = dpnp.ascontiguousarray(
dpnp.reshape(res.ravel(), res.shape, order="F")
)
if row_major:
if res.flags.f_contiguous is True:
# read data in "F" order and write it in "C" order
res = dpnp.reshape(dpnp.ravel(res, order="F"), res.shape, order="C")
else:
if res.flags.c_contiguous is True:
# read data in "C" order and write it in "F" order
res = dpnp.reshape(dpnp.ravel(res, order="C"), res.shape, order="F")

return res
return dpnp.ascontiguousarray(res)


def _greedy_path(input_sets, output_set, idx_dict, memory_limit):
Expand Down Expand Up @@ -1746,6 +1762,13 @@ def dpnp_dot(a, b, /, out=None, *, conjugate=False):
)

res_usm_type, exec_q = get_usm_allocations([a, b])
if (
out is not None
and dpctl.utils.get_execution_queue((exec_q, out.sycl_queue)) is None
):
raise ExecutionPlacementError(
"Input and output allocation queues are not compatible"
)

# Determine the appropriate data types
dot_dtype, res_dtype = _compute_res_dtype(a, b, sycl_queue=exec_q)
Expand Down Expand Up @@ -1812,6 +1835,12 @@ def dpnp_einsum(
arrays.append(a)

res_usm_type, exec_q = get_usm_allocations(arrays)
if out is not None:
dpnp.check_supported_arrays_type(out)
if dpctl.utils.get_execution_queue((exec_q, out.sycl_queue)) is None:
raise ExecutionPlacementError(
"Input and output allocation queues are not compatible"
)
result_dtype = dpnp.result_type(*arrays) if dtype is None else dtype
for id, a in enumerate(operands):
if dpnp.isscalar(a):
Expand Down Expand Up @@ -2056,10 +2085,17 @@ def dpnp_matmul(
"""

x1_ndim = x1.ndim
x2_ndim = x2.ndim
dpnp.check_supported_arrays_type(x1, x2)
res_usm_type, exec_q = get_usm_allocations([x1, x2])
if out is not None:
dpnp.check_supported_arrays_type(out)
if dpctl.utils.get_execution_queue((exec_q, out.sycl_queue)) is None:
raise ExecutionPlacementError(
"Input and output allocation queues are not compatible"
)

x1_ndim = x1.ndim
x2_ndim = x2.ndim
if axes is not None:
axes = _validate_axes(x1, x2, axes)

Expand All @@ -2072,7 +2108,6 @@ def dpnp_matmul(
x2 = dpnp.moveaxis(x2, axes_x2, (-2, -1)) if x2_ndim != 1 else x2
out_orig = out
if out is not None:
dpnp.check_supported_arrays_type(out)
# out that is passed to the backend should have the correct shape
if len(axes_res) == 2:
out = dpnp.moveaxis(out, axes_res, (-2, -1))
Expand Down Expand Up @@ -2161,8 +2196,18 @@ def dpnp_matmul(
res = dpnp_dot(x1, x2, out=out)
res_shape = res.shape
else:
x1_contig_flag, _, x1_f = _define_contig_flag(x1)
x2_contig_flag, _, x2_f = _define_contig_flag(x2)
res_order = "F" if (x1_f and x2_f and call_flag == "gemm") else "C"
res = _create_result_array(
x1, x2, out, res_shape, compute_dtype, res_usm_type, exec_q
x1,
x2,
out,
res_shape,
compute_dtype,
res_usm_type,
exec_q,
res_order,
)

# calculate result
Expand All @@ -2175,21 +2220,21 @@ def dpnp_matmul(
# their base (last 2-dimensions) to be c-contiguous or f-contiguous
dep_events_list = []
host_tasks_list = []
contig_flag = _define_contig_flag(x1)
x1 = _copy_array(
x1,
dep_events_list,
host_tasks_list,
copy_flag=not contig_flag,
copy_flag=not x1_contig_flag,
dtype=compute_dtype,
order=res_order,
)
contig_flag = _define_contig_flag(x2)
x2 = _copy_array(
x2,
dep_events_list,
host_tasks_list,
copy_flag=not contig_flag,
copy_flag=not x2_contig_flag,
dtype=compute_dtype,
order=res_order,
)

if call_flag == "gemv":
Expand Down
16 changes: 16 additions & 0 deletions tests/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,12 +613,28 @@ def test_einsum_trivial_cases(self):
expected = numpy.einsum("i,i,i", b_np, b_np, b_np, optimize="greedy")
assert_dtype_allclose(result, expected)

def test_einsum_out(self):
a = inp.ones((5, 5))
a_np = a.asnumpy()
out = inp.empty((5,))
out_np = out.asnumpy()
result = inp.einsum("ii->i", a, out=out)
assert result is out
expected = numpy.einsum("ii->i", a_np, out=out_np)
assert_dtype_allclose(result, expected)

def test_einsum_error(self):
a = inp.ones((5, 5))
# unknown keyword argument
with pytest.raises(TypeError):
inp.einsum("ii->i", a, copy=False)

a = inp.ones((5, 5))
out = inp.empty((5,), sycl_queue=dpctl.SyclQueue())
# inconsistent sycl_queue
with pytest.raises(ExecutionPlacementError):
inp.einsum("ii->i", a, out=out)

# unknown value for optimize keyword
with pytest.raises(TypeError):
inp.einsum("ii->i", a, optimize="average")
Expand Down
Loading

0 comments on commit 5698c06

Please sign in to comment.