Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement gesv_batch via gesv call #1877

Merged
merged 50 commits into from
Aug 8, 2024
Merged
Show file tree
Hide file tree
Changes from 35 commits
Commits
Show all changes
50 commits
Select commit Hold shift + click to select a range
4c7c8c2
Init work
vlad-perevezentsev Jun 6, 2024
8e2cb23
First working version with transpose and C contig
vlad-perevezentsev Jun 7, 2024
67fa435
Second working version with moveaxis, transpose and F contig
vlad-perevezentsev Jun 7, 2024
4f5abec
Add more shape checks
vlad-perevezentsev Jun 11, 2024
0cb2808
Pass sycl::queue by reference for gesv/gesv_batch
vlad-perevezentsev Jun 11, 2024
bfa37d4
qwe
vlad-perevezentsev Jun 11, 2024
4a44292
Update _batched_solve implementation
vlad-perevezentsev Jun 12, 2024
df4774e
Remove old impl in _batched_solve
vlad-perevezentsev Jun 12, 2024
8dbe3c4
Merge master into impl_gesv_batch_via_gesv
vlad-perevezentsev Jun 12, 2024
8fb2af3
Use py::gil_scoped_release before gesv call
vlad-perevezentsev Jun 12, 2024
ddcf9fe
Remove junk files
vlad-perevezentsev Jun 12, 2024
262794f
Move gesv_batch to gesv_batch.cpp
vlad-perevezentsev Jun 13, 2024
3a7b8ca
Improve gesv_batch with independent linear streams
vlad-perevezentsev Jun 13, 2024
2016a8c
Extend checks for gesv/gesv_batch
vlad-perevezentsev Jun 13, 2024
2c42290
Update comment
vlad-perevezentsev Jun 13, 2024
e030da8
junk files
vlad-perevezentsev Jun 14, 2024
3f99ae5
Merge master into impl_gesv_batch_via_gesv
vlad-perevezentsev Jun 17, 2024
a0a683b
Merge master into impl_gesv_batch_via_gesv
vlad-perevezentsev Jul 11, 2024
5a48f33
Add common_gesv_checks
vlad-perevezentsev Jul 12, 2024
924fee7
Release GIL in gesv_batch_impl
vlad-perevezentsev Jul 12, 2024
2b15e6c
Remove junk file
vlad-perevezentsev Jul 12, 2024
5a1cab6
Merge master into impl_gesv_batch_via_gesv
vlad-perevezentsev Jul 12, 2024
b5c3062
Merge master into impl_gesv_batch_via_gesv
vlad-perevezentsev Jul 15, 2024
afca803
Remove junk files
vlad-perevezentsev Jul 16, 2024
ed99888
Merge master into impl_gesv_batch_via_gesv
vlad-perevezentsev Jul 16, 2024
0c97aff
Merge master into impl_gesv_batch_via_gesv
vlad-perevezentsev Jul 19, 2024
1b275ea
Merge master into impl_gesv_batch_via_gesv
vlad-perevezentsev Jul 26, 2024
e5b53a1
Remove host_task_events from gesv
vlad-perevezentsev Jul 26, 2024
d5adbd6
Use check_zeros_shape in gesv and gesv_batch
vlad-perevezentsev Jul 26, 2024
5b2780c
Add additional checks for gesv_impl
vlad-perevezentsev Jul 26, 2024
d4547d4
Move alloc_scratchpad to common_helpers.hpp
vlad-perevezentsev Jul 26, 2024
6759164
Use helper::alloc_scratchpad in gesv_batch_impl
vlad-perevezentsev Jul 26, 2024
f37ec43
Remove current_scratch_gesv check
vlad-perevezentsev Jul 26, 2024
adc17ba
Remove lda, ldb pass to gesv_batch_impl, gesv_impl
vlad-perevezentsev Jul 26, 2024
77ba0e2
Use const and constexpr in gesv/gesv_batch
vlad-perevezentsev Jul 26, 2024
9bf94b5
Applied review comments
vlad-perevezentsev Jul 29, 2024
b81893c
Use dpnp.reshape in _batched_solve
vlad-perevezentsev Jul 29, 2024
f8d68ef
Implement alloc_ipiv in common_helpers.hpp
vlad-perevezentsev Jul 29, 2024
fc6c7fa
Add gesv_common_utils.hpp
vlad-perevezentsev Jul 29, 2024
75079d2
Implement handle_lapack_exc function
vlad-perevezentsev Jul 29, 2024
6e82632
Merge master into impl_gesv_batch_via_gesv
vlad-perevezentsev Jul 29, 2024
7e0f384
Use try/catch for scratchpad/ipiv allocation
vlad-perevezentsev Jul 29, 2024
f5ee368
Update alloc_scratchpad/alloc_ipiv
vlad-perevezentsev Jul 29, 2024
eb8c3a0
gesv_scratchpad_size can be 0
vlad-perevezentsev Jul 30, 2024
3c8cda6
Implement help functions alloc_ipiv/alloc_scratchpad
vlad-perevezentsev Jul 30, 2024
3f4d672
Merge master into impl_gesv_batch_via_gesv
vlad-perevezentsev Jul 30, 2024
e56e07e
Merge master into impl_gesv_batch_via_gesv
vlad-perevezentsev Aug 2, 2024
629b97a
Reuse alloc_scratchpad/ipiv in batch versions
vlad-perevezentsev Aug 2, 2024
a9cc253
Merge master into impl_gesv_batch_via_gesv
vlad-perevezentsev Aug 6, 2024
3786ca2
Merge master into impl_gesv_batch_via_gesv
vlad-perevezentsev Aug 7, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions dpnp/backend/extensions/lapack/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ set(_module_src
${CMAKE_CURRENT_SOURCE_DIR}/geqrf.cpp
${CMAKE_CURRENT_SOURCE_DIR}/geqrf_batch.cpp
${CMAKE_CURRENT_SOURCE_DIR}/gesv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/gesv_batch.cpp
${CMAKE_CURRENT_SOURCE_DIR}/gesvd.cpp
${CMAKE_CURRENT_SOURCE_DIR}/getrf.cpp
${CMAKE_CURRENT_SOURCE_DIR}/getrf_batch.cpp
Expand Down
36 changes: 35 additions & 1 deletion dpnp/backend/extensions/lapack/common_helpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,11 @@
//*****************************************************************************

#pragma once
#include <oneapi/mkl.hpp>
vlad-perevezentsev marked this conversation as resolved.
Show resolved Hide resolved
#include <pybind11/pybind11.h>

#include <complex>
#include <cstring>
#include <pybind11/pybind11.h>
#include <stdexcept>

namespace dpnp::extensions::lapack::helper
Expand Down Expand Up @@ -63,4 +65,36 @@ inline bool check_zeros_shape(int ndim, const py::ssize_t *shape)
}
return src_nelems == 0;
}

