Skip to content

Commit

Permalink
[WIP] T x ToT overload of einsum: first attempt.
Browse files Browse the repository at this point in the history
  • Loading branch information
bimalgaudel committed Nov 13, 2023
1 parent 4ceb416 commit 65f4374
Showing 1 changed file with 225 additions and 0 deletions.
225 changes: 225 additions & 0 deletions src/TiledArray/einsum/tiledarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,231 @@ auto einsum(expressions::TsrExpr<Array_> A, expressions::TsrExpr<Array_> B,
return C.array;
}

/// True when the tile type (`value_type`) of the array type \c DArrayT is
/// reported as a tensor by `detail::is_tensor_v`.
///
/// Declared as an `inline constexpr` variable template at the enclosing
/// namespace scope instead of inside an unnamed namespace: this is a header,
/// and an unnamed namespace would give every translation unit its own
/// internal-linkage copy of the trait that the `einsum` templates below then
/// refer to.
template <typename DArrayT>
inline constexpr bool IsArrayT =
    detail::is_tensor_v<typename DArrayT::value_type>;

/// True when the tile type (`value_type`) of the array type \c DArrayToT is
/// reported as a tensor-of-tensors by `detail::is_tensor_of_tensor_v`.
template <typename DArrayToT>
inline constexpr bool IsArrayToT =
    detail::is_tensor_of_tensor_v<typename DArrayToT::value_type>;

