From c521e41aef17aac403ba457a4d4b77d996750644 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Mon, 5 Feb 2024 05:46:57 -0600 Subject: [PATCH 1/2] Reduce code duplication Wrote functors using 'if constexpr' to avoid use of m_groups = 1 specializations. --- .../include/kernels/linalg_functions/gemm.hpp | 1754 ++++------------- 1 file changed, 374 insertions(+), 1380 deletions(-) diff --git a/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp b/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp index 0d90917885..fbd6402924 100644 --- a/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp +++ b/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp @@ -21,6 +21,8 @@ namespace tensor namespace kernels { +using dpctl::tensor::ssize_t; + namespace gemm_detail { @@ -361,6 +363,7 @@ sycl::event tree_reduction_for_gemm(sycl::queue &exec_q, using KernelName = class gemm_tree_reduction_krn< T, T, ReductionOpT, InputOutputIterIndexerT, ReductionIndexerT>; + cgh.parallel_for( sycl::nd_range<1>(globalRange, localRange), ReductionOverGroupNoAtomicFunctor it) const { - size_t gr_id = it.get_group_linear_id(); + const size_t gr_id = it.get_group_linear_id(); // lift group_id to (block_i, block_j, block_s), // 0 <= block_i < n_blocks, 0 <= block_j < m_blocks, 0 <= block_s < // k_blocks - size_t block_i = gr_id / (m_blocks * k_blocks); - size_t block_r = gr_id - block_i * (m_blocks * k_blocks); - size_t block_j = block_r / k_blocks; - size_t block_s = block_r - block_j * k_blocks; + const size_t block_i = gr_id / (m_blocks * k_blocks); + const size_t block_r = gr_id - block_i * (m_blocks * k_blocks); + const size_t block_j = block_r / k_blocks; + const size_t block_s = block_r - block_j * k_blocks; - size_t lid = it.get_local_linear_id(); - size_t local_i = lid / wg_delta_m; // 0<= local_i < wg_delta_n - size_t local_j = lid - local_i * wg_delta_m; // 0<= local_j < wg_delta_m + const size_t lid = it.get_local_linear_id(); + const size_t local_i = lid / wg_delta_m; // 0<= local_i < wg_delta_n + const size_t local_j = + lid - local_i * wg_delta_m; // 0<= local_j < wg_delta_m // load A block and B blocks into SLM size_t i = block_i * wi_delta_n * wg_delta_n; size_t j = block_j * wi_delta_m * wg_delta_m; - size_t s = block_s * wi_delta_k; + const size_t s = block_s * wi_delta_k; const std::int64_t a_st0 = k; const std::int64_t a_st1 = 1; @@ -670,11 +674,12 @@ class GemmFunctorThreadNM size_t lws = it.get_local_range(0); for (size_t vid = lid; vid < local_A_block.size(); vid += lws) { - size_t v_i = vid / wi_delta_k; // 0<= v_i < wg_delta_n * wi_delta_n - size_t v_s = vid - v_i * wi_delta_k; // 0<= v_s < wi_delta_k + const size_t v_i = + vid / wi_delta_k; // 0<= v_i < wg_delta_n * wi_delta_n + const size_t v_s = vid - v_i * wi_delta_k; // 0<= v_s < wi_delta_k - size_t g_i = i + v_i; - size_t g_s = s + v_s; + const size_t g_i = i + v_i; + const size_t g_s = s + v_s; local_A_block[vid] = (g_i < n && g_s < k) @@ -683,25 +688,37 @@ class GemmFunctorThreadNM : resT(0); } + using slmB_t = typename LocAccT2::value_type; + for (size_t vid = lid; vid < local_B_block.size(); vid += lws) { - size_t v_j = vid / wi_delta_k; // 0 <= v_i < wg_delta_m - size_t v_s = vid - v_j * wi_delta_k; // 0 <= v_s < wi_delta_k + const size_t v_j = vid / wi_delta_k; // 0 <= v_i < wg_delta_m + const size_t v_s = vid - v_j * wi_delta_k; // 0 <= v_s < wi_delta_k - size_t g_j0 = j + v_j * wi_delta_m; - size_t g_s = s + v_s; + const size_t g_j = j + v_j * wi_delta_m; + const size_t g_s = s + 
v_s; - sycl::vec vec{}; -#pragma unroll - for (std::uint8_t lane_id = 0; lane_id < wi_delta_m; ++lane_id) { - size_t g_j = g_j0 + lane_id; - vec[lane_id] = + if constexpr (wi_delta_m == 1 && std::is_same_v) { + local_B_block[vid] = (g_j < m && g_s < k) ? static_cast( rhs[rhs_indexer(g_s * b_st0 + g_j * b_st1)]) : resT(0); } + else { + slmB_t vec{}; +#pragma unroll + for (std::uint8_t lane_id = 0; lane_id < wi_delta_m; ++lane_id) + { + const size_t g_j1 = g_j + lane_id; + vec[lane_id] = + (g_j1 < m && g_s < k) + ? static_cast( + rhs[rhs_indexer(g_s * b_st0 + g_j1 * b_st1)]) + : resT(0); + } - local_B_block[vid] = vec; + local_B_block[vid] = vec; + } } it.barrier(sycl::access::fence_space::local_space); @@ -709,189 +726,49 @@ class GemmFunctorThreadNM i += local_i * wi_delta_n; j += local_j * wi_delta_m; - size_t a_offset = local_i * wi_delta_k * wi_delta_n; - size_t b_offset = local_j * wi_delta_k; + const size_t a_offset = local_i * wi_delta_k * wi_delta_n; + const size_t b_offset = local_j * wi_delta_k; constexpr resT identity_(0); for (std::uint8_t private_i = 0; private_i < wi_delta_n; ++private_i) { size_t a_pr_offset = private_i * wi_delta_k; - sycl::vec local_sum(identity_); + slmB_t local_sum(identity_); for (size_t private_s = 0; private_s < wi_delta_k; ++private_s) { local_sum = local_sum + (local_A_block[a_offset + a_pr_offset + private_s] * local_B_block[b_offset + private_s]); } - size_t gl_i = i + private_i; - -#pragma unroll - for (std::uint8_t lane_id = 0; lane_id < wi_delta_m; ++lane_id) { - size_t gl_j = j + lane_id; + const size_t gl_i = i + private_i; + if constexpr (wi_delta_m == 1 && std::is_same_v) { + const size_t gl_j = j; if (gl_i < n && gl_j < m) { sycl::atomic_ref aout(res[res_indexer(gl_i * c_st0 + gl_j * c_st1)]); - aout += local_sum[lane_id]; + aout += local_sum; } } - } - } -}; - -// specialization for wi_delta_m == 1 -template -class GemmFunctorThreadNM -{ -private: - const lhsT *lhs = nullptr; - const rhsT *rhs = nullptr; - resT *res = nullptr; - LocAccT1 local_A_block; - LocAccT2 local_B_block; - size_t n = 0; - size_t wg_delta_n = 0; - size_t k = 0; - size_t k_blocks = 0; - size_t wi_delta_k = 0; - size_t m = 0; - size_t m_blocks = 0; - size_t wg_delta_m = 0; - OuterInnerDimsIndexerT lhs_indexer; - OuterInnerDimsIndexerT rhs_indexer; - OuterInnerDimsIndexerT res_indexer; - -public: - GemmFunctorThreadNM(const lhsT *lhs_, - const rhsT *rhs_, - resT *res_, - LocAccT1 local_A_block_, - LocAccT2 local_B_block_, - size_t n_, - size_t wg_delta_n_, - size_t k_, - size_t k_blocks_, - size_t wi_delta_k_, - size_t m_, - size_t m_blocks_, - size_t wg_delta_m_, - OuterInnerDimsIndexerT lhs_indexer_, - OuterInnerDimsIndexerT rhs_indexer_, - OuterInnerDimsIndexerT res_indexer_) - : lhs(lhs_), rhs(rhs_), res(res_), local_A_block(local_A_block_), - local_B_block(local_B_block_), n(n_), wg_delta_n(wg_delta_n_), k(k_), - k_blocks(k_blocks_), wi_delta_k(wi_delta_k_), m(m_), - m_blocks(m_blocks_), wg_delta_m(wg_delta_m_), - lhs_indexer(lhs_indexer_), rhs_indexer(rhs_indexer_), - res_indexer(res_indexer_) - { - } - - void operator()(sycl::nd_item<1> it) const - { - size_t gr_id = it.get_group_linear_id(); - // lift group_id to (block_i, block_j, block_s), - // 0 <= block_i < n_blocks, 0 <= block_j < m_blocks, 0 <= block_s < - // k_blocks - size_t block_i = gr_id / (m_blocks * k_blocks); - size_t block_r = gr_id - block_i * (m_blocks * k_blocks); - size_t block_j = block_r / k_blocks; - size_t block_s = block_r - block_j * k_blocks; - - size_t lid = it.get_local_linear_id(); - 
size_t local_i = lid / wg_delta_m; // 0<= local_i < wg_delta_n - size_t local_j = lid - local_i * wg_delta_m; // 0<= local_j < wg_delta_m - - // load A block and B blocks into SLM - - size_t i = block_i * wi_delta_n * wg_delta_n; - size_t j = block_j * wg_delta_m; - size_t s = block_s * wi_delta_k; - - const std::int64_t a_st0 = k; - const std::int64_t a_st1 = 1; - - const std::int64_t b_st0 = m; - const std::int64_t b_st1 = 1; - - const std::int64_t c_st0 = m; - const std::int64_t c_st1 = 1; - - size_t lws = it.get_local_range(0); - - for (size_t vid = lid; vid < local_A_block.size(); vid += lws) { - size_t v_i = vid / wi_delta_k; // 0<= v_i < wg_delta_n * wi_delta_n - size_t v_s = vid - v_i * wi_delta_k; // 0<= v_s < wi_delta_k - - size_t g_i = i + v_i; - size_t g_s = s + v_s; - - local_A_block[vid] = - (g_i < n && g_s < k) - ? static_cast( - lhs[lhs_indexer(g_i * a_st0 + g_s * a_st1)]) - : resT(0); - } - - for (size_t vid = lid; vid < local_B_block.size(); vid += lws) { - size_t v_j = vid / wi_delta_k; // 0<= v_i < wg_delta_m - size_t v_s = vid - v_j * wi_delta_k; // 0<= v_s < wi_delta_k - - size_t g_j0 = j + v_j; - size_t g_s = s + v_s; - - resT val = (g_j0 < m && g_s < k) - ? static_cast( - rhs[rhs_indexer(g_s * b_st0 + g_j0 * b_st1)]) - : resT(0); - - local_B_block[vid] = val; - } - - it.barrier(sycl::access::fence_space::local_space); - - i += local_i * wi_delta_n; - j += local_j; - - size_t a_offset = local_i * wi_delta_k * wi_delta_n; - size_t b_offset = local_j * wi_delta_k; - - constexpr resT identity_(0); - for (std::uint8_t private_i = 0; private_i < wi_delta_n; ++private_i) { - size_t a_pr_offset = private_i * wi_delta_k; - - resT local_sum(identity_); - for (size_t private_s = 0; private_s < wi_delta_k; ++private_s) { - local_sum = local_sum + - (local_A_block[a_offset + a_pr_offset + private_s] * - local_B_block[b_offset + private_s]); - } - - size_t gl_i = i + private_i; - - if (gl_i < n && j < m) { - sycl::atomic_ref - aout(res[res_indexer(gl_i * c_st0 + j * c_st1)]); - - aout += local_sum; + else { +#pragma unroll + for (std::uint8_t lane_id = 0; lane_id < wi_delta_m; ++lane_id) + { + const size_t gl_j = j + lane_id; + + if (gl_i < n && gl_j < m) { + sycl::atomic_ref< + resT, sycl::memory_order::relaxed, + sycl::memory_scope::device, + sycl::access::address_space::global_space> + aout(res[res_indexer(gl_i * c_st0 + gl_j * c_st1)]); + + aout += local_sum[lane_id]; + } + } } } } @@ -972,21 +849,32 @@ class GemmFunctorThreadK size_t j = m_groups * block_j; size_t s = block_s * delta_k * n_wi + local_s; + using accV_t = typename LocAccT::value_type; + constexpr resT identity_ = resT(0); if (local_i == 0) { for (size_t q = 0; q < n_wi * delta_k; q += delta_k) { size_t sq = s + q; size_t sqmj = sq * m + j; - sycl::vec local_B_vec; -#pragma unroll - for (size_t vec_idx = 0; vec_idx < m_groups; ++vec_idx) { - local_B_vec[vec_idx] = - (sq < k && j + vec_idx < m) - ? static_cast( - rhs[rhs_indexer(sqmj + vec_idx)]) + + if constexpr (m_groups == 1 && std::is_same_v) { + local_B_block[local_s + q] = + (sq < k && j < m) + ? static_cast(rhs[rhs_indexer(sqmj)]) : identity_; } - local_B_block[local_s + q] = local_B_vec; + else { + accV_t local_B_vec; +#pragma unroll + for (size_t vec_idx = 0; vec_idx < m_groups; ++vec_idx) { + local_B_vec[vec_idx] = + (sq < k && j + vec_idx < m) + ? 
static_cast( + rhs[rhs_indexer(sqmj + vec_idx)]) + : identity_; + } + local_B_block[local_s + q] = local_B_vec; + } } } @@ -995,8 +883,8 @@ class GemmFunctorThreadK size_t t_shift = block_s * delta_k * n_wi; size_t global_s_offset = i * k + t_shift; - sycl::vec private_sum(identity_); - constexpr sycl::vec vec_identity_(identity_); + accV_t private_sum(identity_); + constexpr accV_t vec_identity_(identity_); for (size_t t = local_s; t < local_B_block.size(); t += delta_k) { private_sum += ((i < n) && (t + t_shift < k)) ? (static_cast( @@ -1011,7 +899,7 @@ class GemmFunctorThreadK it.barrier(sycl::access::fence_space::local_space); if (local_s == 0 && i < n) { - sycl::vec local_sum(workspace[workspace_i_shift]); + accV_t local_sum(workspace[workspace_i_shift]); for (size_t t = 1; t < delta_k; ++t) { local_sum += workspace[workspace_i_shift + t]; } @@ -1021,169 +909,53 @@ class GemmFunctorThreadK sycl::access::address_space::global_space> aout0(res[res_indexer(i * m + j)]); - aout0 += local_sum[0]; + if constexpr (m_groups == 1 && std::is_same_v) { + aout0 += local_sum; + } + else { + aout0 += local_sum[0]; #pragma unroll - for (size_t vec_id = 1; vec_id < m_groups; ++vec_id) { - if (j + vec_id < m) { - sycl::atomic_ref - aout1(res[res_indexer(i * m + j + vec_id)]); - - aout1 += local_sum[vec_id]; + for (size_t vec_id = 1; vec_id < m_groups; ++vec_id) { + if (j + vec_id < m) { + sycl::atomic_ref< + resT, sycl::memory_order::relaxed, + sycl::memory_scope::device, + sycl::access::address_space::global_space> + aout1(res[res_indexer(i * m + j + vec_id)]); + + aout1 += local_sum[vec_id]; + } } } } } }; -// specialization for m_groups == 1 -template -class GemmFunctorThreadK -{ -private: - const lhsT *lhs = nullptr; - const rhsT *rhs = nullptr; - resT *res = nullptr; - LocAccT workspace; - LocAccT local_B_block; - size_t n = 0; - size_t n_blocks = 0; - size_t delta_n = 0; - size_t k = 0; - size_t k_blocks = 0; - size_t delta_k = 0; - size_t n_wi = 0; - size_t m = 0; - OuterInnerDimsIndexerT lhs_indexer; - OuterInnerDimsIndexerT rhs_indexer; - OuterInnerDimsIndexerT res_indexer; - -public: - GemmFunctorThreadK(const lhsT *lhs_, - const rhsT *rhs_, - resT *res_, - LocAccT workspace_, - LocAccT local_B_block_, - size_t n_, - size_t n_blocks_, - size_t delta_n_, - size_t k_, - size_t k_blocks_, - size_t delta_k_, - size_t n_wi_, - size_t m_, - OuterInnerDimsIndexerT lhs_indexer_, - OuterInnerDimsIndexerT rhs_indexer_, - OuterInnerDimsIndexerT res_indexer_) - : lhs(lhs_), rhs(rhs_), res(res_), workspace(workspace_), - local_B_block(local_B_block_), n(n_), n_blocks(n_blocks_), - delta_n(delta_n_), k(k_), k_blocks(k_blocks_), delta_k(delta_k_), - n_wi(n_wi_), m(m_), lhs_indexer(lhs_indexer_), - rhs_indexer(rhs_indexer_), res_indexer(res_indexer_) - { - } +template class gemm_init_krn; - void operator()(sycl::nd_item<1> it) const - { - size_t gr_id = it.get_group_linear_id(); - size_t lid = it.get_local_linear_id(); +template +class gemm_k_krn; - // lift gr_id -> (block_i, block_j, block_s) - // block_i moves fastest, then block_s, then block_j +template +class gemm_nm_krn; - size_t block_j = - gr_id / (n_blocks * k_blocks); // 0 <= block_j < m_blocks - size_t block_r = - gr_id - block_j * (n_blocks * - k_blocks); // 0 <= block_r < n_blocks * k_blocks - size_t block_s = block_r / n_blocks; // 0 <= block_s < k_blocks - size_t block_i = - block_r - block_s * n_blocks; // 0 <= block_i < n_blocks - - size_t local_i = lid / (delta_k); // 0 <= local_i < delta_n - size_t local_s = lid - local_i * (delta_k); // 
0 <= local_s < delta_k - - size_t i = block_i * delta_n + local_i; - size_t j = block_j; - size_t s = block_s * delta_k * n_wi + local_s; - - constexpr resT identity_ = resT(0); - if (local_i == 0) { - for (size_t q = 0; q < n_wi * delta_k; q += delta_k) { - size_t sq = s + q; - size_t sqmj = sq * m + j; - local_B_block[local_s + q] = - (sq < k && j < m) - ? static_cast(rhs[rhs_indexer(sqmj)]) - : identity_; - } - } - - it.barrier(sycl::access::fence_space::local_space); - - size_t t_shift = block_s * delta_k * n_wi; - size_t global_s_offset = i * k + t_shift; - - resT private_sum(identity_); - for (size_t t = local_s; t < local_B_block.size(); t += delta_k) { - private_sum += ((i < n) && (t + t_shift < k)) - ? (static_cast( - lhs[lhs_indexer(global_s_offset + t)]) * - local_B_block[t]) - : identity_; - } - - size_t workspace_i_shift = local_i * delta_k; - workspace[workspace_i_shift + local_s] = private_sum; - - it.barrier(sycl::access::fence_space::local_space); - - if (local_s == 0 && i < n) { - resT local_sum(workspace[workspace_i_shift]); - for (size_t t = 1; t < delta_k; ++t) { - local_sum += workspace[workspace_i_shift + t]; - } - - sycl::atomic_ref - aout(res[res_indexer(i * m + j)]); - - aout += local_sum; - } - } -}; - -template class gemm_init_krn; - -template -class gemm_k_krn; - -template -class gemm_nm_krn; - -typedef sycl::event (*gemm_impl_fn_ptr_t)( - sycl::queue &, - const char *, // lhs - const char *, // rhs - char *, // res - size_t, // lhs_outer_nelems (n) - size_t, // inner_nelems (k) - size_t, // rhs_outer_nelems (m) - int, // inner nd - int, // lhs outer nd - const ssize_t *, // lhs shape and strides - int, // rhs outer nd - const ssize_t *, // rhs shape and strides - int, // res outer nd - const ssize_t *, // res shape and strides - std::vector const &); +typedef sycl::event (*gemm_impl_fn_ptr_t)( + sycl::queue &, + const char *, // lhs + const char *, // rhs + char *, // res + size_t, // lhs_outer_nelems (n) + size_t, // inner_nelems (k) + size_t, // rhs_outer_nelems (m) + int, // inner nd + int, // lhs outer nd + const ssize_t *, // lhs shape and strides + int, // rhs outer nd + const ssize_t *, // rhs shape and strides + int, // res outer nd + const ssize_t *, // res shape and strides + std::vector const &); template sycl::event gemm_impl(sycl::queue &exec_q, @@ -1633,25 +1405,37 @@ class GemmNoAtomicFunctorThreadNM : resT(0); } + using slmB_t = typename LocAccT2::value_type; + for (size_t vid = lid; vid < local_B_block.size(); vid += lws) { size_t v_j = vid / wi_delta_k; // 0<= v_i < wg_delta_m size_t v_s = vid - v_j * wi_delta_k; // 0<= v_s < wi_delta_k - size_t g_j0 = j + v_j * wi_delta_m; + size_t g_j = j + v_j * wi_delta_m; size_t g_s = s + v_s; - sycl::vec vec{}; -#pragma unroll - for (std::uint8_t lane_id = 0; lane_id < wi_delta_m; ++lane_id) { - size_t g_j = g_j0 + lane_id; - vec[lane_id] = + if constexpr (wi_delta_m == 1 && std::is_same_v) { + local_B_block[vid] = (g_j < m && g_s < k) ? static_cast( rhs[rhs_indexer(g_s * b_st0 + g_j * b_st1)]) : resT(0); } + else { + slmB_t vec{}; +#pragma unroll + for (std::uint8_t lane_id = 0; lane_id < wi_delta_m; ++lane_id) + { + size_t g_j1 = g_j + lane_id; + vec[lane_id] = + (g_j1 < m && g_s < k) + ? 
static_cast( + rhs[rhs_indexer(g_s * b_st0 + g_j1 * b_st1)]) + : resT(0); + } - local_B_block[vid] = vec; + local_B_block[vid] = vec; + } } it.barrier(sycl::access::fence_space::local_space); @@ -1659,15 +1443,15 @@ class GemmNoAtomicFunctorThreadNM i += local_i * wi_delta_n; j += local_j * wi_delta_m; - size_t a_offset = local_i * wi_delta_k * wi_delta_n; - size_t b_offset = local_j * wi_delta_k; + const size_t a_offset = local_i * wi_delta_k * wi_delta_n; + const size_t b_offset = local_j * wi_delta_k; constexpr resT identity_(0); for (std::uint8_t private_i = 0; private_i < wi_delta_n; ++private_i) { - size_t a_pr_offset = private_i * wi_delta_k; + const size_t a_pr_offset = private_i * wi_delta_k; - sycl::vec local_sum(identity_); + slmB_t local_sum(identity_); for (size_t private_s = 0; private_s < wi_delta_k; ++private_s) { local_sum = local_sum + (local_A_block[a_offset + a_pr_offset + private_s] * @@ -1676,167 +1460,24 @@ class GemmNoAtomicFunctorThreadNM size_t gl_i = i + private_i; -#pragma unroll - for (std::uint8_t lane_id = 0; lane_id < wi_delta_m; ++lane_id) { - size_t gl_j = j + lane_id; - + if constexpr (wi_delta_m == 1 && std::is_same_v) { + const size_t gl_j = j; if (gl_i < n && gl_j < m) { res[res_indexer(gl_i * c_st0 + gl_j * c_st1 + - block_s * n * m)] = local_sum[lane_id]; + block_s * n * m)] = local_sum; } } - } - } -}; - -template -class GemmNoAtomicFunctorThreadNM -{ -private: - const lhsT *lhs = nullptr; - const rhsT *rhs = nullptr; - resT *res = nullptr; - LocAccT1 local_A_block; - LocAccT2 local_B_block; - size_t n = 0; - size_t wg_delta_n = 0; - size_t k = 0; - size_t k_blocks = 0; - size_t wi_delta_k = 0; - size_t m = 0; - size_t m_blocks = 0; - size_t wg_delta_m = 0; - OuterInnerDimsIndexerT lhs_indexer; - OuterInnerDimsIndexerT rhs_indexer; - ResIndexerT res_indexer; - -public: - GemmNoAtomicFunctorThreadNM(const lhsT *lhs_, - const rhsT *rhs_, - resT *res_, - LocAccT1 local_A_block_, - LocAccT2 local_B_block_, - size_t n_, - size_t wg_delta_n_, - size_t k_, - size_t k_blocks_, - size_t wi_delta_k_, - size_t m_, - size_t m_blocks_, - size_t wg_delta_m_, - OuterInnerDimsIndexerT lhs_indexer_, - OuterInnerDimsIndexerT rhs_indexer_, - ResIndexerT res_indexer_) - : lhs(lhs_), rhs(rhs_), res(res_), local_A_block(local_A_block_), - local_B_block(local_B_block_), n(n_), wg_delta_n(wg_delta_n_), k(k_), - k_blocks(k_blocks_), wi_delta_k(wi_delta_k_), m(m_), - m_blocks(m_blocks_), wg_delta_m(wg_delta_m_), - lhs_indexer(lhs_indexer_), rhs_indexer(rhs_indexer_), - res_indexer(res_indexer_) - { - } - - void operator()(sycl::nd_item<1> it) const - { - size_t gr_id = it.get_group_linear_id(); - // lift group_id to (block_i, block_j, block_s), - // 0 <= block_i < n_blocks, 0 <= block_j < m_blocks, 0 <= block_s < - // k_blocks - size_t block_i = gr_id / (m_blocks * k_blocks); - size_t block_r = gr_id - block_i * (m_blocks * k_blocks); - size_t block_j = block_r / k_blocks; - size_t block_s = block_r - block_j * k_blocks; - - size_t lid = it.get_local_linear_id(); - size_t local_i = lid / wg_delta_m; // 0<= local_i < wg_delta_n - size_t local_j = lid - local_i * wg_delta_m; // 0<= local_j < wg_delta_m - - // load A block and B blocks into SLM - - size_t i = block_i * wi_delta_n * wg_delta_n; - size_t j = block_j * wg_delta_m; - size_t s = block_s * wi_delta_k; - - const std::int64_t a_st0 = k; - const std::int64_t a_st1 = 1; - - const std::int64_t b_st0 = m; - const std::int64_t b_st1 = 1; - - const std::int64_t c_st0 = m; - const std::int64_t c_st1 = 1; - - size_t lws = 
it.get_local_range(0); - - for (size_t vid = lid; vid < local_A_block.size(); vid += lws) { - size_t v_i = vid / wi_delta_k; // 0<= v_i < wg_delta_n * wi_delta_n - size_t v_s = vid - v_i * wi_delta_k; // 0<= v_s < wi_delta_k - - size_t g_i = i + v_i; - size_t g_s = s + v_s; - - local_A_block[vid] = - (g_i < n && g_s < k) - ? static_cast( - lhs[lhs_indexer(g_i * a_st0 + g_s * a_st1)]) - : resT(0); - } - - for (size_t vid = lid; vid < local_B_block.size(); vid += lws) { - size_t v_j = vid / wi_delta_k; // 0<= v_i < wg_delta_m - size_t v_s = vid - v_j * wi_delta_k; // 0<= v_s < wi_delta_k - - size_t g_j0 = j + v_j; - size_t g_s = s + v_s; - - resT val = (g_j0 < m && g_s < k) - ? static_cast( - rhs[rhs_indexer(g_s * b_st0 + g_j0 * b_st1)]) - : resT(0); - - local_B_block[vid] = val; - } - - it.barrier(sycl::access::fence_space::local_space); - - i += local_i * wi_delta_n; - j += local_j; - - size_t a_offset = local_i * wi_delta_k * wi_delta_n; - size_t b_offset = local_j * wi_delta_k; - - constexpr resT identity_(0); - - for (std::uint8_t private_i = 0; private_i < wi_delta_n; ++private_i) { - size_t a_pr_offset = private_i * wi_delta_k; - - resT local_sum(identity_); - for (size_t private_s = 0; private_s < wi_delta_k; ++private_s) { - local_sum = local_sum + - (local_A_block[a_offset + a_pr_offset + private_s] * - local_B_block[b_offset + private_s]); - } - - size_t gl_i = i + private_i; - - if (gl_i < n && j < m) { - res[res_indexer(gl_i * c_st0 + j * c_st1 + block_s * n * m)] = - local_sum; + else { +#pragma unroll + for (std::uint8_t lane_id = 0; lane_id < wi_delta_m; ++lane_id) + { + const size_t gl_j = j + lane_id; + + if (gl_i < n && gl_j < m) { + res[res_indexer(gl_i * c_st0 + gl_j * c_st1 + + block_s * n * m)] = local_sum[lane_id]; + } + } } } } @@ -1918,21 +1559,33 @@ class GemmNoAtomicFunctorThreadK size_t j = m_groups * block_j; size_t s = block_s * delta_k * n_wi + local_s; + using accV_t = typename LocAccT::value_type; + constexpr resT identity_ = resT(0); if (local_i == 0) { for (size_t q = 0; q < n_wi * delta_k; q += delta_k) { size_t sq = s + q; size_t sqmj = sq * m + j; - sycl::vec local_B_vec; -#pragma unroll - for (size_t vec_idx = 0; vec_idx < m_groups; ++vec_idx) { - local_B_vec[vec_idx] = - (sq < k && j + vec_idx < m) - ? static_cast( - rhs[rhs_indexer(sqmj + vec_idx)]) + + if constexpr (m_groups == 1 && std::is_same_v) { + local_B_block[local_s + q] = + (sq < k && j < m) + ? static_cast(rhs[rhs_indexer(sqmj)]) : identity_; + ; + } + else { + accV_t local_B_vec; +#pragma unroll + for (size_t vec_idx = 0; vec_idx < m_groups; ++vec_idx) { + local_B_vec[vec_idx] = + (sq < k && j + vec_idx < m) + ? static_cast( + rhs[rhs_indexer(sqmj + vec_idx)]) + : identity_; + } + local_B_block[local_s + q] = local_B_vec; } - local_B_block[local_s + q] = local_B_vec; } } @@ -1941,8 +1594,8 @@ class GemmNoAtomicFunctorThreadK size_t t_shift = block_s * delta_k * n_wi; size_t global_s_offset = i * k + t_shift; - sycl::vec private_sum(identity_); - constexpr sycl::vec vec_identity_(identity_); + accV_t private_sum(identity_); + constexpr accV_t vec_identity_(identity_); for (size_t t = local_s; t < local_B_block.size(); t += delta_k) { private_sum += ((i < n) && (t + t_shift < k)) ? 
(static_cast( @@ -1957,155 +1610,39 @@ class GemmNoAtomicFunctorThreadK it.barrier(sycl::access::fence_space::local_space); if (local_s == 0 && i < n) { - sycl::vec local_sum(workspace[workspace_i_shift]); + accV_t local_sum(workspace[workspace_i_shift]); for (size_t t = 1; t < delta_k; ++t) { local_sum += workspace[workspace_i_shift + t]; } const size_t res_offset = (block_s * n * m); - res[res_indexer(i * m + j) + res_offset] = local_sum[0]; + + if constexpr (m_groups == 1 && std::is_same_v) { + res[res_indexer(i * m + j) + res_offset] = local_sum; + } + else { + static_assert(m_groups >= 1); + res[res_indexer(i * m + j) + res_offset] = local_sum[0]; #pragma unroll - for (size_t vec_id = 1; vec_id < m_groups; ++vec_id) { - if (j + vec_id < m) { - res[res_indexer(i * m + j + vec_id) + res_offset] = - local_sum[vec_id]; + for (size_t vec_id = 1; vec_id < m_groups; ++vec_id) { + if (j + vec_id < m) { + res[res_indexer(i * m + j + vec_id) + res_offset] = + local_sum[vec_id]; + } } } } } }; -template -class GemmNoAtomicFunctorThreadK -{ -private: - const lhsT *lhs = nullptr; - const rhsT *rhs = nullptr; - resT *res = nullptr; - LocAccT workspace; - LocAccT local_B_block; - size_t n = 0; - size_t n_blocks = 0; - size_t delta_n = 0; - size_t k = 0; - size_t k_blocks = 0; - size_t delta_k = 0; - size_t n_wi = 0; - size_t m = 0; - OuterInnerDimsIndexerT lhs_indexer; - OuterInnerDimsIndexerT rhs_indexer; - ResIndexerT res_indexer; - -public: - GemmNoAtomicFunctorThreadK(const lhsT *lhs_, - const rhsT *rhs_, - resT *res_, - LocAccT workspace_, - LocAccT local_B_block_, - size_t n_, - size_t n_blocks_, - size_t delta_n_, - size_t k_, - size_t k_blocks_, - size_t delta_k_, - size_t n_wi_, - size_t m_, - OuterInnerDimsIndexerT lhs_indexer_, - OuterInnerDimsIndexerT rhs_indexer_, - ResIndexerT res_indexer_) - : lhs(lhs_), rhs(rhs_), res(res_), workspace(workspace_), - local_B_block(local_B_block_), n(n_), n_blocks(n_blocks_), - delta_n(delta_n_), k(k_), k_blocks(k_blocks_), delta_k(delta_k_), - n_wi(n_wi_), m(m_), lhs_indexer(lhs_indexer_), - rhs_indexer(rhs_indexer_), res_indexer(res_indexer_) - { - } - - void operator()(sycl::nd_item<1> it) const - { - size_t gr_id = it.get_group_linear_id(); - size_t lid = it.get_local_linear_id(); - - // lift gr_id -> (block_i, block_j, block_s) - // block_i moves fastest, then block_s, then block_j - - size_t block_j = - gr_id / (n_blocks * k_blocks); // 0 <= block_j < m_blocks - size_t block_r = - gr_id - block_j * (n_blocks * - k_blocks); // 0 <= block_r < n_blocks * k_blocks - size_t block_s = block_r / n_blocks; // 0 <= block_s < k_blocks - size_t block_i = - block_r - block_s * n_blocks; // 0 <= block_i < n_blocks - - size_t local_i = lid / (delta_k); // 0 <= local_i < delta_n - size_t local_s = lid - local_i * (delta_k); // 0 <= local_s < delta_k - - size_t i = block_i * delta_n + local_i; - size_t j = block_j; - size_t s = block_s * delta_k * n_wi + local_s; - - constexpr resT identity_ = resT(0); - if (local_i == 0) { - for (size_t q = 0; q < n_wi * delta_k; q += delta_k) { - size_t sq = s + q; - size_t sqmj = sq * m + j; - local_B_block[local_s + q] = - (sq < k && j < m) - ? static_cast(rhs[rhs_indexer(sqmj)]) - : identity_; - } - } - - it.barrier(sycl::access::fence_space::local_space); - - size_t t_shift = block_s * delta_k * n_wi; - size_t global_s_offset = i * k + t_shift; - - resT private_sum(identity_); - for (size_t t = local_s; t < local_B_block.size(); t += delta_k) { - private_sum += ((i < n) && (t + t_shift < k)) - ? 
(static_cast( - lhs[lhs_indexer(global_s_offset + t)]) * - local_B_block[t]) - : identity_; - } - - size_t workspace_i_shift = local_i * delta_k; - workspace[workspace_i_shift + local_s] = private_sum; - - it.barrier(sycl::access::fence_space::local_space); - - if (local_s == 0 && i < n) { - resT local_sum(workspace[workspace_i_shift]); - for (size_t t = 1; t < delta_k; ++t) { - local_sum += workspace[workspace_i_shift + t]; - } - - res[res_indexer(i * m + j) + (block_s * n * m)] = local_sum; - } - } -}; - -template -class gemm_tree_nm_krn; +template +class gemm_tree_nm_krn; template vec{}; -#pragma unroll - for (std::uint8_t lane_id = 0; lane_id < wi_delta_m; ++lane_id) { - size_t g_j = g_j0 + lane_id; - vec[lane_id] = + if constexpr (wi_delta_m == 1 && std::is_same_v) { + local_B_block[vid] = (g_j < m && g_s < k) ? static_cast( rhs[rhs_offset + rhs_indexer(g_s * b_st0 + g_j * b_st1)]) : resT(0); } + else { + slmB_t vec{}; +#pragma unroll + for (std::uint8_t lane_id = 0; lane_id < wi_delta_m; ++lane_id) + { + const size_t g_j1 = g_j + lane_id; + vec[lane_id] = + (g_j1 < m && g_s < k) + ? static_cast( + rhs[rhs_offset + + rhs_indexer(g_s * b_st0 + g_j1 * b_st1)]) + : resT(0); + } - local_B_block[vid] = vec; + local_B_block[vid] = vec; + } } it.barrier(sycl::access::fence_space::local_space); @@ -3685,7 +3235,7 @@ class GemmBatchFunctorThreadNM for (std::uint8_t private_i = 0; private_i < wi_delta_n; ++private_i) { size_t a_pr_offset = private_i * wi_delta_k; - sycl::vec local_sum(identity_); + slmB_t local_sum(identity_); for (size_t private_s = 0; private_s < wi_delta_k; ++private_s) { local_sum = local_sum + (local_A_block[a_offset + a_pr_offset + private_s] * @@ -3694,10 +3244,8 @@ class GemmBatchFunctorThreadNM size_t gl_i = i + private_i; -#pragma unroll - for (std::uint8_t lane_id = 0; lane_id < wi_delta_m; ++lane_id) { - size_t gl_j = j + lane_id; - + if constexpr (wi_delta_m == 1 && std::is_same_v) { + const size_t gl_j = j; if (gl_i < n && gl_j < m) { sycl::atomic_ref -class GemmBatchFunctorThreadNM -{ -private: - const lhsT *lhs = nullptr; - const rhsT *rhs = nullptr; - resT *res = nullptr; - LocAccT1 local_A_block; - LocAccT2 local_B_block; - size_t n = 0; - size_t wg_delta_n = 0; - size_t k = 0; - size_t k_blocks = 0; - size_t wi_delta_k = 0; - size_t m = 0; - size_t m_blocks = 0; - size_t wg_delta_m = 0; - size_t batch_nelems; - BatchDimsIndexerT batch_indexer; - OuterInnerDimsIndexerT lhs_indexer; - OuterInnerDimsIndexerT rhs_indexer; - OuterInnerDimsIndexerT res_indexer; - -public: - GemmBatchFunctorThreadNM(const lhsT *lhs_, - const rhsT *rhs_, - resT *res_, - LocAccT1 local_A_block_, - LocAccT2 local_B_block_, - size_t n_, - size_t wg_delta_n_, - size_t k_, - size_t k_blocks_, - size_t wi_delta_k_, - size_t m_, - size_t m_blocks_, - size_t wg_delta_m_, - size_t batch_nelems_, - BatchDimsIndexerT batch_indexer_, - OuterInnerDimsIndexerT lhs_indexer_, - OuterInnerDimsIndexerT rhs_indexer_, - OuterInnerDimsIndexerT res_indexer_) - : lhs(lhs_), rhs(rhs_), res(res_), local_A_block(local_A_block_), - local_B_block(local_B_block_), n(n_), wg_delta_n(wg_delta_n_), k(k_), - k_blocks(k_blocks_), wi_delta_k(wi_delta_k_), m(m_), - m_blocks(m_blocks_), wg_delta_m(wg_delta_m_), - batch_nelems(batch_nelems_), batch_indexer(batch_indexer_), - lhs_indexer(lhs_indexer_), rhs_indexer(rhs_indexer_), - res_indexer(res_indexer_) - { - } - - void operator()(sycl::nd_item<1> it) const - { - const size_t n_groups_per_batch = it.get_group_range(0) / batch_nelems; - const size_t m_id = 
it.get_group_linear_id() / n_groups_per_batch; - const size_t gr_id = - it.get_group_linear_id() - m_id * n_groups_per_batch; - - const auto &three_offsets_ = batch_indexer(static_cast(m_id)); - - // lift group_id to (block_i, block_j, block_s), - // 0 <= block_i < n_blocks, 0 <= block_j < m_blocks, 0 <= block_s - // < k_blocks - - const auto &lhs_offset = three_offsets_.get_first_offset(); - const auto &rhs_offset = three_offsets_.get_second_offset(); - const auto &res_offset = three_offsets_.get_third_offset(); - - size_t block_i = gr_id / (m_blocks * k_blocks); - size_t block_r = gr_id - block_i * (m_blocks * k_blocks); - size_t block_j = block_r / k_blocks; - size_t block_s = block_r - block_j * k_blocks; - - size_t lid = it.get_local_linear_id(); - size_t local_i = lid / wg_delta_m; // 0<= local_i < wg_delta_n - size_t local_j = lid - local_i * wg_delta_m; // 0<= local_j < wg_delta_m - - // load A block and B blocks into SLM - - size_t i = block_i * wi_delta_n * wg_delta_n; - size_t j = block_j * wg_delta_m; - size_t s = block_s * wi_delta_k; - - const std::int64_t a_st0 = k; - const std::int64_t a_st1 = 1; - - const std::int64_t b_st0 = m; - const std::int64_t b_st1 = 1; - - const std::int64_t c_st0 = m; - const std::int64_t c_st1 = 1; - - size_t lws = it.get_local_range(0); - - for (size_t vid = lid; vid < local_A_block.size(); vid += lws) { - size_t v_i = vid / wi_delta_k; // 0<= v_i < wg_delta_n * wi_delta_n - size_t v_s = vid - v_i * wi_delta_k; // 0<= v_s < wi_delta_k - - size_t g_i = i + v_i; - size_t g_s = s + v_s; - - local_A_block[vid] = - (g_i < n && g_s < k) - ? static_cast( - lhs[lhs_offset + - lhs_indexer(g_i * a_st0 + g_s * a_st1)]) - : resT(0); - } - - for (size_t vid = lid; vid < local_B_block.size(); vid += lws) { - size_t v_j = vid / wi_delta_k; // 0<= v_i < wg_delta_m - size_t v_s = vid - v_j * wi_delta_k; // 0<= v_s < wi_delta_k - - size_t g_j0 = j + v_j; - size_t g_s = s + v_s; - - resT val = (g_j0 < m && g_s < k) - ? 
static_cast( - rhs[rhs_offset + - rhs_indexer(g_s * b_st0 + g_j0 * b_st1)]) - : resT(0); - - local_B_block[vid] = val; - } - - it.barrier(sycl::access::fence_space::local_space); - - i += local_i * wi_delta_n; - j += local_j; - - size_t a_offset = local_i * wi_delta_k * wi_delta_n; - size_t b_offset = local_j * wi_delta_k; - - constexpr resT identity_(0); - for (std::uint8_t private_i = 0; private_i < wi_delta_n; ++private_i) { - size_t a_pr_offset = private_i * wi_delta_k; - - resT local_sum(identity_); - for (size_t private_s = 0; private_s < wi_delta_k; ++private_s) { - local_sum = local_sum + - (local_A_block[a_offset + a_pr_offset + private_s] * - local_B_block[b_offset + private_s]); - } - - size_t gl_i = i + private_i; - - if (gl_i < n && j < m) { - sycl::atomic_ref - aout(res[res_offset + - res_indexer(gl_i * c_st0 + j * c_st1)]); - - aout += local_sum; + else { +#pragma unroll + for (std::uint8_t lane_id = 0; lane_id < wi_delta_m; ++lane_id) + { + const size_t gl_j = j + lane_id; + + if (gl_i < n && gl_j < m) { + sycl::atomic_ref< + resT, sycl::memory_order::relaxed, + sycl::memory_scope::device, + sycl::access::address_space::global_space> + aout(res[res_offset + + res_indexer(gl_i * c_st0 + gl_j * c_st1)]); + + aout += local_sum[lane_id]; + } + } } } } @@ -3955,7 +3346,7 @@ class GemmBatchFunctorThreadK const size_t m_id = it.get_group_linear_id() / n_groups_per_batch; const size_t gr_id = it.get_group_linear_id() - m_id * n_groups_per_batch; - size_t lid = it.get_local_linear_id(); + const size_t lid = it.get_local_linear_id(); const auto &three_offsets_ = batch_indexer(static_cast(m_id)); @@ -3966,37 +3357,51 @@ class GemmBatchFunctorThreadK // lift gr_id -> (block_i, block_j, block_s) // block_i moves fastest, then block_s, then block_j - size_t block_j = + const size_t block_j = gr_id / (n_blocks * k_blocks); // 0 <= block_j < m_blocks - size_t block_r = + const size_t block_r = gr_id - block_j * (n_blocks * k_blocks); // 0 <= block_r < n_blocks * k_blocks - size_t block_s = block_r / n_blocks; // 0 <= block_s < k_blocks - size_t block_i = + const size_t block_s = block_r / n_blocks; // 0 <= block_s < k_blocks + const size_t block_i = block_r - block_s * n_blocks; // 0 <= block_i < n_blocks - size_t local_i = lid / (delta_k); // 0 <= local_i < delta_n - size_t local_s = lid - local_i * (delta_k); // 0 <= local_s < delta_k + const size_t local_i = lid / (delta_k); // 0 <= local_i < delta_n + const size_t local_s = + lid - local_i * (delta_k); // 0 <= local_s < delta_k size_t i = block_i * delta_n + local_i; size_t j = m_groups * block_j; size_t s = block_s * delta_k * n_wi + local_s; + using accV_t = typename LocAccT::value_type; + constexpr resT identity_ = resT(0); if (local_i == 0) { for (size_t q = 0; q < n_wi * delta_k; q += delta_k) { - size_t sq = s + q; - size_t sqmj = sq * m + j; - sycl::vec local_B_vec; -#pragma unroll - for (size_t vec_idx = 0; vec_idx < m_groups; ++vec_idx) { - local_B_vec[vec_idx] = - (sq < k && j + vec_idx < m) + const size_t sq = s + q; + const size_t sqmj = sq * m + j; + + if constexpr (m_groups == 1 && std::is_same_v) { + local_B_block[local_s + q] = + (sq < k && j < m) ? static_cast( - rhs[rhs_offset + rhs_indexer(sqmj + vec_idx)]) + rhs[rhs_offset + rhs_indexer(sqmj)]) : identity_; } - local_B_block[local_s + q] = local_B_vec; + else { + accV_t local_B_vec; +#pragma unroll + for (size_t vec_idx = 0; vec_idx < m_groups; ++vec_idx) { + local_B_vec[vec_idx] = + (sq < k && j + vec_idx < m) + ? 
static_cast( + rhs[rhs_offset + + rhs_indexer(sqmj + vec_idx)]) + : identity_; + } + local_B_block[local_s + q] = local_B_vec; + } } } @@ -4005,8 +3410,8 @@ class GemmBatchFunctorThreadK size_t t_shift = block_s * delta_k * n_wi; size_t global_s_offset = i * k + t_shift; - sycl::vec private_sum(identity_); - constexpr sycl::vec vec_identity_(identity_); + accV_t private_sum(identity_); + constexpr accV_t vec_identity_(identity_); for (size_t t = local_s; t < local_B_block.size(); t += delta_k) { private_sum += ((i < n) && (t + t_shift < k)) @@ -4022,7 +3427,7 @@ class GemmBatchFunctorThreadK it.barrier(sycl::access::fence_space::local_space); if (local_s == 0 && i < n) { - sycl::vec local_sum(workspace[workspace_i_shift]); + accV_t local_sum(workspace[workspace_i_shift]); for (size_t t = 1; t < delta_k; ++t) { local_sum += workspace[workspace_i_shift + t]; } @@ -4032,173 +3437,30 @@ class GemmBatchFunctorThreadK sycl::access::address_space::global_space> aout0(res[res_offset + res_indexer(i * m + j)]); - aout0 += local_sum[0]; + if constexpr (m_groups == 1 && std::is_same_v) { + aout0 += local_sum; + } + else { + aout0 += local_sum[0]; #pragma unroll - for (size_t vec_id = 1; vec_id < m_groups; ++vec_id) { - if (j + vec_id < m) { - sycl::atomic_ref - aout1( - res[res_offset + res_indexer(i * m + j + vec_id)]); - - aout1 += local_sum[vec_id]; + for (size_t vec_id = 1; vec_id < m_groups; ++vec_id) { + if (j + vec_id < m) { + sycl::atomic_ref< + resT, sycl::memory_order::relaxed, + sycl::memory_scope::device, + sycl::access::address_space::global_space> + aout1(res[res_offset + + res_indexer(i * m + j + vec_id)]); + + aout1 += local_sum[vec_id]; + } } } } } }; -template -class GemmBatchFunctorThreadK -{ -private: - const lhsT *lhs = nullptr; - const rhsT *rhs = nullptr; - resT *res = nullptr; - LocAccT workspace; - LocAccT local_B_block; - size_t n = 0; - size_t n_blocks = 0; - size_t delta_n = 0; - size_t k = 0; - size_t k_blocks = 0; - size_t delta_k = 0; - size_t n_wi = 0; - size_t m = 0; - size_t batch_nelems = 0; - BatchDimsIndexerT batch_indexer; - OuterInnerDimsIndexerT lhs_indexer; - OuterInnerDimsIndexerT rhs_indexer; - OuterInnerDimsIndexerT res_indexer; - -public: - GemmBatchFunctorThreadK(const lhsT *lhs_, - const rhsT *rhs_, - resT *res_, - LocAccT workspace_, - LocAccT local_B_block_, - size_t n_, - size_t n_blocks_, - size_t delta_n_, - size_t k_, - size_t k_blocks_, - size_t delta_k_, - size_t n_wi_, - size_t m_, - size_t batch_nelems_, - BatchDimsIndexerT batch_indexer_, - OuterInnerDimsIndexerT lhs_indexer_, - OuterInnerDimsIndexerT rhs_indexer_, - OuterInnerDimsIndexerT res_indexer_) - : lhs(lhs_), rhs(rhs_), res(res_), workspace(workspace_), - local_B_block(local_B_block_), n(n_), n_blocks(n_blocks_), - delta_n(delta_n_), k(k_), k_blocks(k_blocks_), delta_k(delta_k_), - n_wi(n_wi_), m(m_), batch_nelems(batch_nelems_), - batch_indexer(batch_indexer_), lhs_indexer(lhs_indexer_), - rhs_indexer(rhs_indexer_), res_indexer(res_indexer_) - { - } - - void operator()(sycl::nd_item<1> it) const - { - // for batching: - // (current matrix in batch) m_id = global_id / (global_range / - // batch_nelems) for lhs, offset = m_id * (n * k) for rhs, offset = - // m_id - // * (k * m) for res, offset = m_id * (n * m) - const size_t n_groups_per_batch = it.get_group_range(0) / batch_nelems; - const size_t m_id = it.get_group_linear_id() / n_groups_per_batch; - const size_t gr_id = - it.get_group_linear_id() - m_id * n_groups_per_batch; - size_t lid = it.get_local_linear_id(); - - const auto 
&three_offsets_ = batch_indexer(static_cast(m_id)); - - const auto &lhs_offset = three_offsets_.get_first_offset(); - const auto &rhs_offset = three_offsets_.get_second_offset(); - const auto &res_offset = three_offsets_.get_third_offset(); - - // lift gr_id -> (block_i, block_j, block_s) - // block_i moves fastest, then block_s, then block_j - - size_t block_j = - gr_id / (n_blocks * k_blocks); // 0 <= block_j < m_blocks - size_t block_r = - gr_id - block_j * (n_blocks * - k_blocks); // 0 <= block_r < n_blocks * k_blocks - size_t block_s = block_r / n_blocks; // 0 <= block_s < k_blocks - size_t block_i = - block_r - block_s * n_blocks; // 0 <= block_i < n_blocks - - size_t local_i = lid / (delta_k); // 0 <= local_i < delta_n - size_t local_s = lid - local_i * (delta_k); // 0 <= local_s < delta_k - - size_t i = block_i * delta_n + local_i; - size_t j = block_j; - size_t s = block_s * delta_k * n_wi + local_s; - - constexpr resT identity_ = resT(0); - if (local_i == 0) { - for (size_t q = 0; q < n_wi * delta_k; q += delta_k) { - size_t sq = s + q; - size_t sqmj = sq * m + j; - local_B_block[local_s + q] = - (sq < k && j < m) - ? static_cast(rhs[rhs_offset + rhs_indexer(sqmj)]) - : identity_; - ; - } - } - - it.barrier(sycl::access::fence_space::local_space); - - size_t t_shift = block_s * delta_k * n_wi; - size_t global_s_offset = i * k + t_shift; - - resT private_sum(identity_); - for (size_t t = local_s; t < local_B_block.size(); t += delta_k) { - private_sum += - ((i < n) && (t + t_shift < k)) - ? (static_cast( - lhs[lhs_offset + lhs_indexer(global_s_offset + t)]) * - local_B_block[t]) - : identity_; - } - - size_t workspace_i_shift = local_i * delta_k; - workspace[workspace_i_shift + local_s] = private_sum; - - it.barrier(sycl::access::fence_space::local_space); - - if (local_s == 0 && i < n) { - resT local_sum(workspace[workspace_i_shift]); - for (size_t t = 1; t < delta_k; ++t) { - local_sum += workspace[workspace_i_shift + t]; - } - - sycl::atomic_ref - aout(res[res_offset + res_indexer(i * m + j)]); - - aout += local_sum; - } - } -}; - template class gemm_batch_init_krn; template vec{}; -#pragma unroll - for (std::uint8_t lane_id = 0; lane_id < wi_delta_m; ++lane_id) { - size_t g_j = g_j0 + lane_id; - vec[lane_id] = + if constexpr (wi_delta_m == 1 && std::is_same_v) { + local_B_block[vid] = (g_j < m && g_s < k) ? static_cast( rhs[rhs_offset + rhs_indexer(g_s * b_st0 + g_j * b_st1)]) : resT(0); } + else { + slmB_t vec{}; +#pragma unroll + for (std::uint8_t lane_id = 0; lane_id < wi_delta_m; ++lane_id) + { + size_t g_j1 = g_j + lane_id; + vec[lane_id] = + (g_j1 < m && g_s < k) + ? 
static_cast( + rhs[rhs_offset + + rhs_indexer(g_s * b_st0 + g_j1 * b_st1)]) + : resT(0); + } - local_B_block[vid] = vec; + local_B_block[vid] = vec; + } } it.barrier(sycl::access::fence_space::local_space); @@ -4798,203 +4073,43 @@ class GemmBatchNoAtomicFunctorThreadNM i += local_i * wi_delta_n; j += local_j * wi_delta_m; - size_t a_offset = local_i * wi_delta_k * wi_delta_n; - size_t b_offset = local_j * wi_delta_k; + const size_t a_offset = local_i * wi_delta_k * wi_delta_n; + const size_t b_offset = local_j * wi_delta_k; constexpr resT identity_(0); for (std::uint8_t private_i = 0; private_i < wi_delta_n; ++private_i) { - size_t a_pr_offset = private_i * wi_delta_k; + const size_t a_pr_offset = private_i * wi_delta_k; - sycl::vec local_sum(identity_); + slmB_t local_sum(identity_); for (size_t private_s = 0; private_s < wi_delta_k; ++private_s) { local_sum = local_sum + (local_A_block[a_offset + a_pr_offset + private_s] * local_B_block[b_offset + private_s]); } - size_t gl_i = i + private_i; - -#pragma unroll - for (std::uint8_t lane_id = 0; lane_id < wi_delta_m; ++lane_id) { - size_t gl_j = j + lane_id; + const size_t gl_i = i + private_i; + if constexpr (wi_delta_m == 1 && std::is_same_v) { + const size_t gl_j = j; if (gl_i < n && gl_j < m) { res[res_offset + res_indexer(gl_i * c_st0 + gl_j * c_st1) + - (block_s * n * m * batch_nelems)] = local_sum[lane_id]; + (block_s * n * m * batch_nelems)] = local_sum; } } - } - } -}; - -template -class GemmBatchNoAtomicFunctorThreadNM -{ -private: - const lhsT *lhs = nullptr; - const rhsT *rhs = nullptr; - resT *res = nullptr; - LocAccT1 local_A_block; - LocAccT2 local_B_block; - size_t n = 0; - size_t wg_delta_n = 0; - size_t k = 0; - size_t k_blocks = 0; - size_t wi_delta_k = 0; - size_t m = 0; - size_t m_blocks = 0; - size_t wg_delta_m = 0; - size_t batch_nelems; - BatchDimsIndexerT batch_indexer; - OuterInnerDimsIndexerT lhs_indexer; - OuterInnerDimsIndexerT rhs_indexer; - ResIndexerT res_indexer; - -public: - GemmBatchNoAtomicFunctorThreadNM(const lhsT *lhs_, - const rhsT *rhs_, - resT *res_, - LocAccT1 local_A_block_, - LocAccT2 local_B_block_, - size_t n_, - size_t wg_delta_n_, - size_t k_, - size_t k_blocks_, - size_t wi_delta_k_, - size_t m_, - size_t m_blocks_, - size_t wg_delta_m_, - size_t batch_nelems_, - BatchDimsIndexerT batch_indexer_, - OuterInnerDimsIndexerT lhs_indexer_, - OuterInnerDimsIndexerT rhs_indexer_, - ResIndexerT res_indexer_) - : lhs(lhs_), rhs(rhs_), res(res_), local_A_block(local_A_block_), - local_B_block(local_B_block_), n(n_), wg_delta_n(wg_delta_n_), k(k_), - k_blocks(k_blocks_), wi_delta_k(wi_delta_k_), m(m_), - m_blocks(m_blocks_), wg_delta_m(wg_delta_m_), - batch_nelems(batch_nelems_), batch_indexer(batch_indexer_), - lhs_indexer(lhs_indexer_), rhs_indexer(rhs_indexer_), - res_indexer(res_indexer_) - { - } - - void operator()(sycl::nd_item<1> it) const - { - const size_t n_groups_per_batch = it.get_group_range(0) / batch_nelems; - const size_t m_id = it.get_group_linear_id() / n_groups_per_batch; - const size_t gr_id = - it.get_group_linear_id() - m_id * n_groups_per_batch; - - const auto &three_offsets_ = batch_indexer(static_cast(m_id)); - - // lift group_id to (block_i, block_j, block_s), - // 0 <= block_i < n_blocks, 0 <= block_j < m_blocks, 0 <= block_s - // < k_blocks - - const auto &lhs_offset = three_offsets_.get_first_offset(); - const auto &rhs_offset = three_offsets_.get_second_offset(); - const auto &res_offset = three_offsets_.get_third_offset(); - - size_t block_i = gr_id / (m_blocks * 
k_blocks); - size_t block_r = gr_id - block_i * (m_blocks * k_blocks); - size_t block_j = block_r / k_blocks; - size_t block_s = block_r - block_j * k_blocks; - - size_t lid = it.get_local_linear_id(); - size_t local_i = lid / wg_delta_m; // 0<= local_i < wg_delta_n - size_t local_j = lid - local_i * wg_delta_m; // 0<= local_j < wg_delta_m - - // load A block and B blocks into SLM - - size_t i = block_i * wi_delta_n * wg_delta_n; - size_t j = block_j * wg_delta_m; - size_t s = block_s * wi_delta_k; - - const std::int64_t a_st0 = k; - const std::int64_t a_st1 = 1; - - const std::int64_t b_st0 = m; - const std::int64_t b_st1 = 1; - - const std::int64_t c_st0 = m; - const std::int64_t c_st1 = 1; - - size_t lws = it.get_local_range(0); - - for (size_t vid = lid; vid < local_A_block.size(); vid += lws) { - size_t v_i = vid / wi_delta_k; // 0<= v_i < wg_delta_n * wi_delta_n - size_t v_s = vid - v_i * wi_delta_k; // 0<= v_s < wi_delta_k - - size_t g_i = i + v_i; - size_t g_s = s + v_s; - - local_A_block[vid] = - (g_i < n && g_s < k) - ? static_cast( - lhs[lhs_offset + - lhs_indexer(g_i * a_st0 + g_s * a_st1)]) - : resT(0); - } - - for (size_t vid = lid; vid < local_B_block.size(); vid += lws) { - size_t v_j = vid / wi_delta_k; // 0<= v_i < wg_delta_m - size_t v_s = vid - v_j * wi_delta_k; // 0<= v_s < wi_delta_k - - size_t g_j0 = j + v_j; - size_t g_s = s + v_s; - - resT val = (g_j0 < m && g_s < k) - ? static_cast( - rhs[rhs_offset + - rhs_indexer(g_s * b_st0 + g_j0 * b_st1)]) - : resT(0); - - local_B_block[vid] = val; - } - - it.barrier(sycl::access::fence_space::local_space); - - i += local_i * wi_delta_n; - j += local_j; - - size_t a_offset = local_i * wi_delta_k * wi_delta_n; - size_t b_offset = local_j * wi_delta_k; - - constexpr resT identity_(0); - for (std::uint8_t private_i = 0; private_i < wi_delta_n; ++private_i) { - size_t a_pr_offset = private_i * wi_delta_k; - - resT local_sum(identity_); - for (size_t private_s = 0; private_s < wi_delta_k; ++private_s) { - local_sum = local_sum + - (local_A_block[a_offset + a_pr_offset + private_s] * - local_B_block[b_offset + private_s]); - } - - size_t gl_i = i + private_i; - - if (gl_i < n && j < m) { - res[res_offset + res_indexer(gl_i * c_st0 + j * c_st1) + - (block_s * n * m * batch_nelems)] = local_sum; + else { +#pragma unroll + for (std::uint8_t lane_id = 0; lane_id < wi_delta_m; ++lane_id) + { + const size_t gl_j = j + lane_id; + + if (gl_i < n && gl_j < m) { + res[res_offset + + res_indexer(gl_i * c_st0 + gl_j * c_st1) + + (block_s * n * m * batch_nelems)] = + local_sum[lane_id]; + } + } } } } @@ -5090,21 +4205,34 @@ class GemmBatchNoAtomicFunctorThreadK size_t j = m_groups * block_j; size_t s = block_s * delta_k * n_wi + local_s; + using accV_t = typename LocAccT::value_type; + constexpr resT identity_ = resT(0); if (local_i == 0) { for (size_t q = 0; q < n_wi * delta_k; q += delta_k) { size_t sq = s + q; size_t sqmj = sq * m + j; - sycl::vec local_B_vec; -#pragma unroll - for (size_t vec_idx = 0; vec_idx < m_groups; ++vec_idx) { - local_B_vec[vec_idx] = - (sq < k && j + vec_idx < m) + + if constexpr (m_groups == 1 && std::is_same_v) { + local_B_block[local_s + q] = + (sq < k && j < m) ? static_cast( - rhs[rhs_offset + rhs_indexer(sqmj + vec_idx)]) + rhs[rhs_offset + rhs_indexer(sqmj)]) : identity_; } - local_B_block[local_s + q] = local_B_vec; + else { + accV_t local_B_vec; +#pragma unroll + for (size_t vec_idx = 0; vec_idx < m_groups; ++vec_idx) { + local_B_vec[vec_idx] = + (sq < k && j + vec_idx < m) + ? 
static_cast( + rhs[rhs_offset + + rhs_indexer(sqmj + vec_idx)]) + : identity_; + } + local_B_block[local_s + q] = local_B_vec; + } } } @@ -5113,8 +4241,8 @@ class GemmBatchNoAtomicFunctorThreadK size_t t_shift = block_s * delta_k * n_wi; size_t global_s_offset = i * k + t_shift; - sycl::vec private_sum(identity_); - constexpr sycl::vec vec_identity_(identity_); + accV_t private_sum(identity_); + constexpr accV_t vec_identity_(identity_); for (size_t t = local_s; t < local_B_block.size(); t += delta_k) { private_sum += ((i < n) && (t + t_shift < k)) @@ -5130,162 +4258,28 @@ class GemmBatchNoAtomicFunctorThreadK it.barrier(sycl::access::fence_space::local_space); if (local_s == 0 && i < n) { - sycl::vec local_sum(workspace[workspace_i_shift]); + accV_t local_sum(workspace[workspace_i_shift]); for (size_t t = 1; t < delta_k; ++t) { local_sum += workspace[workspace_i_shift + t]; } const size_t total_offset = res_offset + (block_s * n * m * batch_nelems); - res[total_offset + res_indexer(i * m + j)] = local_sum[0]; - -#pragma unroll - for (size_t vec_id = 1; vec_id < m_groups; ++vec_id) { - if (j + vec_id < m) { - res[total_offset + res_indexer(i * m + j + vec_id)] = - local_sum[1]; - } - } - } - } -}; - -template -class GemmBatchNoAtomicFunctorThreadK -{ -private: - const lhsT *lhs = nullptr; - const rhsT *rhs = nullptr; - resT *res = nullptr; - LocAccT workspace; - LocAccT local_B_block; - size_t n = 0; - size_t n_blocks = 0; - size_t delta_n = 0; - size_t k = 0; - size_t k_blocks = 0; - size_t delta_k = 0; - size_t n_wi = 0; - size_t m = 0; - size_t batch_nelems = 0; - BatchDimsIndexerT batch_indexer; - OuterInnerDimsIndexerT lhs_indexer; - OuterInnerDimsIndexerT rhs_indexer; - ResIndexerT res_indexer; - -public: - GemmBatchNoAtomicFunctorThreadK(const lhsT *lhs_, - const rhsT *rhs_, - resT *res_, - LocAccT workspace_, - LocAccT local_B_block_, - size_t n_, - size_t n_blocks_, - size_t delta_n_, - size_t k_, - size_t k_blocks_, - size_t delta_k_, - size_t n_wi_, - size_t m_, - size_t batch_nelems_, - BatchDimsIndexerT batch_indexer_, - OuterInnerDimsIndexerT lhs_indexer_, - OuterInnerDimsIndexerT rhs_indexer_, - ResIndexerT res_indexer_) - : lhs(lhs_), rhs(rhs_), res(res_), workspace(workspace_), - local_B_block(local_B_block_), n(n_), n_blocks(n_blocks_), - delta_n(delta_n_), k(k_), k_blocks(k_blocks_), delta_k(delta_k_), - n_wi(n_wi_), m(m_), batch_nelems(batch_nelems_), - batch_indexer(batch_indexer_), lhs_indexer(lhs_indexer_), - rhs_indexer(rhs_indexer_), res_indexer(res_indexer_) - { - } - - void operator()(sycl::nd_item<1> it) const - { - const size_t n_groups_per_batch = it.get_group_range(0) / batch_nelems; - const size_t m_id = it.get_group_linear_id() / n_groups_per_batch; - const size_t gr_id = - it.get_group_linear_id() - m_id * n_groups_per_batch; - size_t lid = it.get_local_linear_id(); - - const auto &three_offsets_ = batch_indexer(static_cast(m_id)); - const auto &lhs_offset = three_offsets_.get_first_offset(); - const auto &rhs_offset = three_offsets_.get_second_offset(); - const auto &res_offset = three_offsets_.get_third_offset(); - - // lift gr_id -> (block_i, block_j, block_s) - // block_i moves fastest, then block_s, then block_j - size_t block_j = - gr_id / (n_blocks * k_blocks); // 0 <= block_j < m_blocks - size_t block_r = - gr_id - block_j * (n_blocks * - k_blocks); // 0 <= block_r < n_blocks * k_blocks - size_t block_s = block_r / n_blocks; // 0 <= block_s < k_blocks - size_t block_i = - block_r - block_s * n_blocks; // 0 <= block_i < n_blocks - - size_t local_i = 
lid / (delta_k); // 0 <= local_i < delta_n - size_t local_s = lid - local_i * (delta_k); // 0 <= local_s < delta_k - - size_t i = block_i * delta_n + local_i; - size_t j = block_j; - size_t s = block_s * delta_k * n_wi + local_s; - - constexpr resT identity_ = resT(0); - if (local_i == 0) { - for (size_t q = 0; q < n_wi * delta_k; q += delta_k) { - size_t sq = s + q; - size_t sqmj = sq * m + j; - local_B_block[local_s + q] = - (sq < k && j < m) - ? static_cast(rhs[rhs_offset + rhs_indexer(sqmj)]) - : identity_; + if constexpr (m_groups == 1 && std::is_same_v) { + res[total_offset + res_indexer(i * m + j)] = local_sum; } - } - - it.barrier(sycl::access::fence_space::local_space); - - size_t t_shift = block_s * delta_k * n_wi; - size_t global_s_offset = i * k + t_shift; - - resT private_sum(identity_); - for (size_t t = local_s; t < local_B_block.size(); t += delta_k) { - private_sum += - ((i < n) && ((t + t_shift < k))) - ? (static_cast( - lhs[lhs_offset + lhs_indexer(global_s_offset + t)]) * - local_B_block[t]) - : identity_; - } - - size_t workspace_i_shift = local_i * delta_k; - workspace[workspace_i_shift + local_s] = private_sum; - - it.barrier(sycl::access::fence_space::local_space); + else { + res[total_offset + res_indexer(i * m + j)] = local_sum[0]; - if (local_s == 0 && i < n) { - resT local_sum(workspace[workspace_i_shift]); - for (size_t t = 1; t < delta_k; ++t) { - local_sum += workspace[workspace_i_shift + t]; +#pragma unroll + for (size_t vec_id = 1; vec_id < m_groups; ++vec_id) { + if (j + vec_id < m) { + res[total_offset + res_indexer(i * m + j + vec_id)] = + local_sum[1]; + } + } } - - res[res_offset + res_indexer(i * m + j) + - (block_s * n * m * batch_nelems)] = local_sum; } } }; @@ -5387,15 +4381,15 @@ gemm_batch_tree_k_impl(sycl::queue &exec_q, lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, BatchDimsIndexerT, m_groups>; - cgh.parallel_for( - ndRange, - GemmBatchNoAtomicFunctorThreadK< - lhsTy, rhsTy, resTy, LocAccT, OuterInnerDimsIndexerT, - OuterInnerDimsIndexerT, BatchDimsIndexerT, m_groups>( - lhs_tp, rhs_tp, res_tp, workspace, local_B_block, n, - n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, - batch_nelems, batch_indexer, lhs_indexer, rhs_indexer, - res_indexer)); + const auto &krn_body = GemmBatchNoAtomicFunctorThreadK< + lhsTy, rhsTy, resTy, LocAccT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, BatchDimsIndexerT, m_groups>( + lhs_tp, rhs_tp, res_tp, workspace, local_B_block, n, + n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, + batch_nelems, batch_indexer, lhs_indexer, rhs_indexer, + res_indexer); + + cgh.parallel_for(ndRange, krn_body); } else { using LocAccT = From e3f41608056e1ba90067de96c8d7cc5c80ad38c1 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Mon, 5 Feb 2024 07:06:37 -0600 Subject: [PATCH 2/2] Use const references for read-use-only vectors --- .../tensor/libtensor/source/linalg_functions/dot.cpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/dpctl/tensor/libtensor/source/linalg_functions/dot.cpp b/dpctl/tensor/libtensor/source/linalg_functions/dot.cpp index 9a2b51497e..91e07e3793 100644 --- a/dpctl/tensor/libtensor/source/linalg_functions/dot.cpp +++ b/dpctl/tensor/libtensor/source/linalg_functions/dot.cpp @@ -296,14 +296,14 @@ py_dot(const dpctl::tensor::usm_ndarray &x1, const char *x2_data = x2.get_data(); char *dst_data = dst.get_data(); - auto x1_shape_vec = x1.get_shape_vector(); - auto x1_strides_vec = x1.get_strides_vector(); + const auto &x1_shape_vec = 
x1.get_shape_vector();
+    const auto &x1_strides_vec = x1.get_strides_vector();

-    auto x2_shape_vec = x2.get_shape_vector();
-    auto x2_strides_vec = x2.get_strides_vector();
+    const auto &x2_shape_vec = x2.get_shape_vector();
+    const auto &x2_strides_vec = x2.get_strides_vector();

-    auto dst_shape_vec = dst.get_shape_vector();
-    auto dst_strides_vec = dst.get_strides_vector();
+    const auto &dst_shape_vec = dst.get_shape_vector();
+    const auto &dst_strides_vec = dst.get_strides_vector();

     bool is_x1_c_contig = x1.is_c_contiguous();
     bool is_x1_f_contig = x1.is_f_contiguous();
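
Illustrative sketch of the pattern used by the first commit (not code from gemm.hpp): the deleted m_groups == 1 / wi_delta_m == 1 partial specializations are replaced by a single functor whose local-memory value type is either a scalar or a sycl::vec, with a C++17 `if constexpr` selecting the scalar path at compile time. In the sketch below the names accumulator_t, store_partial_sum and the flat i * m + j addressing are hypothetical stand-ins; slmB_t, m_groups, local_sum, lane_id and the bounds guards mirror identifiers that do appear in the patch.

#include <cstddef>
#include <type_traits>
#include <sycl/sycl.hpp>

// Storage for one work-item's partial sums: a plain scalar when a single
// result column is processed, a sycl::vec of m_groups lanes otherwise.
template <typename resT, int m_groups> struct accumulator_t
{
    using type = sycl::vec<resT, m_groups>;
};

template <typename resT> struct accumulator_t<resT, 1>
{
    using type = resT;
};

// One function body serves both cases; `if constexpr` discards the branch
// that does not apply, so no separate class specialization is needed.
template <typename resT, int m_groups>
void store_partial_sum(resT *res,
                       std::size_t i,
                       std::size_t j,
                       std::size_t m,
                       const typename accumulator_t<resT, m_groups>::type
                           &local_sum)
{
    using slmB_t = typename accumulator_t<resT, m_groups>::type;

    if constexpr (m_groups == 1 && std::is_same_v<slmB_t, resT>) {
        // scalar path: what the removed specializations used to implement
        if (j < m) {
            res[i * m + j] = local_sum;
        }
    }
    else {
        // vector path: one store per lane, guarded against the column tail
#pragma unroll
        for (int lane_id = 0; lane_id < m_groups; ++lane_id) {
            if (j + lane_id < m) {
                res[i * m + j + lane_id] = local_sum[lane_id];
            }
        }
    }
}

Because both branches live in one template, the scalar fast path is preserved while the duplicated specializations can be dropped, which is what accounts for the roughly 1380 deleted against 374 added lines in this patch.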