Skip to content

Commit

Permalink
Fix ToT times ToT into T kind of evaluations so that varying inner …
Browse files Browse the repository at this point in the history
…tensor extents are supported.
  • Loading branch information
bimalgaudel committed May 1, 2024
1 parent 2bfd5aa commit 92e416e
Showing 1 changed file with 133 additions and 15 deletions.
148 changes: 133 additions & 15 deletions src/TiledArray/einsum/tiledarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,102 @@ using ::Einsum::index::IndexMap;
using ::Einsum::index::Permutation;
using ::Einsum::index::permutation;

///
/// \tparam T A type that parameterizes ::Einsum::Index<T>.
///
/// This class makes it easier to work with indices involved in a binary
/// tensor multiplication. Also defines a canonical order of the indices.
///
/// Consider an arbitrary binary tensor multiplication annotated as:
/// A(a_1,...,a_m) * B(b_1,...,b_n) -> C(c_1,...,c_l)
/// Note that {c_1,...,c_l} is subset of ({a_1,...,a_m} union {b_1,...,b_n}).
///
/// We define following index types.
/// * Hadamard index: An index that annotates A, B, and C.
/// * Contracted index: An index that annotates A and B but not C.
/// * External index of A: An index that annotates A and C but not B.
/// * External index of B: An index that annotates B and C but not A.
///
/// Defining canonical index ordering.
/// * Hadamard indices are canonically ordered if they appear in the same
/// order in A's annotation.
/// * Contracted indices are canonically ordered if they appear in the same
/// order in A's annotation.
/// * External indices of A are canonically ordered if they appear in the
/// same order in A's annotation.
/// * External indices of B are canonically ordered if they appear in the
/// same order in B's annotation.
/// * Tensor A's indices are canonically ordered if Hadamard, external
/// indices of A, and contracted indices appear in that order and all
/// three index groups are themselves canonically ordered.
/// * Tensor B's indices are canonically ordered if Hadamard, external
/// indices of B, and contracted indices appear in that order and all
/// three index groups are themselves canonically ordered.
/// * Tensor C's indices are canonically ordered if Hadamard, external
/// indices of A and external indices of B appear in that order and all
/// three index groups are themselves canonically ordered.
///
/// Example: Consider the evaluation: A(i,j,p,a,b) * B(j,i,q,b,a) -> C(i,p,j,q).
/// - Hadamard indices: {i,j}
/// - External indices of A: {p}
/// - External indices of B: {q}
/// - Contracted indices: {a, b}
/// All index groups above are canonically ordered.
/// Writing C's indices in canonical order would give: {i,j,p,q}.
///
template <typename T>
class TensorOpIndices {
public:
using index_t = ::Einsum::Index<T>;

TensorOpIndices(index_t const &ixA, index_t const &ixB, index_t const &ixC)
: orig_indices_({ixA, ixB, ixC}) {
hadamard_ = ixA & ixB & ixC;
contracted_ = (ixA & ixB) - ixC;
external_A_ = (ixA - ixB) & ixC;
external_B_ = (ixB - ixA) & ixC;
}

[[nodiscard]] index_t const &ix_A() const { return orig_indices_[A]; }
[[nodiscard]] index_t const &ix_B() const { return orig_indices_[B]; }
[[nodiscard]] index_t const &ix_C() const { return orig_indices_[C]; }

[[nodiscard]] index_t ix_A_canon() const {
return hadamard() + external_A() + contracted();
}

[[nodiscard]] index_t ix_B_canon() const {
return hadamard() + external_B() + contracted();
}

[[nodiscard]] index_t ix_C_canon() const {
return hadamard() + external_A() + external_B();
}

[[nodiscard]] index_t const &hadamard() const { return hadamard_; }
[[nodiscard]] index_t const &contracted() const { return contracted_; }
[[nodiscard]] index_t const &external_A() const { return external_A_; }
[[nodiscard]] index_t const &external_B() const { return external_B_; }

[[nodiscard]] Permutation to_canon_A() const {
return ::Einsum::index::permutation(ix_A(), ix_A_canon());
}

[[nodiscard]] Permutation to_canon_B() const {
return ::Einsum::index::permutation(ix_B(), ix_B_canon());
}

[[nodiscard]] Permutation to_canon_C() const {
return ::Einsum::index::permutation(ix_C(), ix_C_canon());
}

private:
enum { A, B, C, ABC };
std::array<index_t, ABC> orig_indices_;

index_t hadamard_, contracted_, external_A_, external_B_;
};

