Skip to content

Commit

Permalink
Complete implementation of tensor_contract free function.
Browse files Browse the repository at this point in the history
Also adds function to compare contraction results with btas in the test fixture for ToT.
  • Loading branch information
bimalgaudel committed Jan 12, 2024
1 parent 8443d05 commit c7d43a6
Show file tree
Hide file tree
Showing 2 changed files with 165 additions and 40 deletions.
38 changes: 25 additions & 13 deletions src/TiledArray/tensor/kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -1182,8 +1182,8 @@ auto tensor_contract(TensorA const& A, Annot const& aA, TensorB const& B,

struct {
Indices A, B, C;
} const blas_layout{(indices.A - indices.B) + indices.i,
indices.i + (indices.B - indices.A), indices.e};
} const blas_layout{(indices.A - indices.B) | indices.i,
indices.i | (indices.B - indices.A), indices.e};

struct {
Permutation A, B, C;
Expand All @@ -1196,30 +1196,42 @@ auto tensor_contract(TensorA const& A, Annot const& aA, TensorB const& B,
} const do_perm{indices.A != blas_layout.A, indices.B != blas_layout.B,
indices.C != blas_layout.C};

auto permedA = [&]() -> TensorA {
return do_perm.A ? A.permute(perm.A) : std::cref(A);
};

auto permedB = [&]() -> TensorB {
return do_perm.B ? B.permute(perm.B) : std::cref(B);
};

math::GemmHelper gemm_helper{blas::Op::NoTrans, blas::Op::NoTrans,
static_cast<unsigned int>(indices.e.size()),
static_cast<unsigned int>(indices.A.size()),
static_cast<unsigned int>(indices.B.size())};

// initialize result with correct rank
// initialize result with the correct extents
Result result;
{
container::vector<size_t> rng(indices.e.size(), 0);
using Index = typename Indices::value_type;
using Extent = std::remove_cv_t<
typename decltype(std::declval<Range>().extent())::value_type>;
using ExtentMap = ::Einsum::index::IndexMap<Index, Extent>;

// Map tensor indices to their extents.
// Note that whether the contracting indices have matching extents is
// implicitly checked here by the pipe(|) operator on ExtentMap.

ExtentMap extent = (ExtentMap{indices.A, A.range().extent()} |
ExtentMap{indices.B, B.range().extent()});

container::vector<Extent> rng;
rng.reserve(indices.e.size());
for (auto&& ix : indices.e) {
// assuming ix _exists_ in extent
rng.emplace_back(extent[ix]);
}
result = Result{TA::Range(rng)};
}

using Numeric = typename Result::numeric_type;

// call gemm
gemm(Numeric{1}, permedA(), permedB(), Numeric{0}, result, gemm_helper);
gemm(Numeric{1}, //
do_perm.A ? A.permute(perm.A) : A, //
do_perm.B ? B.permute(perm.B) : B, //
Numeric{0}, result, gemm_helper);

return do_perm.C ? result.permute(perm.C.inv()) : result;
}
Expand Down
167 changes: 140 additions & 27 deletions tests/tot_array_fixture.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@
#include "tiledarray.h"
#include "unit_test_config.h"
#ifdef TILEDARRAY_HAS_BTAS
#include <TiledArray/conversions/btas.h>
#include <TiledArray/external/btas.h>
#include <btas/generic/contract.h>
#endif

/* Notes:
Expand Down Expand Up @@ -90,51 +92,162 @@ using output_archive_type = madness::archive::BinaryFstreamOutputArchive;

enum class ShapeComp { True, False };

template <typename NumericT>
template <typename TensorT,
std::enable_if_t<TA::detail::is_tensor_v<TensorT>, bool> = true>
auto random_tensor(TA::Range const& rng) {
TA::Tensor<NumericT> result{rng};
std::generate(result.begin(), result.end(),
TensorT result{rng};
using NumericT = typename TensorT::numeric_type;
std::generate(/*std::execution::par, */
result.begin(), result.end(),
TA::detail::MakeRandom<NumericT>::generate_value);
return result;
}

