Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support non-zero ToTs with some zero inner Ts #492

Merged
merged 2 commits into from
Nov 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions src/TiledArray/einsum/tiledarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -485,8 +485,13 @@ auto einsum(expressions::TsrExpr<ArrayA_> A, expressions::TsrExpr<ArrayB_> B,
// Step IV: C2(ijpq) -> C(ipjq)

// Reduces a tensor-of-tensors tile to a plain tile by summing each inner
// tensor: result(ix) = tot(ix).sum(). Empty (zero) inner tensors contribute
// a zero of the numeric type instead of being reduced, so ToTs with some
// zero inner Ts are handled without touching empty tiles.
auto sum_tot_2_tos = [](auto const &tot) {
  using tot_t = std::remove_reference_t<decltype(tot)>;
  typename tot_t::value_type result(
      tot.range(), [tot](auto &&ix) {
        // An empty inner tensor has no elements to sum; treat it as zero.
        if (!tot(ix).empty())
          return tot(ix).sum();
        else
          return typename tot_t::numeric_type{};
      });
  return result;
};

Expand Down
3 changes: 2 additions & 1 deletion src/TiledArray/expressions/cont_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -513,7 +513,8 @@ class ContEngine : public BinaryEngine<Derived> {
const left_tile_element_type& left,
const right_tile_element_type& right) {
contrreduce_op(result, left, right);
result = contrreduce_op(result); // permutations of result are applied as "postprocessing"
if (!TA::empty(result))
result = contrreduce_op(result); // permutations of result are applied as "postprocessing"
};
} // ToT x ToT
} else if (inner_prod == TensorProduct::Hadamard) {
Expand Down
2 changes: 2 additions & 0 deletions src/TiledArray/tensor/kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -996,6 +996,8 @@ auto tensor_reduce(ReduceOp&& reduce_op, JoinOp&& join_op,

auto result = identity;
for (std::remove_cv_t<decltype(volume)> ord = 0ul; ord < volume; ++ord) {
if (tensor1.data()[ord].range().volume() == 0
|| ((tensors.data()[ord].range().volume() == 0) || ...)) continue;
auto temp = tensor_reduce(reduce_op, join_op, identity, tensor1.data()[ord],
tensors.data()[ord]...);
join_op(result, temp);
Expand Down
36 changes: 27 additions & 9 deletions src/TiledArray/tensor/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,8 @@ class Tensor {
auto volume = total_size();
for (decltype(volume) i = 0; i < volume; ++i) {
auto& el = *(data() + i);
el = p(el, inner_perm);
if (!el.empty())
el = p(el, inner_perm);
}
}
}
Expand Down Expand Up @@ -588,9 +589,13 @@ class Tensor {
Tensor clone() const {
Tensor result;
if (data_) {
result = detail::tensor_op<Tensor>(
[](const numeric_type value) -> numeric_type { return value; },
*this);
if constexpr (detail::is_tensor_of_tensor_v<Tensor>) {
result = Tensor(*this, [](value_type const& el) { return el.clone(); });
} else {
result = detail::tensor_op<Tensor>(
[](const numeric_type value) -> numeric_type { return value; },
*this);
}
} else if (range_) { // corner case: data_ = null implies range_.volume()
// == 0;
TA_ASSERT(range_.volume() == 0);
Expand Down Expand Up @@ -1538,6 +1543,7 @@ class Tensor {
detail::is_bipartite_permutation_v<Perm>;
// tile ops pass bipartite permutations here even if this is a plain tensor
if constexpr (!is_tot) {
if (empty()) return *this;
if constexpr (is_bperm) {
TA_ASSERT(inner_size(perm) == 0); // ensure this is a plain permutation
return Tensor(*this, op, outer(std::forward<Perm>(perm)));
Expand Down Expand Up @@ -1574,6 +1580,7 @@ class Tensor {
template <typename Scalar, typename std::enable_if<
detail::is_numeric_v<Scalar>>::type* = nullptr>
Tensor scale(const Scalar factor) const {
if (range().volume() == 0) return *this;
return unary([factor](const value_type& a) -> decltype(auto) {
using namespace TiledArray::detail;
return a * factor;
Expand Down Expand Up @@ -1626,6 +1633,10 @@ class Tensor {
return binary(
right,
[](const value_type& l, const value_t<Right>& r) -> decltype(auto) {
if constexpr (detail::is_tensor_v<value_type>) {
if (l.empty() && r.empty())
return value_type{};
}
return l + r;
});
}
Expand Down Expand Up @@ -1740,6 +1751,7 @@ class Tensor {
template <typename Right,
typename std::enable_if<is_tensor<Right>::value>::type* = nullptr>
Tensor& add_to(const Right& right) {
if (right.empty()) return *this;
if (empty()) {
*this = Tensor{right.range(), value_type{}};
}
Expand Down Expand Up @@ -1923,11 +1935,17 @@ class Tensor {
typename std::enable_if<detail::is_nested_tensor_v<Right>>::type* =
nullptr>
decltype(auto) mult(const Right& right) const {
  /// Element-wise multiply of this tensor by \c right.
  /// \param right the right-hand operand (any nested tensor type)
  /// \return a new tensor whose elements are `(*this)[i] * right[i]`;
  ///         an empty (default-constructed) result if either operand is empty
  auto mult_op = [](const value_type& l,
                    const value_t<Right>& r) -> decltype(auto) {
    return l * r;
  };

  // Short-circuit: multiplying by an empty (zero) tensor yields an empty
  // result; avoids invoking binary() on tensors with no data.
  if (empty() || right.empty()) {
    using res_t =
        decltype(std::declval<Tensor>().binary(std::declval<Right>(), mult_op));
    return res_t{};
  }

  return binary(right, mult_op);
}

/// Multiply this by \c right to create a new, permuted tensor
Expand Down
8 changes: 4 additions & 4 deletions src/TiledArray/tile_op/contract_reduce.h
Original file line number Diff line number Diff line change
Expand Up @@ -326,17 +326,17 @@ class ContractReduce : public ContractReduceBase<Result, Left, Right, Scalar> {
/// \param[in] right The right-hand tile to be contracted
void operator()(result_type& result, const first_argument_type& left,
const second_argument_type& right) const {
using TiledArray::empty;
using TiledArray::gemm;
if (empty(left) || empty(right)) return;

if constexpr (!ContractReduceBase_::plain_tensors) {
TA_ASSERT(this->elem_muladd_op());
// not yet implemented
using TiledArray::empty;
using TiledArray::gemm;
gemm(result, left, right, ContractReduceBase_::gemm_helper(),
this->elem_muladd_op());
} else { // plain tensors
TA_ASSERT(!this->elem_muladd_op());
using TiledArray::empty;
using TiledArray::gemm;
if (empty(result))
result = gemm(left, right, ContractReduceBase_::factor(),
ContractReduceBase_::gemm_helper());
Expand Down
54 changes: 29 additions & 25 deletions tests/retile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,26 +6,24 @@
BOOST_AUTO_TEST_SUITE(retile_suite)

// Checks that retile() applied to dense and sparse arrays produces arrays
// with the requested (new) TiledRange.
BOOST_AUTO_TEST_CASE(retile_tensor) {
  TA::detail::matrix_il<double> some_values = {{0.1, 0.2, 0.3, 0.4, 0.5},
                                               {0.6, 0.7, 0.8, 0.9, 1.0},
                                               {1.1, 1.2, 1.3, 1.4, 1.5},
                                               {1.6, 1.7, 1.8, 1.9, 2.0},
                                               {2.1, 2.2, 2.3, 2.4, 2.5}};

  // Target tiling: 5x5 elements split into {[0,3),[3,5)} x {[0,4),[4,5)}
  auto range0 = TA::TiledRange1(0, 3, 5);
  auto range1 = TA::TiledRange1(0, 4, 5);
  auto trange = TA::TiledRange({range0, range1});

  TA::TArrayD default_dense(*GlobalFixture::world, some_values);
  TA::TSpArrayD default_sparse(*GlobalFixture::world, some_values);

  auto result_dense = retile(default_dense, trange);
  auto result_sparse = retile(default_sparse, trange);

  BOOST_CHECK_EQUAL(result_dense.trange(), trange);
  BOOST_CHECK_EQUAL(result_sparse.trange(), trange);
}

BOOST_AUTO_TEST_CASE(retile_more) {
Expand Down Expand Up @@ -69,17 +67,20 @@ BOOST_AUTO_TEST_CASE(retile_more) {
return tile.norm();
};

auto arr_source0 =
TA::make_array<ArrayT>(world, tr_source, set_random_tensor_tile);
auto arr_target0 = TA::retile(arr_source0, tr_target);

auto get_elem = [](auto const& arr, auto const& eix) {
auto tix = arr.trange().element_to_tile(eix);
auto&& tile = arr.find(tix).get(false);
return tile(eix);
};

auto arr_source0 =
TA::make_array<ArrayT>(world, tr_source, set_random_tensor_tile);
auto arr_target0 = TA::retile(arr_source0, tr_target);

for (auto&& eix : elem_rng) {
auto tix = arr_source0.trange().element_to_tile(eix);
BOOST_REQUIRE(arr_source0.is_zero(tix) == arr_target0.is_zero(tix));
if (arr_source0.is_zero(tix)) continue;
BOOST_REQUIRE(get_elem(arr_source0, eix) == get_elem(arr_target0, eix));
}

Expand All @@ -94,8 +95,11 @@ BOOST_AUTO_TEST_CASE(retile_more) {
world.gop.fence();

for (auto&& eix : elem_rng) {
auto tix = arr_source.trange().element_to_tile(eix);
BOOST_REQUIRE(arr_source.is_zero(tix) == arr_target.is_zero(tix));
if (arr_source.is_zero(tix)) continue;
BOOST_REQUIRE(get_elem(arr_source, eix) == get_elem(arr_target, eix));
}
}

BOOST_AUTO_TEST_SUITE_END()
BOOST_AUTO_TEST_SUITE_END()
Loading