Skip to content

Commit

Permalink
[unit] retile_suite/retile_more skip zero tiles
Browse files Browse the repository at this point in the history
  • Loading branch information
evaleev committed Nov 10, 2024
1 parent c6e9490 commit bc69ec5
Showing 1 changed file with 29 additions and 25 deletions.
54 changes: 29 additions & 25 deletions tests/retile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,26 +6,24 @@
BOOST_AUTO_TEST_SUITE(retile_suite)

BOOST_AUTO_TEST_CASE(retile_tensor) {
TA::detail::matrix_il<double> some_values = {
{0.1, 0.2, 0.3, 0.4, 0.5},
{0.6, 0.7, 0.8, 0.9, 1.0},
{1.1, 1.2, 1.3, 1.4, 1.5},
{1.6, 1.7, 1.8, 1.9, 2.0},
{2.1, 2.2, 2.3, 2.4, 2.5}
};

auto range0 = TA::TiledRange1(0, 3, 5);
auto range1 = TA::TiledRange1(0, 4, 5);
auto trange = TA::TiledRange({range0, range1});

TA::TArrayD default_dense(*GlobalFixture::world, some_values);
TA::TSpArrayD default_sparse(*GlobalFixture::world, some_values);

auto result_dense = retile(default_dense, trange);
auto result_sparse = retile(default_sparse, trange);

BOOST_CHECK_EQUAL(result_dense.trange(), trange);
BOOST_CHECK_EQUAL(result_sparse.trange(), trange);
TA::detail::matrix_il<double> some_values = {{0.1, 0.2, 0.3, 0.4, 0.5},
{0.6, 0.7, 0.8, 0.9, 1.0},
{1.1, 1.2, 1.3, 1.4, 1.5},
{1.6, 1.7, 1.8, 1.9, 2.0},
{2.1, 2.2, 2.3, 2.4, 2.5}};

auto range0 = TA::TiledRange1(0, 3, 5);
auto range1 = TA::TiledRange1(0, 4, 5);
auto trange = TA::TiledRange({range0, range1});

TA::TArrayD default_dense(*GlobalFixture::world, some_values);
TA::TSpArrayD default_sparse(*GlobalFixture::world, some_values);

auto result_dense = retile(default_dense, trange);
auto result_sparse = retile(default_sparse, trange);

BOOST_CHECK_EQUAL(result_dense.trange(), trange);
BOOST_CHECK_EQUAL(result_sparse.trange(), trange);
}

BOOST_AUTO_TEST_CASE(retile_more) {
Expand Down Expand Up @@ -69,17 +67,20 @@ BOOST_AUTO_TEST_CASE(retile_more) {
return tile.norm();
};

auto arr_source0 =
TA::make_array<ArrayT>(world, tr_source, set_random_tensor_tile);
auto arr_target0 = TA::retile(arr_source0, tr_target);

auto get_elem = [](auto const& arr, auto const& eix) {
auto tix = arr.trange().element_to_tile(eix);
auto&& tile = arr.find(tix).get(false);
return tile(eix);
};

auto arr_source0 =
TA::make_array<ArrayT>(world, tr_source, set_random_tensor_tile);
auto arr_target0 = TA::retile(arr_source0, tr_target);

for (auto&& eix : elem_rng) {
auto tix = arr_source0.trange().element_to_tile(eix);
BOOST_REQUIRE(arr_source0.is_zero(tix) == arr_target0.is_zero(tix));
if (arr_source0.is_zero(tix)) continue;
BOOST_REQUIRE(get_elem(arr_source0, eix) == get_elem(arr_target0, eix));
}

Expand All @@ -94,8 +95,11 @@ BOOST_AUTO_TEST_CASE(retile_more) {
world.gop.fence();

for (auto&& eix : elem_rng) {
auto tix = arr_source.trange().element_to_tile(eix);
BOOST_REQUIRE(arr_source.is_zero(tix) == arr_target.is_zero(tix));
if (arr_source.is_zero(tix)) continue;
BOOST_REQUIRE(get_elem(arr_source, eix) == get_elem(arr_target, eix));
}
}

BOOST_AUTO_TEST_SUITE_END()
BOOST_AUTO_TEST_SUITE_END()

0 comments on commit bc69ec5

Please sign in to comment.