update OneMKL gemm_batch call inside dpnp.matmul and column_major version of gemm #1793

Merged · 13 commits · May 8, 2024
24 changes: 11 additions & 13 deletions dpnp/backend/extensions/blas/blas_py.cpp
@@ -64,13 +64,13 @@ PYBIND11_MODULE(_blas_impl, m)
                                           blas_ext::DotContigFactory>(
             dot_dispatch_vector);
 
-        auto dot_pypi = [&](sycl::queue exec_q, arrayT src1, arrayT src2,
-                            arrayT dst, const event_vecT &depends = {}) {
+        auto dot_pyapi = [&](sycl::queue exec_q, arrayT src1, arrayT src2,
+                             arrayT dst, const event_vecT &depends = {}) {
             return dot_ext::dot_func(exec_q, src1, src2, dst, depends,
                                      dot_dispatch_vector);
         };
 
-        m.def("_dot", dot_pypi,
+        m.def("_dot", dot_pyapi,
               "Call `dot` from OneMKL BLAS library to return "
               "the dot product of two real-valued vectors.",
               py::arg("sycl_queue"), py::arg("vectorA"), py::arg("vectorB"),
@@ -82,13 +82,13 @@ PYBIND11_MODULE(_blas_impl, m)
                                           blas_ext::DotcContigFactory>(
             dotc_dispatch_vector);
 
-        auto dotc_pypi = [&](sycl::queue exec_q, arrayT src1, arrayT src2,
-                             arrayT dst, const event_vecT &depends = {}) {
+        auto dotc_pyapi = [&](sycl::queue exec_q, arrayT src1, arrayT src2,
+                              arrayT dst, const event_vecT &depends = {}) {
             return dot_ext::dot_func(exec_q, src1, src2, dst, depends,
                                      dotc_dispatch_vector);
         };
 
-        m.def("_dotc", dotc_pypi,
+        m.def("_dotc", dotc_pyapi,
               "Call `dotc` from OneMKL BLAS library to return "
               "the dot product of two complex vectors, "
               "conjugating the first vector.",
@@ -101,13 +101,13 @@ PYBIND11_MODULE(_blas_impl, m)
                                           blas_ext::DotuContigFactory>(
             dotu_dispatch_vector);
 
-        auto dotu_pypi = [&](sycl::queue exec_q, arrayT src1, arrayT src2,
-                             arrayT dst, const event_vecT &depends = {}) {
+        auto dotu_pyapi = [&](sycl::queue exec_q, arrayT src1, arrayT src2,
+                              arrayT dst, const event_vecT &depends = {}) {
             return dot_ext::dot_func(exec_q, src1, src2, dst, depends,
                                      dotu_dispatch_vector);
         };
 
-        m.def("_dotu", dotu_pypi,
+        m.def("_dotu", dotu_pyapi,
               "Call `dotu` from OneMKL BLAS library to return "
               "the dot product of two complex vectors.",
               py::arg("sycl_queue"), py::arg("vectorA"), py::arg("vectorB"),
@@ -119,16 +119,14 @@ PYBIND11_MODULE(_blas_impl, m)
               "Call `gemm` from OneMKL BLAS library to return "
               "the matrix-matrix product with 2-D matrices.",
               py::arg("sycl_queue"), py::arg("matrixA"), py::arg("matrixB"),
-              py::arg("result"), py::arg("depends") = py::list());
+              py::arg("resultC"), py::arg("depends") = py::list());
     }
 
     {
         m.def("_gemm_batch", &blas_ext::gemm_batch,
               "Call `gemm_batch` from OneMKL BLAS library to return "
               "the matrix-matrix product for a batch of 2-D matrices.",
               py::arg("sycl_queue"), py::arg("matrixA"), py::arg("matrixB"),
-              py::arg("result"), py::arg("batch_size"), py::arg("stridea"),
-              py::arg("strideb"), py::arg("stridec"),
-              py::arg("depends") = py::list());
+              py::arg("resultC"), py::arg("depends") = py::list());
     }
 }
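The net effect of this binding change is that Python callers now pass only the queue, the three arrays, and an optional event list; `batch_size` and the batch strides are derived inside the extension from the array metadata. A minimal pybind11 sketch of the simplified Python-facing signature (stubbed argument types and module name, not the actual dpnp extension):

```cpp
#include <pybind11/pybind11.h>

namespace py = pybind11;

// Stub standing in for blas_ext::gemm_batch; the real function takes a
// sycl::queue and dpctl usm_ndarray arguments and returns a pair of events.
void gemm_batch_stub(const py::object &sycl_queue, const py::object &matrixA,
                     const py::object &matrixB, const py::object &resultC,
                     const py::list &depends)
{
    // The real implementation derives batch_size and the batch strides from
    // resultC's shape and the arrays' stride vectors, then calls oneMKL.
}

PYBIND11_MODULE(_blas_impl_sketch, m)
{
    m.def("_gemm_batch", &gemm_batch_stub,
          "Call `gemm_batch` from OneMKL BLAS library to return "
          "the matrix-matrix product for a batch of 2-D matrices.",
          py::arg("sycl_queue"), py::arg("matrixA"), py::arg("matrixB"),
          py::arg("resultC"), py::arg("depends") = py::list());
}
```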
4 changes: 0 additions & 4 deletions dpnp/backend/extensions/blas/gemm.hpp
@@ -50,10 +50,6 @@ extern std::pair<sycl::event, sycl::event>
                dpctl::tensor::usm_ndarray matrixA,
                dpctl::tensor::usm_ndarray matrixB,
                dpctl::tensor::usm_ndarray resultC,
-               const std::int64_t batch_size,
-               size_t stridea,
-               size_t strideb,
-               size_t stridec,
                const std::vector<sycl::event> &depends);
 
 extern void init_gemm_dispatch_table(void);
75 changes: 41 additions & 34 deletions dpnp/backend/extensions/blas/gemm_batch.cpp
@@ -150,10 +150,6 @@ std::pair<sycl::event, sycl::event>
                dpctl::tensor::usm_ndarray matrixA,
                dpctl::tensor::usm_ndarray matrixB,
                dpctl::tensor::usm_ndarray resultC,
-               const std::int64_t batch_size,
-               size_t stridea,
-               size_t strideb,
-               size_t stridec,
                const std::vector<sycl::event> &depends = {})
 {
     const int matrixA_nd = matrixA.get_ndim();
@@ -185,49 +181,60 @@ std::pair<sycl::event, sycl::event>
     const py::ssize_t *a_shape = matrixA.get_shape_raw();
     const py::ssize_t *b_shape = matrixB.get_shape_raw();
     const py::ssize_t *c_shape = resultC.get_shape_raw();
-    const std::int64_t m = a_shape[matrixA_nd - 2];
-    const std::int64_t n = b_shape[matrixB_nd - 1];
-    const std::int64_t k = a_shape[matrixA_nd - 1];
-    if (a_shape[matrixA_nd - 1] != b_shape[matrixB_nd - 2]) {
+    const std::int64_t m = a_shape[1];
+    const std::int64_t n = b_shape[2];
+    const std::int64_t k = a_shape[2];
+    const std::int64_t batch_size = c_shape[0];
+    if (a_shape[2] != b_shape[1]) {
         throw py::value_error("The number of columns in A must be equal to "
                               "the number of rows in B.");
     }
-    if (a_shape[matrixA_nd - 2] != c_shape[resultC_nd - 2]) {
+    if (a_shape[1] != c_shape[1]) {
         throw py::value_error("The number of rows in A must be equal to "
                               "the number of rows in result array.");
     }
-    if (b_shape[matrixB_nd - 1] != c_shape[resultC_nd - 1]) {
+    if (b_shape[2] != c_shape[2]) {
         throw py::value_error("The number of columns in B must be equal to "
                               "the number of columns in result array.");
     }
 
-    bool shapes_equal = true;
-    size_t src_nelems = 1;
-    py::ssize_t lead_dim;
-    for (int i = 0; i < matrixA_nd - 2; ++i) {
-        if (a_shape[i] == b_shape[i]) {
-            lead_dim = a_shape[i];
-        }
-        else if (a_shape[i] == 1 || b_shape[i] == 1) {
-            lead_dim = std::max(a_shape[i], b_shape[i]);
-        }
-        else {
-            throw py::value_error("Array shapes do not match.");
-        }
-        src_nelems *= static_cast<size_t>(lead_dim);
-        shapes_equal = shapes_equal && (lead_dim == c_shape[i]);
-    }
-    src_nelems *= (m * n);
-    if (!shapes_equal) {
-        throw py::value_error("Array shapes do not match.");
-    }
+    std::int64_t first_dim;
+    if (a_shape[0] == b_shape[0]) {
+        first_dim = a_shape[0];
+    }
+    else if (a_shape[0] == 1 || b_shape[0] == 1) {
+        first_dim = std::max(a_shape[0], b_shape[0]);
+    }
+    else {
+        throw py::value_error("Array shapes do not match.");
+    }
+    if (first_dim != c_shape[0]) {
+        throw py::value_error("Array shapes do not match.");
+    }
+    std::int64_t src_nelems = first_dim * m * n;
     dpctl::tensor::validation::CheckWritable::throw_if_not_writable(resultC);
     dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(resultC,
                                                                src_nelems);
 
-    // transA and transB are always False
-    oneapi::mkl::transpose transA = oneapi::mkl::transpose::N;
-    oneapi::mkl::transpose transB = oneapi::mkl::transpose::N;
+    std::vector<py::ssize_t> a_stride = matrixA.get_strides_vector();
+    std::vector<py::ssize_t> b_stride = matrixB.get_strides_vector();
+    std::vector<py::ssize_t> c_stride = resultC.get_strides_vector();
+    const std::int64_t stridea = a_stride[0];
+    const std::int64_t strideb = b_stride[0];
+    const std::int64_t stridec = c_stride[0];
+
+    bool A_base_is_f_contig = a_stride[1] == 1 && a_stride[2] == a_shape[1];
+    bool B_base_is_f_contig = b_stride[1] == 1 && b_stride[2] == b_shape[1];
+
+    oneapi::mkl::transpose transA = A_base_is_f_contig
+                                        ? oneapi::mkl::transpose::T
+                                        : oneapi::mkl::transpose::N;
+    oneapi::mkl::transpose transB = B_base_is_f_contig
+                                        ? oneapi::mkl::transpose::T
+                                        : oneapi::mkl::transpose::N;
+
+    const std::int64_t lda = (transA == oneapi::mkl::transpose::N) ? k : m;
+    const std::int64_t ldb = (transB == oneapi::mkl::transpose::N) ? n : k;
+    const std::int64_t ldc = n; // always n for row_major
 
     int matrixA_typenum = matrixA.get_typenum();
     int matrixB_typenum = matrixB.get_typenum();
@@ -252,10 +259,10 @@ std::pair<sycl::event, sycl::event>
     char *b_typeless_ptr = matrixB.get_data();
     char *r_typeless_ptr = resultC.get_data();
 
-    // Note that lda = k, ldb = n, and ld_result = n
-    sycl::event gemm_batch_ev = gemm_batch_fn(
-        exec_q, m, n, k, batch_size, k, n, n, stridea, strideb, stridec, transA,
-        transB, a_typeless_ptr, b_typeless_ptr, r_typeless_ptr, depends);
+    sycl::event gemm_batch_ev =
+        gemm_batch_fn(exec_q, m, n, k, batch_size, lda, ldb, ldc, stridea,
+                      strideb, stridec, transA, transB, a_typeless_ptr,
+                      b_typeless_ptr, r_typeless_ptr, depends);
 
     sycl::event args_batch_ev = dpctl::utils::keep_args_alive(
         exec_q, {matrixA, matrixB, resultC}, {gemm_batch_ev});
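A note on the transpose handling introduced above: a column-major (F-contiguous) m×k base matrix has the same memory layout as the row-major k×m transpose, so it can be handed to a row-major gemm as op(A) = Aᵀ with the leading dimension switched from k to m. A small self-contained sketch of that selection logic (plain C++; the enum and function names are illustrative, not from this PR):

```cpp
#include <cstdint>
#include <iostream>

enum class transpose { N, T };

struct LdConfig {
    transpose transA, transB;
    std::int64_t lda, ldb, ldc;
};

// Mirror of the PR's logic: an F-contiguous base is treated as transposed,
// and the leading dimension is the matrix's row stride in row-major terms.
LdConfig make_ld_config(std::int64_t m, std::int64_t n, std::int64_t k,
                        bool A_base_is_f_contig, bool B_base_is_f_contig)
{
    LdConfig cfg;
    cfg.transA = A_base_is_f_contig ? transpose::T : transpose::N;
    cfg.transB = B_base_is_f_contig ? transpose::T : transpose::N;
    cfg.lda = (cfg.transA == transpose::N) ? k : m; // k for A, m for A^T
    cfg.ldb = (cfg.transB == transpose::N) ? n : k; // n for B, k for B^T
    cfg.ldc = n;                                    // result is row-major
    return cfg;
}

int main()
{
    // A is F-contiguous, B is C-contiguous: expect transA = T and lda = m.
    LdConfig cfg = make_ld_config(3, 4, 5, true, false);
    std::cout << "lda=" << cfg.lda << " ldb=" << cfg.ldb
              << " ldc=" << cfg.ldc << '\n'; // prints lda=3 ldb=4 ldc=4
}
```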
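For context, a hedged sketch of the strided oneMKL call that `gemm_batch_fn` ultimately dispatches to, assuming the row_major strided-batch USM overload of `oneapi::mkl::blas::gemm_batch`; the `float` type, alpha = 1, beta = 0, and the sizes are illustrative choices, not taken from this PR:

```cpp
#include <sycl/sycl.hpp>
#include <oneapi/mkl.hpp>

int main()
{
    sycl::queue q;
    const std::int64_t batch_size = 8, m = 3, n = 4, k = 5;
    // Both bases C-contiguous: no transposes, leading dims are row strides.
    const std::int64_t lda = k, ldb = n, ldc = n;
    const std::int64_t stridea = m * k, strideb = k * n, stridec = m * n;

    float *a = sycl::malloc_device<float>(batch_size * m * k, q);
    float *b = sycl::malloc_device<float>(batch_size * k * n, q);
    float *c = sycl::malloc_device<float>(batch_size * m * n, q);
    q.fill(a, 1.0f, batch_size * m * k).wait();
    q.fill(b, 1.0f, batch_size * k * n).wait();

    // C_i = 1.0 * A_i x B_i + 0.0 * C_i for each of the batch_size matrices
    sycl::event ev = oneapi::mkl::blas::row_major::gemm_batch(
        q, oneapi::mkl::transpose::N, oneapi::mkl::transpose::N, m, n, k, 1.0f,
        a, lda, stridea, b, ldb, strideb, 0.0f, c, ldc, stridec, batch_size);
    ev.wait(); // with all-ones inputs, every entry of every C_i is k = 5.0f

    sycl::free(a, q);
    sycl::free(b, q);
    sycl::free(c, q);
}
```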