Skip to content

Commit

Permalink
Completes TA::retile support for DistArray with tensor-of-tensor …
Browse files Browse the repository at this point in the history
…tiles.
  • Loading branch information
bimalgaudel committed Sep 18, 2024
1 parent ede81f3 commit 342dd25
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 42 deletions.
13 changes: 11 additions & 2 deletions src/TiledArray/tensor/kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -599,8 +599,17 @@ inline void inplace_tensor_op(Op&& op, TR& result, const Ts&... tensors) {
[&op, stride](
typename TR::pointer MADNESS_RESTRICT const result_data,
typename Ts::const_pointer MADNESS_RESTRICT const... tensors_data) {
for (decltype(result.range().volume()) i = 0ul; i < stride; ++i)
inplace_tensor_op(op, result_data[i], tensors_data[i]...);
for (decltype(result.range().volume()) i = 0ul; i < stride; ++i) {
if constexpr (std::is_invocable_v<
std::remove_reference_t<Op>,
typename std::remove_reference_t<TR>::value_type&,
typename std::remove_reference_t<
Ts>::value_type const&...>) {
std::forward<Op>(op)(result_data[i], tensors_data[i]...);
} else {
inplace_tensor_op(op, result_data[i], tensors_data[i]...);
}
}
};

for (std::decay_t<decltype(volume)> ord = 0ul; ord < volume; ord += stride)
Expand Down
15 changes: 3 additions & 12 deletions src/TiledArray/tensor/tensor_interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -198,18 +198,9 @@ class TensorInterface {
TA_ASSERT(data_ != other.data());
}

if constexpr (detail::is_tensor_v<value_type>) {
range_ = BlockRange(other.range(), other.range().lobound(),
other.range().upbound());
data_ = new value_type[other.total_size()];
auto cpy = other.clone();
for (auto i = 0; i < other.total_size(); ++i)
std::swap(data_[i], cpy.data()[i]);
} else {
detail::inplace_tensor_op([](numeric_type& MADNESS_RESTRICT result,
const numeric_t<T1> arg) { result = arg; },
*this, other);
}
detail::inplace_tensor_op(
[](value_type& MADNESS_RESTRICT result, auto&& arg) { result = arg; },
*this, other);

return *this;
}
Expand Down
65 changes: 37 additions & 28 deletions tests/retile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,51 +39,60 @@ BOOST_AUTO_TEST_CASE(retile_more) {

auto const tr_source = TA::TiledRange({{0, 2, 4, 8}, {0, 3, 5}});
auto const tr_target = TA::TiledRange({{0, 4, 6, 8}, {0, 2, 4, 5}});
auto const& elem_rng = tr_source.elements_range();

BOOST_REQUIRE(elem_rng.volume() == tr_target.elements_range().volume());

auto rand_num = [](auto&&) {
return TA::detail::MakeRandom<Numeric>::generate_value();
auto const inner_rng = TA::Range({3, 3});

auto rand_tensor = [](auto const& rng) -> T {
return T(rng, [](auto&&) {
return TA::detail::MakeRandom<Numeric>::generate_value();
});
};

auto rand_tensor = [rand_num](auto const& rng) -> T {
return T(rng, rand_num);
auto set_random_tensor_tile = [rand_tensor](auto& tile, auto const& rng) {
tile = rand_tensor(rng);
return tile.norm();
};

auto rand_tensor_of_tensor = [rand_tensor](auto const& inner_rng) {
return [rand_tensor, inner_rng](auto const& rng) -> ToT {
return ToT(rng, rand_tensor(inner_rng));
};
auto rand_tensor_of_tensor = [rand_tensor,
inner_rng](auto const& rng) -> ToT {
return ToT(rng, [rand_tensor, inner_rng](auto&&) {
return rand_tensor(inner_rng);
});
};

auto set_random_tensor_of_tensor_tile = [rand_tensor_of_tensor](
auto const& inner_rng) {
return
[gen = rand_tensor_of_tensor(inner_rng)](auto& tile, auto const& rng) {
tile = gen(rng);
return tile.norm();
};
auto& tile, auto const& rng) {
tile = rand_tensor_of_tensor(rng);
return tile.norm();
};

auto const inner_rng = TA::Range({3, 3});
auto arr_source = TA::make_array<ArrayToT>(
world, tr_source, set_random_tensor_of_tensor_tile(inner_rng));
arr_source.truncate();
auto get_elem = [](auto const& arr, auto const& eix) {
auto tix = arr.trange().element_to_tile(eix);
auto&& tile = arr.find(tix).get(false);
return tile(eix);
};

auto arr_source0 =
TA::make_array<ArrayT>(world, tr_source, set_random_tensor_tile);
auto arr_target0 = TA::retile(arr_source0, tr_target);

for (auto&& eix : elem_rng) {
BOOST_REQUIRE(get_elem(arr_source0, eix) == get_elem(arr_target0, eix));
}

auto arr_source = TA::make_array<ArrayToT>(world, tr_source,
set_random_tensor_of_tensor_tile);
auto arr_target = TA::retile(arr_source, tr_target);

arr_source.make_replicated();
world.gop.fence();
arr_target.make_replicated();
arr_source.truncate();
arr_target.truncate();
world.gop.fence();

auto const& elem_rng = tr_source.elements_range();
BOOST_REQUIRE(elem_rng.volume() == tr_target.elements_range().volume());

auto get_elem = [](ArrayToT const& arr, auto const& eix) {
auto tix = arr.trange().element_to_tile(eix);
auto&& tile = arr.find(tix).get(false);
return tile(eix);
};

for (auto&& eix : elem_rng) {
BOOST_REQUIRE(get_elem(arr_source, eix) == get_elem(arr_target, eix));
}
Expand Down

0 comments on commit 342dd25

Please sign in to comment.