From 2e67af6243476f41cf137cc9bede3ee18aa0905f Mon Sep 17 00:00:00 2001 From: Bimal Gaudel Date: Tue, 7 May 2024 17:00:04 -0400 Subject: [PATCH] Simplify `(H+E,H)->H+E` logic. --- src/TiledArray/einsum/tiledarray.h | 31 ++++++++++-------------------- 1 file changed, 10 insertions(+), 21 deletions(-) diff --git a/src/TiledArray/einsum/tiledarray.h b/src/TiledArray/einsum/tiledarray.h index cff0e2cd7b..3da230ca19 100644 --- a/src/TiledArray/einsum/tiledarray.h +++ b/src/TiledArray/einsum/tiledarray.h @@ -504,34 +504,23 @@ auto einsum(expressions::TsrExpr A, expressions::TsrExpr B, auto range_map = (RangeMap(a, A.array().trange()) | RangeMap(b, B.array().trange())); - auto perm_and_rank_replicate = [delta_trng = make_trange(range_map, e)]( - auto pre, // - std::string const &pre_annot, // - std::string const &permed_annot) { - decltype(pre) permed; - permed(permed_annot) = pre(pre_annot); - return replicate_array(permed, delta_trng); - }; - // special Hadamard if (h.size() == a.size() || h.size() == b.size()) { TA_ASSERT(!i && e); - bool small_a = h.size() == a.size(); - std::string const eh_annot = (e | h); - std::string const permed_annot = - std::string(h) + (small_a ? inner.a : inner.b); - std::string const C_annot = std::string(c) + inner.c; - std::string const temp_annot = std::string(e) + "," + permed_annot; + bool const small_a = h.size() == a.size(); + auto const delta_trng = make_trange(range_map, e); + std::string target_layout = std::string(c) + inner.c; ArrayC C; if (small_a) { - auto temp = - perm_and_rank_replicate(A.array(), A.annotation(), permed_annot); - C(C_annot) = temp(temp_annot) * B; + auto temp = replicate_array(A.array(), delta_trng); + std::string temp_layout = std::string(e) + "," + A.annotation(); + C(target_layout) = temp(temp_layout) * B; } else { - auto temp = - perm_and_rank_replicate(B.array(), B.annotation(), permed_annot); - C(C_annot) = A * temp(temp_annot); + auto temp = replicate_array(B.array(), delta_trng); + std::string temp_layout = std::string(e) + "," + B.annotation(); + C(target_layout) = A * temp(temp_layout); } + return C; }