Skip to content

Commit

Permalink
Merge pull request #484 from ValeevGroup/evaleev/feature/distarray-in…
Browse files Browse the repository at this point in the history
…it-fence

fix synchronization in collective `DistArray` initializations/transformations
  • Loading branch information
evaleev authored Oct 11, 2024
2 parents 5c2681c + c955339 commit a73a17b
Show file tree
Hide file tree
Showing 6 changed files with 118 additions and 42 deletions.
45 changes: 40 additions & 5 deletions src/TiledArray/array_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,17 @@ std::ostream& operator<<(std::ostream& os, const TileConstReference<Impl>& a) {
return os;
}

/// Callback used to update a counter (typically, a task-completion counter
/// awaited via `World::await`) when the future/task it is registered on
/// completes.
///
/// \tparam AtomicInt an integral counter type supporting pre-increment
///         (e.g. `std::atomic<std::int64_t>`)
/// \note Single-use: `notify()` deletes the object after incrementing, so
///       instances must be heap-allocated with `new` and registered with
///       exactly one callback source.
template <typename AtomicInt>
struct IncrementCounter : public madness::CallbackInterface {
  AtomicInt& counter;  // counter to bump; must outlive this callback
  IncrementCounter(AtomicInt& counter) : counter(counter) {}
  void notify() override {
    ++counter;    // signal one completed task
    delete this;  // self-destruct: each callback fires exactly once
  }
};

} // namespace detail
} // namespace TiledArray

Expand Down Expand Up @@ -770,20 +781,24 @@ class ArrayImpl : public TensorImpl<Policy>,
/// \tparam Op The type of the functor/function
/// \param[in] op The operation used to generate tiles
/// \param[in] skip_set If false, will throw if any tiles are already set
/// \return the total number of tiles that have been (or will be) initialized
/// \throw TiledArray::Exception if the PIMPL is not set. Strong throw
/// guarantee.
/// \throw TiledArray::Exception if a tile is already set and skip_set is
/// false. Weak throw guarantee.
template <HostExecutor Exec = HostExecutor::Default, typename Op>
void init_tiles(Op&& op, bool skip_set = false) {
template <HostExecutor Exec = HostExecutor::Default, Fence fence = Fence::No,
typename Op>
std::int64_t init_tiles(Op&& op, bool skip_set = false) {
// lifetime management of op depends on whether it is a lvalue ref (i.e. has
// an external owner) or an rvalue ref
// - if op is an lvalue ref: pass op to tasks
// - if op is an rvalue ref pass make_shared_function(op) to tasks
auto op_shared_handle = make_op_shared_handle(std::forward<Op>(op));

std::int64_t ntiles_initialized{0};
auto it = this->pmap()->begin();
const auto end = this->pmap()->end();
std::atomic<std::int64_t> ntask_completed{0};
for (; it != end; ++it) {
const auto& index = *it;
if (!this->is_zero(index)) {
Expand All @@ -792,19 +807,39 @@ class ArrayImpl : public TensorImpl<Policy>,
if (fut.probe()) continue;
}
if constexpr (Exec == HostExecutor::MADWorld) {
Future<value_type> tile = this->world().taskq.add(
[this_sptr = this->shared_from_this(),
index = ordinal_type(index), op_shared_handle]() -> value_type {
Future<value_type> tile =
this->world().taskq.add([this_sptr = this->shared_from_this(),
index = ordinal_type(index),
op_shared_handle, this]() -> value_type {
return op_shared_handle(
this_sptr->trange().make_tile_range(index));
});
++ntiles_initialized;
if constexpr (fence == Fence::Local) {
tile.register_callback(
new IncrementCounter<decltype(ntask_completed)>(
ntask_completed));
}
set(index, std::move(tile));
} else {
static_assert(Exec == HostExecutor::Thread);
set(index, op_shared_handle(this->trange().make_tile_range(index)));
++ntiles_initialized;
}
}
}

if constexpr (fence == Fence::Local) {
if constexpr (Exec == HostExecutor::MADWorld) {
if (ntiles_initialized > 0)
this->world().await([&ntask_completed, ntiles_initialized]() {
return ntask_completed == ntiles_initialized;
});
}
} else if constexpr (fence == Fence::Global) {
this->world().gop.fence();
}
return ntiles_initialized;
}

}; // class ArrayImpl
Expand Down
23 changes: 13 additions & 10 deletions src/TiledArray/conversions/foreach.h
Original file line number Diff line number Diff line change
Expand Up @@ -283,19 +283,17 @@ inline std::
arg.trange().tiles_range(), 0);

