diff --git a/src/TiledArray/einsum/tiledarray.h b/src/TiledArray/einsum/tiledarray.h index 0e66c4114b..db62720fc8 100644 --- a/src/TiledArray/einsum/tiledarray.h +++ b/src/TiledArray/einsum/tiledarray.h @@ -2,6 +2,7 @@ #define TILEDARRAY_EINSUM_H__INCLUDED #include "TiledArray/fwd.h" +#include "TiledArray/dist_array.h" #include "TiledArray/expressions/fwd.h" #include "TiledArray/einsum/index.h" #include "TiledArray/einsum/range.h" @@ -53,6 +54,8 @@ auto einsum( { using Array = std::remove_cv_t; + using Tensor = typename Array::value_type; + using Shape = typename Array::shape_type; auto a = std::get<0>(Einsum::idx(A)); auto b = std::get<0>(Einsum::idx(B)); @@ -103,6 +106,7 @@ auto einsum( TiledRange ei_tiled_range; Array ei; std::string expr; + std::vector< std::pair,Tensor> > local_tiles; bool own(Einsum::Index h) const { for (Einsum::Index ei : tiles) { auto idx = apply_inverse(permutation, h+ei); @@ -149,7 +153,6 @@ auto einsum( } using Index = Einsum::Index; - using Tensor = typename Array::value_type; if constexpr(std::tuple_size::value > 1) { TA_ASSERT(e); @@ -169,7 +172,7 @@ auto einsum( for (size_t i = 0; i < h.size(); ++i) { batch *= H.batch[i].at(h[i]); } - Tensor tile(TiledArray::Range{batch}); + Tensor tile(TiledArray::Range{batch}, typename Tensor::value_type()); for (Index i : tiles) { // skip this unless both input tiles exist const auto pahi_inv = apply_inverse(pa,h+i); @@ -208,6 +211,7 @@ auto einsum( } std::vector< std::shared_ptr > worlds; + std::vector< std::tuple > local_tiles; // iterates over tiles of hadamard indices for (Index h : H.tiles) { @@ -222,21 +226,29 @@ auto einsum( batch *= H.batch[i].at(h[i]); } for (auto &term : AB) { - term.ei = Array(*owners, term.ei_tiled_range); + term.local_tiles.clear(); const Permutation &P = term.permutation; for (Index ei : term.tiles) { auto idx = apply_inverse(P, h+ei); if (!term.array.is_local(idx)) continue; + if (term.array.is_zero(idx)) continue; auto tile = term.array.find(idx).get(); if (P) tile = tile.permute(P); - auto shape = term.ei.trange().tile(ei); + auto shape = term.ei_tiled_range.tile(ei); tile = tile.reshape(shape, batch); - term.ei.set(ei, tile); + term.local_tiles.push_back({ei, tile}); } + term.ei = TiledArray::make_array( + *owners, + term.ei_tiled_range, + term.local_tiles.begin(), + term.local_tiles.end() + ); } C.ei(C.expr) = (A.ei(A.expr) * B.ei(B.expr)).set_world(*owners); for (Index e : C.tiles) { if (!C.ei.is_local(e)) continue; + if (C.ei.is_zero(e)) continue; auto tile = C.ei.find(e).get(); assert(tile.batch_size() == batch); const Permutation &P = C.permutation; @@ -245,7 +257,7 @@ auto einsum( shape = apply_inverse(P, shape); tile = tile.reshape(shape); if (P) tile = tile.permute(P); - C.array.set(c, tile); + local_tiles.push_back({c, tile}); } // mark for lazy deletion A.ei = Array(); @@ -253,6 +265,21 @@ auto einsum( C.ei = Array(); } + if constexpr (!Shape::is_dense()) { + TiledRange tiled_range = TiledRange(range_map[c]); + std::vector< std::pair > tile_norms; + for (auto& [index,tile] : local_tiles) { + tile_norms.push_back({index,tile.norm()}); + } + Shape shape(world, tile_norms, tiled_range); + C.array = Array(world, TiledRange(range_map[c]), shape); + } + + for (auto& [index,tile] : local_tiles) { + if (C.array.is_zero(index)) continue; + C.array.set(index, tile); + } + for (auto &w : worlds) { w->gop.fence(); } diff --git a/tests/einsum.cpp b/tests/einsum.cpp index 4e2ac664b2..386bfa4096 100644 --- a/tests/einsum.cpp +++ b/tests/einsum.cpp @@ -494,15 +494,28 @@ BOOST_AUTO_TEST_SUITE_END() // TiledArray einsum expressions BOOST_AUTO_TEST_SUITE(einsum_tiledarray) -template, typename ... Args> +using TiledArray::SparsePolicy; +using TiledArray::DensePolicy; + +template, typename ... Args> auto random(Args ... args) { TiledArray::TiledRange tr{ {0, args}... }; auto& world = TiledArray::get_default_world(); - TiledArray::DistArray t(world,tr); + TiledArray::DistArray t(world,tr); t.fill_random(); return t; } +template, typename ... Args> +auto sparse_zero(Args ... args) { + TiledArray::TiledRange tr{ {0, args}... }; + auto& world = TiledArray::get_default_world(); + TiledArray::SparsePolicy::shape_type shape(0.0f, tr); + TiledArray::DistArray t(world,tr,shape); + t.fill(0); + return t; +} + template void einsum_tiledarray_check( TiledArray::DistArray &&A, @@ -523,85 +536,124 @@ void einsum_tiledarray_check( array_to_eigen_tensor>(B) ); auto result = array_to_eigen_tensor(C); + //std::cout << "e=" << result << std::endl; BOOST_CHECK(isApprox(result, reference)); } BOOST_AUTO_TEST_CASE(einsum_tiledarray_ak_bk_ab) { einsum_tiledarray_check<2,2,2>( - random(11,7), - random(13,7), + random(11,7), + random(13,7), "ak,bk->ab" ); } BOOST_AUTO_TEST_CASE(einsum_tiledarray_ka_bk_ba) { einsum_tiledarray_check<2,2,2>( - random(7,11), - random(13,7), + random(7,11), + random(13,7), "ka,bk->ba" ); } BOOST_AUTO_TEST_CASE(einsum_tiledarray_abi_cdi_cdab) { einsum_tiledarray_check<3,3,4>( - random(21,22,3), - random(24,25,3), + random(21,22,3), + random(24,25,3), "abi,cdi->cdab" ); } BOOST_AUTO_TEST_CASE(einsum_tiledarray_icd_ai_abcd) { einsum_tiledarray_check<3,3,4>( - random(3,12,13), - random(14,15,3), + random(3,12,13), + random(14,15,3), "icd,bai->abcd" ); } BOOST_AUTO_TEST_CASE(einsum_tiledarray_cdji_ibja_abcd) { einsum_tiledarray_check<4,4,4>( - random(14,15,3,5), - random(5,12,3,13), + random(14,15,3,5), + random(5,12,3,13), "cdji,ibja->abcd" ); } BOOST_AUTO_TEST_CASE(einsum_tiledarray_hai_hbi_hab) { einsum_tiledarray_check<3,3,3>( - random(7,14,3), - random(7,15,3), + random(7,14,3), + random(7,15,3), + "hai,hbi->hab" + ); + einsum_tiledarray_check<3,3,3>( + sparse_zero(7,14,3), + sparse_zero(7,15,3), "hai,hbi->hab" ); } BOOST_AUTO_TEST_CASE(einsum_tiledarray_iah_hib_bha) { einsum_tiledarray_check<3,3,3>( - random(7,14,3), - random(3,7,15), + random(7,14,3), + random(3,7,15), + "iah,hib->bha" + ); + einsum_tiledarray_check<3,3,3>( + sparse_zero(7,14,3), + sparse_zero(3,7,15), "iah,hib->bha" ); } BOOST_AUTO_TEST_CASE(einsum_tiledarray_iah_hib_abh) { einsum_tiledarray_check<3,3,3>( - random(7,14,3), - random(3,7,15), + random(7,14,3), + random(3,7,15), "iah,hib->abh" ); + einsum_tiledarray_check<3,3,3>( + sparse_zero(7,14,3), + sparse_zero(3,7,15), + "iah,hib->abh" + ); +} + +BOOST_AUTO_TEST_CASE(einsum_tiledarray_hai_hibc_habc) { + einsum_tiledarray_check<3,4,4>( + random(9,3,11), + random(9,11,5,7), + "hai,hibc->habc" + ); + einsum_tiledarray_check<3,4,4>( + sparse_zero(9,3,11), + sparse_zero(9,11,5,7), + "hai,hibc->habc" + ); } BOOST_AUTO_TEST_CASE(einsum_tiledarray_hi_hi_h) { einsum_tiledarray_check<2,2,1>( - random(7,14), - random(7,14), + random(7,14), + random(7,14), + "hi,hi->h" + ); + einsum_tiledarray_check<2,2,1>( + sparse_zero(7,14), + sparse_zero(7,14), "hi,hi->h" ); } BOOST_AUTO_TEST_CASE(einsum_tiledarray_hji_jih_hj) { einsum_tiledarray_check<3,3,2>( - random(14,7,5), - random(7,5,14), + random(14,7,5), + random(7,5,14), + "hji,jih->hj" + ); + einsum_tiledarray_check<3,3,2>( + sparse_zero(14,7,5), + sparse_zero(7,5,14), "hji,jih->hj" ); }