Skip to content

Commit

Permalink
Merge pull request #466 from ValeevGroup/evaleev/feature/element-block
Browse files Browse the repository at this point in the history
can change DistArray's trange (retile + more)
  • Loading branch information
evaleev authored Aug 27, 2024
2 parents 7e45348 + 09819de commit d174eda
Show file tree
Hide file tree
Showing 24 changed files with 1,083 additions and 575 deletions.
3 changes: 2 additions & 1 deletion src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ TiledArray/util/backtrace.h
TiledArray/util/bug.h
TiledArray/util/function.h
TiledArray/util/initializer_list.h
TiledArray/util/invoke.h
TiledArray/util/logger.h
TiledArray/util/ptr_registry.h
TiledArray/util/random.h
Expand Down Expand Up @@ -258,7 +259,7 @@ set_source_files_properties(

# the list of libraries on which TiledArray depends on, will be cached later
# when FetchContent umpire: set(_TILEDARRAY_DEPENDENCIES MADworld TiledArray_Eigen BTAS::BTAS blaspp_headers umpire)
set(_TILEDARRAY_DEPENDENCIES MADworld TiledArray_Eigen BTAS::BTAS blaspp_headers TiledArray_UMPIRE)
set(_TILEDARRAY_DEPENDENCIES MADworld TiledArray_Eigen BTAS::BTAS blaspp_headers TiledArray_UMPIRE range-v3::range-v3)

if(CUDA_FOUND OR HIP_FOUND)

Expand Down
326 changes: 324 additions & 2 deletions src/TiledArray/array_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include <TiledArray/tensor_impl.h>
#include <TiledArray/transform_iterator.h>
#include <TiledArray/type_traits.h>
#include <TiledArray/util/function.h>

namespace TiledArray {
namespace detail {
Expand Down Expand Up @@ -407,7 +408,8 @@ class ArrayIterator {
/// \note It is the user's responsibility to ensure the process maps on all
/// nodes are identical.
template <typename Tile, typename Policy>
class ArrayImpl : public TensorImpl<Policy> {
class ArrayImpl : public TensorImpl<Policy>,
public std::enable_shared_from_this<ArrayImpl<Tile, Policy>> {
public:
typedef ArrayImpl<Tile, Policy> ArrayImpl_; ///< This object type
typedef TensorImpl<Policy> TensorImpl_; ///< The base class of this object
Expand Down Expand Up @@ -440,6 +442,68 @@ class ArrayImpl : public TensorImpl<Policy> {
private:
storage_type data_; ///< Tile container

public:
static madness::AtomicInt cleanup_counter_;

/// Array deleter function

/// This function schedules a task for lazy cleanup. Array objects are
/// deleted only after the object has been deleted in all processes.
/// \param pimpl The implementation pointer to be deleted.
static void lazy_deleter(const ArrayImpl_* const pimpl) {
if (pimpl) {
if (madness::initialized()) {
World& world = pimpl->world();
const madness::uniqueidT id = pimpl->id();
cleanup_counter_++;

// wait for all DelayedSet's to vanish
world.await([&]() { return (pimpl->num_live_ds() == 0); }, true);

try {
world.gop.lazy_sync(id, [pimpl]() {
delete pimpl;
ArrayImpl_::cleanup_counter_--;
});
} catch (madness::MadnessException& e) {
fprintf(stderr,
"!! ERROR TiledArray: madness::MadnessException thrown in "
"DistArray::lazy_deleter().\n"
"%s\n"
"!! ERROR TiledArray: The exception has been absorbed.\n"
"!! ERROR TiledArray: rank=%i\n",
e.what(), world.rank());

cleanup_counter_--;
delete pimpl;
} catch (std::exception& e) {
fprintf(stderr,
"!! ERROR TiledArray: std::exception thrown in "
"DistArray::lazy_deleter().\n"
"%s\n"
"!! ERROR TiledArray: The exception has been absorbed.\n"
"!! ERROR TiledArray: rank=%i\n",
e.what(), world.rank());

cleanup_counter_--;
delete pimpl;
} catch (...) {
fprintf(stderr,
"!! ERROR TiledArray: An unknown exception was thrown in "
"DistArray::lazy_deleter().\n"
"!! ERROR TiledArray: The exception has been absorbed.\n"
"!! ERROR TiledArray: rank=%i\n",
world.rank());

cleanup_counter_--;
delete pimpl;
}
} else {
delete pimpl;
}
}
}

public:
/// Constructor

Expand All @@ -453,7 +517,32 @@ class ArrayImpl : public TensorImpl<Policy> {
  /// Constructor

  /// Initializes the base TensorImpl and the tile container, then validates
  /// (via TA_ASSERT, i.e. debug builds only) that the process map and the
  /// shape are consistent with the tiled range and the world.
  /// \param world The world where this array lives
  /// \param trange The tiled range describing the array's tiling
  /// \param shape The shape of the array
  /// \param pmap The process map describing tile ownership
  ArrayImpl(World& world, const trange_type& trange, const shape_type& shape,
            const std::shared_ptr<const pmap_interface>& pmap)
      : TensorImpl_(world, trange, shape, pmap),
        data_(world, trange.tiles_range().volume(), pmap) {
    // Validate the process map
    TA_ASSERT(pmap->size() == trange.tiles_range().volume() &&
              "TiledArray::DistArray::DistArray() -- The size of the process "
              "map is not "
              "equal to the number of tiles in the TiledRange object.");
    TA_ASSERT(pmap->rank() ==
                  typename pmap_interface::size_type(world.rank()) &&
              "TiledArray::DistArray::DistArray() -- The rank of the process "
              "map is not equal to that "
              "of the world object.");
    TA_ASSERT(pmap->procs() ==
                  typename pmap_interface::size_type(world.size()) &&
              "TiledArray::DistArray::DistArray() -- The number of processes "
              "in the process map is not "
              "equal to that of the world object.");

    // Validate the shape
    TA_ASSERT(
        !shape.empty() &&
        "TiledArray::DistArray::DistArray() -- The shape is not initialized.");
    TA_ASSERT(shape.validate(trange.tiles_range()) &&
              "TiledArray::DistArray::DistArray() -- The range of the shape is "
              "not equal to "
              "the tiles range.");
  }

/// Virtual destructor
virtual ~ArrayImpl() {}
Expand Down Expand Up @@ -649,8 +738,80 @@ class ArrayImpl : public TensorImpl<Policy> {
return data_.num_live_df();
}

/// Initialize (local) tiles with a user provided functor

/// This function is used to initialize the local, non-zero tiles of the array
/// via a function (or functor). The work is done in parallel, therefore \c op
/// must be a thread safe function/functor. The signature of the functor
/// should be:
/// \code
/// value_type op(const range_type&)
/// \endcode
/// For example, in the following code, the array tiles are initialized with
/// random numbers from 0 to 1:
/// \code
/// array.init_tiles([] (const TiledArray::Range& range) ->
/// TiledArray::Tensor<double>
/// {
/// // Initialize the tile with the given range object
/// TiledArray::Tensor<double> tile(range);
///
/// // Initialize the random number generator
/// std::default_random_engine generator;
/// std::uniform_real_distribution<double> distribution(0.0,1.0);
///
/// // Fill the tile with random numbers
/// for(auto& value : tile)
/// value = distribution(generator);
///
/// return tile;
/// });
/// \endcode
/// \tparam Op The type of the functor/function
/// \param[in] op The operation used to generate tiles
/// \param[in] skip_set If false, will throw if any tiles are already set
/// \throw TiledArray::Exception if the PIMPL is not set. Strong throw
/// guarantee.
/// \throw TiledArray::Exception if a tile is already set and skip_set is
/// false. Weak throw guarantee.
template <HostExecutor Exec = HostExecutor::Default, typename Op>
void init_tiles(Op&& op, bool skip_set = false) {
// lifetime management of op depends on whether it is a lvalue ref (i.e. has
// an external owner) or an rvalue ref
// - if op is an lvalue ref: pass op to tasks
// - if op is an rvalue ref pass make_shared_function(op) to tasks
auto op_shared_handle = make_op_shared_handle(std::forward<Op>(op));

auto it = this->pmap()->begin();
const auto end = this->pmap()->end();
for (; it != end; ++it) {
const auto& index = *it;
if (!this->is_zero(index)) {
if (skip_set) {
auto& fut = this->get_local(index);
if (fut.probe()) continue;
}
if constexpr (Exec == HostExecutor::MADWorld) {
Future<value_type> tile = this->world().taskq.add(
[this_sptr = this->shared_from_this(),
index = ordinal_type(index), op_shared_handle]() -> value_type {
return op_shared_handle(
this_sptr->trange().make_tile_range(index));
});
set(index, std::move(tile));
} else {
static_assert(Exec == HostExecutor::Thread);
set(index, op_shared_handle(this->trange().make_tile_range(index)));
}
}
}
}

}; // class ArrayImpl

/// Per-process count of pending lazy cleanup operations; incremented by
/// ArrayImpl::lazy_deleter when a deletion is scheduled and decremented
/// once the deletion (or its fallback cleanup) has run
template <typename Tile, typename Policy>
madness::AtomicInt ArrayImpl<Tile, Policy>::cleanup_counter_;

#ifndef TILEDARRAY_HEADER_ONLY

extern template class ArrayImpl<Tensor<double>, DensePolicy>;
Expand All @@ -673,6 +834,167 @@ extern template class ArrayImpl<Tensor<std::complex<float>>, SparsePolicy>;

#endif // TILEDARRAY_HEADER_ONLY

/// Writes a block contribution into a tile of an existing array

/// Looks up the array's tile storage by its world-unique id and assigns
/// \p target_tile_contribution into the matching block of the (local)
/// target tile. The target tile must be owned by this process.
/// \param target_array_id world-unique id of the destination array's storage
/// \param target_tile_ord ordinal of the destination tile
/// \param target_tile_contribution block data (its range selects the
///        destination block within the tile)
template <typename Tile, typename Policy>
void write_tile_block(madness::uniqueidT target_array_id,
                      std::size_t target_tile_ord,
                      const Tile& target_tile_contribution) {
  using storage_type = typename ArrayImpl<Tile, Policy>::storage_type;
  auto* world = World::world_from_id(target_array_id.get_world_id());
  auto storage_opt = world->ptr_from_id<storage_type>(target_array_id);
  TA_ASSERT(storage_opt);
  auto& storage = *storage_opt;
  TA_ASSERT(storage->is_local(target_tile_ord));
  // assign the contribution into the corresponding block of the local tile
  auto& target_tile = storage->get_local(target_tile_ord).get();
  target_tile.block(target_tile_contribution.range()) =
      target_tile_contribution;
}

/// Creates a copy of an array retiled according to a new TiledRange

/// The target array is first filled with \p new_value_fill; then every local
/// source tile is decomposed into blocks that are sent (as tasks) to the
/// owners of the overlapping target tiles. The call is collective: it fences
/// after initialization and again after all block writes have been scheduled.
/// \param source_array_sptr the source array; must be non-null
/// \param target_trange the tiling of the result; must have the same rank as
///        the source array's tiled range
/// \param new_value_fill value assigned to target elements not covered by the
///        source (defaults to 0)
/// \return shared pointer to the new implementation (cleaned up lazily via
///         ArrayImpl::lazy_deleter)
template <typename Tile, typename Policy>
std::shared_ptr<ArrayImpl<Tile, Policy>> make_with_new_trange(
    const std::shared_ptr<const ArrayImpl<Tile, Policy>>& source_array_sptr,
    const TiledRange& target_trange,
    typename ArrayImpl<Tile, Policy>::numeric_type new_value_fill =
        typename ArrayImpl<Tile, Policy>::numeric_type{0}) {
  TA_ASSERT(source_array_sptr);
  auto& source_array = *source_array_sptr;
  auto& world = source_array.world();
  const auto rank = source_array.trange().rank();
  TA_ASSERT(rank == target_trange.rank());

  // compute metadata
  // - list of target tile indices and the corresponding Range1 for each 1-d
  // source tile
  using target_tiles_t = std::vector<std::pair<TA_1INDEX_TYPE, Range1>>;
  using mode_target_tiles_t = std::vector<target_tiles_t>;
  using all_target_tiles_t = std::vector<mode_target_tiles_t>;

  all_target_tiles_t all_target_tiles(target_trange.rank());
  // for each mode ...
  for (auto d = 0; d != target_trange.rank(); ++d) {
    mode_target_tiles_t& mode_target_tiles = all_target_tiles[d];
    auto& target_tr1 = target_trange.dim(d);
    auto& target_element_range = target_tr1.elements_range();
    // ... and each tile in that mode ...
    for (auto&& source_tile : source_array.trange().dim(d)) {
      mode_target_tiles.emplace_back();
      auto& target_tiles = mode_target_tiles.back();
      auto source_tile_lo = source_tile.lobound();
      auto source_tile_up = source_tile.upbound();
      auto source_element_idx = source_tile_lo;
      // ... find all target tiles that overlap with it
      if (target_element_range.overlaps_with(source_tile)) {
        // walk the source tile's extent, emitting one
        // (target tile index, overlap Range1) pair per overlapping target tile
        while (source_element_idx < source_tile_up) {
          if (target_element_range.includes(source_element_idx)) {
            auto target_tile_idx =
                target_tr1.element_to_tile(source_element_idx);
            auto target_tile = target_tr1.tile(target_tile_idx);
            auto target_lo =
                std::max(source_element_idx, target_tile.lobound());
            auto target_up = std::min(source_tile_up, target_tile.upbound());
            target_tiles.emplace_back(target_tile_idx,
                                      Range1(target_lo, target_up));
            source_element_idx = target_up;
          } else if (source_element_idx < target_element_range.lobound()) {
            // before the target element range: jump to its start
            source_element_idx = target_element_range.lobound();
          } else if (source_element_idx >= target_element_range.upbound())
            break;
        }
      }
    }
  }

  // estimate the shape, if sparse
  // use max value for each nonzero tile, then will recompute after tiles are
  // assigned
  using shape_type = typename Policy::shape_type;
  shape_type target_shape;
  const auto& target_tiles_range = target_trange.tiles_range();
  if constexpr (!is_dense_v<Policy>) {
    // each rank computes contributions to the shape norms from its local tiles
    Tensor<float> target_shape_norms(target_tiles_range, 0);
    auto& source_trange = source_array.trange();
    const auto e = source_array.cend();
    for (auto it = source_array.cbegin(); it != e; ++it) {
      auto source_tile_idx = it.index();

      // make range for iterating over all possible target tile idx combinations
      TA::Index target_tile_ord_extent_range(rank);
      for (auto d = 0; d != rank; ++d) {
        target_tile_ord_extent_range[d] =
            all_target_tiles[d][source_tile_idx[d]].size();
      }

      // loop over every target tile combination
      TA::Range target_tile_ord_extent(target_tile_ord_extent_range);
      for (auto& target_tile_ord : target_tile_ord_extent) {
        TA::Index target_tile_idx(rank);
        for (auto d = 0; d != rank; ++d) {
          target_tile_idx[d] =
              all_target_tiles[d][source_tile_idx[d]][target_tile_ord[d]].first;
        }
        // conservative overestimate; NOTE(review): the comment above says the
        // norms will be recomputed later -- confirm callers truncate
        target_shape_norms(target_tile_idx) = std::numeric_limits<float>::max();
      }
    }
    // combine per-rank contributions so all ranks agree on the shape
    world.gop.max(target_shape_norms.data(), target_shape_norms.size());
    target_shape = SparseShape(target_shape_norms, target_trange);
  }

  using Array = ArrayImpl<Tile, Policy>;
  auto target_array_sptr = std::shared_ptr<Array>(
      new Array(
          source_array.world(), target_trange, target_shape,
          Policy::default_pmap(world, target_trange.tiles_range().volume())),
      Array::lazy_deleter);
  auto& target_array = *target_array_sptr;
  // pre-fill the target so elements not covered by the source have a
  // well-defined value
  target_array.init_tiles([value = new_value_fill](const Range& range) {
    return typename Array::value_type(range, value);
  });
  target_array.world().gop.fence();

  // loop over local tile and sends its contributions to the targets
  {
    auto& source_trange = source_array.trange();
    const auto e = source_array.cend();
    auto& target_tiles_range = target_trange.tiles_range();
    for (auto it = source_array.cbegin(); it != e; ++it) {
      const auto& source_tile = *it;
      auto source_tile_idx = it.index();

      // make range for iterating over all possible target tile idx combinations
      TA::Index target_tile_ord_extent_range(rank);
      for (auto d = 0; d != rank; ++d) {
        target_tile_ord_extent_range[d] =
            all_target_tiles[d][source_tile_idx[d]].size();
      }

      // loop over every target tile combination
      TA::Range target_tile_ord_extent(target_tile_ord_extent_range);
      for (auto& target_tile_ord : target_tile_ord_extent) {
        TA::Index target_tile_idx(rank);
        container::svector<TA::Range1> target_tile_rngs1(rank);
        // assemble the target tile index and per-mode block ranges
        for (auto d = 0; d != rank; ++d) {
          std::tie(target_tile_idx[d], target_tile_rngs1[d]) =
              all_target_tiles[d][source_tile_idx[d]][target_tile_ord[d]];
        }
        // presumably the iteration yields only ready local tiles, so the
        // future should already be set -- TODO confirm against iterator docs
        TA_ASSERT(source_tile.future().probe());
        Tile target_tile_contribution(
            source_tile.get().block(target_tile_rngs1));
        auto target_tile_idx_ord = target_tiles_range.ordinal(target_tile_idx);
        auto target_proc = target_array.pmap()->owner(target_tile_idx_ord);
        // ship the block to the target tile's owner; write_tile_block
        // mutates the destination tile in place
        world.taskq.add(target_proc, &write_tile_block<Tile, Policy>,
                        target_array.id(), target_tile_idx_ord,
                        target_tile_contribution);
      }
    }
  }
  // data is mutated in place, so must wait for all tasks to complete
  target_array.world().gop.fence();
  // WARNING!! need to truncate in DistArray ctor

  return target_array_sptr;
}

} // namespace detail
} // namespace TiledArray

Expand Down
Loading

0 comments on commit d174eda

Please sign in to comment.