// Construct the task function used to construct the result tiles.
madness::AtomicInt counter;
counter = 0;
int task_count = 0;
std::atomic<std::int64_t> ntask_completed{0};
std::int64_t ntask_created{0};
auto op_shared_handle = make_op_shared_handle(std::forward<Op>(op));
const auto task = [op_shared_handle, &counter, &tile_norms](
const auto task = [op_shared_handle, &tile_norms](
const ordinal_type ord,
const_if_t<not inplace, arg_value_type>& arg_tile,
const ArgTiles&... arg_tiles) -> result_value_type {
op_helper<inplace, result_value_type> op_caller;
auto result_tile =
op_caller(std::move(op_shared_handle), tile_norms.at_ordinal(ord),
arg_tile, arg_tiles...);
++counter;
return result_tile;
};

Expand All @@ -310,7 +308,9 @@ inline std::
continue;
auto result_tile =
world.taskq.add(task, ord, arg.find_local(ord), args.find(ord)...);
++task_count;
++ntask_created;
result_tile.register_callback(
new IncrementCounter<decltype(ntask_completed)>(ntask_completed));
tiles.emplace_back(ord, std::move(result_tile));
if (op_returns_void) // if Op does not evaluate norms, use the (scaled)
// norms of the first arg
Expand All @@ -324,7 +324,9 @@ inline std::
auto result_tile =
world.taskq.add(task, ord, detail::get_sparse_tile(ord, arg),
detail::get_sparse_tile(ord, args)...);
++task_count;
++ntask_created;
result_tile.register_callback(
new IncrementCounter<decltype(ntask_completed)>(ntask_completed));
tiles.emplace_back(ord, std::move(result_tile));
if (op_returns_void) // if Op does not evaluate norms, find max
// (scaled) norms of all args
Expand All @@ -339,9 +341,10 @@ inline std::
}

// Wait for tile norm data to be collected.
if (task_count > 0)
world.await(
[&counter, task_count]() -> bool { return counter == task_count; });
if (ntask_created > 0)
world.await([&ntask_completed, ntask_created]() -> bool {
return ntask_created == ntask_completed;
});

// Construct the new array
result_array_type result(
Expand Down
36 changes: 26 additions & 10 deletions src/TiledArray/conversions/make_array.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#ifndef TILEDARRAY_CONVERSIONS_MAKE_ARRAY_H__INCLUDED
#define TILEDARRAY_CONVERSIONS_MAKE_ARRAY_H__INCLUDED

#include "TiledArray/array_impl.h"
#include "TiledArray/external/madness.h"
#include "TiledArray/shape.h"
#include "TiledArray/type_traits.h"
Expand Down Expand Up @@ -79,6 +80,10 @@ inline Array make_array(
// Make an empty result array
Array result(world, trange);

// Construct the task function used to construct the result tiles.
std::atomic<std::int64_t> ntask_completed{0};
std::int64_t ntask_created{0};

// Iterate over local tiles of arg
for (const auto index : *result.pmap()) {
// Spawn a task to evaluate the tile
Expand All @@ -89,11 +94,20 @@ inline Array make_array(
return tile;
},
trange.make_tile_range(index));

++ntask_created;
tile.register_callback(
new detail::IncrementCounter<decltype(ntask_completed)>(
ntask_completed));
// Store result tile
result.set(index, tile);
result.set(index, std::move(tile));
}

// Wait for tile tasks to complete
if (ntask_created > 0)
world.await([&ntask_completed, ntask_created]() -> bool {
return ntask_completed == ntask_created;
});

return result;
}

