From c7d43a606476d7af74feaa5a8534f34f9233565c Mon Sep 17 00:00:00 2001 From: Bimal Gaudel Date: Fri, 12 Jan 2024 11:07:20 -0500 Subject: [PATCH] Complete implementation of tensor_contract free function. Also adds function to compare contraction results with btas in the test fixture for ToT. --- src/TiledArray/tensor/kernels.h | 38 +++++--- tests/tot_array_fixture.h | 167 ++++++++++++++++++++++++++------ 2 files changed, 165 insertions(+), 40 deletions(-) diff --git a/src/TiledArray/tensor/kernels.h b/src/TiledArray/tensor/kernels.h index af951755a3..5dc32db65d 100644 --- a/src/TiledArray/tensor/kernels.h +++ b/src/TiledArray/tensor/kernels.h @@ -1182,8 +1182,8 @@ auto tensor_contract(TensorA const& A, Annot const& aA, TensorB const& B, struct { Indices A, B, C; - } const blas_layout{(indices.A - indices.B) + indices.i, - indices.i + (indices.B - indices.A), indices.e}; + } const blas_layout{(indices.A - indices.B) | indices.i, + indices.i | (indices.B - indices.A), indices.e}; struct { Permutation A, B, C; @@ -1196,30 +1196,42 @@ auto tensor_contract(TensorA const& A, Annot const& aA, TensorB const& B, } const do_perm{indices.A != blas_layout.A, indices.B != blas_layout.B, indices.C != blas_layout.C}; - auto permedA = [&]() -> TensorA { - return do_perm.A ? A.permute(perm.A) : std::cref(A); - }; - - auto permedB = [&]() -> TensorB { - return do_perm.B ? B.permute(perm.B) : std::cref(B); - }; - math::GemmHelper gemm_helper{blas::Op::NoTrans, blas::Op::NoTrans, static_cast(indices.e.size()), static_cast(indices.A.size()), static_cast(indices.B.size())}; - // initialize result with correct rank + // initialize result with the correct extents Result result; { - container::vector rng(indices.e.size(), 0); + using Index = typename Indices::value_type; + using Extent = std::remove_cv_t< + typename decltype(std::declval().extent())::value_type>; + using ExtentMap = ::Einsum::index::IndexMap; + + // Map tensor indices to their extents. + // Note that whether the contracting indices have matching extents is + // implicitly checked here by the pipe(|) operator on ExtentMap. + + ExtentMap extent = (ExtentMap{indices.A, A.range().extent()} | + ExtentMap{indices.B, B.range().extent()}); + + container::vector rng; + rng.reserve(indices.e.size()); + for (auto&& ix : indices.e) { + // assuming ix _exists_ in extent + rng.emplace_back(extent[ix]); + } result = Result{TA::Range(rng)}; } using Numeric = typename Result::numeric_type; // call gemm - gemm(Numeric{1}, permedA(), permedB(), Numeric{0}, result, gemm_helper); + gemm(Numeric{1}, // + do_perm.A ? A.permute(perm.A) : A, // + do_perm.B ? B.permute(perm.B) : B, // + Numeric{0}, result, gemm_helper); return do_perm.C ? result.permute(perm.C.inv()) : result; } diff --git a/tests/tot_array_fixture.h b/tests/tot_array_fixture.h index 7345401518..87749afec0 100644 --- a/tests/tot_array_fixture.h +++ b/tests/tot_array_fixture.h @@ -22,7 +22,9 @@ #include "tiledarray.h" #include "unit_test_config.h" #ifdef TILEDARRAY_HAS_BTAS +#include #include +#include #endif /* Notes: @@ -90,51 +92,162 @@ using output_archive_type = madness::archive::BinaryFstreamOutputArchive; enum class ShapeComp { True, False }; -template +template , bool> = true> auto random_tensor(TA::Range const& rng) { - TA::Tensor result{rng}; - std::generate(result.begin(), result.end(), + TensorT result{rng}; + using NumericT = typename TensorT::numeric_type; + std::generate(/*std::execution::par, */ + result.begin(), result.end(), TA::detail::MakeRandom::generate_value); return result; } +// // note: all the inner tensors (elements of the outer tensor) // have the same @c inner_rng -template -auto random_tensor_of_tensor(TA::Range const& outer_rng, - TA::Range const& inner_rng) { - TA::Tensor> result{outer_rng}; - - std::generate(result.begin(), result.end(), - [inner_rng]() { return random_tensor(inner_rng); }); +// +template < + typename TensorT, + std::enable_if_t, bool> = true> +auto random_tensor(TA::Range const& outer_rng, TA::Range const& inner_rng) { + using InnerTensorT = typename TensorT::value_type; + TensorT result{outer_rng}; + + std::generate(/*std::execution::par,*/ + result.begin(), result.end(), [inner_rng]() { + return random_tensor(inner_rng); + }); return result; } -template -auto make_random_array(TA::TiledRange const& trange) { - using ArrayT = TA::DistArray, Policy>; - - auto make_tile = [](TA::Tensor& tile, TA::Range const& rng) { - tile = random_tensor(rng); - if constexpr (std::is_same_v) return tile.norm(); +/// +/// \tparam Array The type of DistArray to be generated. Cannot be cv-qualified +/// or reference type. +/// \tparam Args TA::Range type for inner tensor if the tile type of the result +/// is a tensor-of-tensor. +/// \param trange The TiledRange of the result DistArray. +/// \param args Either exactly one TA::Range type when the tile type of Array is +/// tensor-of-tensor or nothing. +/// \return Returns a DistArray of type Array whose elements are randomly +/// generated. +/// @note: +/// - Although DistArrays with Sparse policy can be generated all of their +/// tiles are initialized with random values -- technically the returned value +/// is dense. +/// - In case of arrays with tensor-of-tensor tiles, all the inner tensors have +/// the same rank and the same extent of corresponding modes. +/// +template < + typename Array, typename... Args, + typename = + std::void_t, + std::enable_if_t, + bool> = true> +auto random_array(TA::TiledRange const& trange, Args const&... args) { + static_assert( + (sizeof...(Args) == 0 && + TA::detail::is_tensor_v) || + (sizeof...(Args) == 1) && + (TA::detail::is_tensor_of_tensor_v)); + + if constexpr (sizeof...(Args) == 1) + static_assert(std::is_convertible_v); + + using TensorT = typename Array::value_type; + using PolicyT = typename Array::policy_type; + + auto make_tile_meta = [](auto&&... args) { + return [=](TensorT& tile, TA::Range const& rng) { + tile = random_tensor(rng, args...); + if constexpr (std::is_same_v) + return tile.norm(); + }; }; - return TA::make_array(TA::get_default_world(), trange, make_tile); + return TA::make_array(TA::get_default_world(), trange, + make_tile_meta(args...)); +} + +/// +/// Succinctly call TA::detail::tensor_contract +/// +/// \tparam T TA::Tensor type. +/// \param einsum_annot Example annot: 'ik,kj->ij', when @c A is annotated by +/// 'i' and 'k' for its two modes, and @c B is annotated by 'k' and 'j' for the +/// same. The result tensor is rank-2 as well and its modes are annotated by 'i' +/// and 'j'. +/// \return Tensor contraction result. +/// +template , bool> = true> +auto tensor_contract(std::string const& einsum_annot, T const& A, T const& B) { + using ::Einsum::string::split2; + auto [ab, aC] = split2(einsum_annot, "->"); + auto [aA, aB] = split2(ab, ","); + + return TA::detail::tensor_contract(A, aA, B, aB, aC); } -template -auto make_random_array(TA::TiledRange const& trange, TA::Range const& inner) { - using ArrayT = TA::DistArray>, Policy>; +#ifdef TILEDARRAY_HAS_BTAS - auto make_tile = [inner](TA::Tensor>& tile, - TA::Range const& rng) { - tile = random_tensor_of_tensor(rng, inner); - if constexpr (std::is_same_v) return tile.norm(); - }; - return TA::make_array(TA::get_default_world(), trange, make_tile); +template >> +auto tensor_to_btas_tensor(T const& ta_tensor) { + using value_type = typename T::value_type; + using range_type = typename T::range_type; + + btas::Tensor result{ta_tensor.range()}; + TA::tensor_to_btas_subtensor(ta_tensor, result); + return result; +} + +template >> +auto btas_tensor_to_tensor( + btas::Tensor const& btas_tensor) { + TA::Tensor result{TA::Range(btas_tensor.range())}; + TA::btas_subtensor_to_tensor(btas_tensor, result); + return result; } +/// +/// @c einsum_annot pattern example: 'ik,kj->ij'. See tensor_contract function. +/// +template , bool> = true> +auto tensor_contract_btas(std::string const& einsum_annot, T const& A, + T const& B) { + using ::Einsum::string::split2; + auto [ab, aC] = split2(einsum_annot, "->"); + auto [aA, aB] = split2(ab, ","); + + using NumericT = typename T::numeric_type; + + struct { + btas::Tensor A, B, C; + } btas_tensor{tensor_to_btas_tensor(A), tensor_to_btas_tensor(B), {}}; + + btas::contract(NumericT{1}, btas_tensor.A, aA, btas_tensor.B, aB, NumericT{0}, + btas_tensor.C, aC); + + return btas_tensor_to_tensor(btas_tensor.C); +} + +/// +/// \tparam T TA::Tensor type +/// \param einsum_annot see tensor_contract_mult +/// \return True when TA::detail::tensor_contract and btas::contract result the +/// result. Performs bitwise comparison. +/// +template >> +auto tensor_contract_equal(std::string const& einsum_annot, T const& A, + T const& B) { + T result_ta = tensor_contract(einsum_annot, A, B); + T result_btas = tensor_contract_btas(einsum_annot, A, B); + return result_ta == result_btas; +} + +#endif + /* * * When generating arrays containing tensors of tensors (ToT) we adopt simple