Skip to content

Commit

Permalink
Update dpnp.linalg.qr() function (#1673)
Browse files Browse the repository at this point in the history
* Impl dpnp.linalg.qr for 2d array

* Add cupy tests for dpnp.linalg.qr

* Add batch implementation of dpnp.linalg.qr

* Remove an old impl of dpnp_qr

* Update test_qr in test_sycl_queue

* Add test_qr in test_usm_type

* Use _real_type for _orgqr

* Use _real_type for _orgqr_batch

* Update dpnp tests for dpnp.linalg.qr

* Pass scratchpad_size to the error message test

* Add additional checks

* Extend error handler for mkl batch funcs

* Add ungqr mkl extension to support complex dtype

* Update tau array size check for orgqr

* Add ungqr_batch mkl extension to support complex dtype

* Add arrays type check

* Fix test_det_singular_matrix

* Expand tests for dpnp.linalg.qr with complex types

* Update examples

* Remove astype for output arrays

* Use empty_like instead of empty

* Use ht_list_ev with dpctl.SyclEvent.wait_for

* Add _triu_inplace func

* Use copy_usm for a_t array overwritten by geqrf/geqrf_batch


---------

Co-authored-by: Anton <[email protected]>
  • Loading branch information
vlad-perevezentsev and antonwolfy authored Feb 8, 2024
1 parent d45bb24 commit 1e86753
Show file tree
Hide file tree
Showing 22 changed files with 2,767 additions and 246 deletions.
6 changes: 6 additions & 0 deletions dpnp/backend/extensions/lapack/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,21 @@
set(python_module_name _lapack_impl)
set(_module_src
${CMAKE_CURRENT_SOURCE_DIR}/lapack_py.cpp
${CMAKE_CURRENT_SOURCE_DIR}/geqrf.cpp
${CMAKE_CURRENT_SOURCE_DIR}/geqrf_batch.cpp
${CMAKE_CURRENT_SOURCE_DIR}/gesv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/gesvd.cpp
${CMAKE_CURRENT_SOURCE_DIR}/getrf.cpp
${CMAKE_CURRENT_SOURCE_DIR}/getrf_batch.cpp
${CMAKE_CURRENT_SOURCE_DIR}/getri_batch.cpp
${CMAKE_CURRENT_SOURCE_DIR}/heevd.cpp
${CMAKE_CURRENT_SOURCE_DIR}/orgqr.cpp
${CMAKE_CURRENT_SOURCE_DIR}/orgqr_batch.cpp
${CMAKE_CURRENT_SOURCE_DIR}/potrf.cpp
${CMAKE_CURRENT_SOURCE_DIR}/potrf_batch.cpp
${CMAKE_CURRENT_SOURCE_DIR}/syevd.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ungqr.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ungqr_batch.cpp
)

pybind11_add_module(${python_module_name} MODULE ${_module_src})
Expand Down
262 changes: 262 additions & 0 deletions dpnp/backend/extensions/lapack/geqrf.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,262 @@
//*****************************************************************************
// Copyright (c) 2024, Intel Corporation
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
// - Redistributions of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
// - Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
// THE POSSIBILITY OF SUCH DAMAGE.
//*****************************************************************************

#include <pybind11/pybind11.h>

// dpctl tensor headers
#include "utils/memory_overlap.hpp"
#include "utils/type_utils.hpp"

#include "geqrf.hpp"
#include "types_matrix.hpp"

#include "dpnp_utils.hpp"

namespace dpnp
{
namespace backend
{
namespace ext
{
namespace lapack
{
namespace mkl_lapack = oneapi::mkl::lapack;
namespace py = pybind11;
namespace type_utils = dpctl::tensor::type_utils;

typedef sycl::event (*geqrf_impl_fn_ptr_t)(sycl::queue,
const std::int64_t,
const std::int64_t,
char *,
std::int64_t,
char *,
std::vector<sycl::event> &,
const std::vector<sycl::event> &);

static geqrf_impl_fn_ptr_t geqrf_dispatch_vector[dpctl_td_ns::num_types];

template <typename T>
static sycl::event geqrf_impl(sycl::queue exec_q,
const std::int64_t m,
const std::int64_t n,
char *in_a,
std::int64_t lda,
char *in_tau,
std::vector<sycl::event> &host_task_events,
const std::vector<sycl::event> &depends)
{
type_utils::validate_type_for_device<T>(exec_q);

T *a = reinterpret_cast<T *>(in_a);
T *tau = reinterpret_cast<T *>(in_tau);

const std::int64_t scratchpad_size =
mkl_lapack::geqrf_scratchpad_size<T>(exec_q, m, n, lda);
T *scratchpad = nullptr;

std::stringstream error_msg;
std::int64_t info = 0;
bool is_exception_caught = false;

sycl::event geqrf_event;
try {
scratchpad = sycl::malloc_device<T>(scratchpad_size, exec_q);

geqrf_event = mkl_lapack::geqrf(
exec_q,
m, // The number of rows in the matrix; (0 ≤ m).
n, // The number of columns in the matrix; (0 ≤ n).
a, // Pointer to the m-by-n matrix.
lda, // The leading dimension of `a`; (1 ≤ m).
tau, // Pointer to the array of scalar factors of the
// elementary reflectors.
scratchpad, // Pointer to scratchpad memory to be used by MKL
// routine for storing intermediate results.
scratchpad_size, depends);
} catch (mkl_lapack::exception const &e) {
is_exception_caught = true;
info = e.info();

if (info < 0) {
error_msg << "Parameter number " << -info
<< " had an illegal value.";
}
else if (info == scratchpad_size && e.detail() != 0) {
error_msg
<< "Insufficient scratchpad size. Required size is at least "
<< e.detail() << ", but current size is " << scratchpad_size
<< ".";
}
else {
error_msg << "Unexpected MKL exception caught during geqrf() "
"call:\nreason: "
<< e.what() << "\ninfo: " << info;
}
} catch (sycl::exception const &e) {
is_exception_caught = true;
error_msg << "Unexpected SYCL exception caught during geqrf() call:\n"
<< e.what();
}

if (is_exception_caught) // an unexpected error occurs
{
if (scratchpad != nullptr) {
sycl::free(scratchpad, exec_q);
}
throw std::runtime_error(error_msg.str());
}

sycl::event clean_up_event = exec_q.submit([&](sycl::handler &cgh) {
cgh.depends_on(geqrf_event);
auto ctx = exec_q.get_context();
cgh.host_task([ctx, scratchpad]() { sycl::free(scratchpad, ctx); });
});
host_task_events.push_back(clean_up_event);

return geqrf_event;
}

std::pair<sycl::event, sycl::event>
geqrf(sycl::queue q,
dpctl::tensor::usm_ndarray a_array,
dpctl::tensor::usm_ndarray tau_array,
const std::vector<sycl::event> &depends)
{
const int a_array_nd = a_array.get_ndim();
const int tau_array_nd = tau_array.get_ndim();

if (a_array_nd != 2) {
throw py::value_error(
"The input array has ndim=" + std::to_string(a_array_nd) +
", but a 2-dimensional array is expected.");
}

if (tau_array_nd != 1) {
throw py::value_error("The array of Householder scalars has ndim=" +
std::to_string(tau_array_nd) +
", but a 1-dimensional array is expected.");
}

// check compatibility of execution queue and allocation queue
if (!dpctl::utils::queues_are_compatible(q, {a_array, tau_array})) {
throw py::value_error(
"Execution queue is not compatible with allocation queues");
}

auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
if (overlap(a_array, tau_array)) {
throw py::value_error(
"The input array and the array of Householder scalars "
"are overlapping segments of memory");
}

bool is_a_array_c_contig = a_array.is_c_contiguous();
if (!is_a_array_c_contig) {
throw py::value_error("The input array "
"must be C-contiguous");
}

bool is_tau_array_c_contig = tau_array.is_c_contiguous();
bool is_tau_array_f_contig = tau_array.is_f_contiguous();

if (!is_tau_array_c_contig || !is_tau_array_f_contig) {
throw py::value_error("The array of Householder scalars "
"must be contiguous");
}

auto array_types = dpctl_td_ns::usm_ndarray_types();
int a_array_type_id =
array_types.typenum_to_lookup_id(a_array.get_typenum());
int tau_array_type_id =
array_types.typenum_to_lookup_id(tau_array.get_typenum());

if (a_array_type_id != tau_array_type_id) {
throw py::value_error(
"The types of the input array and "
"the array of Householder scalars are mismatched");
}

geqrf_impl_fn_ptr_t geqrf_fn = geqrf_dispatch_vector[a_array_type_id];
if (geqrf_fn == nullptr) {
throw py::value_error(
"No geqrf implementation defined for the provided type "
"of the input matrix.");
}

char *a_array_data = a_array.get_data();
char *tau_array_data = tau_array.get_data();

const py::ssize_t *a_array_shape = a_array.get_shape_raw();

// The input array is transponded
// Change the order of getting m, n
const std::int64_t m = a_array_shape[1];
const std::int64_t n = a_array_shape[0];
const std::int64_t lda = std::max<size_t>(1UL, m);

const size_t tau_array_size = tau_array.get_size();
const size_t min_m_n = std::max<size_t>(1UL, std::min<size_t>(m, n));

if (tau_array_size != min_m_n) {
throw py::value_error("The array of Householder scalars has size=" +
std::to_string(tau_array_size) + ", but a size=" +
std::to_string(min_m_n) + " array is expected.");
}

std::vector<sycl::event> host_task_events;
sycl::event geqrf_ev = geqrf_fn(q, m, n, a_array_data, lda, tau_array_data,
host_task_events, depends);

sycl::event args_ev = dpctl::utils::keep_args_alive(q, {a_array, tau_array},
host_task_events);

return std::make_pair(args_ev, geqrf_ev);
}

template <typename fnT, typename T>
struct GeqrfContigFactory
{
fnT get()
{
if constexpr (types::GeqrfTypePairSupportFactory<T>::is_defined) {
return geqrf_impl<T>;
}
else {
return nullptr;
}
}
};

void init_geqrf_dispatch_vector(void)
{
dpctl_td_ns::DispatchVectorBuilder<geqrf_impl_fn_ptr_t, GeqrfContigFactory,
dpctl_td_ns::num_types>
contig;
contig.populate_dispatch_vector(geqrf_dispatch_vector);
}
} // namespace lapack
} // namespace ext
} // namespace backend
} // namespace dpnp
63 changes: 63 additions & 0 deletions dpnp/backend/extensions/lapack/geqrf.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
//*****************************************************************************
// Copyright (c) 2024, Intel Corporation
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
// - Redistributions of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
// - Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
// THE POSSIBILITY OF SUCH DAMAGE.
//*****************************************************************************

#pragma once

#include <CL/sycl.hpp>
#include <oneapi/mkl.hpp>

#include <dpctl4pybind11.hpp>

namespace dpnp
{
namespace backend
{
namespace ext
{
namespace lapack
{
extern std::pair<sycl::event, sycl::event>
geqrf(sycl::queue exec_q,
dpctl::tensor::usm_ndarray a_array,
dpctl::tensor::usm_ndarray tau_array,
const std::vector<sycl::event> &depends = {});

extern std::pair<sycl::event, sycl::event>
geqrf_batch(sycl::queue exec_q,
dpctl::tensor::usm_ndarray a_array,
dpctl::tensor::usm_ndarray tau_array,
std::int64_t m,
std::int64_t n,
std::int64_t stride_a,
std::int64_t stride_tau,
std::int64_t batch_size,
const std::vector<sycl::event> &depends = {});

extern void init_geqrf_batch_dispatch_vector(void);
extern void init_geqrf_dispatch_vector(void);
} // namespace lapack
} // namespace ext
} // namespace backend
} // namespace dpnp
Loading

0 comments on commit 1e86753

Please sign in to comment.