/**
 * @brief Allocate one device scratchpad buffer that serves all linear streams
 *        of a batch implementation.
 *
 * The requested element count is `n_linear_streams * scratchpad_size`, passed
 * through `round_up_mult` together with the number of elements of type T that
 * fit into 256 bytes (presumably rounding the count up to a multiple of that
 * value; the helper is defined outside this view).
 *
 * @tparam T               Element type of the scratchpad.
 * @param scratchpad_size  Scratchpad size for one stream, in elements of T;
 *                         must be greater than zero.
 * @param n_linear_streams Number of linear streams sharing the buffer.
 * @param exec_q           Queue on which the device allocation is made.
 * @return Raw pointer to the device allocation. Nothing in this function
 *         releases it, so the caller owns the memory.
 * @throws std::runtime_error if `scratchpad_size` is not positive or if the
 *         device allocation returns a null pointer.
 */
template <typename T>
inline T *alloc_scratchpad(std::int64_t scratchpad_size,
                           std::int64_t n_linear_streams,
                           sycl::queue &exec_q)
{
    // Get padding size to ensure memory allocations are aligned to 256 bytes
    // for better performance
    const std::int64_t padding = 256 / sizeof(T);

    if (scratchpad_size <= 0) {
        throw std::runtime_error(
            "Invalid scratchpad size: must be greater than zero."
            " Calculated scratchpad size: " +
            std::to_string(scratchpad_size));
    }

    // Calculate the total scratchpad memory size needed for all linear
    // streams with proper alignment
    const size_t alloc_scratch_size =
        round_up_mult(n_linear_streams * scratchpad_size, padding);

    // Allocate memory for the total scratchpad
    T *scratchpad = sycl::malloc_device<T>(alloc_scratch_size, exec_q);
    if (!scratchpad) {
        throw std::runtime_error("Device allocation for scratchpad failed");
    }

    return scratchpad;
}
} // namespace dpnp::extensions::lapack::helper
30 changes: 0 additions & 30 deletions dpnp/backend/extensions/lapack/evd_batch_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,34 +119,4 @@ std::pair<sycl::event, sycl::event>

