Skip to content

Commit

Permalink
Tensor-of-tensor tiles can have some elements (i.e. tensors) that are zero.
Browse files Browse the repository at this point in the history
  • Loading branch information
bimalgaudel authored and evaleev committed Nov 9, 2024
1 parent b81da44 commit c6e9490
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 16 deletions.
9 changes: 7 additions & 2 deletions src/TiledArray/einsum/tiledarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -485,8 +485,13 @@ auto einsum(expressions::TsrExpr<ArrayA_> A, expressions::TsrExpr<ArrayB_> B,
// Step IV: C2(ijpq) -> C(ipjq)

// Reduces a tensor-of-tensors `tot` to an object of its value_type built from
// tot.range() and a per-index callable that returns the sum() of the inner
// tensor at that index (presumably the (range, callable) constructor invokes
// the callable once per index -- TODO confirm).
// NOTE(review): this listing is a rendered diff without +/- markers. The hunk
// summary reports 2 deletions and 7 additions, so the first `result`
// declaration below (2 lines) appears to be the removed statement and the
// `using tot_t` ... `});` lines (7 lines) its replacement. Confirm against the
// repository before treating both declarations as live code.
auto sum_tot_2_tos = [](auto const &tot) {
// Apparently the pre-change form: calls sum() on every inner tensor
// unconditionally.
typename std::remove_reference_t<decltype(tot)>::value_type result(
tot.range(), [tot](auto &&ix) { return tot(ix).sum(); });
// Apparently the post-change form: same construction, with a guard for empty
// inner tensors.
using tot_t = std::remove_reference_t<decltype(tot)>;
typename tot_t::value_type result(
tot.range(), [tot](auto &&ix) {
// An empty inner tensor yields a value-initialized numeric_type instead of
// having sum() called on it.
if (!tot(ix).empty())
return tot(ix).sum();
else return typename tot_t::numeric_type{};
});
return result;
};

Expand Down
3 changes: 2 additions & 1 deletion src/TiledArray/expressions/cont_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -513,7 +513,8 @@ class ContEngine : public BinaryEngine<Derived> {
const left_tile_element_type& left,
const right_tile_element_type& right) {
contrreduce_op(result, left, right);
result = contrreduce_op(result); // permutations of result are applied as "postprocessing"
if (!TA::empty(result))
result = contrreduce_op(result); // permutations of result are applied as "postprocessing"
};
} // ToT x ToT
} else if (inner_prod == TensorProduct::Hadamard) {
Expand Down
2 changes: 2 additions & 0 deletions src/TiledArray/tensor/kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -996,6 +996,8 @@ auto tensor_reduce(ReduceOp&& reduce_op, JoinOp&& join_op,

auto result = identity;
for (std::remove_cv_t<decltype(volume)> ord = 0ul; ord < volume; ++ord) {
if (tensor1.data()[ord].range().volume() == 0
|| ((tensors.data()[ord].range().volume() == 0) || ...)) continue;
auto temp = tensor_reduce(reduce_op, join_op, identity, tensor1.data()[ord],
tensors.data()[ord]...);
join_op(result, temp);
Expand Down
36 changes: 27 additions & 9 deletions src/TiledArray/tensor/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,8 @@ class Tensor {
auto volume = total_size();
for (decltype(volume) i = 0; i < volume; ++i) {
auto& el = *(data() + i);
el = p(el, inner_perm);
if (!el.empty())
el = p(el, inner_perm);
}
}
}
Expand Down Expand Up @@ -588,9 +589,13 @@ class Tensor {
Tensor clone() const {
Tensor result;
if (data_) {
result = detail::tensor_op<Tensor>(
[](const numeric_type value) -> numeric_type { return value; },
*this);
if constexpr (detail::is_tensor_of_tensor_v<Tensor>) {
result = Tensor(*this, [](value_type const& el) { return el.clone(); });
} else {
result = detail::tensor_op<Tensor>(
[](const numeric_type value) -> numeric_type { return value; },
*this);
}
} else if (range_) { // corner case: data_ = null implies range_.volume()
// == 0;
TA_ASSERT(range_.volume() == 0);
Expand Down Expand Up @@ -1538,6 +1543,7 @@ class Tensor {
detail::is_bipartite_permutation_v<Perm>;
// tile ops pass bipartite permutations here even if this is a plain tensor
if constexpr (!is_tot) {
if (empty()) return *this;
if constexpr (is_bperm) {
TA_ASSERT(inner_size(perm) == 0); // ensure this is a plain permutation
return Tensor(*this, op, outer(std::forward<Perm>(perm)));
Expand Down Expand Up @@ -1574,6 +1580,7 @@ class Tensor {
template <typename Scalar, typename std::enable_if<
detail::is_numeric_v<Scalar>>::type* = nullptr>
Tensor scale(const Scalar factor) const {
if (range().volume() == 0) return *this;
return unary([factor](const value_type& a) -> decltype(auto) {
using namespace TiledArray::detail;
return a * factor;
Expand Down Expand Up @@ -1626,6 +1633,10 @@ class Tensor {
return binary(
right,
[](const value_type& l, const value_t<Right>& r) -> decltype(auto) {
if constexpr (detail::is_tensor_v<value_type>) {
if (l.empty() && r.empty())
return value_type{};
}
return l + r;
});
}
Expand Down Expand Up @@ -1740,6 +1751,7 @@ class Tensor {
template <typename Right,
typename std::enable_if<is_tensor<Right>::value>::type* = nullptr>
Tensor& add_to(const Right& right) {
if (right.empty()) return *this;
if (empty()) {
*this = Tensor{right.range(), value_type{}};
}
Expand Down Expand Up @@ -1923,11 +1935,17 @@ class Tensor {
typename std::enable_if<detail::is_nested_tensor_v<Right>>::type* =
nullptr>
decltype(auto) mult(const Right& right) const {
return binary(
right,
[](const value_type& l, const value_t<Right>& r) -> decltype(auto) {
return l * r;
});

// Binary kernel passed to binary(right, ...) below: returns the product of its
// two arguments. decltype(auto) keeps exactly the type that `l * r` yields.
auto mult_op =[](const value_type& l, const value_t<Right>& r) -> decltype(auto) {
return l * r;
};

if (empty() || right.empty()) {
using res_t = decltype(std::declval<Tensor>().binary(std::declval<Right>(), mult_op));
return res_t{};
}

return binary(right, mult_op);
}

/// Multiply this by \c right to create a new, permuted tensor
Expand Down
8 changes: 4 additions & 4 deletions src/TiledArray/tile_op/contract_reduce.h
Original file line number Diff line number Diff line change
Expand Up @@ -326,17 +326,17 @@ class ContractReduce : public ContractReduceBase<Result, Left, Right, Scalar> {
/// \param[in] right The right-hand tile to be contracted
void operator()(result_type& result, const first_argument_type& left,
const second_argument_type& right) const {
using TiledArray::empty;
using TiledArray::gemm;
if (empty(left) || empty(right)) return;

if constexpr (!ContractReduceBase_::plain_tensors) {
TA_ASSERT(this->elem_muladd_op());
// not yet implemented
using TiledArray::empty;
using TiledArray::gemm;
gemm(result, left, right, ContractReduceBase_::gemm_helper(),
this->elem_muladd_op());
} else { // plain tensors
TA_ASSERT(!this->elem_muladd_op());
using TiledArray::empty;
using TiledArray::gemm;
if (empty(result))
result = gemm(left, right, ContractReduceBase_::factor(),
ContractReduceBase_::gemm_helper());
Expand Down

0 comments on commit c6e9490

Please sign in to comment.