From 1e8675368d598faf34d120ef5260a386159cd810 Mon Sep 17 00:00:00 2001 From: vlad-perevezentsev Date: Thu, 8 Feb 2024 13:52:36 +0100 Subject: [PATCH] Update dpnp.linalg.qr() function (#1673) * Impl dpnp.linalg.qr for 2d array * Add cupy tests for dpnp.linalg.qr * Add batch implementation of dpnp.linalg.qr * Remove an old impl of dpnp_qr * Update test_qr in test_sycl_queue * Add test_qr in test_usm_type * Use _real_type for _orgqr * Use _real_type for _orgqr_batch * Update dpnp tests for dpnp.linalg.qr * Pass scratchpad_size to the error message test * Add additional checks * Extend error handler for mkl batch funcs * Add ungqr mkl extension to support complex dtype * Update tau array size check for orgqr * Add ungqr_batch mkl extension to support complex dtype * Add arrays type check * Fix test_det_singular_matrix * Expand tests for dpnp.linalg.qr with complex types * Update examples * Remove astype for output arrays * Use empty_like instead of empty * Use ht_list_ev with dpctl.SyclEvent.wait_for * Add _triu_inplace func * Use copy_usm for a_t array overwritten by geqrf/geqrf_batch --------- Co-authored-by: Anton <100830759+antonwolfy@users.noreply.github.com> --- dpnp/backend/extensions/lapack/CMakeLists.txt | 6 + dpnp/backend/extensions/lapack/geqrf.cpp | 262 +++++++++++ dpnp/backend/extensions/lapack/geqrf.hpp | 63 +++ .../backend/extensions/lapack/geqrf_batch.cpp | 273 +++++++++++ dpnp/backend/extensions/lapack/lapack_py.cpp | 55 +++ dpnp/backend/extensions/lapack/orgqr.cpp | 263 +++++++++++ dpnp/backend/extensions/lapack/orgqr.hpp | 67 +++ .../backend/extensions/lapack/orgqr_batch.cpp | 278 ++++++++++++ .../extensions/lapack/types_matrix.hpp | 147 ++++++ dpnp/backend/extensions/lapack/ungqr.cpp | 263 +++++++++++ dpnp/backend/extensions/lapack/ungqr.hpp | 67 +++ .../backend/extensions/lapack/ungqr_batch.cpp | 278 ++++++++++++ dpnp/backend/include/dpnp_iface_fptr.hpp | 2 - dpnp/backend/kernels/dpnp_krnl_linalg.cpp | 34 -- dpnp/dpnp_algo/dpnp_algo.pxd | 2 - dpnp/linalg/dpnp_algo_linalg.pyx | 56 --- dpnp/linalg/dpnp_iface_linalg.py | 70 ++- dpnp/linalg/dpnp_utils_linalg.py | 427 +++++++++++++++++- tests/test_linalg.py | 210 +++++---- tests/test_sycl_queue.py | 56 ++- tests/test_usm_type.py | 37 ++ .../cupy/linalg_tests/test_decomposition.py | 97 +++- 22 files changed, 2767 insertions(+), 246 deletions(-) create mode 100644 dpnp/backend/extensions/lapack/geqrf.cpp create mode 100644 dpnp/backend/extensions/lapack/geqrf.hpp create mode 100644 dpnp/backend/extensions/lapack/geqrf_batch.cpp create mode 100644 dpnp/backend/extensions/lapack/orgqr.cpp create mode 100644 dpnp/backend/extensions/lapack/orgqr.hpp create mode 100644 dpnp/backend/extensions/lapack/orgqr_batch.cpp create mode 100644 dpnp/backend/extensions/lapack/ungqr.cpp create mode 100644 dpnp/backend/extensions/lapack/ungqr.hpp create mode 100644 dpnp/backend/extensions/lapack/ungqr_batch.cpp diff --git a/dpnp/backend/extensions/lapack/CMakeLists.txt b/dpnp/backend/extensions/lapack/CMakeLists.txt index 28fa2072d7d..8f4b35f20ed 100644 --- a/dpnp/backend/extensions/lapack/CMakeLists.txt +++ b/dpnp/backend/extensions/lapack/CMakeLists.txt @@ -27,15 +27,21 @@ set(python_module_name _lapack_impl) set(_module_src ${CMAKE_CURRENT_SOURCE_DIR}/lapack_py.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/geqrf.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/geqrf_batch.cpp ${CMAKE_CURRENT_SOURCE_DIR}/gesv.cpp ${CMAKE_CURRENT_SOURCE_DIR}/gesvd.cpp ${CMAKE_CURRENT_SOURCE_DIR}/getrf.cpp ${CMAKE_CURRENT_SOURCE_DIR}/getrf_batch.cpp ${CMAKE_CURRENT_SOURCE_DIR}/getri_batch.cpp ${CMAKE_CURRENT_SOURCE_DIR}/heevd.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/orgqr.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/orgqr_batch.cpp ${CMAKE_CURRENT_SOURCE_DIR}/potrf.cpp ${CMAKE_CURRENT_SOURCE_DIR}/potrf_batch.cpp ${CMAKE_CURRENT_SOURCE_DIR}/syevd.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/ungqr.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/ungqr_batch.cpp ) pybind11_add_module(${python_module_name} MODULE ${_module_src}) diff --git a/dpnp/backend/extensions/lapack/geqrf.cpp b/dpnp/backend/extensions/lapack/geqrf.cpp new file mode 100644 index 00000000000..a91f689d503 --- /dev/null +++ b/dpnp/backend/extensions/lapack/geqrf.cpp @@ -0,0 +1,262 @@ +//***************************************************************************** +// Copyright (c) 2024, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** + +#include + +// dpctl tensor headers +#include "utils/memory_overlap.hpp" +#include "utils/type_utils.hpp" + +#include "geqrf.hpp" +#include "types_matrix.hpp" + +#include "dpnp_utils.hpp" + +namespace dpnp +{ +namespace backend +{ +namespace ext +{ +namespace lapack +{ +namespace mkl_lapack = oneapi::mkl::lapack; +namespace py = pybind11; +namespace type_utils = dpctl::tensor::type_utils; + +typedef sycl::event (*geqrf_impl_fn_ptr_t)(sycl::queue, + const std::int64_t, + const std::int64_t, + char *, + std::int64_t, + char *, + std::vector &, + const std::vector &); + +static geqrf_impl_fn_ptr_t geqrf_dispatch_vector[dpctl_td_ns::num_types]; + +template +static sycl::event geqrf_impl(sycl::queue exec_q, + const std::int64_t m, + const std::int64_t n, + char *in_a, + std::int64_t lda, + char *in_tau, + std::vector &host_task_events, + const std::vector &depends) +{ + type_utils::validate_type_for_device(exec_q); + + T *a = reinterpret_cast(in_a); + T *tau = reinterpret_cast(in_tau); + + const std::int64_t scratchpad_size = + mkl_lapack::geqrf_scratchpad_size(exec_q, m, n, lda); + T *scratchpad = nullptr; + + std::stringstream error_msg; + std::int64_t info = 0; + bool is_exception_caught = false; + + sycl::event geqrf_event; + try { + scratchpad = sycl::malloc_device(scratchpad_size, exec_q); + + geqrf_event = mkl_lapack::geqrf( + exec_q, + m, // The number of rows in the matrix; (0 ≤ m). + n, // The number of columns in the matrix; (0 ≤ n). + a, // Pointer to the m-by-n matrix. + lda, // The leading dimension of `a`; (1 ≤ m). + tau, // Pointer to the array of scalar factors of the + // elementary reflectors. + scratchpad, // Pointer to scratchpad memory to be used by MKL + // routine for storing intermediate results. + scratchpad_size, depends); + } catch (mkl_lapack::exception const &e) { + is_exception_caught = true; + info = e.info(); + + if (info < 0) { + error_msg << "Parameter number " << -info + << " had an illegal value."; + } + else if (info == scratchpad_size && e.detail() != 0) { + error_msg + << "Insufficient scratchpad size. Required size is at least " + << e.detail() << ", but current size is " << scratchpad_size + << "."; + } + else { + error_msg << "Unexpected MKL exception caught during geqrf() " + "call:\nreason: " + << e.what() << "\ninfo: " << info; + } + } catch (sycl::exception const &e) { + is_exception_caught = true; + error_msg << "Unexpected SYCL exception caught during geqrf() call:\n" + << e.what(); + } + + if (is_exception_caught) // an unexpected error occurs + { + if (scratchpad != nullptr) { + sycl::free(scratchpad, exec_q); + } + throw std::runtime_error(error_msg.str()); + } + + sycl::event clean_up_event = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(geqrf_event); + auto ctx = exec_q.get_context(); + cgh.host_task([ctx, scratchpad]() { sycl::free(scratchpad, ctx); }); + }); + host_task_events.push_back(clean_up_event); + + return geqrf_event; +} + +std::pair + geqrf(sycl::queue q, + dpctl::tensor::usm_ndarray a_array, + dpctl::tensor::usm_ndarray tau_array, + const std::vector &depends) +{ + const int a_array_nd = a_array.get_ndim(); + const int tau_array_nd = tau_array.get_ndim(); + + if (a_array_nd != 2) { + throw py::value_error( + "The input array has ndim=" + std::to_string(a_array_nd) + + ", but a 2-dimensional array is expected."); + } + + if (tau_array_nd != 1) { + throw py::value_error("The array of Householder scalars has ndim=" + + std::to_string(tau_array_nd) + + ", but a 1-dimensional array is expected."); + } + + // check compatibility of execution queue and allocation queue + if (!dpctl::utils::queues_are_compatible(q, {a_array, tau_array})) { + throw py::value_error( + "Execution queue is not compatible with allocation queues"); + } + + auto const &overlap = dpctl::tensor::overlap::MemoryOverlap(); + if (overlap(a_array, tau_array)) { + throw py::value_error( + "The input array and the array of Householder scalars " + "are overlapping segments of memory"); + } + + bool is_a_array_c_contig = a_array.is_c_contiguous(); + if (!is_a_array_c_contig) { + throw py::value_error("The input array " + "must be C-contiguous"); + } + + bool is_tau_array_c_contig = tau_array.is_c_contiguous(); + bool is_tau_array_f_contig = tau_array.is_f_contiguous(); + + if (!is_tau_array_c_contig || !is_tau_array_f_contig) { + throw py::value_error("The array of Householder scalars " + "must be contiguous"); + } + + auto array_types = dpctl_td_ns::usm_ndarray_types(); + int a_array_type_id = + array_types.typenum_to_lookup_id(a_array.get_typenum()); + int tau_array_type_id = + array_types.typenum_to_lookup_id(tau_array.get_typenum()); + + if (a_array_type_id != tau_array_type_id) { + throw py::value_error( + "The types of the input array and " + "the array of Householder scalars are mismatched"); + } + + geqrf_impl_fn_ptr_t geqrf_fn = geqrf_dispatch_vector[a_array_type_id]; + if (geqrf_fn == nullptr) { + throw py::value_error( + "No geqrf implementation defined for the provided type " + "of the input matrix."); + } + + char *a_array_data = a_array.get_data(); + char *tau_array_data = tau_array.get_data(); + + const py::ssize_t *a_array_shape = a_array.get_shape_raw(); + + // The input array is transponded + // Change the order of getting m, n + const std::int64_t m = a_array_shape[1]; + const std::int64_t n = a_array_shape[0]; + const std::int64_t lda = std::max(1UL, m); + + const size_t tau_array_size = tau_array.get_size(); + const size_t min_m_n = std::max(1UL, std::min(m, n)); + + if (tau_array_size != min_m_n) { + throw py::value_error("The array of Householder scalars has size=" + + std::to_string(tau_array_size) + ", but a size=" + + std::to_string(min_m_n) + " array is expected."); + } + + std::vector host_task_events; + sycl::event geqrf_ev = geqrf_fn(q, m, n, a_array_data, lda, tau_array_data, + host_task_events, depends); + + sycl::event args_ev = dpctl::utils::keep_args_alive(q, {a_array, tau_array}, + host_task_events); + + return std::make_pair(args_ev, geqrf_ev); +} + +template +struct GeqrfContigFactory +{ + fnT get() + { + if constexpr (types::GeqrfTypePairSupportFactory::is_defined) { + return geqrf_impl; + } + else { + return nullptr; + } + } +}; + +void init_geqrf_dispatch_vector(void) +{ + dpctl_td_ns::DispatchVectorBuilder + contig; + contig.populate_dispatch_vector(geqrf_dispatch_vector); +} +} // namespace lapack +} // namespace ext +} // namespace backend +} // namespace dpnp diff --git a/dpnp/backend/extensions/lapack/geqrf.hpp b/dpnp/backend/extensions/lapack/geqrf.hpp new file mode 100644 index 00000000000..4ab65286b29 --- /dev/null +++ b/dpnp/backend/extensions/lapack/geqrf.hpp @@ -0,0 +1,63 @@ +//***************************************************************************** +// Copyright (c) 2024, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** + +#pragma once + +#include +#include + +#include + +namespace dpnp +{ +namespace backend +{ +namespace ext +{ +namespace lapack +{ +extern std::pair + geqrf(sycl::queue exec_q, + dpctl::tensor::usm_ndarray a_array, + dpctl::tensor::usm_ndarray tau_array, + const std::vector &depends = {}); + +extern std::pair + geqrf_batch(sycl::queue exec_q, + dpctl::tensor::usm_ndarray a_array, + dpctl::tensor::usm_ndarray tau_array, + std::int64_t m, + std::int64_t n, + std::int64_t stride_a, + std::int64_t stride_tau, + std::int64_t batch_size, + const std::vector &depends = {}); + +extern void init_geqrf_batch_dispatch_vector(void); +extern void init_geqrf_dispatch_vector(void); +} // namespace lapack +} // namespace ext +} // namespace backend +} // namespace dpnp diff --git a/dpnp/backend/extensions/lapack/geqrf_batch.cpp b/dpnp/backend/extensions/lapack/geqrf_batch.cpp new file mode 100644 index 00000000000..a4fe980a539 --- /dev/null +++ b/dpnp/backend/extensions/lapack/geqrf_batch.cpp @@ -0,0 +1,273 @@ +//***************************************************************************** +// Copyright (c) 2024, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** + +#include + +// dpctl tensor headers +#include "utils/memory_overlap.hpp" +#include "utils/type_utils.hpp" + +#include "geqrf.hpp" +#include "types_matrix.hpp" + +#include "dpnp_utils.hpp" + +namespace dpnp +{ +namespace backend +{ +namespace ext +{ +namespace lapack +{ +namespace mkl_lapack = oneapi::mkl::lapack; +namespace py = pybind11; +namespace type_utils = dpctl::tensor::type_utils; + +typedef sycl::event (*geqrf_batch_impl_fn_ptr_t)( + sycl::queue, + std::int64_t, + std::int64_t, + char *, + std::int64_t, + std::int64_t, + char *, + std::int64_t, + std::int64_t, + std::vector &, + const std::vector &); + +static geqrf_batch_impl_fn_ptr_t + geqrf_batch_dispatch_vector[dpctl_td_ns::num_types]; + +template +static sycl::event geqrf_batch_impl(sycl::queue exec_q, + std::int64_t m, + std::int64_t n, + char *in_a, + std::int64_t lda, + std::int64_t stride_a, + char *in_tau, + std::int64_t stride_tau, + std::int64_t batch_size, + std::vector &host_task_events, + const std::vector &depends) +{ + type_utils::validate_type_for_device(exec_q); + + T *a = reinterpret_cast(in_a); + T *tau = reinterpret_cast(in_tau); + + const std::int64_t scratchpad_size = + mkl_lapack::geqrf_batch_scratchpad_size(exec_q, m, n, lda, stride_a, + stride_tau, batch_size); + T *scratchpad = nullptr; + + std::stringstream error_msg; + std::int64_t info = 0; + bool is_exception_caught = false; + + sycl::event geqrf_batch_event; + try { + scratchpad = sycl::malloc_device(scratchpad_size, exec_q); + + geqrf_batch_event = mkl_lapack::geqrf_batch( + exec_q, + m, // The number of rows in each matrix in the batch; (0 ≤ m). + // It must be a non-negative integer. + n, // The number of columns in each matrix in the batch; (0 ≤ n). + // It must be a non-negative integer. + a, // Pointer to the batch of matrices, each of size (m x n). + lda, // The leading dimension of each matrix in the batch. + // For row major layout, lda ≥ max(1, m). + stride_a, // Stride between consecutive matrices in the batch. + tau, // Pointer to the array of scalar factors of the elementary + // reflectors for each matrix in the batch. + stride_tau, // Stride between arrays of scalar factors in the batch. + batch_size, // The number of matrices in the batch. + scratchpad, // Pointer to scratchpad memory to be used by MKL + // routine for storing intermediate results. + scratchpad_size, depends); + } catch (mkl_lapack::exception const &e) { + is_exception_caught = true; + info = e.info(); + + if (info < 0) { + error_msg << "Parameter number " << -info + << " had an illegal value."; + } + else if (info == scratchpad_size && e.detail() != 0) { + error_msg + << "Insufficient scratchpad size. Required size is at least " + << e.detail() << ", but current size is " << scratchpad_size + << "."; + } + else if (info != 0 && e.detail() == 0) { + error_msg << "Error in batch processing. " + "Number of failed calculations: " + << info; + } + else { + error_msg << "Unexpected MKL exception caught during geqrf_batch() " + "call:\nreason: " + << e.what() << "\ninfo: " << e.info(); + } + } catch (sycl::exception const &e) { + is_exception_caught = true; + error_msg + << "Unexpected SYCL exception caught during geqrf_batch() call:\n" + << e.what(); + } + + if (is_exception_caught) // an unexpected error occurs + { + if (scratchpad != nullptr) { + sycl::free(scratchpad, exec_q); + } + + throw std::runtime_error(error_msg.str()); + } + + sycl::event clean_up_event = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(geqrf_batch_event); + auto ctx = exec_q.get_context(); + cgh.host_task([ctx, scratchpad]() { sycl::free(scratchpad, ctx); }); + }); + host_task_events.push_back(clean_up_event); + return geqrf_batch_event; +} + +std::pair + geqrf_batch(sycl::queue q, + dpctl::tensor::usm_ndarray a_array, + dpctl::tensor::usm_ndarray tau_array, + std::int64_t m, + std::int64_t n, + std::int64_t stride_a, + std::int64_t stride_tau, + std::int64_t batch_size, + const std::vector &depends) +{ + const int a_array_nd = a_array.get_ndim(); + const int tau_array_nd = tau_array.get_ndim(); + + if (a_array_nd < 3) { + throw py::value_error( + "The input array has ndim=" + std::to_string(a_array_nd) + + ", but an array with ndim >= 3 is expected."); + } + + if (tau_array_nd != 2) { + throw py::value_error("The array of Householder scalars has ndim=" + + std::to_string(tau_array_nd) + + ", but a 2-dimensional array is expected."); + } + + // check compatibility of execution queue and allocation queue + if (!dpctl::utils::queues_are_compatible(q, {a_array, tau_array})) { + throw py::value_error( + "Execution queue is not compatible with allocation queues"); + } + + auto const &overlap = dpctl::tensor::overlap::MemoryOverlap(); + if (overlap(a_array, tau_array)) { + throw py::value_error( + "The input array and the array of Householder scalars " + "are overlapping segments of memory"); + } + + bool is_a_array_c_contig = a_array.is_c_contiguous(); + bool is_tau_array_c_contig = tau_array.is_c_contiguous(); + if (!is_a_array_c_contig) { + throw py::value_error("The input array " + "must be C-contiguous"); + } + if (!is_tau_array_c_contig) { + throw py::value_error("The array of Householder scalars " + "must be C-contiguous"); + } + + auto array_types = dpctl_td_ns::usm_ndarray_types(); + int a_array_type_id = + array_types.typenum_to_lookup_id(a_array.get_typenum()); + int tau_array_type_id = + array_types.typenum_to_lookup_id(tau_array.get_typenum()); + + if (a_array_type_id != tau_array_type_id) { + throw py::value_error( + "The types of the input array and " + "the array of Householder scalars are mismatched"); + } + + geqrf_batch_impl_fn_ptr_t geqrf_batch_fn = + geqrf_batch_dispatch_vector[a_array_type_id]; + if (geqrf_batch_fn == nullptr) { + throw py::value_error( + "No geqrf_batch implementation defined for the provided type " + "of the input matrix."); + } + + char *a_array_data = a_array.get_data(); + char *tau_array_data = tau_array.get_data(); + + const std::int64_t lda = std::max(1UL, m); + + std::vector host_task_events; + sycl::event geqrf_batch_ev = + geqrf_batch_fn(q, m, n, a_array_data, lda, stride_a, tau_array_data, + stride_tau, batch_size, host_task_events, depends); + + sycl::event args_ev = dpctl::utils::keep_args_alive(q, {a_array, tau_array}, + host_task_events); + + return std::make_pair(args_ev, geqrf_batch_ev); +} + +template +struct GeqrfBatchContigFactory +{ + fnT get() + { + if constexpr (types::GeqrfBatchTypePairSupportFactory::is_defined) { + return geqrf_batch_impl; + } + else { + return nullptr; + } + } +}; + +void init_geqrf_batch_dispatch_vector(void) +{ + dpctl_td_ns::DispatchVectorBuilder + contig; + contig.populate_dispatch_vector(geqrf_batch_dispatch_vector); +} +} // namespace lapack +} // namespace ext +} // namespace backend +} // namespace dpnp diff --git a/dpnp/backend/extensions/lapack/lapack_py.cpp b/dpnp/backend/extensions/lapack/lapack_py.cpp index 0c76d0fc096..eb815ac9f6b 100644 --- a/dpnp/backend/extensions/lapack/lapack_py.cpp +++ b/dpnp/backend/extensions/lapack/lapack_py.cpp @@ -30,14 +30,17 @@ #include #include +#include "geqrf.hpp" #include "gesv.hpp" #include "gesvd.hpp" #include "getrf.hpp" #include "getri.hpp" #include "heevd.hpp" #include "linalg_exceptions.hpp" +#include "orgqr.hpp" #include "potrf.hpp" #include "syevd.hpp" +#include "ungqr.hpp" namespace lapack_ext = dpnp::backend::ext::lapack; namespace py = pybind11; @@ -45,13 +48,19 @@ namespace py = pybind11; // populate dispatch vectors void init_dispatch_vectors(void) { + lapack_ext::init_geqrf_batch_dispatch_vector(); + lapack_ext::init_geqrf_dispatch_vector(); lapack_ext::init_gesv_dispatch_vector(); lapack_ext::init_getrf_batch_dispatch_vector(); lapack_ext::init_getrf_dispatch_vector(); lapack_ext::init_getri_batch_dispatch_vector(); + lapack_ext::init_orgqr_batch_dispatch_vector(); + lapack_ext::init_orgqr_dispatch_vector(); lapack_ext::init_potrf_batch_dispatch_vector(); lapack_ext::init_potrf_dispatch_vector(); lapack_ext::init_syevd_dispatch_vector(); + lapack_ext::init_ungqr_batch_dispatch_vector(); + lapack_ext::init_ungqr_dispatch_vector(); } // populate dispatch tables @@ -71,6 +80,20 @@ PYBIND11_MODULE(_lapack_impl, m) init_dispatch_vectors(); init_dispatch_tables(); + m.def("_geqrf_batch", &lapack_ext::geqrf_batch, + "Call `geqrf_batch` from OneMKL LAPACK library to return " + "the QR factorization of a batch general matrix ", + py::arg("sycl_queue"), py::arg("a_array"), py::arg("tau_array"), + py::arg("m"), py::arg("n"), py::arg("stride_a"), + py::arg("stride_tau"), py::arg("batch_size"), + py::arg("depends") = py::list()); + + m.def("_geqrf", &lapack_ext::geqrf, + "Call `geqrf` from OneMKL LAPACK library to return " + "the QR factorization of a general m x n matrix ", + py::arg("sycl_queue"), py::arg("a_array"), py::arg("tau_array"), + py::arg("depends") = py::list()); + m.def("_gesv", &lapack_ext::gesv, "Call `gesv` from OneMKL LAPACK library to return " "the solution of a system of linear equations with " @@ -114,6 +137,22 @@ PYBIND11_MODULE(_lapack_impl, m) py::arg("eig_vecs"), py::arg("eig_vals"), py::arg("depends") = py::list()); + m.def("_orgqr_batch", &lapack_ext::orgqr_batch, + "Call `_orgqr_batch` from OneMKL LAPACK library to return " + "the real orthogonal matrix Qi of the QR factorization " + "for a batch of general matrices", + py::arg("sycl_queue"), py::arg("a_array"), py::arg("tau_array"), + py::arg("m"), py::arg("n"), py::arg("k"), py::arg("stride_a"), + py::arg("stride_tau"), py::arg("batch_size"), + py::arg("depends") = py::list()); + + m.def("_orgqr", &lapack_ext::orgqr, + "Call `orgqr` from OneMKL LAPACK library to return " + "the real orthogonal matrix Q of the QR factorization", + py::arg("sycl_queue"), py::arg("m"), py::arg("n"), py::arg("k"), + py::arg("a_array"), py::arg("tau_array"), + py::arg("depends") = py::list()); + m.def("_potrf", &lapack_ext::potrf, "Call `potrf` from OneMKL LAPACK library to return " "the Cholesky factorization of a symmetric positive-definite matrix", @@ -134,4 +173,20 @@ PYBIND11_MODULE(_lapack_impl, m) py::arg("sycl_queue"), py::arg("jobz"), py::arg("upper_lower"), py::arg("eig_vecs"), py::arg("eig_vals"), py::arg("depends") = py::list()); + + m.def("_ungqr_batch", &lapack_ext::ungqr_batch, + "Call `_ungqr_batch` from OneMKL LAPACK library to return " + "the complex unitary matrices matrix Qi of the QR factorization " + "for a batch of general matrices", + py::arg("sycl_queue"), py::arg("a_array"), py::arg("tau_array"), + py::arg("m"), py::arg("n"), py::arg("k"), py::arg("stride_a"), + py::arg("stride_tau"), py::arg("batch_size"), + py::arg("depends") = py::list()); + + m.def("_ungqr", &lapack_ext::ungqr, + "Call `ungqr` from OneMKL LAPACK library to return " + "the complex unitary matrix Q of the QR factorization", + py::arg("sycl_queue"), py::arg("m"), py::arg("n"), py::arg("k"), + py::arg("a_array"), py::arg("tau_array"), + py::arg("depends") = py::list()); } diff --git a/dpnp/backend/extensions/lapack/orgqr.cpp b/dpnp/backend/extensions/lapack/orgqr.cpp new file mode 100644 index 00000000000..22cbbe05bee --- /dev/null +++ b/dpnp/backend/extensions/lapack/orgqr.cpp @@ -0,0 +1,263 @@ +//***************************************************************************** +// Copyright (c) 2024, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** + +#include + +// dpctl tensor headers +#include "utils/memory_overlap.hpp" +#include "utils/type_utils.hpp" + +#include "orgqr.hpp" +#include "types_matrix.hpp" + +#include "dpnp_utils.hpp" + +namespace dpnp +{ +namespace backend +{ +namespace ext +{ +namespace lapack +{ +namespace mkl_lapack = oneapi::mkl::lapack; +namespace py = pybind11; +namespace type_utils = dpctl::tensor::type_utils; + +typedef sycl::event (*orgqr_impl_fn_ptr_t)(sycl::queue, + const std::int64_t, + const std::int64_t, + const std::int64_t, + char *, + std::int64_t, + char *, + std::vector &, + const std::vector &); + +static orgqr_impl_fn_ptr_t orgqr_dispatch_vector[dpctl_td_ns::num_types]; + +template +static sycl::event orgqr_impl(sycl::queue exec_q, + const std::int64_t m, + const std::int64_t n, + const std::int64_t k, + char *in_a, + std::int64_t lda, + char *in_tau, + std::vector &host_task_events, + const std::vector &depends) +{ + type_utils::validate_type_for_device(exec_q); + + T *a = reinterpret_cast(in_a); + T *tau = reinterpret_cast(in_tau); + + const std::int64_t scratchpad_size = + mkl_lapack::orgqr_scratchpad_size(exec_q, m, n, k, lda); + T *scratchpad = nullptr; + + std::stringstream error_msg; + std::int64_t info = 0; + bool is_exception_caught = false; + + sycl::event orgqr_event; + try { + scratchpad = sycl::malloc_device(scratchpad_size, exec_q); + + orgqr_event = mkl_lapack::orgqr( + exec_q, + m, // The number of rows in the matrix; (0 ≤ m). + n, // The number of columns in the matrix; (0 ≤ n). + k, // The number of elementary reflectors + // whose product defines the matrix Q; (0 ≤ k ≤ n). + a, // Pointer to the m-by-n matrix. + lda, // The leading dimension of `a`; (1 ≤ m). + tau, // Pointer to the array of scalar factors of the + // elementary reflectors. + scratchpad, // Pointer to scratchpad memory to be used by MKL + // routine for storing intermediate results. + scratchpad_size, depends); + } catch (mkl_lapack::exception const &e) { + is_exception_caught = true; + info = e.info(); + + if (info < 0) { + error_msg << "Parameter number " << -info + << " had an illegal value."; + } + else if (info == scratchpad_size && e.detail() != 0) { + error_msg + << "Insufficient scratchpad size. Required size is at least " + << e.detail() << ", but current size is " << scratchpad_size + << "."; + } + else { + error_msg << "Unexpected MKL exception caught during orgqr() " + "call:\nreason: " + << e.what() << "\ninfo: " << info; + } + } catch (sycl::exception const &e) { + is_exception_caught = true; + error_msg << "Unexpected SYCL exception caught during orfqr() call:\n" + << e.what(); + } + + if (is_exception_caught) // an unexpected error occurs + { + if (scratchpad != nullptr) { + sycl::free(scratchpad, exec_q); + } + throw std::runtime_error(error_msg.str()); + } + + sycl::event clean_up_event = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(orgqr_event); + auto ctx = exec_q.get_context(); + cgh.host_task([ctx, scratchpad]() { sycl::free(scratchpad, ctx); }); + }); + host_task_events.push_back(clean_up_event); + + return orgqr_event; +} + +std::pair + orgqr(sycl::queue q, + const std::int64_t m, + const std::int64_t n, + const std::int64_t k, + dpctl::tensor::usm_ndarray a_array, + dpctl::tensor::usm_ndarray tau_array, + const std::vector &depends) +{ + const int a_array_nd = a_array.get_ndim(); + const int tau_array_nd = tau_array.get_ndim(); + + if (a_array_nd != 2) { + throw py::value_error( + "The input array has ndim=" + std::to_string(a_array_nd) + + ", but a 2-dimensional array is expected."); + } + + if (tau_array_nd != 1) { + throw py::value_error("The array of Householder scalars has ndim=" + + std::to_string(tau_array_nd) + + ", but a 1-dimensional array is expected."); + } + + // check compatibility of execution queue and allocation queue + if (!dpctl::utils::queues_are_compatible(q, {a_array, tau_array})) { + throw py::value_error( + "Execution queue is not compatible with allocation queues"); + } + + auto const &overlap = dpctl::tensor::overlap::MemoryOverlap(); + if (overlap(a_array, tau_array)) { + throw py::value_error( + "The input array and the array of Householder scalars " + "are overlapping segments of memory"); + } + + bool is_a_array_c_contig = a_array.is_c_contiguous(); + if (!is_a_array_c_contig) { + throw py::value_error("The input array " + "must be C-contiguous"); + } + + bool is_tau_array_c_contig = tau_array.is_c_contiguous(); + bool is_tau_array_f_contig = tau_array.is_f_contiguous(); + + if (!is_tau_array_c_contig || !is_tau_array_f_contig) { + throw py::value_error("The array of Householder scalars " + "must be contiguous"); + } + + const size_t tau_array_size = tau_array.get_size(); + + if (static_cast(tau_array_size) != k) { + throw py::value_error("The array of Householder scalars has size=" + + std::to_string(tau_array_size) + + ", but an array of size=" + std::to_string(k) + + " is expected."); + } + + auto array_types = dpctl_td_ns::usm_ndarray_types(); + int a_array_type_id = + array_types.typenum_to_lookup_id(a_array.get_typenum()); + int tau_array_type_id = + array_types.typenum_to_lookup_id(tau_array.get_typenum()); + + if (a_array_type_id != tau_array_type_id) { + throw py::value_error( + "The types of the input array and " + "the array of Householder scalars are mismatched"); + } + + orgqr_impl_fn_ptr_t orgqr_fn = orgqr_dispatch_vector[a_array_type_id]; + if (orgqr_fn == nullptr) { + throw py::value_error( + "No orgqr implementation defined for the provided type " + "of the input matrix."); + } + + char *a_array_data = a_array.get_data(); + const std::int64_t lda = std::max(1UL, m); + + char *tau_array_data = tau_array.get_data(); + + std::vector host_task_events; + sycl::event orgqr_ev = orgqr_fn(q, m, n, k, a_array_data, lda, + tau_array_data, host_task_events, depends); + + sycl::event args_ev = dpctl::utils::keep_args_alive(q, {a_array, tau_array}, + host_task_events); + + return std::make_pair(args_ev, orgqr_ev); +} + +template +struct OrgqrContigFactory +{ + fnT get() + { + if constexpr (types::OrgqrTypePairSupportFactory::is_defined) { + return orgqr_impl; + } + else { + return nullptr; + } + } +}; + +void init_orgqr_dispatch_vector(void) +{ + dpctl_td_ns::DispatchVectorBuilder + contig; + contig.populate_dispatch_vector(orgqr_dispatch_vector); +} +} // namespace lapack +} // namespace ext +} // namespace backend +} // namespace dpnp diff --git a/dpnp/backend/extensions/lapack/orgqr.hpp b/dpnp/backend/extensions/lapack/orgqr.hpp new file mode 100644 index 00000000000..9cc4f530d03 --- /dev/null +++ b/dpnp/backend/extensions/lapack/orgqr.hpp @@ -0,0 +1,67 @@ +//***************************************************************************** +// Copyright (c) 2024, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** + +#pragma once + +#include +#include + +#include + +namespace dpnp +{ +namespace backend +{ +namespace ext +{ +namespace lapack +{ +extern std::pair + orgqr(sycl::queue exec_q, + const std::int64_t m, + const std::int64_t n, + const std::int64_t k, + dpctl::tensor::usm_ndarray a_array, + dpctl::tensor::usm_ndarray tau_array, + const std::vector &depends = {}); + +extern std::pair + orgqr_batch(sycl::queue exec_q, + dpctl::tensor::usm_ndarray a_array, + dpctl::tensor::usm_ndarray tau_array, + std::int64_t m, + std::int64_t n, + std::int64_t k, + std::int64_t stride_a, + std::int64_t stride_tau, + std::int64_t batch_size, + const std::vector &depends = {}); + +extern void init_orgqr_batch_dispatch_vector(void); +extern void init_orgqr_dispatch_vector(void); +} // namespace lapack +} // namespace ext +} // namespace backend +} // namespace dpnp diff --git a/dpnp/backend/extensions/lapack/orgqr_batch.cpp b/dpnp/backend/extensions/lapack/orgqr_batch.cpp new file mode 100644 index 00000000000..dfa9932a8e0 --- /dev/null +++ b/dpnp/backend/extensions/lapack/orgqr_batch.cpp @@ -0,0 +1,278 @@ +//***************************************************************************** +// Copyright (c) 2024, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** + +#include + +// dpctl tensor headers +#include "utils/memory_overlap.hpp" +#include "utils/type_utils.hpp" + +#include "orgqr.hpp" +#include "types_matrix.hpp" + +#include "dpnp_utils.hpp" + +namespace dpnp +{ +namespace backend +{ +namespace ext +{ +namespace lapack +{ +namespace mkl_lapack = oneapi::mkl::lapack; +namespace py = pybind11; +namespace type_utils = dpctl::tensor::type_utils; + +typedef sycl::event (*orgqr_batch_impl_fn_ptr_t)( + sycl::queue, + std::int64_t, + std::int64_t, + std::int64_t, + char *, + std::int64_t, + std::int64_t, + char *, + std::int64_t, + std::int64_t, + std::vector &, + const std::vector &); + +static orgqr_batch_impl_fn_ptr_t + orgqr_batch_dispatch_vector[dpctl_td_ns::num_types]; + +template +static sycl::event orgqr_batch_impl(sycl::queue exec_q, + std::int64_t m, + std::int64_t n, + std::int64_t k, + char *in_a, + std::int64_t lda, + std::int64_t stride_a, + char *in_tau, + std::int64_t stride_tau, + std::int64_t batch_size, + std::vector &host_task_events, + const std::vector &depends) +{ + type_utils::validate_type_for_device(exec_q); + + T *a = reinterpret_cast(in_a); + T *tau = reinterpret_cast(in_tau); + + const std::int64_t scratchpad_size = + mkl_lapack::orgqr_batch_scratchpad_size( + exec_q, m, n, k, lda, stride_a, stride_tau, batch_size); + T *scratchpad = nullptr; + + std::stringstream error_msg; + std::int64_t info = 0; + bool is_exception_caught = false; + + sycl::event orgqr_batch_event; + try { + scratchpad = sycl::malloc_device(scratchpad_size, exec_q); + + orgqr_batch_event = mkl_lapack::orgqr_batch( + exec_q, + m, // The number of rows in each matrix in the batch; (0 ≤ m). + // It must be a non-negative integer. + n, // The number of columns in each matrix in the batch; (0 ≤ n). + // It must be a non-negative integer. + k, // The number of elementary reflectors + // whose product defines the matrices Qi; (0 ≤ k ≤ n). + a, // Pointer to the batch of matrices, each of size (m x n). + lda, // The leading dimension of each matrix in the batch. + // For row major layout, lda ≥ max(1, m). + stride_a, // Stride between consecutive matrices in the batch. + tau, // Pointer to the array of scalar factors of the elementary + // reflectors for each matrix in the batch. + stride_tau, // Stride between arrays of scalar factors in the batch. + batch_size, // The number of matrices in the batch. + scratchpad, // Pointer to scratchpad memory to be used by MKL + // routine for storing intermediate results. + scratchpad_size, depends); + } catch (mkl_lapack::exception const &e) { + is_exception_caught = true; + info = e.info(); + + if (info < 0) { + error_msg << "Parameter number " << -info + << " had an illegal value."; + } + else if (info == scratchpad_size && e.detail() != 0) { + error_msg + << "Insufficient scratchpad size. Required size is at least " + << e.detail() << ", but current size is " << scratchpad_size + << "."; + } + else if (info != 0 && e.detail() == 0) { + error_msg << "Error in batch processing. " + "Number of failed calculations: " + << info; + } + else { + error_msg << "Unexpected MKL exception caught during orgqr_batch() " + "call:\nreason: " + << e.what() << "\ninfo: " << e.info(); + } + } catch (sycl::exception const &e) { + is_exception_caught = true; + error_msg + << "Unexpected SYCL exception caught during orgqr_batch() call:\n" + << e.what(); + } + + if (is_exception_caught) // an unexpected error occurs + { + if (scratchpad != nullptr) { + sycl::free(scratchpad, exec_q); + } + + throw std::runtime_error(error_msg.str()); + } + + sycl::event clean_up_event = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(orgqr_batch_event); + auto ctx = exec_q.get_context(); + cgh.host_task([ctx, scratchpad]() { sycl::free(scratchpad, ctx); }); + }); + host_task_events.push_back(clean_up_event); + return orgqr_batch_event; +} + +std::pair + orgqr_batch(sycl::queue q, + dpctl::tensor::usm_ndarray a_array, + dpctl::tensor::usm_ndarray tau_array, + std::int64_t m, + std::int64_t n, + std::int64_t k, + std::int64_t stride_a, + std::int64_t stride_tau, + std::int64_t batch_size, + const std::vector &depends) +{ + const int a_array_nd = a_array.get_ndim(); + const int tau_array_nd = tau_array.get_ndim(); + + if (a_array_nd < 3) { + throw py::value_error( + "The input array has ndim=" + std::to_string(a_array_nd) + + ", but an array with ndim >= 3 is expected."); + } + + if (tau_array_nd != 2) { + throw py::value_error("The array of Householder scalars has ndim=" + + std::to_string(tau_array_nd) + + ", but a 2-dimensional array is expected."); + } + + // check compatibility of execution queue and allocation queue + if (!dpctl::utils::queues_are_compatible(q, {a_array, tau_array})) { + throw py::value_error( + "Execution queue is not compatible with allocation queues"); + } + + auto const &overlap = dpctl::tensor::overlap::MemoryOverlap(); + if (overlap(a_array, tau_array)) { + throw py::value_error( + "The input array and the array of Householder scalars " + "are overlapping segments of memory"); + } + + bool is_a_array_c_contig = a_array.is_c_contiguous(); + bool is_tau_array_c_contig = tau_array.is_c_contiguous(); + if (!is_a_array_c_contig) { + throw py::value_error("The input array " + "must be C-contiguous"); + } + if (!is_tau_array_c_contig) { + throw py::value_error("The array of Householder scalars " + "must be C-contiguous"); + } + + auto array_types = dpctl_td_ns::usm_ndarray_types(); + int a_array_type_id = + array_types.typenum_to_lookup_id(a_array.get_typenum()); + int tau_array_type_id = + array_types.typenum_to_lookup_id(tau_array.get_typenum()); + + if (a_array_type_id != tau_array_type_id) { + throw py::value_error( + "The types of the input array and " + "the array of Householder scalars are mismatched"); + } + + orgqr_batch_impl_fn_ptr_t orgqr_batch_fn = + orgqr_batch_dispatch_vector[a_array_type_id]; + if (orgqr_batch_fn == nullptr) { + throw py::value_error( + "No orgqr_batch implementation defined for the provided type " + "of the input matrix."); + } + + char *a_array_data = a_array.get_data(); + char *tau_array_data = tau_array.get_data(); + + const std::int64_t lda = std::max(1UL, m); + + std::vector host_task_events; + sycl::event orgqr_batch_ev = + orgqr_batch_fn(q, m, n, k, a_array_data, lda, stride_a, tau_array_data, + stride_tau, batch_size, host_task_events, depends); + + sycl::event args_ev = dpctl::utils::keep_args_alive(q, {a_array, tau_array}, + host_task_events); + + return std::make_pair(args_ev, orgqr_batch_ev); +} + +template +struct OrgqrBatchContigFactory +{ + fnT get() + { + if constexpr (types::OrgqrBatchTypePairSupportFactory::is_defined) { + return orgqr_batch_impl; + } + else { + return nullptr; + } + } +}; + +void init_orgqr_batch_dispatch_vector(void) +{ + dpctl_td_ns::DispatchVectorBuilder + contig; + contig.populate_dispatch_vector(orgqr_batch_dispatch_vector); +} +} // namespace lapack +} // namespace ext +} // namespace backend +} // namespace dpnp diff --git a/dpnp/backend/extensions/lapack/types_matrix.hpp b/dpnp/backend/extensions/lapack/types_matrix.hpp index 893619e6afb..9a0ab36c8a4 100644 --- a/dpnp/backend/extensions/lapack/types_matrix.hpp +++ b/dpnp/backend/extensions/lapack/types_matrix.hpp @@ -43,6 +43,61 @@ namespace lapack { namespace types { +/** + * @brief A factory to define pairs of supported types for which + * MKL LAPACK library provides support in oneapi::mkl::lapack::geqrf_batch + * function. + * + * @tparam T Type of array containing the input matrices to be QR factorized in + * batch mode. Upon execution, each matrix in the batch is transformed to output + * arrays representing their respective orthogonal matrix Q and upper triangular + * matrix R. + */ +template +struct GeqrfBatchTypePairSupportFactory +{ + static constexpr bool is_defined = std::disjunction< + dpctl_td_ns::TypePairDefinedEntry, + dpctl_td_ns::TypePairDefinedEntry, + dpctl_td_ns::TypePairDefinedEntry, + T, + std::complex>, + dpctl_td_ns::TypePairDefinedEntry, + T, + std::complex>, + // fall-through + dpctl_td_ns::NotDefinedEntry>::is_defined; +}; + +/** + * @brief A factory to define pairs of supported types for which + * MKL LAPACK library provides support in oneapi::mkl::lapack::geqrf + * function. + * + * @tparam T Type of array containing the input matrix to be QR factorized. + * Upon execution, this matrix is transformed to output arrays representing + * the orthogonal matrix Q and the upper triangular matrix R. + */ +template +struct GeqrfTypePairSupportFactory +{ + static constexpr bool is_defined = std::disjunction< + dpctl_td_ns::TypePairDefinedEntry, + dpctl_td_ns::TypePairDefinedEntry, + dpctl_td_ns::TypePairDefinedEntry, + T, + std::complex>, + dpctl_td_ns::TypePairDefinedEntry, + T, + std::complex>, + // fall-through + dpctl_td_ns::NotDefinedEntry>::is_defined; +}; + /** * @brief A factory to define pairs of supported types for which * MKL LAPACK library provides support in oneapi::mkl::lapack::gesv @@ -190,6 +245,46 @@ struct HeevdTypePairSupportFactory dpctl_td_ns::NotDefinedEntry>::is_defined; }; +/** + * @brief A factory to define pairs of supported types for which + * MKL LAPACK library provides support in oneapi::mkl::lapack::orgqr_batch + * function. + * + * @tparam T Type of array containing the matrix A, + * each from a separate instance in the batch, from which the + * elementary reflectors were generated (as in QR factorization). + * Upon execution, each array in the batch is overwritten with + * its respective orthonormal matrix Q. + */ +template +struct OrgqrBatchTypePairSupportFactory +{ + static constexpr bool is_defined = std::disjunction< + dpctl_td_ns::TypePairDefinedEntry, + dpctl_td_ns::TypePairDefinedEntry, + // fall-through + dpctl_td_ns::NotDefinedEntry>::is_defined; +}; + +/** + * @brief A factory to define pairs of supported types for which + * MKL LAPACK library provides support in oneapi::mkl::lapack::orgqr + * function. + * + * @tparam T Type of array containing the matrix A from which the + * elementary reflectors were generated (as in QR factorization). + * Upon execution, the array is overwritten with the orthonormal matrix Q. + */ +template +struct OrgqrTypePairSupportFactory +{ + static constexpr bool is_defined = std::disjunction< + dpctl_td_ns::TypePairDefinedEntry, + dpctl_td_ns::TypePairDefinedEntry, + // fall-through + dpctl_td_ns::NotDefinedEntry>::is_defined; +}; + /** * @brief A factory to define pairs of supported types for which * MKL LAPACK library provides support in oneapi::mkl::lapack::potrf @@ -259,6 +354,58 @@ struct SyevdTypePairSupportFactory // fall-through dpctl_td_ns::NotDefinedEntry>::is_defined; }; + +/** + * @brief A factory to define pairs of supported types for which + * MKL LAPACK library provides support in oneapi::mkl::lapack::ungqr_batch + * function. + * + * @tparam T Type of array containing the matrix A, + * each from a separate instance in the batch, from which the + * elementary reflectors were generated (as in QR factorization). + * Upon execution, each array in the batch is overwritten with + * its respective complex unitary matrix Q. + */ +template +struct UngqrBatchTypePairSupportFactory +{ + static constexpr bool is_defined = std::disjunction< + dpctl_td_ns::TypePairDefinedEntry, + T, + std::complex>, + dpctl_td_ns::TypePairDefinedEntry, + T, + std::complex>, + // fall-through + dpctl_td_ns::NotDefinedEntry>::is_defined; +}; + +/** + * @brief A factory to define pairs of supported types for which + * MKL LAPACK library provides support in oneapi::mkl::lapack::ungqr + * function. + * + * @tparam T Type of array containing the matrix A from which the + * elementary reflectors were generated (as in QR factorization). + * Upon execution, the array is overwritten with the complex unitary matrix Q. + */ +template +struct UngqrTypePairSupportFactory +{ + static constexpr bool is_defined = std::disjunction< + dpctl_td_ns::TypePairDefinedEntry, + T, + std::complex>, + dpctl_td_ns::TypePairDefinedEntry, + T, + std::complex>, + // fall-through + dpctl_td_ns::NotDefinedEntry>::is_defined; +}; } // namespace types } // namespace lapack } // namespace ext diff --git a/dpnp/backend/extensions/lapack/ungqr.cpp b/dpnp/backend/extensions/lapack/ungqr.cpp new file mode 100644 index 00000000000..7c8dea4e950 --- /dev/null +++ b/dpnp/backend/extensions/lapack/ungqr.cpp @@ -0,0 +1,263 @@ +//***************************************************************************** +// Copyright (c) 2024, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** + +#include + +// dpctl tensor headers +#include "utils/memory_overlap.hpp" +#include "utils/type_utils.hpp" + +#include "types_matrix.hpp" +#include "ungqr.hpp" + +#include "dpnp_utils.hpp" + +namespace dpnp +{ +namespace backend +{ +namespace ext +{ +namespace lapack +{ +namespace mkl_lapack = oneapi::mkl::lapack; +namespace py = pybind11; +namespace type_utils = dpctl::tensor::type_utils; + +typedef sycl::event (*ungqr_impl_fn_ptr_t)(sycl::queue, + const std::int64_t, + const std::int64_t, + const std::int64_t, + char *, + std::int64_t, + char *, + std::vector &, + const std::vector &); + +static ungqr_impl_fn_ptr_t ungqr_dispatch_vector[dpctl_td_ns::num_types]; + +template +static sycl::event ungqr_impl(sycl::queue exec_q, + const std::int64_t m, + const std::int64_t n, + const std::int64_t k, + char *in_a, + std::int64_t lda, + char *in_tau, + std::vector &host_task_events, + const std::vector &depends) +{ + type_utils::validate_type_for_device(exec_q); + + T *a = reinterpret_cast(in_a); + T *tau = reinterpret_cast(in_tau); + + const std::int64_t scratchpad_size = + mkl_lapack::ungqr_scratchpad_size(exec_q, m, n, k, lda); + T *scratchpad = nullptr; + + std::stringstream error_msg; + std::int64_t info = 0; + bool is_exception_caught = false; + + sycl::event ungqr_event; + try { + scratchpad = sycl::malloc_device(scratchpad_size, exec_q); + + ungqr_event = mkl_lapack::ungqr( + exec_q, + m, // The number of rows in the matrix; (0 ≤ m). + n, // The number of columns in the matrix; (0 ≤ n). + k, // The number of elementary reflectors + // whose product defines the matrix Q; (0 ≤ k ≤ n). + a, // Pointer to the m-by-n matrix. + lda, // The leading dimension of `a`; (1 ≤ m). + tau, // Pointer to the array of scalar factors of the + // elementary reflectors. + scratchpad, // Pointer to scratchpad memory to be used by MKL + // routine for storing intermediate results. + scratchpad_size, depends); + } catch (mkl_lapack::exception const &e) { + is_exception_caught = true; + info = e.info(); + + if (info < 0) { + error_msg << "Parameter number " << -info + << " had an illegal value."; + } + else if (info == scratchpad_size && e.detail() != 0) { + error_msg + << "Insufficient scratchpad size. Required size is at least " + << e.detail() << ", but current size is " << scratchpad_size + << "."; + } + else { + error_msg << "Unexpected MKL exception caught during ungqr() " + "call:\nreason: " + << e.what() << "\ninfo: " << info; + } + } catch (sycl::exception const &e) { + is_exception_caught = true; + error_msg << "Unexpected SYCL exception caught during orfqr() call:\n" + << e.what(); + } + + if (is_exception_caught) // an unexpected error occurs + { + if (scratchpad != nullptr) { + sycl::free(scratchpad, exec_q); + } + throw std::runtime_error(error_msg.str()); + } + + sycl::event clean_up_event = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(ungqr_event); + auto ctx = exec_q.get_context(); + cgh.host_task([ctx, scratchpad]() { sycl::free(scratchpad, ctx); }); + }); + host_task_events.push_back(clean_up_event); + + return ungqr_event; +} + +std::pair + ungqr(sycl::queue q, + const std::int64_t m, + const std::int64_t n, + const std::int64_t k, + dpctl::tensor::usm_ndarray a_array, + dpctl::tensor::usm_ndarray tau_array, + const std::vector &depends) +{ + const int a_array_nd = a_array.get_ndim(); + const int tau_array_nd = tau_array.get_ndim(); + + if (a_array_nd != 2) { + throw py::value_error( + "The input array has ndim=" + std::to_string(a_array_nd) + + ", but a 2-dimensional array is expected."); + } + + if (tau_array_nd != 1) { + throw py::value_error("The array of Householder scalars has ndim=" + + std::to_string(tau_array_nd) + + ", but a 1-dimensional array is expected."); + } + + // check compatibility of execution queue and allocation queue + if (!dpctl::utils::queues_are_compatible(q, {a_array, tau_array})) { + throw py::value_error( + "Execution queue is not compatible with allocation queues"); + } + + auto const &overlap = dpctl::tensor::overlap::MemoryOverlap(); + if (overlap(a_array, tau_array)) { + throw py::value_error( + "The input array and the array of Householder scalars " + "are overlapping segments of memory"); + } + + bool is_a_array_c_contig = a_array.is_c_contiguous(); + if (!is_a_array_c_contig) { + throw py::value_error("The input array " + "must be C-contiguous"); + } + + bool is_tau_array_c_contig = tau_array.is_c_contiguous(); + bool is_tau_array_f_contig = tau_array.is_f_contiguous(); + + if (!is_tau_array_c_contig || !is_tau_array_f_contig) { + throw py::value_error("The array of Householder scalars " + "must be contiguous"); + } + + const size_t tau_array_size = tau_array.get_size(); + + if (static_cast(tau_array_size) != k) { + throw py::value_error("The array of Householder scalars has size=" + + std::to_string(tau_array_size) + + ", but an array of size=" + std::to_string(k) + + " is expected."); + } + + auto array_types = dpctl_td_ns::usm_ndarray_types(); + int a_array_type_id = + array_types.typenum_to_lookup_id(a_array.get_typenum()); + int tau_array_type_id = + array_types.typenum_to_lookup_id(tau_array.get_typenum()); + + if (a_array_type_id != tau_array_type_id) { + throw py::value_error( + "The types of the input array and " + "the array of Householder scalars are mismatched"); + } + + ungqr_impl_fn_ptr_t ungqr_fn = ungqr_dispatch_vector[a_array_type_id]; + if (ungqr_fn == nullptr) { + throw py::value_error( + "No ungqr implementation defined for the provided type " + "of the input matrix."); + } + + char *a_array_data = a_array.get_data(); + const std::int64_t lda = std::max(1UL, m); + + char *tau_array_data = tau_array.get_data(); + + std::vector host_task_events; + sycl::event ungqr_ev = ungqr_fn(q, m, n, k, a_array_data, lda, + tau_array_data, host_task_events, depends); + + sycl::event args_ev = dpctl::utils::keep_args_alive(q, {a_array, tau_array}, + host_task_events); + + return std::make_pair(args_ev, ungqr_ev); +} + +template +struct UngqrContigFactory +{ + fnT get() + { + if constexpr (types::UngqrTypePairSupportFactory::is_defined) { + return ungqr_impl; + } + else { + return nullptr; + } + } +}; + +void init_ungqr_dispatch_vector(void) +{ + dpctl_td_ns::DispatchVectorBuilder + contig; + contig.populate_dispatch_vector(ungqr_dispatch_vector); +} +} // namespace lapack +} // namespace ext +} // namespace backend +} // namespace dpnp diff --git a/dpnp/backend/extensions/lapack/ungqr.hpp b/dpnp/backend/extensions/lapack/ungqr.hpp new file mode 100644 index 00000000000..1a9b68e94f9 --- /dev/null +++ b/dpnp/backend/extensions/lapack/ungqr.hpp @@ -0,0 +1,67 @@ +//***************************************************************************** +// Copyright (c) 2024, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** + +#pragma once + +#include +#include + +#include + +namespace dpnp +{ +namespace backend +{ +namespace ext +{ +namespace lapack +{ +extern std::pair + ungqr(sycl::queue exec_q, + const std::int64_t m, + const std::int64_t n, + const std::int64_t k, + dpctl::tensor::usm_ndarray a_array, + dpctl::tensor::usm_ndarray tau_array, + const std::vector &depends = {}); + +extern std::pair + ungqr_batch(sycl::queue exec_q, + dpctl::tensor::usm_ndarray a_array, + dpctl::tensor::usm_ndarray tau_array, + std::int64_t m, + std::int64_t n, + std::int64_t k, + std::int64_t stride_a, + std::int64_t stride_tau, + std::int64_t batch_size, + const std::vector &depends = {}); + +extern void init_ungqr_batch_dispatch_vector(void); +extern void init_ungqr_dispatch_vector(void); +} // namespace lapack +} // namespace ext +} // namespace backend +} // namespace dpnp diff --git a/dpnp/backend/extensions/lapack/ungqr_batch.cpp b/dpnp/backend/extensions/lapack/ungqr_batch.cpp new file mode 100644 index 00000000000..c07eaf150fc --- /dev/null +++ b/dpnp/backend/extensions/lapack/ungqr_batch.cpp @@ -0,0 +1,278 @@ +//***************************************************************************** +// Copyright (c) 2024, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** + +#include + +// dpctl tensor headers +#include "utils/memory_overlap.hpp" +#include "utils/type_utils.hpp" + +#include "types_matrix.hpp" +#include "ungqr.hpp" + +#include "dpnp_utils.hpp" + +namespace dpnp +{ +namespace backend +{ +namespace ext +{ +namespace lapack +{ +namespace mkl_lapack = oneapi::mkl::lapack; +namespace py = pybind11; +namespace type_utils = dpctl::tensor::type_utils; + +typedef sycl::event (*ungqr_batch_impl_fn_ptr_t)( + sycl::queue, + std::int64_t, + std::int64_t, + std::int64_t, + char *, + std::int64_t, + std::int64_t, + char *, + std::int64_t, + std::int64_t, + std::vector &, + const std::vector &); + +static ungqr_batch_impl_fn_ptr_t + ungqr_batch_dispatch_vector[dpctl_td_ns::num_types]; + +template +static sycl::event ungqr_batch_impl(sycl::queue exec_q, + std::int64_t m, + std::int64_t n, + std::int64_t k, + char *in_a, + std::int64_t lda, + std::int64_t stride_a, + char *in_tau, + std::int64_t stride_tau, + std::int64_t batch_size, + std::vector &host_task_events, + const std::vector &depends) +{ + type_utils::validate_type_for_device(exec_q); + + T *a = reinterpret_cast(in_a); + T *tau = reinterpret_cast(in_tau); + + const std::int64_t scratchpad_size = + mkl_lapack::ungqr_batch_scratchpad_size( + exec_q, m, n, k, lda, stride_a, stride_tau, batch_size); + T *scratchpad = nullptr; + + std::stringstream error_msg; + std::int64_t info = 0; + bool is_exception_caught = false; + + sycl::event ungqr_batch_event; + try { + scratchpad = sycl::malloc_device(scratchpad_size, exec_q); + + ungqr_batch_event = mkl_lapack::ungqr_batch( + exec_q, + m, // The number of rows in each matrix in the batch; (0 ≤ m). + // It must be a non-negative integer. + n, // The number of columns in each matrix in the batch; (0 ≤ n). + // It must be a non-negative integer. + k, // The number of elementary reflectors + // whose product defines the matrices Qi; (0 ≤ k ≤ n). + a, // Pointer to the batch of matrices, each of size (m x n). + lda, // The leading dimension of each matrix in the batch. + // For row major layout, lda ≥ max(1, m). + stride_a, // Stride between consecutive matrices in the batch. + tau, // Pointer to the array of scalar factors of the elementary + // reflectors for each matrix in the batch. + stride_tau, // Stride between arrays of scalar factors in the batch. + batch_size, // The number of matrices in the batch. + scratchpad, // Pointer to scratchpad memory to be used by MKL + // routine for storing intermediate results. + scratchpad_size, depends); + } catch (mkl_lapack::exception const &e) { + is_exception_caught = true; + info = e.info(); + + if (info < 0) { + error_msg << "Parameter number " << -info + << " had an illegal value."; + } + else if (info == scratchpad_size && e.detail() != 0) { + error_msg + << "Insufficient scratchpad size. Required size is at least " + << e.detail() << ", but current size is " << scratchpad_size + << "."; + } + else if (info != 0 && e.detail() == 0) { + error_msg << "Error in batch processing. " + "Number of failed calculations: " + << info; + } + else { + error_msg << "Unexpected MKL exception caught during ungqr_batch() " + "call:\nreason: " + << e.what() << "\ninfo: " << e.info(); + } + } catch (sycl::exception const &e) { + is_exception_caught = true; + error_msg + << "Unexpected SYCL exception caught during ungqr_batch() call:\n" + << e.what(); + } + + if (is_exception_caught) // an unexpected error occurs + { + if (scratchpad != nullptr) { + sycl::free(scratchpad, exec_q); + } + + throw std::runtime_error(error_msg.str()); + } + + sycl::event clean_up_event = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(ungqr_batch_event); + auto ctx = exec_q.get_context(); + cgh.host_task([ctx, scratchpad]() { sycl::free(scratchpad, ctx); }); + }); + host_task_events.push_back(clean_up_event); + return ungqr_batch_event; +} + +std::pair + ungqr_batch(sycl::queue q, + dpctl::tensor::usm_ndarray a_array, + dpctl::tensor::usm_ndarray tau_array, + std::int64_t m, + std::int64_t n, + std::int64_t k, + std::int64_t stride_a, + std::int64_t stride_tau, + std::int64_t batch_size, + const std::vector &depends) +{ + const int a_array_nd = a_array.get_ndim(); + const int tau_array_nd = tau_array.get_ndim(); + + if (a_array_nd < 3) { + throw py::value_error( + "The input array has ndim=" + std::to_string(a_array_nd) + + ", but an array with ndim >= 3 is expected."); + } + + if (tau_array_nd != 2) { + throw py::value_error("The array of Householder scalars has ndim=" + + std::to_string(tau_array_nd) + + ", but a 2-dimensional array is expected."); + } + + // check compatibility of execution queue and allocation queue + if (!dpctl::utils::queues_are_compatible(q, {a_array, tau_array})) { + throw py::value_error( + "Execution queue is not compatible with allocation queues"); + } + + auto const &overlap = dpctl::tensor::overlap::MemoryOverlap(); + if (overlap(a_array, tau_array)) { + throw py::value_error( + "The input array and the array of Householder scalars " + "are overlapping segments of memory"); + } + + bool is_a_array_c_contig = a_array.is_c_contiguous(); + bool is_tau_array_c_contig = tau_array.is_c_contiguous(); + if (!is_a_array_c_contig) { + throw py::value_error("The input array " + "must be C-contiguous"); + } + if (!is_tau_array_c_contig) { + throw py::value_error("The array of Householder scalars " + "must be C-contiguous"); + } + + auto array_types = dpctl_td_ns::usm_ndarray_types(); + int a_array_type_id = + array_types.typenum_to_lookup_id(a_array.get_typenum()); + int tau_array_type_id = + array_types.typenum_to_lookup_id(tau_array.get_typenum()); + + if (a_array_type_id != tau_array_type_id) { + throw py::value_error( + "The types of the input array and " + "the array of Householder scalars are mismatched"); + } + + ungqr_batch_impl_fn_ptr_t ungqr_batch_fn = + ungqr_batch_dispatch_vector[a_array_type_id]; + if (ungqr_batch_fn == nullptr) { + throw py::value_error( + "No ungqr_batch implementation defined for the provided type " + "of the input matrix."); + } + + char *a_array_data = a_array.get_data(); + char *tau_array_data = tau_array.get_data(); + + const std::int64_t lda = std::max(1UL, m); + + std::vector host_task_events; + sycl::event ungqr_batch_ev = + ungqr_batch_fn(q, m, n, k, a_array_data, lda, stride_a, tau_array_data, + stride_tau, batch_size, host_task_events, depends); + + sycl::event args_ev = dpctl::utils::keep_args_alive(q, {a_array, tau_array}, + host_task_events); + + return std::make_pair(args_ev, ungqr_batch_ev); +} + +template +struct UngqrBatchContigFactory +{ + fnT get() + { + if constexpr (types::UngqrBatchTypePairSupportFactory::is_defined) { + return ungqr_batch_impl; + } + else { + return nullptr; + } + } +}; + +void init_ungqr_batch_dispatch_vector(void) +{ + dpctl_td_ns::DispatchVectorBuilder + contig; + contig.populate_dispatch_vector(ungqr_batch_dispatch_vector); +} +} // namespace lapack +} // namespace ext +} // namespace backend +} // namespace dpnp diff --git a/dpnp/backend/include/dpnp_iface_fptr.hpp b/dpnp/backend/include/dpnp_iface_fptr.hpp index 3061bb01f29..e9a3458f84a 100644 --- a/dpnp/backend/include/dpnp_iface_fptr.hpp +++ b/dpnp/backend/include/dpnp_iface_fptr.hpp @@ -220,8 +220,6 @@ enum class DPNPFuncName : size_t DPNP_FN_PUT, /**< Used in numpy.put() impl */ DPNP_FN_PUT_ALONG_AXIS, /**< Used in numpy.put_along_axis() impl */ DPNP_FN_QR, /**< Used in numpy.linalg.qr() impl */ - DPNP_FN_QR_EXT, /**< Used in numpy.linalg.qr() impl, requires extra - parameters */ DPNP_FN_RADIANS, /**< Used in numpy.radians() impl */ DPNP_FN_RADIANS_EXT, /**< Used in numpy.radians() impl, requires extra parameters */ diff --git a/dpnp/backend/kernels/dpnp_krnl_linalg.cpp b/dpnp/backend/kernels/dpnp_krnl_linalg.cpp index 610da8fda3c..d74c593115e 100644 --- a/dpnp/backend/kernels/dpnp_krnl_linalg.cpp +++ b/dpnp/backend/kernels/dpnp_krnl_linalg.cpp @@ -722,17 +722,6 @@ template void (*dpnp_qr_default_c)(void *, void *, void *, void *, size_t, size_t) = dpnp_qr_c<_InputDT, _ComputeDT>; -template -DPCTLSyclEventRef (*dpnp_qr_ext_c)(DPCTLSyclQueueRef, - void *, - void *, - void *, - void *, - size_t, - size_t, - const DPCTLEventVectorRef) = - dpnp_qr_c<_InputDT, _ComputeDT>; - template DPCTLSyclEventRef dpnp_svd_c(DPCTLSyclQueueRef q_ref, void *array1_in, @@ -1000,29 +989,6 @@ void func_map_init_linalg_func(func_map_t &fmap) // fmap[DPNPFuncName::DPNP_FN_QR][eft_C128][eft_C128] = { // eft_C128, (void*)dpnp_qr_c, std::complex>}; - fmap[DPNPFuncName::DPNP_FN_QR_EXT][eft_INT][eft_INT] = { - get_default_floating_type(), - (void *)dpnp_qr_ext_c< - int32_t, func_type_map_t::find_type>, - get_default_floating_type(), - (void *)dpnp_qr_ext_c< - int32_t, func_type_map_t::find_type< - get_default_floating_type()>>}; - fmap[DPNPFuncName::DPNP_FN_QR_EXT][eft_LNG][eft_LNG] = { - get_default_floating_type(), - (void *)dpnp_qr_ext_c< - int64_t, func_type_map_t::find_type>, - get_default_floating_type(), - (void *)dpnp_qr_ext_c< - int64_t, func_type_map_t::find_type< - get_default_floating_type()>>}; - fmap[DPNPFuncName::DPNP_FN_QR_EXT][eft_FLT][eft_FLT] = { - eft_FLT, (void *)dpnp_qr_ext_c}; - fmap[DPNPFuncName::DPNP_FN_QR_EXT][eft_DBL][eft_DBL] = { - eft_DBL, (void *)dpnp_qr_ext_c}; - // fmap[DPNPFuncName::DPNP_FN_QR_EXT][eft_C128][eft_C128] = { - // eft_C128, (void*)dpnp_qr_c, std::complex>}; - fmap[DPNPFuncName::DPNP_FN_SVD][eft_INT][eft_INT] = { eft_DBL, (void *)dpnp_svd_default_c}; fmap[DPNPFuncName::DPNP_FN_SVD][eft_LNG][eft_LNG] = { diff --git a/dpnp/dpnp_algo/dpnp_algo.pxd b/dpnp/dpnp_algo/dpnp_algo.pxd index 2fc7e1b4a3b..71382d38f26 100644 --- a/dpnp/dpnp_algo/dpnp_algo.pxd +++ b/dpnp/dpnp_algo/dpnp_algo.pxd @@ -94,8 +94,6 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na DPNP_FN_PARTITION DPNP_FN_PARTITION_EXT DPNP_FN_PLACE - DPNP_FN_QR - DPNP_FN_QR_EXT DPNP_FN_RADIANS DPNP_FN_RADIANS_EXT DPNP_FN_RNG_BETA diff --git a/dpnp/linalg/dpnp_algo_linalg.pyx b/dpnp/linalg/dpnp_algo_linalg.pyx index 3bf6dad3ee8..67cd5d93034 100644 --- a/dpnp/linalg/dpnp_algo_linalg.pyx +++ b/dpnp/linalg/dpnp_algo_linalg.pyx @@ -50,7 +50,6 @@ __all__ = [ "dpnp_eigvals", "dpnp_matrix_rank", "dpnp_norm", - "dpnp_qr", ] @@ -323,58 +322,3 @@ cpdef object dpnp_norm(object input, ord=None, axis=None): return ret else: raise ValueError("Improper number of dimensions to norm.") - - -cpdef tuple dpnp_qr(utils.dpnp_descriptor x1, str mode): - cdef size_t size_m = x1.shape[0] - cdef size_t size_n = x1.shape[1] - cdef size_t min_m_n = min(size_m, size_n) - cdef size_t size_tau = min_m_n - - cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(x1.dtype) - cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_QR_EXT, param1_type, param1_type) - - x1_obj = x1.get_array() - - cdef (DPNPFuncType, void *) ret_type_and_func = utils.get_ret_type_and_func(kernel_data, - x1_obj.sycl_device.has_aspect_fp64) - cdef DPNPFuncType return_type = ret_type_and_func[0] - cdef custom_linalg_1in_3out_shape_t func = < custom_linalg_1in_3out_shape_t > ret_type_and_func[1] - - cdef utils.dpnp_descriptor res_q = utils.create_output_descriptor((size_m, min_m_n), - return_type, - None, - device=x1_obj.sycl_device, - usm_type=x1_obj.usm_type, - sycl_queue=x1_obj.sycl_queue) - cdef utils.dpnp_descriptor res_r = utils.create_output_descriptor((min_m_n, size_n), - return_type, - None, - device=x1_obj.sycl_device, - usm_type=x1_obj.usm_type, - sycl_queue=x1_obj.sycl_queue) - cdef utils.dpnp_descriptor tau = utils.create_output_descriptor((size_tau, ), - return_type, - None, - device=x1_obj.sycl_device, - usm_type=x1_obj.usm_type, - sycl_queue=x1_obj.sycl_queue) - - result_sycl_queue = res_q.get_array().sycl_queue - - cdef c_dpctl.SyclQueue q = result_sycl_queue - cdef c_dpctl.DPCTLSyclQueueRef q_ref = q.get_queue_ref() - - cdef c_dpctl.DPCTLSyclEventRef event_ref = func(q_ref, - x1.get_data(), - res_q.get_data(), - res_r.get_data(), - tau.get_data(), - size_m, - size_n, - NULL) # dep_events_ref - - with nogil: c_dpctl.DPCTLEvent_WaitAndThrow(event_ref) - c_dpctl.DPCTLEvent_Delete(event_ref) - - return (res_q.get_pyobj(), res_r.get_pyobj()) diff --git a/dpnp/linalg/dpnp_iface_linalg.py b/dpnp/linalg/dpnp_iface_linalg.py index 2b8506130ad..88a904b3c3c 100644 --- a/dpnp/linalg/dpnp_iface_linalg.py +++ b/dpnp/linalg/dpnp_iface_linalg.py @@ -51,6 +51,7 @@ dpnp_det, dpnp_eigh, dpnp_inv, + dpnp_qr, dpnp_slogdet, dpnp_solve, dpnp_svd, @@ -529,7 +530,7 @@ def norm(x1, ord=None, axis=None, keepdims=False): return call_origin(numpy.linalg.norm, x1, ord, axis, keepdims) -def qr(x1, mode="reduced"): +def qr(a, mode="reduced"): """ Compute the qr factorization of a matrix. @@ -538,25 +539,64 @@ def qr(x1, mode="reduced"): For full documentation refer to :obj:`numpy.linalg.qr`. - Limitations - ----------- - Input array is supported as :obj:`dpnp.ndarray`. - Parameter mode='reduced' is supported. + Parameters + ---------- + a : {dpnp.ndarray, usm_ndarray} + The input array with the dimensionality of at least 2. + mode : {"reduced", "complete", "r", "raw"}, optional + If K = min(M, N), then + - "reduced" : returns Q, R with dimensions (…, M, K), (…, K, N) + - "complete" : returns Q, R with dimensions (…, M, M), (…, M, N) + - "r" : returns R only with dimensions (…, K, N) + - "raw" : returns h, tau with dimensions (…, N, M), (…, K,) + Default: "reduced". + + Returns + ------- + When mode is "reduced" or "complete", the result will be a namedtuple with + the attributes Q and R. + Q : dpnp.ndarray + A matrix with orthonormal columns. + When mode = "complete" the result is an orthogonal/unitary matrix + depending on whether or not a is real/complex. + The determinant may be either +/- 1 in that case. + In case the number of dimensions in the input array is greater + than 2 then a stack of the matrices with above properties is returned. + R : dpnp.ndarray + The upper-triangular matrix or a stack of upper-triangular matrices + if the number of dimensions in the input array is greater than 2. + (h, tau) : tuple of dpnp.ndarray + The h array contains the Householder reflectors that generate Q along with R. + The tau array contains scaling factors for the reflectors. + + Examples + -------- + >>> import dpnp as np + >>> a = np.random.randn(9, 6) + >>> Q, R = np.linalg.qr(a) + >>> np.allclose(a, np.dot(Q, R)) # a does equal QR + array([ True]) + >>> R2 = np.linalg.qr(a, mode='r') + >>> np.allclose(R, R2) # mode='r' returns the same R as mode='full' + array([ True]) + >>> a = np.random.normal(size=(3, 2, 2)) # Stack of 2 x 2 matrices as input + >>> Q, R = np.linalg.qr(a) + >>> Q.shape + (3, 2, 2) + >>> R.shape + (3, 2, 2) + >>> np.allclose(a, np.matmul(Q, R)) + array([ True]) """ - x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_nondefault_queue=False) - if x1_desc: - if x1_desc.ndim != 2: - pass - elif mode != "reduced": - pass - else: - result_tup = dpnp_qr(x1_desc, mode) + dpnp.check_supported_arrays_type(a) + check_stacked_2d(a) - return result_tup + if mode not in ("reduced", "complete", "r", "raw"): + raise ValueError(f"Unrecognized mode {mode}") - return call_origin(numpy.linalg.qr, x1, mode) + return dpnp_qr(a, mode) def solve(a, b): diff --git a/dpnp/linalg/dpnp_utils_linalg.py b/dpnp/linalg/dpnp_utils_linalg.py index 93f41883133..a6dcfbf0c2b 100644 --- a/dpnp/linalg/dpnp_utils_linalg.py +++ b/dpnp/linalg/dpnp_utils_linalg.py @@ -39,6 +39,7 @@ "dpnp_det", "dpnp_eigh", "dpnp_inv", + "dpnp_qr", "dpnp_slogdet", "dpnp_solve", "dpnp_svd", @@ -126,29 +127,6 @@ def _check_lapack_dev_info(dev_info, error_msg=None): raise dpnp.linalg.LinAlgError(error_msg) -def _real_type(dtype, device=None): - """ - Returns the real data type corresponding to a given dpnp data type. - - Parameters - ---------- - dtype : dpnp.dtype - The dtype for which to find the corresponding real data type. - device : {None, string, SyclDevice, SyclQueue}, optional - An array API concept of device where an array of default floating type might be created. - - Returns - ------- - out : str - The name of the real data type. - - """ - - default = dpnp.default_float_type(device) - real_type = _real_types_map.get(dtype.name, default) - return dpnp.dtype(real_type) - - def _common_type(*arrays): """ Common type for linear algebra operations. @@ -403,6 +381,29 @@ def _lu_factor(a, res_type): return (a_h, ipiv_h, dev_info_array) +def _real_type(dtype, device=None): + """ + Returns the real data type corresponding to a given dpnp data type. + + Parameters + ---------- + dtype : dpnp.dtype + The dtype for which to find the corresponding real data type. + device : {None, string, SyclDevice, SyclQueue}, optional + An array API concept of device where an array of default floating type might be created. + + Returns + ------- + out : str + The name of the real data type. + + """ + + default = dpnp.default_float_type(device) + real_type = _real_types_map.get(dtype.name, default) + return dpnp.dtype(real_type) + + def _stacked_identity( batch_shape, n, dtype, usm_type="device", sycl_queue=None ): @@ -447,6 +448,48 @@ def _stacked_identity( return x +def _triu_inplace(a, host_tasks, depends=None): + """ + _triu_inplace(a, host_tasks, depends=None) + + Computes the upper triangular part of an array in-place, + but currently allocates extra memory for the result. + + Parameters + ---------- + a : {dpnp.ndarray, usm_ndarray} + Input array from which the upper triangular part is to be extracted. + host_tasks : list + A list to which the function appends the host event corresponding to the computation. + This allows for dependency management and synchronization with other tasks. + depends : list, optional + A list of events that the triangular operation depends on. + These tasks are completed before the triangular computation starts. + If ``None``, defaults to an empty list. + + Returns + ------- + out : dpnp.ndarray + A new array containing the upper triangular part of the input array `a`. + + """ + + # TODO: implement a dedicated kernel for in-place triu instead of + # extra memory allocation for result + if depends is None: + depends = [] + out = dpnp.empty_like(a, order="C") + ht_triu_ev, _ = ti._triu( + src=a.get_array(), + dst=out.get_array(), + k=0, + sycl_queue=a.sycl_queue, + depends=depends, + ) + host_tasks.append(ht_triu_ev) + return out + + def check_stacked_2d(*arrays): """ Return ``True`` if each array in `arrays` has at least two dimensions. @@ -955,6 +998,344 @@ def dpnp_inv(a): return b_f +def dpnp_qr_batch(a, mode="reduced"): + """ + dpnp_qr_batch(a, mode="reduced") + + Return the batched qr factorization of `a` matrix. + + """ + + a_sycl_queue = a.sycl_queue + a_usm_type = a.usm_type + + m, n = a.shape[-2:] + k = min(m, n) + + batch_shape = a.shape[:-2] + batch_size = prod(batch_shape) + + res_type = _common_type(a) + + if batch_size == 0 or k == 0: + if mode == "reduced": + return ( + dpnp.empty_like( + a, + shape=batch_shape + (m, k), + dtype=res_type, + ), + dpnp.empty_like( + a, + shape=batch_shape + (k, n), + dtype=res_type, + ), + ) + elif mode == "complete": + q = _stacked_identity( + batch_shape, + m, + dtype=res_type, + usm_type=a_usm_type, + sycl_queue=a_sycl_queue, + ) + return ( + q, + dpnp.empty_like( + a, + shape=batch_shape + (m, n), + dtype=res_type, + ), + ) + elif mode == "r": + return dpnp.empty_like( + a, + shape=batch_shape + (k, n), + dtype=res_type, + ) + else: # mode=="raw" + return ( + dpnp.empty_like( + a, + shape=batch_shape + (n, m), + dtype=res_type, + ), + dpnp.empty_like( + a, + shape=batch_shape + (k,), + dtype=res_type, + ), + ) + + # get 3d input arrays by reshape + a = a.reshape(-1, m, n) + + a = a.swapaxes(-2, -1) + a_usm_arr = dpnp.get_usm_ndarray(a) + + a_t = dpnp.empty_like(a, order="C", dtype=res_type) + + # use DPCTL tensor function to fill the matrix array + # with content from the input array `a` + a_ht_copy_ev, a_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=a_usm_arr, dst=a_t.get_array(), sycl_queue=a_sycl_queue + ) + + tau_h = dpnp.empty_like( + a_t, + shape=(batch_size, k), + dtype=res_type, + ) + + a_stride = a_t.strides[0] + tau_stride = tau_h.strides[0] + + # Call the LAPACK extension function _geqrf_batch to compute the QR factorization + # of a general m x n matrix. + ht_geqrf_batch_ev, geqrf_batch_ev = li._geqrf_batch( + a_sycl_queue, + a_t.get_array(), + tau_h.get_array(), + m, + n, + a_stride, + tau_stride, + batch_size, + [a_copy_ev], + ) + + ht_list_ev = [ht_geqrf_batch_ev, a_ht_copy_ev] + + if mode in ["r", "raw"]: + if mode == "r": + r = a_t[..., :k].swapaxes(-2, -1) + r = _triu_inplace(r, ht_list_ev, [geqrf_batch_ev]) + dpctl.SyclEvent.wait_for(ht_list_ev) + return r.reshape(batch_shape + r.shape[-2:]) + + # mode=="raw" + dpctl.SyclEvent.wait_for(ht_list_ev) + q = a_t.reshape(batch_shape + a_t.shape[-2:]) + r = tau_h.reshape(batch_shape + tau_h.shape[-1:]) + return (q, r) + + if mode == "complete" and m > n: + mc = m + q = dpnp.empty_like( + a_t, + shape=(batch_size, m, m), + dtype=res_type, + ) + else: + mc = k + q = dpnp.empty_like( + a_t, + shape=(batch_size, n, m), + dtype=res_type, + ) + + # use DPCTL tensor function to fill the matrix array `q[..., :n, :]` + # with content from the array `a_t` overwritten by geqrf_batch + a_t_ht_copy_ev, a_t_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=a_t.get_array(), + dst=q[..., :n, :].get_array(), + sycl_queue=a_sycl_queue, + depends=[geqrf_batch_ev], + ) + + ht_list_ev.append(a_t_ht_copy_ev) + + q_stride = q.strides[0] + tau_stride = tau_h.strides[0] + + # Get LAPACK function (_orgqr_batch for real or _ungqf_batch for complex data types) + # for QR factorization + lapack_func = ( + "_ungqr_batch" + if dpnp.issubdtype(res_type, dpnp.complexfloating) + else "_orgqr_batch" + ) + + # Call the LAPACK extension function _orgqr_batch/ to generate the real orthogonal/ + # complex unitary matrices `Qi` of the QR factorization + # for a batch of general matrices. + ht_lapack_ev, lapack_ev = getattr(li, lapack_func)( + a_sycl_queue, + q.get_array(), + tau_h.get_array(), + m, + mc, + k, + q_stride, + tau_stride, + batch_size, + [a_t_copy_ev], + ) + + ht_list_ev.append(ht_lapack_ev) + + q = q[..., :mc, :].swapaxes(-2, -1) + r = a_t[..., :mc].swapaxes(-2, -1) + + ht_list_ev.append(ht_lapack_ev) + + r = _triu_inplace(r, ht_list_ev, [lapack_ev]) + dpctl.SyclEvent.wait_for(ht_list_ev) + + return ( + q.reshape(batch_shape + q.shape[-2:]), + r.reshape(batch_shape + r.shape[-2:]), + ) + + +def dpnp_qr(a, mode="reduced"): + """ + dpnp_qr(a, mode="reduced") + + Return the qr factorization of `a` matrix. + + """ + + if a.ndim > 2: + return dpnp_qr_batch(a, mode=mode) + + a_usm_arr = dpnp.get_usm_ndarray(a) + a_sycl_queue = a.sycl_queue + a_usm_type = a.usm_type + + res_type = _common_type(a) + + m, n = a.shape + k = min(m, n) + if k == 0: + if mode == "reduced": + return dpnp.empty_like( + a, + shape=(m, 0), + dtype=res_type, + ), dpnp.empty_like( + a, + shape=(0, n), + dtype=res_type, + ) + elif mode == "complete": + return dpnp.identity( + m, dtype=res_type, sycl_queue=a_sycl_queue, usm_type=a_usm_type + ), dpnp.empty_like( + a, + shape=(m, n), + dtype=res_type, + ) + elif mode == "r": + return dpnp.empty_like( + a, + shape=(0, n), + dtype=res_type, + ) + else: # mode == "raw" + return dpnp.empty_like( + a, + shape=(n, m), + dtype=res_type, + ), dpnp.empty_like( + a, + shape=(0,), + dtype=res_type, + ) + + # Transpose the input matrix to convert from row-major to column-major order. + # This adjustment is necessary for compatibility with OneMKL LAPACK routines, + # which expect matrices in column-major format. + # This allows data to be handled efficiently without the need for additional conversion. + a = a.T + a_usm_arr = dpnp.get_usm_ndarray(a) + a_t = dpnp.empty_like(a, order="C", dtype=res_type) + + # use DPCTL tensor function to fill the matrix array + # with content from the input array `a` + a_ht_copy_ev, a_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=a_usm_arr, dst=a_t.get_array(), sycl_queue=a_sycl_queue + ) + + tau_h = dpnp.empty_like( + a, + shape=(k,), + dtype=res_type, + ) + + # Call the LAPACK extension function _geqrf to compute the QR factorization + # of a general m x n matrix. + ht_geqrf_ev, geqrf_ev = li._geqrf( + a_sycl_queue, a_t.get_array(), tau_h.get_array(), [a_copy_ev] + ) + + ht_list_ev = [ht_geqrf_ev, a_ht_copy_ev] + + if mode in ["r", "raw"]: + if mode == "r": + r = a_t[:, :k].transpose() + r = _triu_inplace(r, ht_list_ev, [geqrf_ev]) + dpctl.SyclEvent.wait_for(ht_list_ev) + return r + + # mode == "raw": + dpctl.SyclEvent.wait_for(ht_list_ev) + return (a_t, tau_h) + + # mc is the total number of columns in the q matrix. + # In `complete` mode, mc equals the number of rows. + # In `reduced` mode, mc is the lesser of the row count or column count. + if mode == "complete" and m > n: + mc = m + q = dpnp.empty_like( + a_t, + shape=(m, m), + dtype=res_type, + ) + else: + mc = k + q = dpnp.empty_like( + a_t, + shape=(n, m), + dtype=res_type, + ) + + # use DPCTL tensor function to fill the matrix array `q[:n]` + # with content from the array `a_t` overwritten by geqrf + a_t_ht_copy_ev, a_t_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=a_t.get_array(), + dst=q[:n].get_array(), + sycl_queue=a_sycl_queue, + depends=[geqrf_ev], + ) + + ht_list_ev.append(a_t_ht_copy_ev) + + # Get LAPACK function (_orgqr for real or _ungqf for complex data types) + # for QR factorization + lapack_func = ( + "_ungqr" + if dpnp.issubdtype(res_type, dpnp.complexfloating) + else "_orgqr" + ) + + # Call the LAPACK extension function _orgqr/_ungqf to generate the real orthogonal/ + # complex unitary matrix `Q` of the QR factorization + ht_lapack_ev, lapack_ev = getattr(li, lapack_func)( + a_sycl_queue, m, mc, k, q.get_array(), tau_h.get_array(), [a_t_copy_ev] + ) + + q = q[:mc].transpose() + r = a_t[:, :mc].transpose() + + ht_list_ev.append(ht_lapack_ev) + + r = _triu_inplace(r, ht_list_ev, [lapack_ev]) + dpctl.SyclEvent.wait_for(ht_list_ev) + + return (q, r) + + def dpnp_solve(a, b): """ dpnp_solve(a, b) diff --git a/tests/test_linalg.py b/tests/test_linalg.py index 85206bad5ba..8e32b867b85 100644 --- a/tests/test_linalg.py +++ b/tests/test_linalg.py @@ -1,7 +1,12 @@ import dpctl import numpy import pytest -from numpy.testing import assert_allclose, assert_array_equal, assert_raises +from numpy.testing import ( + assert_allclose, + assert_almost_equal, + assert_array_equal, + assert_raises, +) import dpnp as inp from tests.third_party.cupy import testing @@ -308,8 +313,8 @@ def test_det_singular_matrix(self, matrix): a_np = numpy.array(matrix, dtype="float32") a_dp = inp.array(a_np) - expected = numpy.linalg.slogdet(a_np) - result = inp.linalg.slogdet(a_dp) + expected = numpy.linalg.det(a_np) + result = inp.linalg.det(a_dp) assert_allclose(expected, result, rtol=1e-3, atol=1e-4) @@ -672,88 +677,141 @@ def test_norm3(array, ord, axis): assert_allclose(expected, result) -@pytest.mark.usefixtures("allow_fall_back_on_numpy") -@pytest.mark.parametrize("type", get_all_dtypes(no_bool=True, no_complex=True)) -@pytest.mark.parametrize( - "shape", - [(2, 2), (3, 4), (5, 3), (16, 16), (0, 0), (0, 2), (2, 0)], - ids=["(2,2)", "(3,4)", "(5,3)", "(16,16)", "(0,0)", "(0,2)", "(2,0)"], -) -@pytest.mark.parametrize( - "mode", ["complete", "reduced"], ids=["complete", "reduced"] -) -def test_qr(type, shape, mode): - a = numpy.arange(shape[0] * shape[1], dtype=type).reshape(shape) - ia = inp.array(a) +class TestQr: + # TODO: New packages that fix issue CMPLRLLVM-53771 are only available in internal CI. + # Skip the tests on cpu until these packages are available for the external CI. + # Specifically dpcpp_linux-64>=2024.1.0 + @pytest.mark.skipif(is_cpu_device(), reason="CMPLRLLVM-53771") + @pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True)) + @pytest.mark.parametrize( + "shape", + [(2, 2), (3, 4), (5, 3), (16, 16), (2, 2, 2), (2, 4, 2), (2, 2, 4)], + ids=[ + "(2, 2)", + "(3, 4)", + "(5, 3)", + "(16, 16)", + "(2, 2, 2)", + "(2, 4, 2)", + "(2, 2, 4)", + ], + ) + @pytest.mark.parametrize( + "mode", + ["r", "raw", "complete", "reduced"], + ids=["r", "raw", "complete", "reduced"], + ) + def test_qr(self, dtype, shape, mode): + a = numpy.random.rand(*shape).astype(dtype) + ia = inp.array(a) + + if mode == "r": + np_r = numpy.linalg.qr(a, mode) + dpnp_r = inp.linalg.qr(ia, mode) + else: + np_q, np_r = numpy.linalg.qr(a, mode) + dpnp_q, dpnp_r = inp.linalg.qr(ia, mode) + + # check decomposition + if mode in ("complete", "reduced"): + if a.ndim == 2: + assert_almost_equal( + inp.dot(dpnp_q, dpnp_r), + a, + decimal=5, + ) + else: # a.ndim > 2 + assert_almost_equal( + inp.matmul(dpnp_q, dpnp_r), + a, + decimal=5, + ) + else: # mode=="raw" + assert_dtype_allclose(dpnp_q, np_q) - np_q, np_r = numpy.linalg.qr(a, mode) - dpnp_q, dpnp_r = inp.linalg.qr(ia, mode) - - support_aspect64 = has_support_aspect64() - - if support_aspect64: - assert dpnp_q.dtype == np_q.dtype - assert dpnp_r.dtype == np_r.dtype - assert dpnp_q.shape == np_q.shape - assert dpnp_r.shape == np_r.shape - - tol = 1e-6 - if type == inp.float32: - tol = 1e-02 - elif not support_aspect64 and type in (inp.int32, inp.int64, None): - tol = 1e-02 - - # check decomposition - assert_allclose( - ia, - inp.dot(dpnp_q, dpnp_r), - rtol=tol, - atol=tol, + if mode in ("raw", "r"): + assert_dtype_allclose(dpnp_r, np_r) + + @pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True)) + @pytest.mark.parametrize( + "shape", + [(0, 0), (0, 2), (2, 0), (2, 0, 3), (2, 3, 0), (0, 2, 3)], + ids=[ + "(0, 0)", + "(0, 2)", + "(2 ,0)", + "(2, 0, 3)", + "(2, 3, 0)", + "(0, 2, 3)", + ], + ) + @pytest.mark.parametrize( + "mode", + ["r", "raw", "complete", "reduced"], + ids=["r", "raw", "complete", "reduced"], ) + def test_qr_empty(self, dtype, shape, mode): + a = numpy.empty(shape, dtype=dtype) + ia = inp.array(a) - # NP change sign for comparison - ncols = min(a.shape[0], a.shape[1]) - for i in range(ncols): - j = numpy.where(numpy.abs(np_q[:, i]) > tol)[0][0] - if np_q[j, i] * dpnp_q[j, i] < 0: - np_q[:, i] = -np_q[:, i] - np_r[i, :] = -np_r[i, :] - - if numpy.any(numpy.abs(np_r[i, :]) > tol): - assert_allclose( - inp.asnumpy(dpnp_q)[:, i], np_q[:, i], rtol=tol, atol=tol - ) + if mode == "r": + np_r = numpy.linalg.qr(a, mode) + dpnp_r = inp.linalg.qr(ia, mode) + else: + np_q, np_r = numpy.linalg.qr(a, mode) + dpnp_q, dpnp_r = inp.linalg.qr(ia, mode) - assert_allclose(dpnp_r, np_r, rtol=tol, atol=tol) + assert_dtype_allclose(dpnp_q, np_q) + assert_dtype_allclose(dpnp_r, np_r) -@pytest.mark.usefixtures("allow_fall_back_on_numpy") -def test_qr_not_2D(): - a = numpy.arange(12, dtype=numpy.float32).reshape((3, 2, 2)) - ia = inp.array(a) + @pytest.mark.skipif(is_cpu_device(), reason="CMPLRLLVM-53771") + @pytest.mark.parametrize( + "mode", + ["r", "raw", "complete", "reduced"], + ids=["r", "raw", "complete", "reduced"], + ) + def test_qr_strides(self, mode): + a = numpy.random.rand(5, 5) + ia = inp.array(a) - np_q, np_r = numpy.linalg.qr(a) - dpnp_q, dpnp_r = inp.linalg.qr(ia) + # positive strides + if mode == "r": + np_r = numpy.linalg.qr(a[::2, ::2], mode) + dpnp_r = inp.linalg.qr(ia[::2, ::2], mode) + else: + np_q, np_r = numpy.linalg.qr(a[::2, ::2], mode) + dpnp_q, dpnp_r = inp.linalg.qr(ia[::2, ::2], mode) - assert dpnp_q.dtype == np_q.dtype - assert dpnp_r.dtype == np_r.dtype - assert dpnp_q.shape == np_q.shape - assert dpnp_r.shape == np_r.shape + assert_dtype_allclose(dpnp_q, np_q) - assert_allclose(ia, inp.matmul(dpnp_q, dpnp_r)) + assert_dtype_allclose(dpnp_r, np_r) - a = numpy.empty((0, 3, 2), dtype=numpy.float32) - ia = inp.array(a) + # negative strides + if mode == "r": + np_r = numpy.linalg.qr(a[::-2, ::-2], mode) + dpnp_r = inp.linalg.qr(ia[::-2, ::-2], mode) + else: + np_q, np_r = numpy.linalg.qr(a[::-2, ::-2], mode) + dpnp_q, dpnp_r = inp.linalg.qr(ia[::-2, ::-2], mode) - np_q, np_r = numpy.linalg.qr(a) - dpnp_q, dpnp_r = inp.linalg.qr(ia) + assert_dtype_allclose(dpnp_q, np_q) - assert dpnp_q.dtype == np_q.dtype - assert dpnp_r.dtype == np_r.dtype - assert dpnp_q.shape == np_q.shape - assert dpnp_r.shape == np_r.shape + assert_dtype_allclose(dpnp_r, np_r) - assert_allclose(ia, inp.matmul(dpnp_q, dpnp_r)) + def test_qr_errors(self): + a_dp = inp.array([[1, 2], [3, 5]], dtype="float32") + + # unsupported type + a_np = inp.asnumpy(a_dp) + assert_raises(TypeError, inp.linalg.qr, a_np) + + # a.ndim < 2 + a_dp_ndim_1 = a_dp.flatten() + assert_raises(inp.linalg.LinAlgError, inp.linalg.qr, a_dp_ndim_1) + + # invalid mode + assert_raises(ValueError, inp.linalg.qr, a_dp, "c") class TestSolve: @@ -1018,14 +1076,6 @@ def check_decomposition( dpnp_diag_s = inp.zeros_like(dp_a, dtype=dp_s.dtype) for i in range(min(dp_a.shape[-2], dp_a.shape[-1])): dpnp_diag_s[..., i, i] = dp_s[..., i] - # TODO: remove it when dpnp.dot is updated - # dpnp.dot does not support complex type - if inp.issubdtype(dp_a.dtype, inp.complexfloating): - reconstructed = numpy.dot( - inp.asnumpy(dp_u), - numpy.dot(inp.asnumpy(dpnp_diag_s), inp.asnumpy(dp_vt)), - ) - else: reconstructed = inp.dot(dp_u, inp.dot(dpnp_diag_s, dp_vt)) # TODO: use assert dpnp.allclose() inside check_decomposition() # when it will support complex dtypes diff --git a/tests/test_sycl_queue.py b/tests/test_sycl_queue.py index f6329d8f216..de243744403 100644 --- a/tests/test_sycl_queue.py +++ b/tests/test_sycl_queue.py @@ -1202,34 +1202,52 @@ def test_matrix_rank(device): assert_array_equal(expected, result) +@pytest.mark.parametrize( + "shape", + [ + (4, 4), + (2, 0), + (2, 2, 3), + (0, 2, 3), + (1, 0, 3), + ], + ids=[ + "(4, 4)", + "(2, 0)", + "(2, 2, 3)", + "(0, 2, 3)", + "(1, 0, 3)", + ], +) +@pytest.mark.parametrize( + "mode", + ["r", "raw", "complete", "reduced"], + ids=["r", "raw", "complete", "reduced"], +) @pytest.mark.parametrize( "device", valid_devices, ids=[device.filter_string for device in valid_devices], ) -def test_qr(device): - data = [[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]] - dpnp_data = dpnp.array(data, device=device) - numpy_data = numpy.array(data, dtype=dpnp_data.dtype) - - np_q, np_r = numpy.linalg.qr(numpy_data, "reduced") - dpnp_q, dpnp_r = dpnp.linalg.qr(dpnp_data, "reduced") +def test_qr(shape, mode, device): + dtype = dpnp.default_float_type(device) + count_elems = numpy.prod(shape) + a = dpnp.arange(count_elems, dtype=dtype, device=device).reshape(shape) - assert dpnp_q.dtype == np_q.dtype - assert dpnp_r.dtype == np_r.dtype - assert dpnp_q.shape == np_q.shape - assert dpnp_r.shape == np_r.shape + expected_queue = a.get_array().sycl_queue - assert_dtype_allclose(dpnp_q, np_q) - assert_dtype_allclose(dpnp_r, np_r) + if mode == "r": + dp_r = dpnp.linalg.qr(a, mode=mode) + dp_r_queue = dp_r.get_array().sycl_queue + assert_sycl_queue_equal(dp_r_queue, expected_queue) + else: + dp_q, dp_r = dpnp.linalg.qr(a, mode=mode) - expected_queue = dpnp_data.get_array().sycl_queue - dpnp_q_queue = dpnp_q.get_array().sycl_queue - dpnp_r_queue = dpnp_r.get_array().sycl_queue + dp_q_queue = dp_q.get_array().sycl_queue + dp_r_queue = dp_r.get_array().sycl_queue - # compare queue and device - assert_sycl_queue_equal(dpnp_q_queue, expected_queue) - assert_sycl_queue_equal(dpnp_r_queue, expected_queue) + assert_sycl_queue_equal(dp_q_queue, expected_queue) + assert_sycl_queue_equal(dp_r_queue, expected_queue) @pytest.mark.parametrize( diff --git a/tests/test_usm_type.py b/tests/test_usm_type.py index 29101cf9f48..56e2a68756a 100644 --- a/tests/test_usm_type.py +++ b/tests/test_usm_type.py @@ -796,3 +796,40 @@ def test_svd(usm_type, shape, full_matrices_param, compute_uv_param): ) assert x.usm_type == s.usm_type + + +@pytest.mark.parametrize("usm_type", list_of_usm_types, ids=list_of_usm_types) +@pytest.mark.parametrize( + "shape", + [ + (4, 4), + (2, 0), + (2, 2, 3), + (0, 2, 3), + (1, 0, 3), + ], + ids=[ + "(4, 4)", + "(2, 0)", + "(2, 2, 3)", + "(0, 2, 3)", + "(1, 0, 3)", + ], +) +@pytest.mark.parametrize( + "mode", + ["r", "raw", "complete", "reduced"], + ids=["r", "raw", "complete", "reduced"], +) +def test_qr(shape, mode, usm_type): + count_elems = numpy.prod(shape) + a = dp.arange(count_elems, usm_type=usm_type).reshape(shape) + + if mode == "r": + dp_r = dp.linalg.qr(a, mode=mode) + assert a.usm_type == dp_r.usm_type + else: + dp_q, dp_r = dp.linalg.qr(a, mode=mode) + + assert a.usm_type == dp_q.usm_type + assert a.usm_type == dp_r.usm_type diff --git a/tests/third_party/cupy/linalg_tests/test_decomposition.py b/tests/third_party/cupy/linalg_tests/test_decomposition.py index fd887c16e6c..234a2e0e381 100644 --- a/tests/third_party/cupy/linalg_tests/test_decomposition.py +++ b/tests/third_party/cupy/linalg_tests/test_decomposition.py @@ -201,38 +201,31 @@ def check_usv(self, shape, dtype): # reconstruct the matrix k = s_cpu.shape[-1] - # dpnp.dot/matmul does not support complex type and unstable on cpu - # TODO: remove it and use xp.dot/matmul when dpnp.dot/matmul is updated - u_gpu = u_gpu.asnumpy() - vh_gpu = vh_gpu.asnumpy() - s_gpu = s_gpu.asnumpy() - xp = numpy - if len(shape) == 2: if self.full_matrices: - a_gpu_usv = numpy.dot(u_gpu[:, :k] * s_gpu, vh_gpu[:k, :]) + a_gpu_usv = cupy.dot(u_gpu[:, :k] * s_gpu, vh_gpu[:k, :]) else: - a_gpu_usv = numpy.dot(u_gpu * s_gpu, vh_gpu) + a_gpu_usv = cupy.dot(u_gpu * s_gpu, vh_gpu) else: if self.full_matrices: - a_gpu_usv = numpy.matmul( + a_gpu_usv = cupy.matmul( u_gpu[..., :k] * s_gpu[..., None, :], vh_gpu[..., :k, :] ) else: - a_gpu_usv = numpy.matmul(u_gpu * s_gpu[..., None, :], vh_gpu) + a_gpu_usv = cupy.matmul(u_gpu * s_gpu[..., None, :], vh_gpu) testing.assert_allclose(a_gpu, a_gpu_usv, rtol=1e-4, atol=1e-4) # assert unitary u_len = u_gpu.shape[-1] vh_len = vh_gpu.shape[-2] testing.assert_allclose( - xp.matmul(u_gpu.swapaxes(-1, -2).conj(), u_gpu), - stacked_identity(xp, shape[:-2], u_len, dtype), + cupy.matmul(u_gpu.swapaxes(-1, -2).conj(), u_gpu), + stacked_identity(cupy, shape[:-2], u_len, dtype), atol=1e-4, ) testing.assert_allclose( - xp.matmul(vh_gpu, vh_gpu.swapaxes(-1, -2).conj()), - stacked_identity(xp, shape[:-2], vh_len, dtype), + cupy.matmul(vh_gpu, vh_gpu.swapaxes(-1, -2).conj()), + stacked_identity(cupy, shape[:-2], vh_len, dtype), atol=1e-4, ) @@ -385,3 +378,77 @@ def test_svd_rank4_empty_array(self): self.check_usv((0, 2, 3, 4)) self.check_usv((1, 2, 0, 4)) self.check_usv((1, 2, 3, 0)) + + +@testing.parameterize( + *testing.product( + { + "mode": ["r", "raw", "complete", "reduced"], + } + ) +) +class TestQRDecomposition(unittest.TestCase): + @testing.for_dtypes("fdFD") + def check_mode(self, array, mode, dtype): + a_cpu = numpy.asarray(array, dtype=dtype) + a_gpu = cupy.asarray(array, dtype=dtype) + result_gpu = cupy.linalg.qr(a_gpu, mode=mode) + if ( + mode != "raw" + or numpy.lib.NumpyVersion(numpy.__version__) >= "1.22.0rc1" + ): + result_cpu = numpy.linalg.qr(a_cpu, mode=mode) + self._check_result(result_cpu, result_gpu) + + def _check_result(self, result_cpu, result_gpu): + if isinstance(result_cpu, tuple): + for b_cpu, b_gpu in zip(result_cpu, result_gpu): + assert b_cpu.dtype == b_gpu.dtype + testing.assert_allclose(b_cpu, b_gpu, atol=1e-4) + else: + assert result_cpu.dtype == result_gpu.dtype + testing.assert_allclose(result_cpu, result_gpu, atol=1e-4) + + # TODO: New packages that fix issue CMPLRLLVM-53771 are only available in internal CI. + # Skip the tests on cpu until these packages are available for the external CI. + # Specifically dpcpp_linux-64>=2024.1.0 + @pytest.mark.skipif(is_cpu_device(), reason="CMPLRLLVM-53771") + @testing.fix_random() + @_condition.repeat(3, 10) + def test_mode(self): + self.check_mode(numpy.random.randn(2, 4), mode=self.mode) + self.check_mode(numpy.random.randn(3, 3), mode=self.mode) + self.check_mode(numpy.random.randn(5, 4), mode=self.mode) + + @pytest.mark.skipif(is_cpu_device(), reason="CMPLRLLVM-53771") + @testing.with_requires("numpy>=1.22") + @testing.fix_random() + def test_mode_rank3(self): + self.check_mode(numpy.random.randn(3, 2, 4), mode=self.mode) + self.check_mode(numpy.random.randn(4, 3, 3), mode=self.mode) + self.check_mode(numpy.random.randn(2, 5, 4), mode=self.mode) + + @pytest.mark.skipif(is_cpu_device(), reason="CMPLRLLVM-53771") + @testing.with_requires("numpy>=1.22") + @testing.fix_random() + def test_mode_rank4(self): + self.check_mode(numpy.random.randn(2, 3, 2, 4), mode=self.mode) + self.check_mode(numpy.random.randn(2, 4, 3, 3), mode=self.mode) + self.check_mode(numpy.random.randn(2, 2, 5, 4), mode=self.mode) + + @testing.with_requires("numpy>=1.16") + def test_empty_array(self): + self.check_mode(numpy.empty((0, 3)), mode=self.mode) + self.check_mode(numpy.empty((3, 0)), mode=self.mode) + + @testing.with_requires("numpy>=1.22") + def test_empty_array_rank3(self): + self.check_mode(numpy.empty((0, 3, 2)), mode=self.mode) + self.check_mode(numpy.empty((3, 0, 2)), mode=self.mode) + self.check_mode(numpy.empty((3, 2, 0)), mode=self.mode) + self.check_mode(numpy.empty((0, 3, 3)), mode=self.mode) + self.check_mode(numpy.empty((3, 0, 3)), mode=self.mode) + self.check_mode(numpy.empty((3, 3, 0)), mode=self.mode) + self.check_mode(numpy.empty((0, 2, 3)), mode=self.mode) + self.check_mode(numpy.empty((2, 0, 3)), mode=self.mode) + self.check_mode(numpy.empty((2, 3, 0)), mode=self.mode)