return std::make_pair(ht_ev, evd_batch_ev);
}

/**
 * @brief Reserve a single device buffer used as scratchpad by every linear
 *        stream.
 *
 * @tparam T               Element type of the scratchpad.
 * @param scratchpad_size  Per-stream scratchpad size in elements of T.
 * @param n_linear_streams Number of linear streams that share the buffer.
 * @param exec_q           Queue used for the device allocation.
 * @return Pointer to the device allocation; it is not freed here.
 * @throws std::runtime_error when `scratchpad_size` is not positive or the
 *         allocation yields a null pointer.
 */
template <typename T>
inline T *alloc_scratchpad(std::int64_t scratchpad_size,
                           std::int64_t n_linear_streams,
                           sycl::queue &exec_q)
{
    // A non-positive per-stream size cannot be served; report the value seen.
    if (!(scratchpad_size > 0)) {
        std::string message =
            "Invalid scratchpad size: must be greater than zero."
            " Calculated scratchpad size: ";
        message += std::to_string(scratchpad_size);
        throw std::runtime_error(message);
    }

    // Number of T elements in 256 bytes; the total element count is passed
    // through round_up_mult with this value as the multiple.
    const std::int64_t elems_per_256_bytes = 256 / sizeof(T);
    const std::int64_t requested_elems = n_linear_streams * scratchpad_size;
    const size_t padded_elems =
        helper::round_up_mult(requested_elems, elems_per_256_bytes);

    // One allocation covers all streams.
    T *buffer = sycl::malloc_device<T>(padded_elems, exec_q);
    if (buffer == nullptr) {
        throw std::runtime_error("Device allocation for scratchpad failed");
    }
    return buffer;
}
} // namespace dpnp::extensions::lapack::evd
211 changes: 135 additions & 76 deletions dpnp/backend/extensions/lapack/gesv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

// dpctl tensor headers
#include "utils/memory_overlap.hpp"
#include "utils/output_validation.hpp"
#include "utils/type_utils.hpp"

#include "common_helpers.hpp"
Expand All @@ -42,49 +43,150 @@ namespace mkl_lapack = oneapi::mkl::lapack;
namespace py = pybind11;
namespace type_utils = dpctl::tensor::type_utils;

