From 731948246a97d511d235fce3423bc9db0ec3cee0 Mon Sep 17 00:00:00 2001 From: Bimal Gaudel Date: Mon, 4 Nov 2024 10:00:09 -0500 Subject: [PATCH] Tensor-of-tensor tiles can have some elements (i.e. tensors) that are zero. --- src/TiledArray/einsum/tiledarray.h | 9 ++++-- src/TiledArray/expressions/cont_engine.h | 3 +- src/TiledArray/tensor/kernels.h | 2 ++ src/TiledArray/tensor/tensor.h | 36 ++++++++++++++++++------ src/TiledArray/tile_op/contract_reduce.h | 8 +++--- 5 files changed, 42 insertions(+), 16 deletions(-) diff --git a/src/TiledArray/einsum/tiledarray.h b/src/TiledArray/einsum/tiledarray.h index 41797efafa..40076ed0ce 100644 --- a/src/TiledArray/einsum/tiledarray.h +++ b/src/TiledArray/einsum/tiledarray.h @@ -485,8 +485,13 @@ auto einsum(expressions::TsrExpr A, expressions::TsrExpr B, // Step IV: C2(ijpq) -> C(ipjq) auto sum_tot_2_tos = [](auto const &tot) { - typename std::remove_reference_t::value_type result( - tot.range(), [tot](auto &&ix) { return tot(ix).sum(); }); + using tot_t = std::remove_reference_t; + typename tot_t::value_type result( + tot.range(), [tot](auto &&ix) { + if (!tot(ix).empty()) + return tot(ix).sum(); + else return typename tot_t::numeric_type{}; + }); return result; }; diff --git a/src/TiledArray/expressions/cont_engine.h b/src/TiledArray/expressions/cont_engine.h index 58d7b9ad57..f0a94c7e05 100644 --- a/src/TiledArray/expressions/cont_engine.h +++ b/src/TiledArray/expressions/cont_engine.h @@ -513,7 +513,8 @@ class ContEngine : public BinaryEngine { const left_tile_element_type& left, const right_tile_element_type& right) { contrreduce_op(result, left, right); - result = contrreduce_op(result); // permutations of result are applied as "postprocessing" + if (!TA::empty(result)) + result = contrreduce_op(result); // permutations of result are applied as "postprocessing" }; } // ToT x ToT } else if (inner_prod == TensorProduct::Hadamard) { diff --git a/src/TiledArray/tensor/kernels.h b/src/TiledArray/tensor/kernels.h index 5d40ce5c14..a2530f2f5d 100644 --- a/src/TiledArray/tensor/kernels.h +++ b/src/TiledArray/tensor/kernels.h @@ -996,6 +996,8 @@ auto tensor_reduce(ReduceOp&& reduce_op, JoinOp&& join_op, auto result = identity; for (std::remove_cv_t ord = 0ul; ord < volume; ++ord) { + if (tensor1.data()[ord].range().volume() == 0 + || ((tensors.data()[ord].range().volume() == 0) || ...)) continue; auto temp = tensor_reduce(reduce_op, join_op, identity, tensor1.data()[ord], tensors.data()[ord]...); join_op(result, temp); diff --git a/src/TiledArray/tensor/tensor.h b/src/TiledArray/tensor/tensor.h index bd72af487c..bd6fb8f3e5 100644 --- a/src/TiledArray/tensor/tensor.h +++ b/src/TiledArray/tensor/tensor.h @@ -431,7 +431,8 @@ class Tensor { auto volume = total_size(); for (decltype(volume) i = 0; i < volume; ++i) { auto& el = *(data() + i); - el = p(el, inner_perm); + if (!el.empty()) + el = p(el, inner_perm); } } } @@ -588,9 +589,13 @@ class Tensor { Tensor clone() const { Tensor result; if (data_) { - result = detail::tensor_op( - [](const numeric_type value) -> numeric_type { return value; }, - *this); + if constexpr (detail::is_tensor_of_tensor_v) { + result = Tensor(*this, [](value_type const& el) { return el.clone(); }); + } else { + result = detail::tensor_op( + [](const numeric_type value) -> numeric_type { return value; }, + *this); + } } else if (range_) { // corner case: data_ = null implies range_.volume() // == 0; TA_ASSERT(range_.volume() == 0); @@ -1538,6 +1543,7 @@ class Tensor { detail::is_bipartite_permutation_v; // tile ops pass bipartite permutations here even if this is a plain tensor if constexpr (!is_tot) { + if (empty()) return *this; if constexpr (is_bperm) { TA_ASSERT(inner_size(perm) == 0); // ensure this is a plain permutation return Tensor(*this, op, outer(std::forward(perm))); @@ -1574,6 +1580,7 @@ class Tensor { template >::type* = nullptr> Tensor scale(const Scalar factor) const { + if (range().volume() == 0) return *this; return unary([factor](const value_type& a) -> decltype(auto) { using namespace TiledArray::detail; return a * factor; @@ -1626,6 +1633,10 @@ class Tensor { return binary( right, [](const value_type& l, const value_t& r) -> decltype(auto) { + if constexpr (detail::is_tensor_v) { + if (l.empty() && r.empty()) + return value_type{}; + } return l + r; }); } @@ -1740,6 +1751,7 @@ class Tensor { template ::value>::type* = nullptr> Tensor& add_to(const Right& right) { + if (right.empty()) return *this; if (empty()) { *this = Tensor{right.range(), value_type{}}; } @@ -1923,11 +1935,17 @@ class Tensor { typename std::enable_if>::type* = nullptr> decltype(auto) mult(const Right& right) const { - return binary( - right, - [](const value_type& l, const value_t& r) -> decltype(auto) { - return l * r; - }); + + auto mult_op =[](const value_type& l, const value_t& r) -> decltype(auto) { + return l * r; + }; + + if (empty() || right.empty()) { + using res_t = decltype(std::declval().binary(std::declval(), mult_op)); + return res_t{}; + } + + return binary(right, mult_op); } /// Multiply this by \c right to create a new, permuted tensor diff --git a/src/TiledArray/tile_op/contract_reduce.h b/src/TiledArray/tile_op/contract_reduce.h index 94c7107343..f0654f1431 100644 --- a/src/TiledArray/tile_op/contract_reduce.h +++ b/src/TiledArray/tile_op/contract_reduce.h @@ -326,17 +326,17 @@ class ContractReduce : public ContractReduceBase { /// \param[in] right The right-hand tile to be contracted void operator()(result_type& result, const first_argument_type& left, const second_argument_type& right) const { + using TiledArray::empty; + using TiledArray::gemm; + if (empty(left) || empty(right)) return; + if constexpr (!ContractReduceBase_::plain_tensors) { TA_ASSERT(this->elem_muladd_op()); // not yet implemented - using TiledArray::empty; - using TiledArray::gemm; gemm(result, left, right, ContractReduceBase_::gemm_helper(), this->elem_muladd_op()); } else { // plain tensors TA_ASSERT(!this->elem_muladd_op()); - using TiledArray::empty; - using TiledArray::gemm; if (empty(result)) result = gemm(left, right, ContractReduceBase_::factor(), ContractReduceBase_::gemm_helper());