Skip to content

Commit

Permalink
Implemented TA::detail::retile_v2, the nearly-optimal version of elem…
Browse files Browse the repository at this point in the history
…ent-level reranging (retiling, etc.)
  • Loading branch information
evaleev committed Aug 26, 2024
1 parent a9633fc commit 569b0bf
Show file tree
Hide file tree
Showing 2 changed files with 176 additions and 20 deletions.
156 changes: 153 additions & 3 deletions src/TiledArray/conversions/retile.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,156 @@ auto retile_v1(const DistArray<Tile, Policy>& tensor,
return output;
}

/// Task body that writes a block of data into a local tile of a target array.
///
/// The target array's storage is looked up from \p target_array_id via the
/// owning World; the tile with ordinal \p target_tile_ord must be local to the
/// rank that executes this function.
/// \tparam Tile the tile type of the target array
/// \tparam Policy the policy type of the target array
/// \param target_array_id the unique id of the target array's storage object
/// \param target_tile_ord the ordinal index of the target tile to write to
/// \param target_tile_contribution the data to write; its range selects the
///        block of the target tile that is overwritten
template <typename Tile, typename Policy>
void write_tile_block(madness::uniqueidT target_array_id,
                      std::size_t target_tile_ord,
                      const Tile& target_tile_contribution) {
  auto* world_ptr = World::world_from_id(target_array_id.get_world_id());
  // the world lookup can fail just like the object lookup below, so check it
  // before dereferencing
  TA_ASSERT(world_ptr != nullptr);
  auto target_array_ptr_opt = world_ptr->ptr_from_id<
      typename DistArray<Tile, Policy>::impl_type::storage_type>(
      target_array_id);
  TA_ASSERT(target_array_ptr_opt);
  auto* target_storage_ptr = *target_array_ptr_opt;
  TA_ASSERT(target_storage_ptr->is_local(target_tile_ord));
  // the target tile is mutated in place: only the sub-block covered by the
  // contribution's range is assigned
  target_storage_ptr->get_local(target_tile_ord)
      .get()
      .block(target_tile_contribution.range()) = target_tile_contribution;
}

template <typename Tile, typename Policy>
auto retile_v2(const DistArray<Tile, Policy>& source_array,
const TiledRange& target_trange) {
auto& world = source_array.world();
const auto rank = source_array.trange().rank();
TA_ASSERT(rank == target_trange.rank());

// compute metadata
// - list of target tile indices and the corresponding Range1 for each 1-d
// source tile
using target_tiles_t = std::vector<std::pair<TA_1INDEX_TYPE, Range1>>;
using mode_target_tiles_t = std::vector<target_tiles_t>;
using all_target_tiles_t = std::vector<mode_target_tiles_t>;

all_target_tiles_t all_target_tiles(target_trange.rank());
// for each mode ...
for (auto d = 0; d != target_trange.rank(); ++d) {
mode_target_tiles_t& mode_target_tiles = all_target_tiles[d];
auto& target_tr1 = target_trange.dim(d);
auto& target_element_range = target_tr1.elements_range();
// ... and each tile in that mode ...
for (auto&& source_tile : source_array.trange().dim(d)) {
mode_target_tiles.emplace_back();
auto& target_tiles = mode_target_tiles.back();
auto source_tile_lo = source_tile.lobound();
auto source_tile_up = source_tile.upbound();
auto source_element_idx = source_tile_lo;
// ... find all target tiles what overlap with it
if (target_element_range.overlaps_with(source_tile)) {
while (source_element_idx < source_tile_up) {
if (target_element_range.includes(source_element_idx)) {
auto target_tile_idx =
target_tr1.element_to_tile(source_element_idx);
auto target_tile = target_tr1.tile(target_tile_idx);
auto target_lo =
std::max(source_element_idx, target_tile.lobound());
auto target_up = std::min(source_tile_up, target_tile.upbound());
target_tiles.emplace_back(target_tile_idx,
Range1(target_lo, target_up));
source_element_idx = target_up;
} else if (source_element_idx < target_element_range.lobound()) {
source_element_idx = target_element_range.lobound();
} else if (source_element_idx >= target_element_range.upbound())
break;
}
}
}
}

// estimate the shape, if sparse
// use max value for each nonzero tile, then will recompute after tiles are
// assigned
using shape_type = typename Policy::shape_type;
shape_type target_shape;
const auto& target_tiles_range = target_trange.tiles_range();
if constexpr (!is_dense_v<Policy>) {
// each rank computes contributions to the shape norms from its local tiles
Tensor<float> target_shape_norms(target_tiles_range, 0);
auto& source_trange = source_array.trange();
const auto e = source_array.end();
for (auto it = source_array.begin(); it != e; ++it) {
auto source_tile_idx = it.index();

// make range for iterating over all possible target tile idx combinations
TA::Index target_tile_ord_extent_range(rank);
for (auto d = 0; d != rank; ++d) {
target_tile_ord_extent_range[d] =
all_target_tiles[d][source_tile_idx[d]].size();
}

// loop over every target tile combination
TA::Range target_tile_ord_extent(target_tile_ord_extent_range);
for (auto& target_tile_ord : target_tile_ord_extent) {
TA::Index target_tile_idx(rank);
for (auto d = 0; d != rank; ++d) {
target_tile_idx[d] =
all_target_tiles[d][source_tile_idx[d]][target_tile_ord[d]].first;
}
target_shape_norms(target_tile_idx) = std::numeric_limits<float>::max();
}
}
world.gop.max(target_shape_norms.data(), target_shape_norms.size());
target_shape = SparseShape(target_shape_norms, target_trange);
}

using Array = DistArray<Tile, Policy>;
Array target_array(source_array.world(), target_trange, target_shape);
target_array.fill_local(0.0);
target_array.world().gop.fence();

// loop over local tile and sends its contributions to the targets
{
auto& source_trange = source_array.trange();
const auto e = source_array.end();
auto& target_tiles_range = target_trange.tiles_range();
for (auto it = source_array.begin(); it != e; ++it) {
const auto& source_tile = *it;
auto source_tile_idx = it.index();

// make range for iterating over all possible target tile idx combinations
TA::Index target_tile_ord_extent_range(rank);
for (auto d = 0; d != rank; ++d) {
target_tile_ord_extent_range[d] =
all_target_tiles[d][source_tile_idx[d]].size();
}

// loop over every target tile combination
TA::Range target_tile_ord_extent(target_tile_ord_extent_range);
for (auto& target_tile_ord : target_tile_ord_extent) {
TA::Index target_tile_idx(rank);
container::svector<TA::Range1> target_tile_rngs1(rank);
for (auto d = 0; d != rank; ++d) {
std::tie(target_tile_idx[d], target_tile_rngs1[d]) =
all_target_tiles[d][source_tile_idx[d]][target_tile_ord[d]];
}
TA_ASSERT(source_tile.future().probe());
Tile target_tile_contribution(
source_tile.get().block(target_tile_rngs1));
auto target_tile_idx_ord = target_tiles_range.ordinal(target_tile_idx);
auto target_proc = target_array.pmap()->owner(target_tile_idx_ord);
world.taskq.add(target_proc, &write_tile_block<Tile, Policy>,
target_array.id(), target_tile_idx_ord,
target_tile_contribution);
}
}
}
// data is mutated in place, so must wait for all tasks to complete
target_array.world().gop.fence();
// recompute norms/trim away zeros
target_array.truncate();

return target_array;
}

} // namespace detail