/// converts the annotation of an expression to an Index
template <typename Array>
auto idx(const std::string &s) {
Expand Down Expand Up @@ -334,33 +430,55 @@ auto einsum(expressions::TsrExpr<ArrayA_> A, expressions::TsrExpr<ArrayB_> B,
"Nested-rank-reduction only supported when the inner tensor "
"ranks match on the arguments");

// Step I: A * B -> C'
// Step II: C' -> C
//
// At "Step I", a general product (without reduction) in outer indices,
// and pure Hadamard product in inner indices is carried out.
// Then at "Step II", the inner tensors are reduced with a unary function.
// The reducing function is determined by looking at the contracting and
// non-contracting outer indices.
// Illustration of steps by an example.
//
// eg. A(i,j,k;a,b) * B(k,j;a,b) -> C(i,j) involves following two steps:
// Step I: A(i,j,k;a,b) * B(k,j;a,b) -> C'(i,j;a,b)
// Step II: C'(i,j;a,b) -> C(i,j)

auto Cp = einsum(A, B, std::string(c) + ";" + std::string(inner.i));
// Consider the evaluation: A(ijpab;xy) * B(jiqba;yx) -> C(ipjq).
//
// Note for the outer indices:
// - Hadamard: 'ij'
// - External A: 'p'
// - External B: 'q'
// - Contracted: 'ab'
//
// Now C is evaluated in the following steps.
// Step I: A(ijpab;xy) * B(jiqba;yx) -> C0(ijpqab;xy)
// Step II: C0(ijpqab;xy) -> C1(ijpqab)
// Step III: C1(ijpqab) -> C2(ijpq)
// Step IV: C2(ijpq) -> C(ipjq)

auto sum_tot_2_tos = [](auto const &tot) {
typename std::remove_reference_t<decltype(tot)>::value_type result(
tot.range(), [tot](auto &&ix) { return tot(ix).sum(); });
return result;
};

auto result = TA::foreach<typename ArrayC::value_type>(
Cp, [sum_tot_2_tos](auto &out_tile, auto const &in_tile) {
auto const oixs = TensorOpIndices(a, b, c);

struct {
std::string C0, C1, C2;
} const Cn_annot{
std::string(oixs.ix_C_canon() + oixs.contracted()) + inner.a,
{oixs.ix_C_canon() + oixs.contracted()},
{oixs.ix_C_canon()}};

// Step I: A(ijpab;xy) * B(jiqba;yx) -> C0(ijpqab;xy)
auto C0 = einsum(A, B, Cn_annot.C0);

// Step II: C0(ijpqab;xy) -> C1(ijpqab)
auto C1 = TA::foreach<typename ArrayC::value_type>(
C0, [sum_tot_2_tos](auto &out_tile, auto const &in_tile) {
out_tile = sum_tot_2_tos(in_tile);
});

return result;
// Step III: C1(ijpqab) -> C2(ijpq)
auto C2 = reduce_modes(C1, oixs.contracted().size());

// Step IV: C2(ijpq) -> C(ipjq)
ArrayC C;
C(c) = C2(Cn_annot.C2);
return C;

} else {
// these are "Hadamard" (fused) indices
auto h = a & b & c;
Expand Down

0 comments on commit 92e416e

Please sign in to comment.