Skip to content

Commit

Permalink
Einsum: handle sparse DistArray
Browse files Browse the repository at this point in the history
  • Loading branch information
asadchev committed Aug 3, 2022
1 parent e967606 commit 5aa547c
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 28 deletions.
39 changes: 33 additions & 6 deletions src/TiledArray/einsum/tiledarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define TILEDARRAY_EINSUM_H__INCLUDED

#include "TiledArray/fwd.h"
#include "TiledArray/dist_array.h"
#include "TiledArray/expressions/fwd.h"
#include "TiledArray/einsum/index.h"
#include "TiledArray/einsum/range.h"
Expand Down Expand Up @@ -53,6 +54,8 @@ auto einsum(
{

using Array = std::remove_cv_t<Array_>;
using Tensor = typename Array::value_type;
using Shape = typename Array::shape_type;

auto a = std::get<0>(Einsum::idx(A));
auto b = std::get<0>(Einsum::idx(B));
Expand Down Expand Up @@ -103,6 +106,7 @@ auto einsum(
TiledRange ei_tiled_range;
Array ei;
std::string expr;
std::vector< std::pair<Einsum::Index<size_t>,Tensor> > local_tiles;
bool own(Einsum::Index<size_t> h) const {
for (Einsum::Index<size_t> ei : tiles) {
auto idx = apply_inverse(permutation, h+ei);
Expand Down Expand Up @@ -149,7 +153,6 @@ auto einsum(
}

using Index = Einsum::Index<size_t>;
using Tensor = typename Array::value_type;

if constexpr(std::tuple_size<decltype(cs)>::value > 1) {
TA_ASSERT(e);
Expand All @@ -169,7 +172,7 @@ auto einsum(
for (size_t i = 0; i < h.size(); ++i) {
batch *= H.batch[i].at(h[i]);
}
Tensor tile(TiledArray::Range{batch});
Tensor tile(TiledArray::Range{batch}, typename Tensor::value_type());
for (Index i : tiles) {
// skip this unless both input tiles exist
const auto pahi_inv = apply_inverse(pa,h+i);
Expand Down Expand Up @@ -208,6 +211,7 @@ auto einsum(
}

std::vector< std::shared_ptr<World> > worlds;
std::vector< std::tuple<Index,Tensor> > local_tiles;

// iterates over tiles of hadamard indices
for (Index h : H.tiles) {
Expand All @@ -222,21 +226,29 @@ auto einsum(
batch *= H.batch[i].at(h[i]);
}
for (auto &term : AB) {
term.ei = Array(*owners, term.ei_tiled_range);
term.local_tiles.clear();
const Permutation &P = term.permutation;
for (Index ei : term.tiles) {
auto idx = apply_inverse(P, h+ei);
if (!term.array.is_local(idx)) continue;
if (term.array.is_zero(idx)) continue;
auto tile = term.array.find(idx).get();
if (P) tile = tile.permute(P);
auto shape = term.ei.trange().tile(ei);
auto shape = term.ei_tiled_range.tile(ei);
tile = tile.reshape(shape, batch);
term.ei.set(ei, tile);
term.local_tiles.push_back({ei, tile});
}
term.ei = TiledArray::make_array<Array>(
*owners,
term.ei_tiled_range,
term.local_tiles.begin(),
term.local_tiles.end()
);
}
C.ei(C.expr) = (A.ei(A.expr) * B.ei(B.expr)).set_world(*owners);
for (Index e : C.tiles) {
if (!C.ei.is_local(e)) continue;
if (C.ei.is_zero(e)) continue;
auto tile = C.ei.find(e).get();
assert(tile.batch_size() == batch);
const Permutation &P = C.permutation;
Expand All @@ -245,14 +257,29 @@ auto einsum(
shape = apply_inverse(P, shape);
tile = tile.reshape(shape);
if (P) tile = tile.permute(P);
C.array.set(c, tile);
local_tiles.push_back({c, tile});
}
// mark for lazy deletion
A.ei = Array();
B.ei = Array();
C.ei = Array();
}

if constexpr (!Shape::is_dense()) {
TiledRange tiled_range = TiledRange(range_map[c]);
std::vector< std::pair<Index,float> > tile_norms;
for (auto& [index,tile] : local_tiles) {
tile_norms.push_back({index,tile.norm()});
}
Shape shape(world, tile_norms, tiled_range);
C.array = Array(world, TiledRange(range_map[c]), shape);
}

for (auto& [index,tile] : local_tiles) {
if (C.array.is_zero(index)) continue;
C.array.set(index, tile);
}

for (auto &w : worlds) {
w->gop.fence();
}
Expand Down
96 changes: 74 additions & 22 deletions tests/einsum.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -494,15 +494,28 @@ BOOST_AUTO_TEST_SUITE_END()
// TiledArray einsum expressions
BOOST_AUTO_TEST_SUITE(einsum_tiledarray)

template<typename T = Tensor<int>, typename ... Args>
using TiledArray::SparsePolicy;
using TiledArray::DensePolicy;

template<typename Policy, typename T = Tensor<int>, typename ... Args>
auto random(Args ... args) {
TiledArray::TiledRange tr{ {0, args}... };
auto& world = TiledArray::get_default_world();
TiledArray::DistArray<T,TiledArray::SparsePolicy> t(world,tr);
TiledArray::DistArray<T,Policy> t(world,tr);
t.fill_random();
return t;
}

// Test helper: creates a SparsePolicy DistArray in the default world whose
// i-th dimension is tiled by the boundaries {0, args_i}, attaches a sparse
// shape constructed from the value 0.0f, and then calls fill(0) on it.
// NOTE(review): presumably the 0.0f shape marks every tile as zero so that
// fill(0) touches nothing -- confirm against the SparseShape constructor.
template<typename T = Tensor<int>, typename ... Args>
auto sparse_zero(Args ... args) {
  using SparseArray = TiledArray::DistArray<T,TiledArray::SparsePolicy>;
  using SparseShape = TiledArray::SparsePolicy::shape_type;

  // One {0, extent} boundary list per requested dimension.
  TiledArray::TiledRange trange{ {0, args}... };
  auto& default_world = TiledArray::get_default_world();

  SparseShape zero_shape(0.0f, trange);
  SparseArray array(default_world, trange, zero_shape);
  array.fill(0);
  return array;
}

template<int NA, int NB, int NC, typename T, typename Policy>
void einsum_tiledarray_check(
TiledArray::DistArray<T,Policy> &&A,
Expand All @@ -523,85 +536,124 @@ void einsum_tiledarray_check(
array_to_eigen_tensor<Tensor<U,NB>>(B)
);
auto result = array_to_eigen_tensor<TC>(C);
//std::cout << "e=" << result << std::endl;
BOOST_CHECK(isApprox(result, reference));
}

BOOST_AUTO_TEST_CASE(einsum_tiledarray_ak_bk_ab) {
einsum_tiledarray_check<2,2,2>(
random(11,7),
random(13,7),
random<SparsePolicy>(11,7),
random<SparsePolicy>(13,7),
"ak,bk->ab"
);
}

BOOST_AUTO_TEST_CASE(einsum_tiledarray_ka_bk_ba) {
einsum_tiledarray_check<2,2,2>(
random(7,11),
random(13,7),
random<SparsePolicy>(7,11),
random<SparsePolicy>(13,7),
"ka,bk->ba"
);
}

BOOST_AUTO_TEST_CASE(einsum_tiledarray_abi_cdi_cdab) {
einsum_tiledarray_check<3,3,4>(
random(21,22,3),
random(24,25,3),
random<SparsePolicy>(21,22,3),
random<SparsePolicy>(24,25,3),
"abi,cdi->cdab"
);
}

BOOST_AUTO_TEST_CASE(einsum_tiledarray_icd_ai_abcd) {
einsum_tiledarray_check<3,3,4>(
random(3,12,13),
random(14,15,3),
random<SparsePolicy>(3,12,13),
random<SparsePolicy>(14,15,3),
"icd,bai->abcd"
);
}

BOOST_AUTO_TEST_CASE(einsum_tiledarray_cdji_ibja_abcd) {
einsum_tiledarray_check<4,4,4>(
random(14,15,3,5),
random(5,12,3,13),
random<SparsePolicy>(14,15,3,5),
random<SparsePolicy>(5,12,3,13),
"cdji,ibja->abcd"
);
}

BOOST_AUTO_TEST_CASE(einsum_tiledarray_hai_hbi_hab) {
einsum_tiledarray_check<3,3,3>(
random(7,14,3),
random(7,15,3),
random<SparsePolicy>(7,14,3),
random<SparsePolicy>(7,15,3),
"hai,hbi->hab"
);
einsum_tiledarray_check<3,3,3>(
sparse_zero(7,14,3),
sparse_zero(7,15,3),
"hai,hbi->hab"
);
}

BOOST_AUTO_TEST_CASE(einsum_tiledarray_iah_hib_bha) {
einsum_tiledarray_check<3,3,3>(
random(7,14,3),
random(3,7,15),
random<SparsePolicy>(7,14,3),
random<SparsePolicy>(3,7,15),
"iah,hib->bha"
);
einsum_tiledarray_check<3,3,3>(
sparse_zero(7,14,3),
sparse_zero(3,7,15),
"iah,hib->bha"
);
}

BOOST_AUTO_TEST_CASE(einsum_tiledarray_iah_hib_abh) {
einsum_tiledarray_check<3,3,3>(
random(7,14,3),
random(3,7,15),
random<SparsePolicy>(7,14,3),
random<SparsePolicy>(3,7,15),
"iah,hib->abh"
);
einsum_tiledarray_check<3,3,3>(
sparse_zero(7,14,3),
sparse_zero(3,7,15),
"iah,hib->abh"
);
}

// Checks the einsum expression "hai,hibc->habc": a rank-3 operand and a
// rank-4 operand producing a rank-4 result (template arguments <3,4,4>).
// Index `h` appears in both operands and in the result; index `i` appears in
// both operands but not in the result.
BOOST_AUTO_TEST_CASE(einsum_tiledarray_hai_hibc_habc) {
// Case 1: randomly filled SparsePolicy inputs.
einsum_tiledarray_check<3,4,4>(
random<SparsePolicy>(9,3,11),
random<SparsePolicy>(9,11,5,7),
"hai,hibc->habc"
);
// Case 2: inputs built by sparse_zero (shape constructed from 0.0f, then
// fill(0)). NOTE(review): presumably exercises the path where every input
// tile is zero -- confirm against the SparseShape semantics.
einsum_tiledarray_check<3,4,4>(
sparse_zero(9,3,11),
sparse_zero(9,11,5,7),
"hai,hibc->habc"
);
}

BOOST_AUTO_TEST_CASE(einsum_tiledarray_hi_hi_h) {
einsum_tiledarray_check<2,2,1>(
random(7,14),
random(7,14),
random<SparsePolicy>(7,14),
random<SparsePolicy>(7,14),
"hi,hi->h"
);
einsum_tiledarray_check<2,2,1>(
sparse_zero(7,14),
sparse_zero(7,14),
"hi,hi->h"
);
}

BOOST_AUTO_TEST_CASE(einsum_tiledarray_hji_jih_hj) {
einsum_tiledarray_check<3,3,2>(
random(14,7,5),
random(7,5,14),
random<SparsePolicy>(14,7,5),
random<SparsePolicy>(7,5,14),
"hji,jih->hj"
);
einsum_tiledarray_check<3,3,2>(
sparse_zero(14,7,5),
sparse_zero(7,5,14),
"hji,jih->hj"
);
}
Expand Down

0 comments on commit 5aa547c

Please sign in to comment.