Skip to content

Commit

Permalink
Merge pull request #495 from ValeevGroup/gaudel/fix/outer_mode_contra…
Browse files Browse the repository at this point in the history
…ction_in_tot

Outer mode contraction in tot made more efficient
  • Loading branch information
evaleev authored Nov 21, 2024
2 parents 87664ae + 42578f7 commit fb5d5b8
Show file tree
Hide file tree
Showing 5 changed files with 149 additions and 101 deletions.
183 changes: 95 additions & 88 deletions src/TiledArray/einsum/tiledarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,9 @@ auto einsum(expressions::TsrExpr<ArrayA_> A, expressions::TsrExpr<ArrayB_> B,
using ResultTensor = typename ArrayC::value_type;
using ResultShape = typename ArrayC::shape_type;

auto const& tnsrExprA = A;
auto const& tnsrExprB = B;

auto a = std::get<0>(Einsum::idx(A));
auto b = std::get<0>(Einsum::idx(B));
Einsum::Index<std::string> c = std::get<0>(cs);
Expand Down Expand Up @@ -536,16 +539,10 @@ auto einsum(expressions::TsrExpr<ArrayA_> A, expressions::TsrExpr<ArrayB_> B,
// the evaluation can be delegated to the expression layer
// for distarrays of both nested and non-nested tensor tiles.
// *) If no Hadamard indices are present (!h) the evaluation
// can be delegated to the expression _only_ for distarrays with
// non-nested tensor tiles.
// This is because even if Hadamard indices are not present, a contracted
// index might be present pertinent to the outer tensor in case of a
// nested-tile distarray, which is especially handled within this
// function because expression layer cannot handle that yet.
// can be delegated to the expression layer.
//
if ((h && !(i || e)) // pure Hadamard
|| (IsArrayToT<ArrayC> && !(i || h)) // ToT result from outer-product
|| (IsArrayT<ArrayC> && !h)) // T from general product without Hadamard
if ((h && !(i || e)) // pure Hadamard
|| !h) // no Hadamard
{
ArrayC C;
C(std::string(c) + inner.c) = A * B;
Expand Down Expand Up @@ -577,21 +574,6 @@ auto einsum(expressions::TsrExpr<ArrayA_> A, expressions::TsrExpr<ArrayB_> B,
return C;
}

//
// when contraction happens in the outer tensor
// need to evaluate specially..
//
if (IsArrayToT<ArrayC> && i.size() > 0) {
auto annot_c = std::string(h + e + i) + inner.c;
auto temp1 = einsum(A, B, idx<ArrayC>(annot_c), world);
auto temp2 = reduce_modes(temp1, i.size());

auto annot_c_ = std::string(h + e) + inner.c;
decltype(temp2) result;
result(std::string(c) + inner.c) = temp2(annot_c_);
return result;
}

using ::Einsum::index::permutation;
using TiledArray::Permutation;

Expand Down Expand Up @@ -640,79 +622,104 @@ auto einsum(expressions::TsrExpr<ArrayA_> A, expressions::TsrExpr<ArrayB_> B,

using Index = Einsum::Index<size_t>;

if constexpr (AreArraySame<ArrayA, ArrayB> &&
AreArraySame<ArrayB, ArrayC>) {
if (!e) { // hadamard reduction
auto &[A, B] = AB;
TiledRange trange(range_map[i]);
RangeProduct tiles;
for (auto idx : i) {
tiles *= Range(range_map[idx].tiles_range());
if (!e) { // hadamard reduction
auto &[A, B] = AB;
TiledRange trange(range_map[i]);
RangeProduct tiles;
for (auto idx : i) {
tiles *= Range(range_map[idx].tiles_range());
}
auto pa = A.permutation;
auto pb = B.permutation;
for (Index h : H.tiles) {
if (!C.array.is_local(h)) continue;
size_t batch = 1;
for (size_t i = 0; i < h.size(); ++i) {
batch *= H.batch[i].at(h[i]);
}
auto pa = A.permutation;
auto pb = B.permutation;
for (Index h : H.tiles) {
if (!C.array.is_local(h)) continue;
size_t batch = 1;
for (size_t i = 0; i < h.size(); ++i) {
batch *= H.batch[i].at(h[i]);
}
ResultTensor tile(TiledArray::Range{batch},
typename ResultTensor::value_type{});
for (Index i : tiles) {
// skip this unless both input tiles exist
const auto pahi_inv = apply_inverse(pa, h + i);
const auto pbhi_inv = apply_inverse(pb, h + i);
if (A.array.is_zero(pahi_inv) || B.array.is_zero(pbhi_inv))
continue;

auto ai = A.array.find(pahi_inv).get();
auto bi = B.array.find(pbhi_inv).get();
if (pa) ai = ai.permute(pa);
if (pb) bi = bi.permute(pb);
auto shape = trange.tile(i);
ai = ai.reshape(shape, batch);
bi = bi.reshape(shape, batch);
for (size_t k = 0; k < batch; ++k) {
using Ix = ::Einsum::Index<std::string>;
if constexpr (AreArrayToT<ArrayA, ArrayB>) {
auto aik = ai.batch(k);
auto bik = bi.batch(k);
auto vol = aik.total_size();
TA_ASSERT(vol == bik.total_size());

auto &el = tile({k});
using TensorT = std::remove_reference_t<decltype(el)>;

auto mult_op = [&inner](auto const &l,
auto const &r) -> TensorT {
return inner.h ? TA::detail::tensor_hadamard(l, inner.A, r,
inner.B, inner.C)
: TA::detail::tensor_contract(
l, inner.A, r, inner.B, inner.C);
};

for (auto i = 0; i < vol; ++i)
el.add_to(mult_op(aik.data()[i], bik.data()[i]));

} else {
auto hk = ai.batch(k).dot(bi.batch(k));
tile({k}) += hk;
}
ResultTensor tile(TiledArray::Range{batch},
typename ResultTensor::value_type{});
for (Index i : tiles) {
// skip this unless both input tiles exist
const auto pahi_inv = apply_inverse(pa, h + i);
const auto pbhi_inv = apply_inverse(pb, h + i);
if (A.array.is_zero(pahi_inv) || B.array.is_zero(pbhi_inv)) continue;

auto ai = A.array.find(pahi_inv).get();
auto bi = B.array.find(pbhi_inv).get();
if (pa) ai = ai.permute(pa);
if (pb) bi = bi.permute(pb);
auto shape = trange.tile(i);
ai = ai.reshape(shape, batch);
bi = bi.reshape(shape, batch);
for (size_t k = 0; k < batch; ++k) {
using Ix = ::Einsum::Index<std::string>;
if constexpr (AreArrayToT<ArrayA, ArrayB>) {
auto aik = ai.batch(k);
auto bik = bi.batch(k);
auto vol = aik.total_size();
TA_ASSERT(vol == bik.total_size());

auto &el = tile({k});
using TensorT = std::remove_reference_t<decltype(el)>;

auto mult_op = [&inner](auto const &l, auto const &r) -> TensorT {
if (l.empty() || r.empty()) return TensorT{};
return inner.h ? TA::detail::tensor_hadamard(l, inner.A, r,
inner.B, inner.C)
: TA::detail::tensor_contract(l, inner.A, r,
inner.B, inner.C);
};

for (auto i = 0; i < vol; ++i)
el.add_to(mult_op(aik.data()[i], bik.data()[i]));

} else if constexpr (!AreArraySame<ArrayA, ArrayB>) {
auto aik = ai.batch(k);
auto bik = bi.batch(k);
auto vol = aik.total_size();
TA_ASSERT(vol == bik.total_size());

auto &el = tile({k});

for (auto i = 0; i < vol; ++i)
if constexpr (IsArrayToT<ArrayA>) {
el.add_to(aik.data()[i].scale(bik.data()[i]));
} else {
el.add_to(bik.data()[i].scale(aik.data()[i]));
}

} else {
auto hk = ai.batch(k).dot(bi.batch(k));
tile({k}) += hk;
}
}
auto pc = C.permutation;
auto shape = apply_inverse(pc, C.array.trange().tile(h));
tile = tile.reshape(shape);
if (pc) tile = tile.permute(pc);
C.array.set(h, tile);
}
return C.array;
auto pc = C.permutation;
auto shape = apply_inverse(pc, C.array.trange().tile(h));
tile = tile.reshape(shape);
if (pc) tile = tile.permute(pc);
C.array.set(h, tile);
}
return C.array;
}

// generalized contraction

if constexpr (IsArrayToT<ArrayC>) {
if (inner.C != inner.h + inner.e) {
// when inner tensor permutation is non-trivial (could be potentially
// elided by extending this function (@c einsum) to take into account
// of inner tensor's permutations)
auto temp_annot = std::string(c) + ";" + std::string(inner.h + inner.e);
ArrayC temp = einsum(tnsrExprA, tnsrExprB,
Einsum::idx<ArrayC>(temp_annot), world);
ArrayC result;
result(std::string(c) + inner.c) = temp(temp_annot);
return result;
}
}

auto update_tr = [&e = std::as_const(e), &i = std::as_const(i),
&range_map = std::as_const(range_map)](auto &term) {
auto ei = (e + i & term.idx);
Expand Down
54 changes: 47 additions & 7 deletions src/TiledArray/expressions/cont_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -279,25 +279,62 @@ class ContEngine : public BinaryEngine<Derived> {
outer_size(left_indices_), outer_size(right_indices_),
(!implicit_permute_outer_ ? std::move(outer_perm) : Permutation{}));
} else {

auto make_total_perm = [this]() -> BipartitePermutation {
if (this->product_type() != TensorProduct::Contraction
|| this->implicit_permute_inner_)
return this->implicit_permute_outer_
? BipartitePermutation()
: BipartitePermutation(outer(this->perm_));

// Here,
// this->product_type() is Tensor::Contraction, and,
// this->implicit_permute_inner_ is false

return this->inner_product_type() == TensorProduct::Scale
? BipartitePermutation(outer(this->perm_))
: this->perm_;
};

auto total_perm = make_total_perm();

// factor_ is absorbed into inner_tile_nonreturn_op_
op_ = op_type(
left_op, right_op, scalar_type(1), outer_size(indices_),
outer_size(left_indices_), outer_size(right_indices_),
(!implicit_permute_outer_ ? std::move(outer_perm) : Permutation{}),
total_perm,
this->element_nonreturn_op_);
}
trange_ = ContEngine_::make_trange(outer_perm);
shape_ = ContEngine_::make_shape(outer_perm);
} else {
// Initialize non-permuted structure

if constexpr (!TiledArray::detail::is_tensor_of_tensor_v<value_type>) {
op_ = op_type(left_op, right_op, factor_, outer_size(indices_),
outer_size(left_indices_), outer_size(right_indices_));
} else {

auto make_total_perm = [this]() -> BipartitePermutation {
if (this->product_type() != TensorProduct::Contraction
|| this->implicit_permute_inner_)
return {};

// Here,
// this->product_type() is Tensor::Contraction, and,
// this->implicit_permute_inner_ is false

return this->inner_product_type() == TensorProduct::Scale
? BipartitePermutation(outer(this->perm_))
: this->perm_;
};

auto total_perm = make_total_perm();

// factor_ is absorbed into inner_tile_nonreturn_op_
op_ = op_type(left_op, right_op, scalar_type(1), outer_size(indices_),
outer_size(left_indices_), outer_size(right_indices_),
BipartitePermutation{}, this->element_nonreturn_op_);
total_perm, this->element_nonreturn_op_);
}
trange_ = ContEngine_::make_trange();
shape_ = ContEngine_::make_shape();
Expand Down Expand Up @@ -509,12 +546,15 @@ class ContEngine : public BinaryEngine<Derived> {
inner_size(this->left_indices_),
inner_size(this->right_indices_));
this->element_nonreturn_op_ =
[contrreduce_op](result_tile_element_type& result,
const left_tile_element_type& left,
const right_tile_element_type& right) {
[contrreduce_op, permute_inner = this->product_type() !=
TensorProduct::Contraction](
result_tile_element_type& result,
const left_tile_element_type& left,
const right_tile_element_type& right) {
contrreduce_op(result, left, right);
if (!TA::empty(result))
result = contrreduce_op(result); // permutations of result are applied as "postprocessing"
// permutations of result are applied as "postprocessing"
if (permute_inner && !TA::empty(result))
result = contrreduce_op(result);
};
} // ToT x ToT
} else if (inner_prod == TensorProduct::Hadamard) {
Expand Down
11 changes: 6 additions & 5 deletions src/TiledArray/tensor/kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -417,14 +417,15 @@ inline void inplace_tensor_op(Op&& op, TR& result, const Ts&... tensors) {
TA_ASSERT(!empty(result, tensors...));
TA_ASSERT(is_range_set_congruent(result, tensors...));

const auto volume = result.range().volume();

for (decltype(result.range().volume()) ord = 0ul; ord < volume; ++ord) {
auto volume = result.total_size();
for (decltype(volume) ord = 0; ord < volume; ++ord) {
if constexpr (is_tensor_of_tensor_v<TR, Ts...>)
if (((tensors.data()[ord].range().volume() == 0) || ...)) continue;
if constexpr (std::is_invocable_r_v<void, Op, typename TR::value_type&,
typename Ts::value_type...>)
op(result.at_ordinal(ord), tensors.at_ordinal(ord)...);
op(result.data()[ord], tensors.data()[ord]...);
else
inplace_tensor_op(op, result.at_ordinal(ord), tensors.at_ordinal(ord)...);
inplace_tensor_op(op, result.data()[ord], tensors.data()[ord]...);
}
}

Expand Down
1 change: 1 addition & 0 deletions src/TiledArray/tensor/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -1630,6 +1630,7 @@ class Tensor {
template <typename Right,
typename std::enable_if<is_tensor<Right>::value>::type* = nullptr>
Tensor add(const Right& right) const& {
if (right.empty()) return *this;
return binary(
right,
[](const value_type& l, const value_t<Right>& r) -> decltype(auto) {
Expand Down
1 change: 0 additions & 1 deletion src/TiledArray/tile_op/contract_reduce.h
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,6 @@ class ContractReduce : public ContractReduceBase<Result, Left, Right, Scalar> {

if constexpr (!ContractReduceBase_::plain_tensors) {
TA_ASSERT(this->elem_muladd_op());
// not yet implemented
gemm(result, left, right, ContractReduceBase_::gemm_helper(),
this->elem_muladd_op());
} else { // plain tensors
Expand Down

0 comments on commit fb5d5b8

Please sign in to comment.