Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update the OneMKL gemm_batch call inside dpnp.matmul and add a column_major version of gemm #1793

Merged
merged 13 commits into from
May 8, 2024
24 changes: 11 additions & 13 deletions dpnp/backend/extensions/blas/blas_py.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,13 @@ PYBIND11_MODULE(_blas_impl, m)
blas_ext::DotContigFactory>(
dot_dispatch_vector);

auto dot_pypi = [&](sycl::queue exec_q, arrayT src1, arrayT src2,
arrayT dst, const event_vecT &depends = {}) {
auto dot_pyapi = [&](sycl::queue exec_q, arrayT src1, arrayT src2,
arrayT dst, const event_vecT &depends = {}) {
return dot_ext::dot_func(exec_q, src1, src2, dst, depends,
dot_dispatch_vector);
};

m.def("_dot", dot_pypi,
m.def("_dot", dot_pyapi,
"Call `dot` from OneMKL BLAS library to return "
"the dot product of two real-valued vectors.",
py::arg("sycl_queue"), py::arg("vectorA"), py::arg("vectorB"),
Expand All @@ -82,13 +82,13 @@ PYBIND11_MODULE(_blas_impl, m)
blas_ext::DotcContigFactory>(
dotc_dispatch_vector);

auto dotc_pypi = [&](sycl::queue exec_q, arrayT src1, arrayT src2,
arrayT dst, const event_vecT &depends = {}) {
auto dotc_pyapi = [&](sycl::queue exec_q, arrayT src1, arrayT src2,
arrayT dst, const event_vecT &depends = {}) {
return dot_ext::dot_func(exec_q, src1, src2, dst, depends,
dotc_dispatch_vector);
};

m.def("_dotc", dotc_pypi,
m.def("_dotc", dotc_pyapi,
"Call `dotc` from OneMKL BLAS library to return "
"the dot product of two complex vectors, "
"conjugating the first vector.",
Expand All @@ -101,13 +101,13 @@ PYBIND11_MODULE(_blas_impl, m)
blas_ext::DotuContigFactory>(
dotu_dispatch_vector);

auto dotu_pypi = [&](sycl::queue exec_q, arrayT src1, arrayT src2,
arrayT dst, const event_vecT &depends = {}) {
auto dotu_pyapi = [&](sycl::queue exec_q, arrayT src1, arrayT src2,
arrayT dst, const event_vecT &depends = {}) {
return dot_ext::dot_func(exec_q, src1, src2, dst, depends,
dotu_dispatch_vector);
};

m.def("_dotu", dotu_pypi,
m.def("_dotu", dotu_pyapi,
"Call `dotu` from OneMKL BLAS library to return "
"the dot product of two complex vectors.",
py::arg("sycl_queue"), py::arg("vectorA"), py::arg("vectorB"),
Expand All @@ -119,16 +119,14 @@ PYBIND11_MODULE(_blas_impl, m)
"Call `gemm` from OneMKL BLAS library to return "
"the matrix-matrix product with 2-D matrices.",
py::arg("sycl_queue"), py::arg("matrixA"), py::arg("matrixB"),
py::arg("result"), py::arg("depends") = py::list());
py::arg("resultC"), py::arg("depends") = py::list());
}

{
m.def("_gemm_batch", &blas_ext::gemm_batch,
"Call `gemm_batch` from OneMKL BLAS library to return "
"the matrix-matrix product for a batch of 2-D matrices.",
py::arg("sycl_queue"), py::arg("matrixA"), py::arg("matrixB"),
py::arg("result"), py::arg("batch_size"), py::arg("stridea"),
py::arg("strideb"), py::arg("stridec"),
py::arg("depends") = py::list());
py::arg("resultC"), py::arg("depends") = py::list());
}
}
78 changes: 63 additions & 15 deletions dpnp/backend/extensions/blas/gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ typedef sycl::event (*gemm_impl_fn_ptr_t)(sycl::queue &,
const std::int64_t,
char *,
const std::int64_t,
bool,
const std::vector<sycl::event> &);

static gemm_impl_fn_ptr_t gemm_dispatch_table[dpctl_td_ns::num_types]
Expand All @@ -77,6 +78,7 @@ static sycl::event gemm_impl(sycl::queue &exec_q,
const std::int64_t ldb,
char *resultC,
const std::int64_t ldc,
bool is_row_major,
vtavana marked this conversation as resolved.
Show resolved Hide resolved
const std::vector<sycl::event> &depends)
{
type_utils::validate_type_for_device<Tab>(exec_q);
Expand All @@ -91,7 +93,25 @@ static sycl::event gemm_impl(sycl::queue &exec_q,

sycl::event gemm_event;
try {
gemm_event = mkl_blas::row_major::gemm(
// Local dispatcher for a single GEMM call. Depending on the captured
// `is_row_major` flag it forwards the very same argument list either to
// mkl_blas::row_major::gemm or to mkl_blas::column_major::gemm and returns
// the event produced by that call.
auto gemm_func =
    [&](sycl::queue &queue, oneapi::mkl::transpose trans_a,
        oneapi::mkl::transpose trans_b, std::int64_t m, std::int64_t n,
        std::int64_t k, Tab alpha, const Tab *mat_a, std::int64_t ld_a,
        const Tab *mat_b, std::int64_t ld_b, Tab beta, Tc *mat_c,
        std::int64_t ld_c,
        const std::vector<sycl::event> &wait_list) -> sycl::event {
    // Column-major layout is the special case; handle it first and leave.
    if (!is_row_major) {
        return mkl_blas::column_major::gemm(queue, trans_a, trans_b, m, n, k,
                                            alpha, mat_a, ld_a, mat_b, ld_b,
                                            beta, mat_c, ld_c, wait_list);
    }

    // Default path: row-major layout.
    return mkl_blas::row_major::gemm(queue, trans_a, trans_b, m, n, k, alpha,
                                     mat_a, ld_a, mat_b, ld_b, beta, mat_c,
                                     ld_c, wait_list);
};
gemm_event = gemm_func(
exec_q,
transA, // Defines the transpose operation for matrix A:
// 'N' indicates no transpose, 'T' for transpose,
Expand Down Expand Up @@ -130,7 +150,7 @@ static sycl::event gemm_impl(sycl::queue &exec_q,
return gemm_event;
}

std::pair<sycl::event, sycl::event>
std::tuple<sycl::event, sycl::event, bool>
gemm(sycl::queue &exec_q,
dpctl::tensor::usm_ndarray matrixA,
dpctl::tensor::usm_ndarray matrixB,
Expand Down Expand Up @@ -208,16 +228,44 @@ std::pair<sycl::event, sycl::event>
throw py::value_error(
"Result array is not c-contiguous nor f-contiguous.");
}
oneapi::mkl::transpose transA = is_matrixA_f_contig
? oneapi::mkl::transpose::T
: oneapi::mkl::transpose::N;
oneapi::mkl::transpose transB = is_matrixB_f_contig
? oneapi::mkl::transpose::T
: oneapi::mkl::transpose::N;
bool is_row_major = true;
if (is_matrixA_f_contig && is_matrixB_f_contig) {
is_row_major = false;
}
oneapi::mkl::transpose transA;
oneapi::mkl::transpose transB;
if (is_row_major) {
transA = is_matrixA_f_contig ? oneapi::mkl::transpose::T
: oneapi::mkl::transpose::N;
transB = is_matrixB_f_contig ? oneapi::mkl::transpose::T
: oneapi::mkl::transpose::N;
}
else {
transA = oneapi::mkl::transpose::N;
transB = oneapi::mkl::transpose::N;
}

const std::int64_t lda = (transA == oneapi::mkl::transpose::N) ? k : m;
const std::int64_t ldb = (transB == oneapi::mkl::transpose::N) ? n : k;
const std::int64_t ldc = n; // always n for row_major
std::int64_t lda;
std::int64_t ldb;
if (is_row_major) {
if (transA == oneapi::mkl::transpose::N) {
lda = k;
}
else {
lda = m;
}
if (transB == oneapi::mkl::transpose::N) {
ldb = n;
}
else {
ldb = k;
}
}
else {
lda = m;
ldb = k;
}
const std::int64_t ldc = is_row_major ? n : m;

int matrixA_typenum = matrixA.get_typenum();
int matrixB_typenum = matrixB.get_typenum();
Expand All @@ -242,14 +290,14 @@ std::pair<sycl::event, sycl::event>
char *b_typeless_ptr = matrixB.get_data();
char *r_typeless_ptr = resultC.get_data();

sycl::event gemm_ev =
gemm_fn(exec_q, transA, transB, m, n, k, a_typeless_ptr, lda,
b_typeless_ptr, ldb, r_typeless_ptr, ldc, depends);
sycl::event gemm_ev = gemm_fn(exec_q, transA, transB, m, n, k,
a_typeless_ptr, lda, b_typeless_ptr, ldb,
r_typeless_ptr, ldc, is_row_major, depends);

sycl::event args_ev = dpctl::utils::keep_args_alive(
exec_q, {matrixA, matrixB, resultC}, {gemm_ev});

return std::make_pair(args_ev, gemm_ev);
return std::make_tuple(args_ev, gemm_ev, is_row_major);
}

template <typename fnT, typename Tab, typename Tc>
Expand Down
8 changes: 2 additions & 6 deletions dpnp/backend/extensions/blas/gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,22 +38,18 @@ namespace ext
{
namespace blas
{
extern std::pair<sycl::event, sycl::event>
extern std::tuple<sycl::event, sycl::event, bool>
gemm(sycl::queue &exec_q,
dpctl::tensor::usm_ndarray matrixA,
dpctl::tensor::usm_ndarray matrixB,
dpctl::tensor::usm_ndarray resultC,
const std::vector<sycl::event> &depends);

extern std::pair<sycl::event, sycl::event>
extern std::tuple<sycl::event, sycl::event, bool>
gemm_batch(sycl::queue &exec_q,
dpctl::tensor::usm_ndarray matrixA,
dpctl::tensor::usm_ndarray matrixB,
dpctl::tensor::usm_ndarray resultC,
const std::int64_t batch_size,
size_t stridea,
size_t strideb,
size_t stridec,
const std::vector<sycl::event> &depends);

extern void init_gemm_dispatch_table(void);
Expand Down
Loading
Loading