/// Creates a new DistArray with the same data as the input tensor, but with a
Expand All @@ -154,11 +304,11 @@ auto retile_v1(const DistArray<Tile, Policy>& tensor,
/// well as increasing the element range (with the new elements initialized to
/// zero)
/// \param array The DistArray whose data is to be retiled
/// \param new_trange The desired TiledRange of the output tensor
/// \param target_trange The desired TiledRange of the output tensor
// Public retiling entry point: forwards its two arguments unchanged to
// detail::retile_v0.
// NOTE(review): this span interleaves the removed (`new_trange`) and added
// (`target_trange`) lines of a diff hunk, so it is not compilable as shown;
// the two variants differ only in the name of the second parameter.
template <typename Tile, typename Policy>
auto retile(const DistArray<Tile, Policy>& array,
const TiledRange& new_trange) {
return detail::retile_v0(array, new_trange);
const TiledRange& target_trange) {
return detail::retile_v0(array, target_trange);
}

} // namespace TiledArray
Expand Down
40 changes: 23 additions & 17 deletions tests/expressions_mixed.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -186,27 +186,33 @@ BOOST_AUTO_TEST_CASE(kronecker) {
TSpArrayD x(*GlobalFixture::world, trange2);
random_fill(x);

TA::TiledRange yrange{{5, 18}, {7, 20}};
TA::TiledRange retiler_range{yrange.dim(0), yrange.dim(1), trange2.dim(0),
trange2.dim(1)};
SpArrayKronDelta retiler(
*GlobalFixture::world, retiler_range,
SparseShape(detail::kronecker_shape(retiler_range), retiler_range),
std::make_shared<detail::ReplicatedPmap>(
*GlobalFixture::world, retiler_range.tiles_range().volume()));
init_kronecker_delta(retiler);

TA::TSpArrayD y;
// includes target tiles that receive contributions from multiple source
// tiles, tiny target tiles with single contribution, and tiles partially and
// completely outside the source range N.B. retile_v0 seems to struggle with
// completely empty tiles (e.g. add 47 to each 1-d range)
TA::TiledRange yrange{{5, 18, 20, 45}, {7, 20, 22, 45}};
TA::TSpArrayD y1;
// TA::TiledRange retiler_range{yrange.dim(0), yrange.dim(1), trange2.dim(0),
// trange2.dim(1)};
// SpArrayKronDelta retiler(
// *GlobalFixture::world, retiler_range,
// SparseShape(detail::kronecker_shape(retiler_range), retiler_range),
// std::make_shared<detail::ReplicatedPmap>(
// *GlobalFixture::world, retiler_range.tiles_range().volume()));
// init_kronecker_delta(retiler);
// y("d1,d2") = retiler("d1,d2,s1,s2") * x("s1,s2");
y = TA::detail::retile_v1(x, yrange);
// std::cout << "y = " << y << std::endl;
// why deadlock without this?
y.world().gop.fence();
y1 = TA::detail::retile_v1(x, yrange);
// std::cout << "y1 = " << y1 << std::endl;
// why deadlock without this?
y1.world().gop.fence();

TA::TSpArrayD y_ref = TA::retile(x, yrange);
TA::TSpArrayD y_ref = TA::detail::retile_v0(x, yrange);
// std::cout << "y_ref = " << y_ref << std::endl;
BOOST_CHECK((y1("d1,d2") - y_ref("d1,d2")).norm().get() == 0.);

BOOST_CHECK((y("d1,d2") - y_ref("d1,d2")).norm().get() == 0.);
auto y2 = TA::detail::retile_v2(x, yrange);
// std::cout << "y2 = " << y2 << std::endl;
BOOST_CHECK((y2("d1,d2") - y_ref("d1,d2")).norm().get() == 0.);
}

BOOST_AUTO_TEST_SUITE_END()

0 comments on commit 569b0bf

Please sign in to comment.