Skip to content

Commit

Permalink
compound assignment operator expressions (+=, -=, *=) work with null …
Browse files Browse the repository at this point in the history
…DistArrays
  • Loading branch information
evaleev committed Sep 30, 2024
1 parent ec51edb commit d04fb08
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 3 deletions.
18 changes: 15 additions & 3 deletions src/TiledArray/expressions/tsr_expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,11 @@ class TsrExpr : public Expr<TsrExpr<Array, Alias>> {
TiledArray::expressions::is_aliased<D>::value,
"no_alias() expressions are not allowed on the right-hand side of "
"the assignment operator.");
return operator=(AddExpr<TsrExpr_, D>(*this, other.derived()));
if (this->array().is_initialized()) {
return operator=(AddExpr<TsrExpr_, D>(*this, other.derived()));
} else {
return operator=(other);
}
}

/// Expression minus-assignment operator
Expand All @@ -160,7 +164,11 @@ class TsrExpr : public Expr<TsrExpr<Array, Alias>> {
TiledArray::expressions::is_aliased<D>::value,
"no_alias() expressions are not allowed on the right-hand side of "
"the assignment operator.");
return operator=(SubtExpr<TsrExpr_, D>(*this, other.derived()));
if (this->array().is_initialized()) {
return operator=(SubtExpr<TsrExpr_, D>(*this, other.derived()));
} else {
return operator=(-1 * other.derived());
}
}

/// Expression multiply-assignment operator
Expand All @@ -173,7 +181,11 @@ class TsrExpr : public Expr<TsrExpr<Array, Alias>> {
TiledArray::expressions::is_aliased<D>::value,
"no_alias() expressions are not allowed on the right-hand side of "
"the assignment operator.");
return operator=(MultExpr<TsrExpr_, D>(*this, other.derived()));
if (this->array().is_initialized()) {
return operator=(MultExpr<TsrExpr_, D>(*this, other.derived()));
} else {
return operator=(0 * other.derived());
}
}

/// Array accessor
Expand Down
42 changes: 42 additions & 0 deletions tests/expressions_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1049,6 +1049,23 @@ BOOST_FIXTURE_TEST_CASE_TEMPLATE(add_to, F, Fixtures, F) {
BOOST_CHECK(a.is_zero(i) && b.is_zero(i));
}
}

// try += with null this
std::decay_t<decltype(c)> d;
BOOST_REQUIRE_NO_THROW(d("a,b,c") += b("a,b,c"));

for (std::size_t i = 0ul; i < d.size(); ++i) {
if (!d.is_zero(i)) {
BOOST_CHECK(!b.is_zero(i));
auto b_tile = b.find(i).get();
auto d_tile = d.find(i).get();

for (std::size_t j = 0ul; j < d_tile.size(); ++j)
BOOST_CHECK_EQUAL(d_tile[j], b_tile[j]);
} else {
BOOST_CHECK(b.is_zero(i));
}
}
}

BOOST_FIXTURE_TEST_CASE_TEMPLATE(add_permute, F, Fixtures, F) {
Expand Down Expand Up @@ -1350,6 +1367,23 @@ BOOST_FIXTURE_TEST_CASE_TEMPLATE(subt_to, F, Fixtures, F) {
BOOST_CHECK(a.is_zero(i) && b.is_zero(i));
}
}

// try -= with null this
std::decay_t<decltype(c)> d;
BOOST_REQUIRE_NO_THROW(d("a,b,c") -= b("a,b,c"));

for (std::size_t i = 0ul; i < d.size(); ++i) {
if (!d.is_zero(i)) {
BOOST_CHECK(!b.is_zero(i));
auto b_tile = b.find(i).get();
auto d_tile = d.find(i).get();

for (std::size_t j = 0ul; j < d_tile.size(); ++j)
BOOST_CHECK_EQUAL(d_tile[j], -b_tile[j]);
} else {
BOOST_CHECK(b.is_zero(i));
}
}
}

BOOST_FIXTURE_TEST_CASE_TEMPLATE(sub_permute, F, Fixtures, F) {
Expand Down Expand Up @@ -1679,6 +1713,14 @@ BOOST_FIXTURE_TEST_CASE_TEMPLATE(mult_to, F, Fixtures, F) {
BOOST_CHECK(a.is_zero(i) || b.is_zero(i));
}
}

// try *= with null this
std::decay_t<decltype(c)> d;
BOOST_REQUIRE_NO_THROW(d("a,b,c") *= b("a,b,c"));

for (std::size_t i = 0ul; i < d.size(); ++i) {
BOOST_CHECK(d.is_zero(i));
}
}

BOOST_FIXTURE_TEST_CASE_TEMPLATE(scale_mult, F, Fixtures, F) {
Expand Down

0 comments on commit d04fb08

Please sign in to comment.