Skip to content

Commit

Permalink
foreach and make_array use callbacks instead of atomic counters for l…
Browse files Browse the repository at this point in the history
…ocal completion checks
  • Loading branch information
evaleev committed Oct 11, 2024
1 parent 6d661ab commit c955339
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 20 deletions.
23 changes: 13 additions & 10 deletions src/TiledArray/conversions/foreach.h
Original file line number Diff line number Diff line change
Expand Up @@ -283,19 +283,17 @@ inline std::
arg.trange().tiles_range(), 0);

// Construct the task function used to construct the result tiles.
madness::AtomicInt counter;
counter = 0;
int task_count = 0;
std::atomic<std::int64_t> ntask_completed{0};
std::int64_t ntask_created{0};
auto op_shared_handle = make_op_shared_handle(std::forward<Op>(op));
const auto task = [op_shared_handle, &counter, &tile_norms](
const auto task = [op_shared_handle, &tile_norms](
const ordinal_type ord,
const_if_t<not inplace, arg_value_type>& arg_tile,
const ArgTiles&... arg_tiles) -> result_value_type {
op_helper<inplace, result_value_type> op_caller;
auto result_tile =
op_caller(std::move(op_shared_handle), tile_norms.at_ordinal(ord),
arg_tile, arg_tiles...);
++counter;
return result_tile;
};

Expand All @@ -310,7 +308,9 @@ inline std::
continue;
auto result_tile =
world.taskq.add(task, ord, arg.find_local(ord), args.find(ord)...);
++task_count;
++ntask_created;
result_tile.register_callback(
new IncrementCounter<decltype(ntask_completed)>(ntask_completed));
tiles.emplace_back(ord, std::move(result_tile));
if (op_returns_void) // if Op does not evaluate norms, use the (scaled)
// norms of the first arg
Expand All @@ -324,7 +324,9 @@ inline std::
auto result_tile =
world.taskq.add(task, ord, detail::get_sparse_tile(ord, arg),
detail::get_sparse_tile(ord, args)...);
++task_count;
++ntask_created;
result_tile.register_callback(
new IncrementCounter<decltype(ntask_completed)>(ntask_completed));
tiles.emplace_back(ord, std::move(result_tile));
if (op_returns_void) // if Op does not evaluate norms, find max
// (scaled) norms of all args
Expand All @@ -339,9 +341,10 @@ inline std::
}

// Wait for tile norm data to be collected.
if (task_count > 0)
world.await(
[&counter, task_count]() -> bool { return counter == task_count; });
if (ntask_created > 0)
world.await([&ntask_completed, ntask_created]() -> bool {
return ntask_created == ntask_completed;
});

// Construct the new array
result_array_type result(
Expand Down
36 changes: 26 additions & 10 deletions src/TiledArray/conversions/make_array.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#ifndef TILEDARRAY_CONVERSIONS_MAKE_ARRAY_H__INCLUDED
#define TILEDARRAY_CONVERSIONS_MAKE_ARRAY_H__INCLUDED

#include "TiledArray/array_impl.h"
#include "TiledArray/external/madness.h"
#include "TiledArray/shape.h"
#include "TiledArray/type_traits.h"
Expand Down Expand Up @@ -79,6 +80,10 @@ inline Array make_array(
// Make an empty result array
Array result(world, trange);

// Construct the task function used to construct the result tiles.
std::atomic<std::int64_t> ntask_completed{0};
std::int64_t ntask_created{0};

// Iterate over local tiles of arg
for (const auto index : *result.pmap()) {
// Spawn a task to evaluate the tile
Expand All @@ -89,11 +94,20 @@ inline Array make_array(
return tile;
},
trange.make_tile_range(index));

++ntask_created;
tile.register_callback(
new detail::IncrementCounter<decltype(ntask_completed)>(
ntask_completed));
// Store result tile
result.set(index, tile);
result.set(index, std::move(tile));
}

// Wait for tile tasks to complete
if (ntask_created > 0)
world.await([&ntask_completed, ntask_created]() -> bool {
return ntask_completed == ntask_created;
});

return result;
}

Expand Down Expand Up @@ -150,26 +164,28 @@ inline Array make_array(
trange.tiles_range(), 0);

// Construct the task function used to construct the result tiles.
madness::AtomicInt counter;
counter = 0;
int task_count = 0;
std::atomic<std::int64_t> ntask_completed{0};
std::int64_t ntask_created{0};
auto task = [&](const ordinal_type index) -> value_type {
value_type tile;
tile_norms.at_ordinal(index) = op(tile, trange.make_tile_range(index));
++counter;
return tile;
};

for (const auto index : *pmap) {
auto result_tile = world.taskq.add(task, index);
++task_count;
++ntask_created;
result_tile.register_callback(
new detail::IncrementCounter<decltype(ntask_completed)>(
ntask_completed));
tiles.emplace_back(index, std::move(result_tile));
}

// Wait for tile norm data to be collected.
if (task_count > 0)
world.await(
[&counter, task_count]() -> bool { return counter == task_count; });
if (ntask_created > 0)
world.await([&ntask_completed, ntask_created]() -> bool {
return ntask_completed == ntask_created;
});

// Construct the new array
Array result(world, trange,
Expand Down

0 comments on commit c955339

Please sign in to comment.