/// Einsum overload for the product of an array of plain tensors (`A`) with an
/// array of tensors-of-tensors (`B`); the result is an `ArrayToT`.
///
/// NOTE(review): this is work in progress. The statement that would compute
/// the per-batch product (`C.ei(C.expr) = ...`) is commented out below, so no
/// visible code assigns `C.ei` before it is read. As written the function
/// sets up the operands and the result plumbing but does not perform the
/// contraction.
///
/// \tparam ArrayT_ array type for which `IsArrayT` is true
/// \tparam ArrayToT_ array type for which `IsArrayToT` is true
/// \tparam Indices types of any further elements of `cs` after the first
/// \param A annotated expression of the plain-tensor array
/// \param B annotated expression of the tensor-of-tensors array
/// \param cs tuple whose first element is the outer index list of the result;
///        when it has exactly two elements the second is read as the inner
///        index list
/// \param world world used to build the result array and whose MPI
///        communicator is split into per-batch sub-communicators
/// \return `C.array`, an `ArrayToT` tiled according to the result index list
template <
typename ArrayT_, typename ArrayToT_, typename... Indices,
typename = std::enable_if_t<IsArrayT<ArrayT_> && IsArrayToT<ArrayToT_>>>
auto einsum(expressions::TsrExpr<ArrayT_> A, expressions::TsrExpr<ArrayToT_> B,
std::tuple<Einsum::Index<std::string>, Indices...> cs,
World &world) {
using ArrayT = std::remove_cv_t<ArrayT_>;
using ArrayToT = std::remove_cv_t<ArrayToT_>;
using Shape = typename ArrayToT::shape_type;
// NOTE(review): `T` is not referenced anywhere else in this function.
using T = typename ArrayT::value_type;
using ToT = typename ArrayToT::value_type;

// outer index lists of the two operands and of the result
auto a = std::get<0>(Einsum::idx(A));
auto b = std::get<0>(Einsum::idx(B));
Einsum::Index<std::string> c = std::get<0>(cs);

// Inner (tensor-of-tensor) annotations, each prefixed with ';'.
// NOTE(review): `inner.b` / `inner.c` are only written here and never read
// later in this function; `inner.a` is never written (A has no inner indices).
struct {
std::string a, b, c;
} inner;
if constexpr (std::tuple_size<decltype(cs)>::value == 2) {
inner.b = ";" + (std::string)std::get<1>(Einsum::idx(B));
inner.c = ";" + (std::string)std::get<1>(cs);
}

// these are "Hadamard" (fused) indices
auto h = a & b & c;

// indices computed as `a ^ b`
// (presumably those appearing in exactly one of a, b -- TODO confirm the
// semantics of Index::operator^)
auto e = (a ^ b);
// contracted indices
auto i = (a & b) - h;

// cannot be hadamard reduction type operation for this overload
TA_ASSERT(e);

// no Hadamard indices => standard contraction (or even outer product)
// same a, b, and c => pure Hadamard
// NOTE(review): `e` is `a ^ b` and was just asserted to be non-empty, so
// `!(a ^ b)` cannot hold here; assuming `!x` tests for an empty index list,
// this assertion only passes when `h` is empty. Confirm that is intended,
// given the Hadamard batching performed below.
TA_ASSERT(!h || (!(a ^ b) && !(b ^ c)));

// maps Index to TiledRange1
// (asserts same index maps to the same TR1 in A, and B)
auto range_map =
(RangeMap(a, A.array().trange()) | RangeMap(b, B.array().trange()));

using ::Einsum::index::permutation;
using TiledArray::Permutation;

// pair each operand array with its outer index list
auto arrayTermA = ArrayTerm<ArrayT>{A.array(), a};
auto arrayTermB = ArrayTerm<ArrayToT>{B.array(), b};

// For each operand: `ei` is the part of (e + i) that the operand carries.
// Note that `+` binds tighter than `&`, so the expression below parses as
// `(e + i) & idx`. A permutation is recorded only when the operand's indices
// are not already in `h + ei` order.
{
auto ei = (e + i & arrayTermA.idx);
if (arrayTermA.idx != h + ei)
arrayTermA.permutation = permutation(arrayTermA.idx, h + ei);
arrayTermA.expr = ei;
}

{
auto ei = (e + i & arrayTermB.idx);
if (arrayTermB.idx != h + ei)
arrayTermB.permutation = permutation(arrayTermB.idx, h + ei);
arrayTermB.expr = ei;
}

// Result term: array tiled by the ranges of `c`; `C.tiles` is the product of
// the tile ranges of the `e` indices; the permutation maps `h + e` order to
// the requested result order.
ArrayTerm<ArrayToT> C = {ArrayToT(world, TiledRange(range_map[c])), c};
for (auto idx : e) {
C.tiles *= Range(range_map[idx].tiles_range());
}
if (C.idx != h + e) {
C.permutation = permutation(h + e, C.idx);
}
C.expr = e;

// Hadamard bookkeeping: `tiles` is the product of the tile ranges of the
// Hadamard indices; `batch[k]` lists, for the k-th Hadamard index, the size
// of each of its tiles.
struct {
RangeProduct tiles;
std::vector<std::vector<size_t>> batch;
} H;

for (auto idx : h) {
H.tiles *= Range(range_map[idx].tiles_range());
H.batch.push_back({});
for (auto r : range_map[idx]) {
H.batch.back().push_back(Range{r}.size());
}
}

using Index = Einsum::Index<size_t>;

// generalized contraction
// For each operand, record the tiled range over its `ei` indices and the
// product of the corresponding tile ranges.
{
auto ei = (e + i & arrayTermA.idx);
arrayTermA.ei_tiled_range = TiledRange(range_map[ei]);
for (auto idx : ei) arrayTermA.tiles *= Range(range_map[idx].tiles_range());
}

{
auto ei = (e + i & arrayTermB.idx);
arrayTermB.ei_tiled_range = TiledRange(range_map[ei]);
for (auto idx : ei) arrayTermB.tiles *= Range(range_map[idx].tiles_range());
}

// `worlds` keeps every sub-world alive until the final fences at the end of
// the function; `local_tiles` accumulates the (index, tile) pairs of the
// result produced on this rank.
std::vector<std::shared_ptr<World>> worlds;
std::vector<std::tuple<Index, ToT>> local_tiles;

// iterates over tiles of hadamard indices
// NOTE(review): the loop variable `h` shadows the outer index list `h`, and
// the references `A` / `B` below shadow the function parameters.
for (Index h : H.tiles) {
auto &A = arrayTermA;
auto &B = arrayTermB;

// Whether this rank takes part in the current batch (semantics of
// ArrayTerm::own are defined outside this view).
auto own = A.own(h) || B.own(h);
// Every rank calls Split -- the `continue` for non-owners comes only after
// it. Presumably `own` is the color and the rank the key, as in
// MPI_Comm_split -- TODO confirm.
auto comm = world.mpi.comm().Split(own, world.rank());
worlds.push_back(std::make_unique<World>(comm));
auto &owners = worlds.back();
if (!own) continue;
// batch = product of the tile sizes selected by `h` along each Hadamard
// index (the inner `i` shadows the contracted-index list `i`)
size_t batch = 1;
for (size_t i = 0; i < h.size(); ++i) {
batch *= H.batch[i].at(h[i]);
}

// Collect the local, non-zero tiles of A for this batch, permute and
// reshape them to the `ei` tile shape with `batch` as batch size, and build
// the temporary array `arrayTermA.ei` on the owners' world.
{
arrayTermA.local_tiles.clear();
const Permutation &P = arrayTermA.permutation;

for (Index ei : arrayTermA.tiles) {
auto idx = apply_inverse(P, h + ei);
if (!arrayTermA.array.is_local(idx)) continue;
if (arrayTermA.array.is_zero(idx)) continue;
// TODO no need for immediate evaluation
auto tile = arrayTermA.array.find_local(idx).get();
if (P) tile = tile.permute(P);
auto shape = arrayTermA.ei_tiled_range.tile(ei);
tile = tile.reshape(shape, batch);
arrayTermA.local_tiles.push_back({ei, tile});
}
bool replicated = arrayTermA.array.pmap()->is_replicated();
arrayTermA.ei = TiledArray::make_array<ArrayT>(
*owners, arrayTermA.ei_tiled_range, arrayTermA.local_tiles.begin(),
arrayTermA.local_tiles.end(), replicated);
}

// Same procedure for B.
{
arrayTermB.local_tiles.clear();
const Permutation &P = arrayTermB.permutation;

for (Index ei : arrayTermB.tiles) {
auto idx = apply_inverse(P, h + ei);
if (!arrayTermB.array.is_local(idx)) continue;
if (arrayTermB.array.is_zero(idx)) continue;
// TODO no need for immediate evaluation
auto tile = arrayTermB.array.find_local(idx).get();
if (P) tile = tile.permute(P);
auto shape = arrayTermB.ei_tiled_range.tile(ei);
tile = tile.reshape(shape, batch);
arrayTermB.local_tiles.push_back({ei, tile});
}
bool replicated = arrayTermB.array.pmap()->is_replicated();
arrayTermB.ei = TiledArray::make_array<ArrayToT>(
*owners, arrayTermB.ei_tiled_range, arrayTermB.local_tiles.begin(),
arrayTermB.local_tiles.end(), replicated);
}

// todo
// NOTE(review): the product below is disabled, so nothing visible here
// assigns `C.ei` before the loop over `C.tiles` reads it -- confirm what
// state a freshly constructed ArrayTerm leaves `ei` in.
// C.ei(C.expr) = (A.ei(A.expr) * B.ei(B.expr)).set_world(*owners);
// release the temporaries, deferring their deleters to the next fence
A.ei.defer_deleter_to_next_fence();
B.ei.defer_deleter_to_next_fence();
A.ei = ArrayT();
B.ei = ArrayToT();
// why omitting this fence leads to deadlock?
owners->gop.fence();
// Copy the local, non-zero result tiles out of `C.ei`, undoing the reshape
// and permutation so they are indexed in the result's own order.
// (`e` here shadows the outer index list `e`; `c` below shadows the result
// index list `c`.)
for (Index e : C.tiles) {
if (!C.ei.is_local(e)) continue;
if (C.ei.is_zero(e)) continue;
// TODO no need for immediate evaluation
auto tile = C.ei.find_local(e).get();
// NOTE(review): plain `assert`, unlike the `TA_ASSERT`s used above.
assert(tile.batch_size() == batch);
const Permutation &P = C.permutation;
auto c = apply(P, h + e);
auto shape = C.array.trange().tile(c);
shape = apply_inverse(P, shape);
tile = tile.reshape(shape);
if (P) tile = tile.permute(P);
local_tiles.push_back({c, tile});
}
// mark for lazy deletion
C.ei = ArrayToT();
}

// For non-dense shapes, rebuild the result array with a shape constructed
// from the norms of the tiles collected on this rank.
if constexpr (!Shape::is_dense()) {
TiledRange tiled_range = TiledRange(range_map[c]);
std::vector<std::pair<Index, float>> tile_norms;
for (auto &[index, tile] : local_tiles) {
tile_norms.push_back({index, tile.norm()});
}
Shape shape(world, tile_norms, tiled_range);
C.array = ArrayToT(world, TiledRange(range_map[c]), shape);
}

// store the collected tiles, skipping those the result array reports as zero
for (auto &[index, tile] : local_tiles) {
if (C.array.is_zero(index)) continue;
C.array.set(index, tile);
}

// fence every sub-world created in the batch loop before returning
for (auto &w : worlds) {
w->gop.fence();
}

return C.array;
}

