Skip to content

Commit

Permalink
fixed assignment to block expression from an expression with nonzero …
Browse files Browse the repository at this point in the history
…base
  • Loading branch information
evaleev committed Sep 4, 2024
1 parent 7f687b3 commit 65b8520
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 3 deletions.
13 changes: 11 additions & 2 deletions src/TiledArray/expressions/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@

#include <TiledArray/tensor/type_traits.h>

#include <range/v3/to_container.hpp>
#include <range/v3/view/zip_with.hpp>

namespace TiledArray::expressions {

template <typename Engine>
Expand Down Expand Up @@ -509,8 +512,14 @@ class Expr {
if (tsr.array().trange().tiles_range().volume() != 0) {
// N.B. must deep copy
TA_ASSERT(tsr.array().trange().tiles_range().includes(tsr.lower_bound()));
const container::svector<long> shift =
tsr.array().trange().make_tile_range(tsr.lower_bound()).lobound();
// N.B. this expression's range,
// dist_eval.trange().elements_range().lobound(), may not be zero!
const auto shift =
ranges::views::zip_with(
[](auto a, auto b) { return a - b; },
tsr.array().trange().make_tile_range(tsr.lower_bound()).lobound(),
dist_eval.trange().elements_range().lobound()) |
ranges::to<container::svector<long>>();

std::shared_ptr<op_type> shift_op =
std::make_shared<op_type>(shift_op_type(shift));
Expand Down
19 changes: 19 additions & 0 deletions tests/expressions_fixture.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ struct ExpressionsFixture : public TiledRangeFixture {
ExpressionsFixture()
: s_tr_1(make_random_sparseshape(tr)),
s_tr_2(make_random_sparseshape(tr)),
s_tr_base1_1(make_random_sparseshape(tr_base1)),
s_tr_base1_2(make_random_sparseshape(tr_base1)),
s_tr1_1(make_random_sparseshape(trange1)),
s_tr1_2(make_random_sparseshape(trange1)),
s_tr2(make_random_sparseshape(trange2)),
Expand All @@ -65,19 +67,26 @@ struct ExpressionsFixture : public TiledRangeFixture {
a(*GlobalFixture::world, tr, s_tr_1),
b(*GlobalFixture::world, tr, s_tr_2),
c(*GlobalFixture::world, tr, s_tr_2),
a_base1(*GlobalFixture::world, tr_base1, s_tr_base1_1),
b_base1(*GlobalFixture::world, tr_base1, s_tr_base1_2),
c_base1(*GlobalFixture::world, tr_base1, s_tr_base1_2),
aC(*GlobalFixture::world, trangeC, s_trC),
aC_f(*GlobalFixture::world, trangeC_f, s_trC_f),
u(*GlobalFixture::world, trange1, s_tr1_1),
v(*GlobalFixture::world, trange1, s_tr1_2),
w(*GlobalFixture::world, trange2, s_tr2) {
random_fill(a);
random_fill(b);
random_fill(a_base1);
random_fill(b_base1);
random_fill(u);
random_fill(v);
random_fill(aC);
GlobalFixture::world->gop.fence();
a.truncate();
b.truncate();
a_base1.truncate();
b_base1.truncate();
u.truncate();
v.truncate();
}
Expand All @@ -89,13 +98,18 @@ struct ExpressionsFixture : public TiledRangeFixture {
: a(*GlobalFixture::world, tr),
b(*GlobalFixture::world, tr),
c(*GlobalFixture::world, tr),
a_base1(*GlobalFixture::world, tr_base1),
b_base1(*GlobalFixture::world, tr_base1),
c_base1(*GlobalFixture::world, tr_base1),
u(*GlobalFixture::world, trange1),
v(*GlobalFixture::world, trange1),
w(*GlobalFixture::world, trange2),
aC(*GlobalFixture::world, trangeC),
aC_f(*GlobalFixture::world, trangeC_f) {
random_fill(a);
random_fill(b);
random_fill(a_base1);
random_fill(b_base1);
random_fill(u);
random_fill(v);
random_fill(aC);
Expand Down Expand Up @@ -229,6 +243,8 @@ struct ExpressionsFixture : public TiledRangeFixture {

SparseShape<float> s_tr_1;
SparseShape<float> s_tr_2;
SparseShape<float> s_tr_base1_1;
SparseShape<float> s_tr_base1_2;
SparseShape<float> s_tr1_1;
SparseShape<float> s_tr1_2;
SparseShape<float> s_tr2;
Expand All @@ -237,6 +253,9 @@ struct ExpressionsFixture : public TiledRangeFixture {
TArray a;
TArray b;
TArray c;
TArray a_base1;
TArray b_base1;
TArray c_base1;
TArray u;
TArray v;
TArray w;
Expand Down
29 changes: 29 additions & 0 deletions tests/expressions_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ BOOST_FIXTURE_TEST_CASE_TEMPLATE(tensor_factories, F, Fixtures, F) {
auto& a = F::a;
auto& c = F::c;
auto& aC = F::aC;
auto& a_base1 = F::a_base1;

const auto& ca = a;
const std::array<int, 3> lobound{{3, 3, 3}};
Expand Down Expand Up @@ -66,6 +67,8 @@ BOOST_FIXTURE_TEST_CASE_TEMPLATE(tensor_factories, F, Fixtures, F) {
BOOST_CHECK_NO_THROW(c("a,b,c") =
ca("a,b,c").block(iv(3, 3, 3), iv(5, 5, 5)));

BOOST_CHECK_NO_THROW(c("a,b,c") = a_base1("a,b,c").block(lobound, upbound));

// make sure that c("abc") = a("abc") does a deep copy
{
BOOST_CHECK_NO_THROW(c("a,b,c") = a("a, b, c"));
Expand Down Expand Up @@ -291,6 +294,7 @@ BOOST_FIXTURE_TEST_CASE_TEMPLATE(block, F, Fixtures, F) {
auto& a = F::a;
auto& b = F::b;
auto& c = F::c;
auto& a_base1 = F::a_base1;

BlockRange block_range(a.trange().tiles_range(), {3, 3, 3}, {5, 5, 5});

Expand Down Expand Up @@ -683,6 +687,31 @@ BOOST_FIXTURE_TEST_CASE_TEMPLATE(assign_subblock_block, F, Fixtures, F) {
}
}

BOOST_FIXTURE_TEST_CASE_TEMPLATE(assign_subblock_block_base1, F, Fixtures, F) {
auto& a = F::a;
auto& b = F::b;
auto& c = F::c;
auto& a_base1 = F::a_base1;
auto& c_base1 = F::c_base1;
auto& ntiles = F::ntiles;

c.fill_local(0.0);
c_base1.fill_local(0.0);

BOOST_REQUIRE_NO_THROW(c("a,b,c").block({3, 3, 3}, {5, 5, 5}) =
a_base1("a,b,c").block({3, 3, 3}, {5, 5, 5}));
BOOST_REQUIRE(tile_ranges_match_trange(c));
BOOST_REQUIRE_NO_THROW(c_base1("a,b,c").block({3, 3, 3}, {5, 5, 5}) =
a("a,b,c").block({3, 3, 3}, {5, 5, 5}));
BOOST_REQUIRE(tile_ranges_match_trange(c_base1));
BOOST_REQUIRE_NO_THROW(c("a,b,c").block({0, 0, 0}, {ntiles, ntiles, ntiles}) =
a_base1("a,b,c"));
BOOST_REQUIRE(tile_ranges_match_trange(c));
BOOST_REQUIRE_NO_THROW(
c_base1("a,b,c").block({0, 0, 0}, {ntiles, ntiles, ntiles}) = a("a,b,c"));
BOOST_REQUIRE(tile_ranges_match_trange(c_base1));
}

BOOST_FIXTURE_TEST_CASE_TEMPLATE(assign_subblock_permute_block, F, Fixtures,
F) {
auto& a = F::a;
Expand Down
2 changes: 1 addition & 1 deletion tests/range_fixture.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ struct RangeFixture {

struct Range1Fixture {
using index1_type = Range1::index1_type;
static const size_t ntiles = 5;
static const inline size_t ntiles = 5;

Range1Fixture()
: tr1_hashmarks(make_hashmarks<ntiles + 1>()),
Expand Down

0 comments on commit 65b8520

Please sign in to comment.