Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update OneMKL gemm_batch call inside dpnp.matmul and column_major version of gemm #1793

Merged
merged 13 commits into from
May 8, 2024
122 changes: 88 additions & 34 deletions dpnp/backend/extensions/blas/gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ typedef sycl::event (*gemm_impl_fn_ptr_t)(sycl::queue &,
const std::int64_t,
char *,
const std::int64_t,
bool,
const std::vector<sycl::event> &);

static gemm_impl_fn_ptr_t gemm_dispatch_table[dpctl_td_ns::num_types]
Expand All @@ -77,6 +78,7 @@ static sycl::event gemm_impl(sycl::queue &exec_q,
const std::int64_t ldb,
char *resultC,
const std::int64_t ldc,
bool is_row_major,
vtavana marked this conversation as resolved.
Show resolved Hide resolved
const std::vector<sycl::event> &depends)
{
type_utils::validate_type_for_device<Tab>(exec_q);
Expand All @@ -91,26 +93,50 @@ static sycl::event gemm_impl(sycl::queue &exec_q,

sycl::event gemm_event;
try {
gemm_event = mkl_blas::row_major::gemm(
exec_q,
transA, // Defines the transpose operation for matrix A:
// 'N' indicates no transpose, 'T' for transpose,
// or 'C' for a conjugate transpose.
transB, // Same as transA but for matrix B.
m, // Number of rows in matrices A and C.
n, // Number of columns in matrices B and C.
k, // Number of columns in matrix A and rows in matrix B.
Tab(1), // Scaling factor for the product of matrices A and B.
a, // Pointer to matrix A.
lda, // Leading dimension of matrix A, which is the
// stride between successive rows (for row major
// layout).
b, // Pointer to matrix B.
ldb, // Leading dimension of matrix B, similar to lda.
Tab(0), // Scaling factor for matrix C.
res, // Pointer to matrix C, where the result is stored.
ldc, // Leading dimension of matrix C.
depends);
if (is_row_major) {
gemm_event = mkl_blas::row_major::gemm(
exec_q,
transA, // Defines the transpose operation for matrix A:
// 'N' indicates no transpose, 'T' for transpose,
// or 'C' for a conjugate transpose.
transB, // Same as transA but for matrix B.
m, // Number of rows in matrices A and C.
n, // Number of columns in matrices B and C.
k, // Number of columns in matrix A and rows in matrix B.
Tab(1), // Scaling factor for the product of matrices A and B.
a, // Pointer to matrix A.
lda, // Leading dimension of matrix A, which is the
// stride between successive rows (for row major
// layout).
b, // Pointer to matrix B.
ldb, // Leading dimension of matrix B, similar to lda.
Tab(0), // Scaling factor for matrix C.
res, // Pointer to matrix C, where the result is stored.
ldc, // Leading dimension of matrix C.
depends);
}
else {
gemm_event = mkl_blas::column_major::gemm(
exec_q,
transA, // Defines the transpose operation for matrix A:
// 'N' indicates no transpose, 'T' for transpose,
// or 'C' for a conjugate transpose.
transB, // Same as transA but for matrix B.
m, // Number of rows in matrices A and C.
n, // Number of columns in matrices B and C.
k, // Number of columns in matrix A and rows in matrix B.
Tab(1), // Scaling factor for the product of matrices A and B.
a, // Pointer to matrix A.
lda, // Leading dimension of matrix A, which is the
// stride between successive columns (for column
// major layout).
b, // Pointer to matrix B.
ldb, // Leading dimension of matrix B, similar to lda.
Tab(0), // Scaling factor for matrix C.
res, // Pointer to matrix C, where the result is stored.
ldc, // Leading dimension of matrix C.
depends);
}
} catch (oneapi::mkl::exception const &e) {
error_msg
<< "Unexpected MKL exception caught during gemm() call:\nreason: "
Expand All @@ -130,7 +156,7 @@ static sycl::event gemm_impl(sycl::queue &exec_q,
return gemm_event;
}

std::pair<sycl::event, sycl::event>
std::tuple<sycl::event, sycl::event, bool>
gemm(sycl::queue &exec_q,
dpctl::tensor::usm_ndarray matrixA,
dpctl::tensor::usm_ndarray matrixB,
Expand Down Expand Up @@ -208,16 +234,44 @@ std::pair<sycl::event, sycl::event>
throw py::value_error(
"Result array is not c-contiguous nor f-contiguous.");
}
oneapi::mkl::transpose transA = is_matrixA_f_contig
? oneapi::mkl::transpose::T
: oneapi::mkl::transpose::N;
oneapi::mkl::transpose transB = is_matrixB_f_contig
? oneapi::mkl::transpose::T
: oneapi::mkl::transpose::N;
bool is_row_major = true;
if (is_matrixA_f_contig && is_matrixB_f_contig) {
is_row_major = false;
}
oneapi::mkl::transpose transA;
oneapi::mkl::transpose transB;
if (is_row_major) {
transA = is_matrixA_f_contig ? oneapi::mkl::transpose::T
: oneapi::mkl::transpose::N;
transB = is_matrixB_f_contig ? oneapi::mkl::transpose::T
: oneapi::mkl::transpose::N;
}
else {
transA = oneapi::mkl::transpose::N;
transB = oneapi::mkl::transpose::N;
}

const std::int64_t lda = (transA == oneapi::mkl::transpose::N) ? k : m;
const std::int64_t ldb = (transB == oneapi::mkl::transpose::N) ? n : k;
const std::int64_t ldc = n; // always n for row_major
std::int64_t lda;
std::int64_t ldb;
if (is_row_major) {
if (transA == oneapi::mkl::transpose::N) {
lda = k;
}
else {
lda = m;
}
if (transB == oneapi::mkl::transpose::N) {
ldb = n;
}
else {
ldb = k;
}
}
else {
lda = m;
ldb = k;
}
const std::int64_t ldc = is_row_major ? n : m;

int matrixA_typenum = matrixA.get_typenum();
int matrixB_typenum = matrixB.get_typenum();
Expand All @@ -242,14 +296,14 @@ std::pair<sycl::event, sycl::event>
char *b_typeless_ptr = matrixB.get_data();
char *r_typeless_ptr = resultC.get_data();

sycl::event gemm_ev =
gemm_fn(exec_q, transA, transB, m, n, k, a_typeless_ptr, lda,
b_typeless_ptr, ldb, r_typeless_ptr, ldc, depends);
sycl::event gemm_ev = gemm_fn(exec_q, transA, transB, m, n, k,
a_typeless_ptr, lda, b_typeless_ptr, ldb,
r_typeless_ptr, ldc, is_row_major, depends);

sycl::event args_ev = dpctl::utils::keep_args_alive(
exec_q, {matrixA, matrixB, resultC}, {gemm_ev});

return std::make_pair(args_ev, gemm_ev);
return std::make_tuple(args_ev, gemm_ev, is_row_major);
}

template <typename fnT, typename Tab, typename Tc>
Expand Down
4 changes: 2 additions & 2 deletions dpnp/backend/extensions/blas/gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,14 @@ namespace ext
{
namespace blas
{
extern std::pair<sycl::event, sycl::event>
extern std::tuple<sycl::event, sycl::event, bool>
gemm(sycl::queue &exec_q,
dpctl::tensor::usm_ndarray matrixA,
dpctl::tensor::usm_ndarray matrixB,
dpctl::tensor::usm_ndarray resultC,
const std::vector<sycl::event> &depends);

extern std::pair<sycl::event, sycl::event>
extern std::tuple<sycl::event, sycl::event, bool>
gemm_batch(sycl::queue &exec_q,
dpctl::tensor::usm_ndarray matrixA,
dpctl::tensor::usm_ndarray matrixB,
Expand Down
135 changes: 98 additions & 37 deletions dpnp/backend/extensions/blas/gemm_batch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ typedef sycl::event (*gemm_batch_impl_fn_ptr_t)(
char *,
char *,
char *,
bool,
const std::vector<sycl::event> &);

static gemm_batch_impl_fn_ptr_t
Expand All @@ -86,6 +87,7 @@ static sycl::event gemm_batch_impl(sycl::queue &exec_q,
char *matrixA,
char *matrixB,
char *resultC,
bool is_row_major,
const std::vector<sycl::event> &depends)
{
type_utils::validate_type_for_device<Tab>(exec_q);
Expand All @@ -100,31 +102,60 @@ static sycl::event gemm_batch_impl(sycl::queue &exec_q,

sycl::event gemm_batch_event;
try {
gemm_batch_event = mkl_blas::row_major::gemm_batch(
exec_q,
transA, // Defines the transpose operation for matrix A:
// 'N' indicates no transpose, 'T' for transpose,
// or 'C' for a conjugate transpose.
transB, // Same as transA but for matrix B.
m, // Number of rows in matrices A and C.
n, // Number of columns in matrices B and C.
k, // Number of columns in matrix A and rows in matrix B.
Tab(1), // Scaling factor for the product of matrices A and B.
a, // Pointer to matrix A.
lda, // Leading dimension of matrix A, which is the
// stride between successive rows (for row major
// layout).
stridea, // Stride between different A matrices.
b, // Pointer to matrix B.
ldb, // Leading dimension of matrix B, similar to lda.
strideb, // Stride between different B matrices.
Tab(0), // Scaling factor for matrix C.
res, // Pointer to matrix C, where the result is stored.
ld_result, // Leading dimension of matrix C.
stridec, // Stride between different C matrices.
batch_size, // Specifies the number of matrix multiply operations to
// perform.
depends);
if (is_row_major) {
gemm_batch_event = mkl_blas::row_major::gemm_batch(
exec_q,
transA, // Defines the transpose operation for matrix A:
// 'N' indicates no transpose, 'T' for transpose,
// or 'C' for a conjugate transpose.
transB, // Same as transA but for matrix B.
m, // Number of rows in matrices A and C.
n, // Number of columns in matrices B and C.
k, // Number of columns in matrix A and rows in matrix B.
Tab(1), // Scaling factor for the product of matrices A and B.
a, // Pointer to matrix A.
lda, // Leading dimension of matrix A, which is the
// stride between successive rows (for row major
// layout).
stridea, // Stride between different A matrices.
b, // Pointer to matrix B.
ldb, // Leading dimension of matrix B, similar to lda.
strideb, // Stride between different B matrices.
Tab(0), // Scaling factor for matrix C.
res, // Pointer to matrix C, where the result is stored.
ld_result, // Leading dimension of matrix C.
stridec, // Stride between different C matrices.
batch_size, // Specifies the number of matrix multiply
// operations to perform.
depends);
}
else {
gemm_batch_event = mkl_blas::column_major::gemm_batch(
exec_q,
transA, // Defines the transpose operation for matrix A:
// 'N' indicates no transpose, 'T' for transpose,
// or 'C' for a conjugate transpose.
transB, // Same as transA but for matrix B.
m, // Number of rows in matrices A and C.
n, // Number of columns in matrices B and C.
k, // Number of columns in matrix A and rows in matrix B.
Tab(1), // Scaling factor for the product of matrices A and B.
a, // Pointer to matrix A.
lda, // Leading dimension of matrix A, which is the
// stride between successive columns (for column
// major layout).
stridea, // Stride between different A matrices.
b, // Pointer to matrix B.
ldb, // Leading dimension of matrix B, similar to lda.
strideb, // Stride between different B matrices.
Tab(0), // Scaling factor for matrix C.
res, // Pointer to matrix C, where the result is stored.
ld_result, // Leading dimension of matrix C.
stridec, // Stride between different C matrices.
batch_size, // Specifies the number of matrix multiply
// operations to perform.
depends);
}
} catch (oneapi::mkl::exception const &e) {
error_msg << "Unexpected MKL exception caught during gemm_batch() "
"call:\nreason: "
Expand All @@ -145,7 +176,7 @@ static sycl::event gemm_batch_impl(sycl::queue &exec_q,
return gemm_batch_event;
}

std::pair<sycl::event, sycl::event>
std::tuple<sycl::event, sycl::event, bool>
gemm_batch(sycl::queue &exec_q,
dpctl::tensor::usm_ndarray matrixA,
dpctl::tensor::usm_ndarray matrixB,
Expand Down Expand Up @@ -222,19 +253,49 @@ std::pair<sycl::event, sycl::event>
const std::int64_t stridea = a_stride[0];
const std::int64_t strideb = b_stride[0];
const std::int64_t stridec = c_stride[0];

bool A_base_is_f_contig = a_stride[1] == 1 && a_stride[2] == a_shape[1];
bool B_base_is_f_contig = b_stride[1] == 1 && b_stride[2] == b_shape[1];

oneapi::mkl::transpose transA = A_base_is_f_contig
? oneapi::mkl::transpose::T
: oneapi::mkl::transpose::N;
oneapi::mkl::transpose transB = B_base_is_f_contig
? oneapi::mkl::transpose::T
: oneapi::mkl::transpose::N;
bool is_row_major = true;
if (A_base_is_f_contig && B_base_is_f_contig) {
is_row_major = false;
}

oneapi::mkl::transpose transA;
oneapi::mkl::transpose transB;
if (is_row_major) {
transA = A_base_is_f_contig ? oneapi::mkl::transpose::T
: oneapi::mkl::transpose::N;
transB = B_base_is_f_contig ? oneapi::mkl::transpose::T
: oneapi::mkl::transpose::N;
}
else {
transA = oneapi::mkl::transpose::N;
transB = oneapi::mkl::transpose::N;
}

const std::int64_t lda = (transA == oneapi::mkl::transpose::N) ? k : m;
const std::int64_t ldb = (transB == oneapi::mkl::transpose::N) ? n : k;
const std::int64_t ldc = n; // always n for row_major
std::int64_t lda;
std::int64_t ldb;
if (is_row_major) {
if (transA == oneapi::mkl::transpose::N) {
lda = k;
}
else {
lda = m;
}
if (transB == oneapi::mkl::transpose::N) {
ldb = n;
}
else {
ldb = k;
}
}
else {
lda = m;
ldb = k;
}
const std::int64_t ldc = is_row_major ? n : m;

int matrixA_typenum = matrixA.get_typenum();
int matrixB_typenum = matrixB.get_typenum();
Expand Down Expand Up @@ -262,12 +323,12 @@ std::pair<sycl::event, sycl::event>
sycl::event gemm_batch_ev =
gemm_batch_fn(exec_q, m, n, k, batch_size, lda, ldb, ldc, stridea,
strideb, stridec, transA, transB, a_typeless_ptr,
b_typeless_ptr, r_typeless_ptr, depends);
b_typeless_ptr, r_typeless_ptr, is_row_major, depends);

sycl::event args_batch_ev = dpctl::utils::keep_args_alive(
exec_q, {matrixA, matrixB, resultC}, {gemm_batch_ev});

return std::make_pair(args_batch_ev, gemm_batch_ev);
return std::make_tuple(args_batch_ev, gemm_batch_ev, is_row_major);
}

template <typename fnT, typename Tab, typename Tc>
Expand Down
Loading
Loading