Expand Down Expand Up @@ -150,26 +164,28 @@ inline Array make_array(
trange.tiles_range(), 0);

// Construct the task function used to construct the result tiles.
madness::AtomicInt counter;
counter = 0;
int task_count = 0;
std::atomic<std::int64_t> ntask_completed{0};
std::int64_t ntask_created{0};
auto task = [&](const ordinal_type index) -> value_type {
value_type tile;
tile_norms.at_ordinal(index) = op(tile, trange.make_tile_range(index));
++counter;
return tile;
};

for (const auto index : *pmap) {
auto result_tile = world.taskq.add(task, index);
++task_count;
++ntask_created;
result_tile.register_callback(
new detail::IncrementCounter<decltype(ntask_completed)>(
ntask_completed));
tiles.emplace_back(index, std::move(result_tile));
}

// Wait for tile norm data to be collected.
if (task_count > 0)
world.await(
[&counter, task_count]() -> bool { return counter == task_count; });
if (ntask_created > 0)
world.await([&ntask_completed, ntask_created]() -> bool {
return ntask_completed == ntask_created;
});

// Construct the new array
Array result(world, trange,
Expand Down
43 changes: 29 additions & 14 deletions src/TiledArray/dist_array.h
Original file line number Diff line number Diff line change
Expand Up @@ -906,23 +906,29 @@ class DistArray : public madness::archive::ParallelSerializableObject {
/// guarantee.
/// \throw TiledArray::Exception if skip_set is false and a local tile is
/// already set. Weak throw guarantee.
void fill_local(const element_type& value = element_type(),
bool skip_set = false) {
init_tiles(
template <Fence fence = Fence::No>
std::int64_t fill_local(const element_type& value = element_type(),
bool skip_set = false) {
return init_tiles<HostExecutor::Default, fence>(
[value](const range_type& range) { return value_type(range, value); },
skip_set);
}

/// Fill all local tiles with the specified value

/// \tparam fence If Fence::No, the operation will return early,
/// before the tasks have completed
/// \param[in] value What each local tile should be filled with.
/// \param[in] skip_set If false, will throw if any tiles are already set
/// \return the total number of tiles that have been (or will be) initialized
/// \throw TiledArray::Exception if the PIMPL is uninitialized. Strong throw
/// guarantee.
/// \throw TiledArray::Exception if skip_set is false and a local tile is
/// already set. Weak throw guarantee.
void fill(const element_type& value = numeric_type(), bool skip_set = false) {
fill_local(value, skip_set);
template <Fence fence = Fence::No>
std::int64_t fill(const element_type& value = numeric_type(),
bool skip_set = false) {
return fill_local<fence>(value, skip_set);
}

/// Fill all local tiles with random values
Expand All @@ -934,18 +940,21 @@ class DistArray : public madness::archive::ParallelSerializableObject {
/// generate random values of type T this function will be disabled via SFINAE
/// and attempting to use it will lead to a compile-time error.
///
/// \tparam fence If Fence::No, the operation will return early,
/// before the tasks have completed
/// \tparam T The type of random value to generate. Defaults to
/// element_type.
/// \param[in] skip_set If false, will throw if any tiles are already set
/// \return the total number of tiles that have been (or will be) initialized
/// \throw TiledArray::Exception if the PIMPL is not initialized. Strong
/// throw guarantee.
/// \throw TiledArray::Exception if skip_set is false and a local tile is
/// already initialized. Weak throw guarantee.
template <HostExecutor Exec = HostExecutor::Default,
typename T = element_type,
typename T = element_type, Fence fence = Fence::No,
typename = detail::enable_if_can_make_random_t<T>>
void fill_random(bool skip_set = false) {
init_elements<Exec>(
std::int64_t fill_random(bool skip_set = false) {
return init_elements<Exec, fence>(
[](const auto&) { return detail::MakeRandom<T>::generate_value(); });
}

Expand Down Expand Up @@ -978,16 +987,20 @@ class DistArray : public madness::archive::ParallelSerializableObject {
/// return tile;
/// });
/// \endcode
/// \tparam fence If Fence::No, the operation will return early,
/// before the tasks have completed
/// \tparam Op The type of the functor/function
/// \param[in] op The operation used to generate tiles
/// \param[in] skip_set If false, will throw if any tiles are already set
/// \throw TiledArray::Exception if the PIMPL is not set. Strong throw
/// guarantee.
/// \throw TiledArray::Exception if a tile is already set and skip_set is
/// false. Weak throw guarantee.
template <HostExecutor Exec = HostExecutor::Default, typename Op>
void init_tiles(Op&& op, bool skip_set = false) {
impl_ref().template init_tiles<Exec>(std::forward<Op>(op), skip_set);
template <HostExecutor Exec = HostExecutor::Default, Fence fence = Fence::No,
typename Op>
std::int64_t init_tiles(Op&& op, bool skip_set = false) {
return impl_ref().template init_tiles<Exec, fence>(std::forward<Op>(op),
skip_set);
}

/// Initialize elements of local, non-zero tiles with a user provided functor
Expand All @@ -1009,15 +1022,17 @@ class DistArray : public madness::archive::ParallelSerializableObject {
/// \tparam Op Type of the function/functor which will generate the elements.
/// \param[in] op The operation used to generate elements
/// \param[in] skip_set If false, will throw if any tiles are already set
/// \return the total number of tiles that have been (or will be) initialized
/// \throw TiledArray::Exception if the PIMPL is not initialized. Strong
/// throw guarantee.
/// \throw TiledArray::Exception if skip_set is false and a local, non-zero
/// tile is already initialized. Weak throw
/// guarantee.
template <HostExecutor Exec = HostExecutor::Default, typename Op>
void init_elements(Op&& op, bool skip_set = false) {
template <HostExecutor Exec = HostExecutor::Default, Fence fence = Fence::No,
typename Op>
std::int64_t init_elements(Op&& op, bool skip_set = false) {
auto op_shared_handle = make_op_shared_handle(std::forward<Op>(op));
init_tiles<Exec>(
return init_tiles<Exec, fence>(
[op = std::move(op_shared_handle)](
const TiledArray::Range& range) -> value_type {
// Initialize the tile with the given range object
Expand Down
8 changes: 8 additions & 0 deletions src/TiledArray/fwd.h
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,14 @@ using Array

enum class HostExecutor { Thread, MADWorld, Default = MADWorld };

/// Fence types controlling synchronization at the end of collective
/// DistArray initializations/transformations.
enum class Fence {
  Global,  //!< global fence (`world.gop.fence()`)
  Local,   //!< local fence (all local work done, equivalent to
           //!< `world.taskq.fence()` in absence of active messages)
  No       //!< no fence
};

namespace conversions {

/// user defined conversions
Expand Down
5 changes: 2 additions & 3 deletions src/TiledArray/special/diagonal_array.h
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,8 @@ std::enable_if_t<is_iterator<RandomAccessIterator>::value, void>
write_diag_tiles_to_array_rng(Array &A, RandomAccessIterator diagonals_begin) {
using Tile = typename Array::value_type;

A.init_tiles(
// N.B. Fence::Local ensures lifetime of the diagonals range
A.template init_tiles<HostExecutor::Default, Fence::Local>(
// Task to create each tile
[diagonals_begin](const Range &rng) {
// Compute range of diagonal elements in the tile
Expand Down Expand Up @@ -221,7 +222,6 @@ diagonal_array(World &world, TiledRange const &trange,
if constexpr (is_dense_v<Policy>) {
Array A(world, trange);
detail::write_diag_tiles_to_array_rng(A, diagonals_begin);
A.world().taskq.fence(); // ensure tasks outlive the diagonals_begin view
return A;
} else {
// Compute shape and init the Array
Expand All @@ -231,7 +231,6 @@ diagonal_array(World &world, TiledRange const &trange,
ShapeType shape(shape_norm, trange);
Array A(world, trange, shape);
detail::write_diag_tiles_to_array_rng(A, diagonals_begin);
A.world().taskq.fence(); // ensure tasks outlive the diagonals_begin view
return A;
}
abort(); // unreachable
Expand Down

0 comments on commit a73a17b

Please sign in to comment.