//
// note: all the inner tensors (elements of the outer tensor)
// have the same @c inner_rng
template <typename NumericT>
auto random_tensor_of_tensor(TA::Range const& outer_rng,
TA::Range const& inner_rng) {
TA::Tensor<TA::Tensor<NumericT>> result{outer_rng};

std::generate(result.begin(), result.end(),
[inner_rng]() { return random_tensor<NumericT>(inner_rng); });
//
template <
typename TensorT,
std::enable_if_t<TA::detail::is_tensor_of_tensor_v<TensorT>, bool> = true>
auto random_tensor(TA::Range const& outer_rng, TA::Range const& inner_rng) {
using InnerTensorT = typename TensorT::value_type;
TensorT result{outer_rng};

std::generate(/*std::execution::par,*/
result.begin(), result.end(), [inner_rng]() {
return random_tensor<InnerTensorT>(inner_rng);
});

return result;
}

template <typename NumericT, typename Policy = TA::DensePolicy>
auto make_random_array(TA::TiledRange const& trange) {
using ArrayT = TA::DistArray<TA::Tensor<NumericT>, Policy>;

auto make_tile = [](TA::Tensor<NumericT>& tile, TA::Range const& rng) {
tile = random_tensor<NumericT>(rng);
if constexpr (std::is_same_v<Policy, TA::SparsePolicy>) return tile.norm();
///
/// \tparam Array The type of DistArray to be generated. Cannot be cv-qualified
/// or reference type.
/// \tparam Args TA::Range type for inner tensor if the tile type of the result
/// is a tensor-of-tensor.
/// \param trange The TiledRange of the result DistArray.
/// \param args Either exactly one TA::Range type when the tile type of Array is
/// tensor-of-tensor or nothing.
/// \return Returns a DistArray of type Array whose elements are randomly
/// generated.
/// @note:
/// - Although DistArrays with Sparse policy can be generated all of their
/// tiles are initialized with random values -- technically the returned value
/// is dense.
/// - In case of arrays with tensor-of-tensor tiles, all the inner tensors have
/// the same rank and the same extent of corresponding modes.
///
template <
typename Array, typename... Args,
typename =
std::void_t<typename Array::value_type, typename Array::policy_type>,
std::enable_if_t<TA::detail::is_nested_tensor_v<typename Array::value_type>,
bool> = true>
auto random_array(TA::TiledRange const& trange, Args const&... args) {
static_assert(
(sizeof...(Args) == 0 &&
TA::detail::is_tensor_v<typename Array::value_type>) ||
(sizeof...(Args) == 1) &&
(TA::detail::is_tensor_of_tensor_v<typename Array::value_type>));

if constexpr (sizeof...(Args) == 1)
static_assert(std::is_convertible_v<Args..., TA::Range>);

using TensorT = typename Array::value_type;
using PolicyT = typename Array::policy_type;

auto make_tile_meta = [](auto&&... args) {
return [=](TensorT& tile, TA::Range const& rng) {
tile = random_tensor<TensorT>(rng, args...);
if constexpr (std::is_same_v<TA::SparsePolicy, PolicyT>)
return tile.norm();
};
};

return TA::make_array<ArrayT>(TA::get_default_world(), trange, make_tile);
return TA::make_array<Array>(TA::get_default_world(), trange,
make_tile_meta(args...));
}

///
/// Succinctly call TA::detail::tensor_contract
///
/// \tparam T TA::Tensor type.
/// \param einsum_annot Example annot: 'ik,kj->ij', when @c A is annotated by
/// 'i' and 'k' for its two modes, and @c B is annotated by 'k' and 'j' for the
/// same. The result tensor is rank-2 as well and its modes are annotated by 'i'
/// and 'j'.
/// \return Tensor contraction result.
///
template <typename T, std::enable_if_t<TA::detail::is_tensor_v<T>, bool> = true>
auto tensor_contract(std::string const& einsum_annot, T const& A, T const& B) {
using ::Einsum::string::split2;
auto [ab, aC] = split2(einsum_annot, "->");
auto [aA, aB] = split2(ab, ",");

return TA::detail::tensor_contract(A, aA, B, aB, aC);
}

template <typename NumericT, typename Policy = TA::DensePolicy>
auto make_random_array(TA::TiledRange const& trange, TA::Range const& inner) {
using ArrayT = TA::DistArray<TA::Tensor<TA::Tensor<NumericT>>, Policy>;
#ifdef TILEDARRAY_HAS_BTAS

auto make_tile = [inner](TA::Tensor<TA::Tensor<NumericT>>& tile,
TA::Range const& rng) {
tile = random_tensor_of_tensor<NumericT>(rng, inner);
if constexpr (std::is_same_v<Policy, TA::SparsePolicy>) return tile.norm();
};
return TA::make_array<ArrayT>(TA::get_default_world(), trange, make_tile);
template <typename T, typename = std::enable_if_t<TA::detail::is_tensor_v<T>>>
auto tensor_to_btas_tensor(T const& ta_tensor) {
using value_type = typename T::value_type;
using range_type = typename T::range_type;

btas::Tensor<value_type, range_type> result{ta_tensor.range()};
TA::tensor_to_btas_subtensor(ta_tensor, result);
return result;
}

template <typename NumericT, typename RangeT, typename... Ts,
typename = std::enable_if_t<std::is_convertible_v<RangeT, TA::Range>>>
auto btas_tensor_to_tensor(
btas::Tensor<NumericT, RangeT, Ts...> const& btas_tensor) {
TA::Tensor<NumericT> result{TA::Range(btas_tensor.range())};
TA::btas_subtensor_to_tensor(btas_tensor, result);
return result;
}

///
/// @c einsum_annot pattern example: 'ik,kj->ij'. See tensor_contract function.
///
template <typename T, std::enable_if_t<TA::detail::is_tensor_v<T>, bool> = true>
auto tensor_contract_btas(std::string const& einsum_annot, T const& A,
T const& B) {
using ::Einsum::string::split2;
auto [ab, aC] = split2(einsum_annot, "->");
auto [aA, aB] = split2(ab, ",");

using NumericT = typename T::numeric_type;

struct {
btas::Tensor<NumericT, TA::Range> A, B, C;
} btas_tensor{tensor_to_btas_tensor(A), tensor_to_btas_tensor(B), {}};

btas::contract(NumericT{1}, btas_tensor.A, aA, btas_tensor.B, aB, NumericT{0},
btas_tensor.C, aC);

return btas_tensor_to_tensor(btas_tensor.C);
}

///
/// \tparam T TA::Tensor type
/// \param einsum_annot see tensor_contract_mult
/// \return True when TA::detail::tensor_contract and btas::contract result the
/// result. Performs bitwise comparison.
///
template <typename T, typename = std::enable_if_t<TA::detail::is_tensor_v<T>>>
auto tensor_contract_equal(std::string const& einsum_annot, T const& A,
T const& B) {
T result_ta = tensor_contract(einsum_annot, A, B);
T result_btas = tensor_contract_btas(einsum_annot, A, B);
return result_ta == result_btas;
}

#endif

/*
*
* When generating arrays containing tensors of tensors (ToT) we adopt simple
Expand Down

0 comments on commit c7d43a6

Please sign in to comment.