diff --git a/src/TiledArray/einsum/tiledarray.h b/src/TiledArray/einsum/tiledarray.h index 41797efafa..40076ed0ce 100644 --- a/src/TiledArray/einsum/tiledarray.h +++ b/src/TiledArray/einsum/tiledarray.h @@ -485,8 +485,13 @@ auto einsum(expressions::TsrExpr A, expressions::TsrExpr B, // Step IV: C2(ijpq) -> C(ipjq) auto sum_tot_2_tos = [](auto const &tot) { - typename std::remove_reference_t::value_type result( - tot.range(), [tot](auto &&ix) { return tot(ix).sum(); }); + using tot_t = std::remove_reference_t; + typename tot_t::value_type result( + tot.range(), [tot](auto &&ix) { + if (!tot(ix).empty()) + return tot(ix).sum(); + else return typename tot_t::numeric_type{}; + }); return result; }; diff --git a/src/TiledArray/expressions/cont_engine.h b/src/TiledArray/expressions/cont_engine.h index 58d7b9ad57..f0a94c7e05 100644 --- a/src/TiledArray/expressions/cont_engine.h +++ b/src/TiledArray/expressions/cont_engine.h @@ -513,7 +513,8 @@ class ContEngine : public BinaryEngine { const left_tile_element_type& left, const right_tile_element_type& right) { contrreduce_op(result, left, right); - result = contrreduce_op(result); // permutations of result are applied as "postprocessing" + if (!TA::empty(result)) + result = contrreduce_op(result); // permutations of result are applied as "postprocessing" }; } // ToT x ToT } else if (inner_prod == TensorProduct::Hadamard) { diff --git a/src/TiledArray/tensor/kernels.h b/src/TiledArray/tensor/kernels.h index 5d40ce5c14..a2530f2f5d 100644 --- a/src/TiledArray/tensor/kernels.h +++ b/src/TiledArray/tensor/kernels.h @@ -996,6 +996,8 @@ auto tensor_reduce(ReduceOp&& reduce_op, JoinOp&& join_op, auto result = identity; for (std::remove_cv_t ord = 0ul; ord < volume; ++ord) { + if (tensor1.data()[ord].range().volume() == 0 + || ((tensors.data()[ord].range().volume() == 0) || ...)) continue; auto temp = tensor_reduce(reduce_op, join_op, identity, tensor1.data()[ord], tensors.data()[ord]...); join_op(result, temp); diff --git a/src/TiledArray/tensor/tensor.h b/src/TiledArray/tensor/tensor.h index bd72af487c..bd6fb8f3e5 100644 --- a/src/TiledArray/tensor/tensor.h +++ b/src/TiledArray/tensor/tensor.h @@ -431,7 +431,8 @@ class Tensor { auto volume = total_size(); for (decltype(volume) i = 0; i < volume; ++i) { auto& el = *(data() + i); - el = p(el, inner_perm); + if (!el.empty()) + el = p(el, inner_perm); } } } @@ -588,9 +589,13 @@ class Tensor { Tensor clone() const { Tensor result; if (data_) { - result = detail::tensor_op( - [](const numeric_type value) -> numeric_type { return value; }, - *this); + if constexpr (detail::is_tensor_of_tensor_v) { + result = Tensor(*this, [](value_type const& el) { return el.clone(); }); + } else { + result = detail::tensor_op( + [](const numeric_type value) -> numeric_type { return value; }, + *this); + } } else if (range_) { // corner case: data_ = null implies range_.volume() // == 0; TA_ASSERT(range_.volume() == 0); @@ -1538,6 +1543,7 @@ class Tensor { detail::is_bipartite_permutation_v; // tile ops pass bipartite permutations here even if this is a plain tensor if constexpr (!is_tot) { + if (empty()) return *this; if constexpr (is_bperm) { TA_ASSERT(inner_size(perm) == 0); // ensure this is a plain permutation return Tensor(*this, op, outer(std::forward(perm))); @@ -1574,6 +1580,7 @@ class Tensor { template >::type* = nullptr> Tensor scale(const Scalar factor) const { + if (range().volume() == 0) return *this; return unary([factor](const value_type& a) -> decltype(auto) { using namespace TiledArray::detail; return a * factor; @@ -1626,6 +1633,10 @@ class Tensor { return binary( right, [](const value_type& l, const value_t& r) -> decltype(auto) { + if constexpr (detail::is_tensor_v) { + if (l.empty() && r.empty()) + return value_type{}; + } return l + r; }); } @@ -1740,6 +1751,7 @@ class Tensor { template ::value>::type* = nullptr> Tensor& add_to(const Right& right) { + if (right.empty()) return *this; if (empty()) { *this = Tensor{right.range(), value_type{}}; } @@ -1923,11 +1935,17 @@ class Tensor { typename std::enable_if>::type* = nullptr> decltype(auto) mult(const Right& right) const { - return binary( - right, - [](const value_type& l, const value_t& r) -> decltype(auto) { - return l * r; - }); + + auto mult_op =[](const value_type& l, const value_t& r) -> decltype(auto) { + return l * r; + }; + + if (empty() || right.empty()) { + using res_t = decltype(std::declval().binary(std::declval(), mult_op)); + return res_t{}; + } + + return binary(right, mult_op); } /// Multiply this by \c right to create a new, permuted tensor diff --git a/src/TiledArray/tile_op/contract_reduce.h b/src/TiledArray/tile_op/contract_reduce.h index 94c7107343..f0654f1431 100644 --- a/src/TiledArray/tile_op/contract_reduce.h +++ b/src/TiledArray/tile_op/contract_reduce.h @@ -326,17 +326,17 @@ class ContractReduce : public ContractReduceBase { /// \param[in] right The right-hand tile to be contracted void operator()(result_type& result, const first_argument_type& left, const second_argument_type& right) const { + using TiledArray::empty; + using TiledArray::gemm; + if (empty(left) || empty(right)) return; + if constexpr (!ContractReduceBase_::plain_tensors) { TA_ASSERT(this->elem_muladd_op()); // not yet implemented - using TiledArray::empty; - using TiledArray::gemm; gemm(result, left, right, ContractReduceBase_::gemm_helper(), this->elem_muladd_op()); } else { // plain tensors TA_ASSERT(!this->elem_muladd_op()); - using TiledArray::empty; - using TiledArray::gemm; if (empty(result)) result = gemm(left, right, ContractReduceBase_::factor(), ContractReduceBase_::gemm_helper()); diff --git a/tests/retile.cpp b/tests/retile.cpp index 0f4100d4c8..6ac15a48c4 100644 --- a/tests/retile.cpp +++ b/tests/retile.cpp @@ -6,26 +6,24 @@ BOOST_AUTO_TEST_SUITE(retile_suite) BOOST_AUTO_TEST_CASE(retile_tensor) { - TA::detail::matrix_il some_values = { - {0.1, 0.2, 0.3, 0.4, 0.5}, - {0.6, 0.7, 0.8, 0.9, 1.0}, - {1.1, 1.2, 1.3, 1.4, 1.5}, - {1.6, 1.7, 1.8, 1.9, 2.0}, - {2.1, 2.2, 2.3, 2.4, 2.5} - }; - - auto range0 = TA::TiledRange1(0, 3, 5); - auto range1 = TA::TiledRange1(0, 4, 5); - auto trange = TA::TiledRange({range0, range1}); - - TA::TArrayD default_dense(*GlobalFixture::world, some_values); - TA::TSpArrayD default_sparse(*GlobalFixture::world, some_values); - - auto result_dense = retile(default_dense, trange); - auto result_sparse = retile(default_sparse, trange); - - BOOST_CHECK_EQUAL(result_dense.trange(), trange); - BOOST_CHECK_EQUAL(result_sparse.trange(), trange); + TA::detail::matrix_il some_values = {{0.1, 0.2, 0.3, 0.4, 0.5}, + {0.6, 0.7, 0.8, 0.9, 1.0}, + {1.1, 1.2, 1.3, 1.4, 1.5}, + {1.6, 1.7, 1.8, 1.9, 2.0}, + {2.1, 2.2, 2.3, 2.4, 2.5}}; + + auto range0 = TA::TiledRange1(0, 3, 5); + auto range1 = TA::TiledRange1(0, 4, 5); + auto trange = TA::TiledRange({range0, range1}); + + TA::TArrayD default_dense(*GlobalFixture::world, some_values); + TA::TSpArrayD default_sparse(*GlobalFixture::world, some_values); + + auto result_dense = retile(default_dense, trange); + auto result_sparse = retile(default_sparse, trange); + + BOOST_CHECK_EQUAL(result_dense.trange(), trange); + BOOST_CHECK_EQUAL(result_sparse.trange(), trange); } BOOST_AUTO_TEST_CASE(retile_more) { @@ -69,17 +67,20 @@ BOOST_AUTO_TEST_CASE(retile_more) { return tile.norm(); }; + auto arr_source0 = + TA::make_array(world, tr_source, set_random_tensor_tile); + auto arr_target0 = TA::retile(arr_source0, tr_target); + auto get_elem = [](auto const& arr, auto const& eix) { auto tix = arr.trange().element_to_tile(eix); auto&& tile = arr.find(tix).get(false); return tile(eix); }; - auto arr_source0 = - TA::make_array(world, tr_source, set_random_tensor_tile); - auto arr_target0 = TA::retile(arr_source0, tr_target); - for (auto&& eix : elem_rng) { + auto tix = arr_source0.trange().element_to_tile(eix); + BOOST_REQUIRE(arr_source0.is_zero(tix) == arr_target0.is_zero(tix)); + if (arr_source0.is_zero(tix)) continue; BOOST_REQUIRE(get_elem(arr_source0, eix) == get_elem(arr_target0, eix)); } @@ -94,8 +95,11 @@ BOOST_AUTO_TEST_CASE(retile_more) { world.gop.fence(); for (auto&& eix : elem_rng) { + auto tix = arr_source.trange().element_to_tile(eix); + BOOST_REQUIRE(arr_source.is_zero(tix) == arr_target.is_zero(tix)); + if (arr_source.is_zero(tix)) continue; BOOST_REQUIRE(get_elem(arr_source, eix) == get_elem(arr_target, eix)); } } -BOOST_AUTO_TEST_SUITE_END() \ No newline at end of file +BOOST_AUTO_TEST_SUITE_END()