/// Einsum overload for the product of an array of tensors-of-tensors (`B`,
/// given first) with an array of plain tensors (`A`, given second).
///
/// Forwards to `einsum(A, B, cs, world)` with the two operands swapped, so
/// the plain-tensor expression is passed first.
///
/// \tparam ArrayT array type for which `IsArrayT` is true
/// \tparam ArrayToT array type for which `IsArrayToT` is true
/// \tparam Indices types of any further elements of `cs` after the first
/// \param B annotated expression of the tensor-of-tensors array
/// \param A annotated expression of the plain-tensor array
/// \param cs index lists of the result, passed through unchanged
/// \param world world passed through unchanged
/// \return whatever the forwarded `einsum` call returns
template <typename ArrayT, typename ArrayToT, typename... Indices,
          typename = std::enable_if_t<IsArrayT<ArrayT> && IsArrayToT<ArrayToT>>>
auto einsum(expressions::TsrExpr<ArrayToT> B, expressions::TsrExpr<ArrayT> A,
            std::tuple<Einsum::Index<std::string>, Indices...> cs,
            World &world) {
  // Swap the operands and delegate; the result is returned by value.
  auto product = einsum(A, B, cs, world);
  return product;
}

/// Computes ternary tensor product whose result
/// is a scalar (a ternary dot product). Optimized for the case where
/// the arguments have common (Hadamard) indices.
Expand Down

0 comments on commit 65f4374

Please sign in to comment.