typedef sycl::event (*gesv_impl_fn_ptr_t)(sycl::queue,
/**
 * @brief Validate the input arrays shared by gesv and gesv_batch.
 *
 * The checks run in this order: number of dimensions of both arrays, square
 * coefficient matrix, matching first dimension, queue compatibility, memory
 * overlap, writability of `dependent_vals`, F-contiguity of both arrays and
 * equal element types.
 *
 * The shape pointers are indexed only after the ndim checks have passed.
 * NOTE(review): this assumes callers pass `expected_coeff_matrix_ndim >= 2`
 * and `min_dependent_vals_ndim >= 1`, as gesv does - confirm for other
 * callers.
 *
 * @param exec_q                     Execution queue.
 * @param coeff_matrix               Array of coefficients.
 * @param dependent_vals             Array of dependent values; must be
 *                                   writable.
 * @param coeff_matrix_shape         Raw shape of `coeff_matrix`.
 * @param dependent_vals_shape       Raw shape of `dependent_vals`.
 * @param expected_coeff_matrix_ndim Required ndim of `coeff_matrix`.
 * @param min_dependent_vals_ndim    Smallest accepted ndim of
 *                                   `dependent_vals`.
 * @param max_dependent_vals_ndim    Largest accepted ndim of
 *                                   `dependent_vals`.
 * @throws py::value_error if any of the checks written in this function
 *         fails; the writability check is delegated to
 *         CheckWritable::throw_if_not_writable.
 */
void common_gesv_checks(sycl::queue &exec_q,
                        dpctl::tensor::usm_ndarray coeff_matrix,
                        dpctl::tensor::usm_ndarray dependent_vals,
                        const py::ssize_t *coeff_matrix_shape,
                        const py::ssize_t *dependent_vals_shape,
                        const int expected_coeff_matrix_ndim,
                        const int min_dependent_vals_ndim,
                        const int max_dependent_vals_ndim)
{
    const int coeff_matrix_nd = coeff_matrix.get_ndim();
    const int dependent_vals_nd = dependent_vals.get_ndim();

    if (coeff_matrix_nd != expected_coeff_matrix_ndim) {
        throw py::value_error("The coefficient matrix has ndim=" +
                              std::to_string(coeff_matrix_nd) + ", but a " +
                              std::to_string(expected_coeff_matrix_ndim) +
                              "-dimensional array is expected.");
    }

    if (dependent_vals_nd < min_dependent_vals_ndim ||
        dependent_vals_nd > max_dependent_vals_ndim)
    {
        throw py::value_error("The dependent values array has ndim=" +
                              std::to_string(dependent_vals_nd) + ", but a " +
                              std::to_string(min_dependent_vals_ndim) +
                              "-dimensional or a " +
                              std::to_string(max_dependent_vals_ndim) +
                              "-dimensional array is expected.");
    }

    // The coeff_matrix and dependent_vals arrays must be F-contiguous arrays
    // for gesv
    // with the shapes (n,n) and (n,nrhs) or (n,) respectively;
    // for gesv_batch
    // with the shapes (n,n,batch_size) and (n,nrhs,batch_size) or
    // (n,batch_size) respectively
    if (coeff_matrix_shape[0] != coeff_matrix_shape[1]) {
        throw py::value_error("The coefficient matrix must be square,"
                              " but got a shape of (" +
                              std::to_string(coeff_matrix_shape[0]) + ", " +
                              std::to_string(coeff_matrix_shape[1]) + ").");
    }
    if (coeff_matrix_shape[0] != dependent_vals_shape[0]) {
        throw py::value_error("The first dimension (n) of coeff_matrix and"
                              " dependent_vals must be the same, but got " +
                              std::to_string(coeff_matrix_shape[0]) + " and " +
                              std::to_string(dependent_vals_shape[0]) + ".");
    }

    // check compatibility of execution queue and allocation queue
    if (!dpctl::utils::queues_are_compatible(exec_q,
                                             {coeff_matrix, dependent_vals}))
    {
        throw py::value_error(
            "Execution queue is not compatible with allocation queues.");
    }

    // The two arrays must not share memory
    auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
    if (overlap(coeff_matrix, dependent_vals)) {
        throw py::value_error(
            "The arrays of coefficients and dependent variables "
            "are overlapping segments of memory.");
    }

    // dependent_vals is the only array required to be writable here
    dpctl::tensor::validation::CheckWritable::throw_if_not_writable(
        dependent_vals);

    const bool is_coeff_matrix_f_contig = coeff_matrix.is_f_contiguous();
    if (!is_coeff_matrix_f_contig) {
        throw py::value_error("The coefficient matrix "
                              "must be F-contiguous.");
    }

    const bool is_dependent_vals_f_contig = dependent_vals.is_f_contiguous();
    if (!is_dependent_vals_f_contig) {
        throw py::value_error("The array of dependent variables "
                              "must be F-contiguous.");
    }

    // Both arrays must map to the same type lookup id
    auto array_types = dpctl_td_ns::usm_ndarray_types();
    const int coeff_matrix_type_id =
        array_types.typenum_to_lookup_id(coeff_matrix.get_typenum());
    const int dependent_vals_type_id =
        array_types.typenum_to_lookup_id(dependent_vals.get_typenum());

    if (coeff_matrix_type_id != dependent_vals_type_id) {
        throw py::value_error("The types of the coefficient matrix and "
                              "dependent variables are mismatched.");
    }
}

// Pointer type for the type-specific gesv implementations stored in the
// dispatch vector below.
// NOTE(review): this view appears to interleave removed and added diff lines,
// so the parameter list shown here may not match the final file - confirm it
// against the gesv_impl signature before relying on it.
typedef sycl::event (*gesv_impl_fn_ptr_t)(sycl::queue &,
const std::int64_t,
const std::int64_t,
char *,
std::int64_t,
char *,
std::int64_t,
std::vector<sycl::event> &,
const std::vector<sycl::event> &);

// Table of gesv implementations with dpctl_td_ns::num_types entries; gesv()
// indexes it with the type lookup id of the coefficient matrix and tests the
// entry against nullptr.
static gesv_impl_fn_ptr_t gesv_dispatch_vector[dpctl_td_ns::num_types];

template <typename T>
static sycl::event gesv_impl(sycl::queue exec_q,
static sycl::event gesv_impl(sycl::queue &exec_q,
const std::int64_t n,
const std::int64_t nrhs,
char *in_a,
std::int64_t lda,
char *in_b,
std::int64_t ldb,
std::vector<sycl::event> &host_task_events,
const std::vector<sycl::event> &depends)
{
type_utils::validate_type_for_device<T>(exec_q);

T *a = reinterpret_cast<T *>(in_a);
T *b = reinterpret_cast<T *>(in_b);

const std::int64_t lda = std::max<size_t>(1UL, n);
const std::int64_t ldb = std::max<size_t>(1UL, n);

const std::int64_t scratchpad_size =
mkl_lapack::gesv_scratchpad_size<T>(exec_q, n, nrhs, lda, ldb);

if (scratchpad_size <= 0) {
throw std::runtime_error(
"Invalid scratchpad size: must be greater than zero."
"Calculated scratchpad size: " +
std::to_string(scratchpad_size));
}

T *scratchpad = nullptr;
// Allocate memory for the scratchpad
scratchpad = sycl::malloc_device<T>(scratchpad_size, exec_q);
vlad-perevezentsev marked this conversation as resolved.
Show resolved Hide resolved
if (!scratchpad)
throw std::runtime_error("Device allocation for scratchpad failed");

std::int64_t *ipiv = nullptr;
// Allocate memory for the ipiv
ipiv = sycl::malloc_device<std::int64_t>(n, exec_q);
if (!ipiv)
throw std::runtime_error("Device allocation for ipiv failed");

std::stringstream error_msg;
std::int64_t info = 0;
bool is_exception_caught = false;

sycl::event gesv_event;
try {
scratchpad = sycl::malloc_device<T>(scratchpad_size, exec_q);
ipiv = sycl::malloc_device<std::int64_t>(n, exec_q);

gesv_event = mkl_lapack::gesv(
exec_q,
n, // The order of the square matrix A
Expand Down Expand Up @@ -156,88 +258,50 @@ static sycl::event gesv_impl(sycl::queue exec_q,
throw std::runtime_error(error_msg.str());
}

sycl::event clean_up_event = exec_q.submit([&](sycl::handler &cgh) {
sycl::event ht_ev = exec_q.submit([&](sycl::handler &cgh) {
cgh.depends_on(gesv_event);
auto ctx = exec_q.get_context();
cgh.host_task([ctx, scratchpad, ipiv]() {
sycl::free(scratchpad, ctx);
sycl::free(ipiv, ctx);
});
});
host_task_events.push_back(clean_up_event);

return gesv_event;
return ht_ev;
}

std::pair<sycl::event, sycl::event>
gesv(sycl::queue exec_q,
gesv(sycl::queue &exec_q,
dpctl::tensor::usm_ndarray coeff_matrix,
dpctl::tensor::usm_ndarray dependent_vals,
const std::vector<sycl::event> &depends)
{
const int coeff_matrix_nd = coeff_matrix.get_ndim();
const int dependent_vals_nd = dependent_vals.get_ndim();

if (coeff_matrix_nd != 2) {
throw py::value_error("The coefficient matrix has ndim=" +
std::to_string(coeff_matrix_nd) +
", but a 2-dimensional array is expected.");
}

if (dependent_vals_nd > 2) {
throw py::value_error(
"The dependent values array has ndim=" +
std::to_string(dependent_vals_nd) +
", but a 1-dimensional or a 2-dimensional array is expected.");
}

const py::ssize_t *coeff_matrix_shape = coeff_matrix.get_shape_raw();
const py::ssize_t *dependent_vals_shape = dependent_vals.get_shape_raw();

if (coeff_matrix_shape[0] != coeff_matrix_shape[1]) {
throw py::value_error("The coefficient matrix must be square,"
" but got a shape of (" +
std::to_string(coeff_matrix_shape[0]) + ", " +
std::to_string(coeff_matrix_shape[1]) + ").");
}
constexpr int expected_coeff_matrix_ndim = 2;
constexpr int min_dependent_vals_ndim = 1;
constexpr int max_dependent_vals_ndim = 2;

// check compatibility of execution queue and allocation queue
if (!dpctl::utils::queues_are_compatible(exec_q,
{coeff_matrix, dependent_vals}))
{
throw py::value_error(
"Execution queue is not compatible with allocation queues");
}

auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
if (overlap(coeff_matrix, dependent_vals)) {
throw py::value_error(
"The arrays of coefficients and dependent variables "
"are overlapping segments of memory");
}

bool is_coeff_matrix_f_contig = coeff_matrix.is_f_contiguous();
if (!is_coeff_matrix_f_contig) {
throw py::value_error("The coefficient matrix "
"must be F-contiguous");
}
common_gesv_checks(exec_q, coeff_matrix, dependent_vals, coeff_matrix_shape,
dependent_vals_shape, expected_coeff_matrix_ndim,
min_dependent_vals_ndim, max_dependent_vals_ndim);

bool is_dependent_vals_f_contig = dependent_vals.is_f_contiguous();
if (!is_dependent_vals_f_contig) {
throw py::value_error("The array of dependent variables "
"must be F-contiguous");
// Ensure `batch_size`, `n` and 'nrhs' are non-zero, otherwise return empty
// events
if (helper::check_zeros_shape(coeff_matrix_nd, coeff_matrix_shape) ||
helper::check_zeros_shape(dependent_vals_nd, dependent_vals_shape))
{
// nothing to do
return std::make_pair(sycl::event(), sycl::event());
}

auto array_types = dpctl_td_ns::usm_ndarray_types();
int coeff_matrix_type_id =
const int coeff_matrix_type_id =
array_types.typenum_to_lookup_id(coeff_matrix.get_typenum());
int dependent_vals_type_id =
array_types.typenum_to_lookup_id(dependent_vals.get_typenum());

if (coeff_matrix_type_id != dependent_vals_type_id) {
throw py::value_error("The types of the coefficient matrix and "
"dependent variables are mismatched");
}

gesv_impl_fn_ptr_t gesv_fn = gesv_dispatch_vector[coeff_matrix_type_id];
if (gesv_fn == nullptr) {
Expand All @@ -253,18 +317,13 @@ std::pair<sycl::event, sycl::event>
const std::int64_t nrhs =
(dependent_vals_nd > 1) ? dependent_vals_shape[1] : 1;

const std::int64_t lda = std::max<size_t>(1UL, n);
const std::int64_t ldb = std::max<size_t>(1UL, n);

std::vector<sycl::event> host_task_events;
sycl::event gesv_ev =
gesv_fn(exec_q, n, nrhs, coeff_matrix_data, lda, dependent_vals_data,
ldb, host_task_events, depends);
sycl::event gesv_ev = gesv_fn(exec_q, n, nrhs, coeff_matrix_data,
dependent_vals_data, depends);

sycl::event args_ev = dpctl::utils::keep_args_alive(
exec_q, {coeff_matrix, dependent_vals}, host_task_events);
sycl::event ht_ev = dpctl::utils::keep_args_alive(
exec_q, {coeff_matrix, dependent_vals}, {gesv_ev});

return std::make_pair(args_ev, gesv_ev);
return std::make_pair(ht_ev, gesv_ev);
}

template <typename fnT, typename T>
Expand Down
Loading
Loading