Skip to content

Commit

Permalink
fixed permutation logic in BinaryEngine::init_indices_ + implemented …
Browse files Browse the repository at this point in the history
…(suboptimal) forms of Mult that were not implemented for the ToT case
  • Loading branch information
evaleev committed Jan 19, 2024
1 parent df0b808 commit 7410b0b
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 49 deletions.
73 changes: 33 additions & 40 deletions src/TiledArray/expressions/binary_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,43 +150,34 @@ class BinaryEngine : public ExprEngine<Derived> {
!left_tile_is_tot && !right_tile_is_tot;
constexpr bool args_are_mixed_tensors =
left_tile_is_tot ^ right_tile_is_tot;
if (args_are_plain_tensors &&
(left_outer_permtype_ == PermutationType::matrix_transpose ||
left_outer_permtype_ == PermutationType::identity)) {
left_.permute_tiles(false);
}
if (!args_are_plain_tensors &&
((left_outer_permtype_ == PermutationType::matrix_transpose ||
left_outer_permtype_ == PermutationType::identity) ||
(left_inner_permtype_ == PermutationType::matrix_transpose ||
left_inner_permtype_ == PermutationType::identity))) {
left_.permute_tiles(false);
}
if (args_are_plain_tensors &&
(right_outer_permtype_ == PermutationType::matrix_transpose ||
right_outer_permtype_ == PermutationType::identity)) {
right_.permute_tiles(false);
}
if (!args_are_plain_tensors &&
((left_outer_permtype_ == PermutationType::matrix_transpose ||
left_outer_permtype_ == PermutationType::identity) ||
(right_inner_permtype_ == PermutationType::matrix_transpose ||
right_inner_permtype_ == PermutationType::identity))) {
right_.permute_tiles(false);
}
if (args_are_mixed_tensors &&
((left_outer_permtype_ == PermutationType::matrix_transpose ||
left_outer_permtype_ == PermutationType::identity) ||
(left_inner_permtype_ == PermutationType::matrix_transpose ||
left_inner_permtype_ == PermutationType::identity))) {
left_.permute_tiles(false);
// permute_tiles() denotes what happens to outer OR inner modes
// if we have contraction happening to BOTH inner and outer modes, no need
// to involve permutation, can fuse it into GEMMs
if (left_tile_is_tot) {
if ((left_outer_permtype_ == PermutationType::matrix_transpose ||
left_outer_permtype_ == PermutationType::identity) &&
(left_inner_permtype_ == PermutationType::matrix_transpose ||
left_inner_permtype_ == PermutationType::identity)) {
left_.permute_tiles(false);
}
} else { // !left_tile_is_tot
if (left_outer_permtype_ == PermutationType::matrix_transpose ||
left_outer_permtype_ == PermutationType::identity) {
left_.permute_tiles(false);
}
}
if (args_are_mixed_tensors &&
((left_outer_permtype_ == PermutationType::matrix_transpose ||
left_outer_permtype_ == PermutationType::identity) ||
(right_inner_permtype_ == PermutationType::matrix_transpose ||
right_inner_permtype_ == PermutationType::identity))) {
right_.permute_tiles(false);
if (right_tile_is_tot) {
if ((right_outer_permtype_ == PermutationType::matrix_transpose ||
right_outer_permtype_ == PermutationType::identity) &&
(right_inner_permtype_ == PermutationType::matrix_transpose ||
right_inner_permtype_ == PermutationType::identity)) {
right_.permute_tiles(false);
}
} else { // !right_tile_is_tot
if (right_outer_permtype_ == PermutationType::matrix_transpose ||
right_outer_permtype_ == PermutationType::identity) {
right_.permute_tiles(false);
}
}
}

Expand All @@ -204,10 +195,12 @@ class BinaryEngine : public ExprEngine<Derived> {
/// \param target_indices The target index list for this expression
void perm_indices(const BipartiteIndexList& target_indices) {
if (permute_tiles_) {
TA_ASSERT(left_.indices().size() == target_indices.size() ||
(left_.indices().second().size() ^ target_indices.second().size()));
TA_ASSERT(right_.indices().size() == target_indices.size() ||
(right_.indices().second().size() ^ target_indices.second().size()));
TA_ASSERT(
left_.indices().size() == target_indices.size() ||
(left_.indices().second().size() ^ target_indices.second().size()));
TA_ASSERT(
right_.indices().size() == target_indices.size() ||
(right_.indices().second().size() ^ target_indices.second().size()));

init_indices_<TensorProduct::Hadamard>(target_indices);

Expand Down
1 change: 0 additions & 1 deletion src/TiledArray/expressions/mult_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,6 @@ class MultEngine : public ContEngine<MultEngine<Left, Right, Result>> {
// the tile op; the type of the tile op does not need to match the type of
// the operation on the outer indices
if (this->product_type() == TensorProduct::Hadamard) {
// assumes inner op is also Hadamard
BinaryEngine_::perm_indices(target_indices);
} else {
auto children_initialized = true;
Expand Down
25 changes: 19 additions & 6 deletions src/TiledArray/tile_op/mult.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,17 +128,30 @@ class Mult {

/// Evaluates the product of two tiles; this overload is enabled when \c LC
/// is true and takes the left argument by non-const reference.
///
/// NOTE(review): this listing comes from a rendered diff without +/- markers.
/// The first three statements of the body look like the pre-change
/// implementation (assert + unconditional \c mult_to ) and the if/else that
/// follows looks like its replacement -- confirm against the repository
/// before relying on the statement order shown here.
/// \param first The left-hand argument (non-const; presumably reused as the
/// result storage by \c mult_to -- TODO confirm)
/// \param second The right-hand argument
/// \return The result of \c mult_to(first, second) when \c element_op_ is
/// not set, otherwise the result of \c binary(first, second, element_op_)
template <bool LC, bool RC, typename std::enable_if<LC>::type* = nullptr>
result_type eval(left_type& first, const right_type& second) const {
TA_ASSERT(!element_op_);
using TiledArray::mult_to;
return mult_to(first, second);
// no custom element op: use mult_to on (first, second)
if (!element_op_) {
using TiledArray::mult_to;
return mult_to(first, second);
} else {
// custom element op: use binary() since the inplace_binary() form
// (kept below for reference) did not build
// TODO figure out why this does not compile
// using TiledArray::inplace_binary;
// return inplace_binary(first, second, element_op_);
using TiledArray::binary;
return binary(first, second, element_op_);
}
}

/// Evaluates the product of two tiles; this overload is enabled when \c LC
/// is false and \c RC is true, and takes the right argument by non-const
/// reference.
///
/// NOTE(review): this listing comes from a rendered diff without +/- markers.
/// The first three statements of the body look like the pre-change
/// implementation (assert + unconditional \c mult_to ) and the if/else that
/// follows looks like its replacement -- confirm against the repository
/// before relying on the statement order shown here.
/// \param first The left-hand argument
/// \param second The right-hand argument (non-const; presumably reused as the
/// result storage by \c mult_to -- TODO confirm)
/// \return The result of \c mult_to(second, first) when \c element_op_ is
/// not set, otherwise the result of \c binary(first, second, element_op_)
template <bool LC, bool RC,
typename std::enable_if<!LC && RC>::type* = nullptr>
result_type eval(const left_type& first, right_type& second) const {
TA_ASSERT(!element_op_);
using TiledArray::mult_to;
return mult_to(second, first);
// no custom element op: arguments are passed to mult_to in swapped order
if (!element_op_) {
using TiledArray::mult_to;
return mult_to(second, first);
} else {
// WARNING: element_op_ might be noncommuting, so first and second cannot
// be swapped here. This could be optimized for GEMM, but element_op_
// cannot be introspected.
using TiledArray::binary;
return binary(first, second, element_op_);
}
}

template <bool LC, bool RC, typename std::enable_if<!RC>::type* = nullptr>
Expand Down
4 changes: 2 additions & 2 deletions tests/einsum.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -797,8 +797,8 @@ BOOST_AUTO_TEST_CASE(ij_mn_eq_ij_mo_times_ji_on) {
using Array = TA::DistArray<TA::Tensor<TA::Tensor<int>>, TA::DensePolicy>;
using Perm = TA::Permutation;

TA::TiledRange lhs_trng{{0, 2, 3}, {0, 2, 4}};
TA::TiledRange rhs_trng{{0, 2, 4}, {0, 2, 3}};
TA::TiledRange lhs_trng{{0, 2, 3}, {0, 1}};
TA::TiledRange rhs_trng{{0, 1}, {0, 2, 3}};
TA::Range lhs_inner_rng{1, 1};
TA::Range rhs_inner_rng{1, 1};

Expand Down

0 comments on commit 7410b0b

Please sign in to comment.