Skip to content

Commit

Permalink
Merge pull request #474 from ValeevGroup/gaudel/feature/retile_tot_arrays
Browse files Browse the repository at this point in the history

`TA::retile` support for `DistArray` with tensor-of-tensors tiles
  • Loading branch information
evaleev authored Oct 29, 2024
2 parents 5944bdb + 4bf23ad commit 486ae16
Show file tree
Hide file tree
Showing 6 changed files with 98 additions and 13 deletions.
7 changes: 5 additions & 2 deletions src/TiledArray/array_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,9 @@ class ArrayImpl : public TensorImpl<Policy>,
typedef typename TensorImpl_::pmap_interface
pmap_interface; ///< process map interface type
typedef Tile value_type; ///< Tile or data type
typedef typename Tile::value_type
element_type; ///< The value type of a tile. It is the numeric_type for
///< tensor-of-scalars tiles.
typedef
typename eval_trait<Tile>::type eval_type; ///< The tile evaluation type
typedef typename numeric_type<value_type>::type
Expand Down Expand Up @@ -889,8 +892,8 @@ template <typename Tile, typename Policy>
std::shared_ptr<ArrayImpl<Tile, Policy>> make_with_new_trange(
const std::shared_ptr<const ArrayImpl<Tile, Policy>>& source_array_sptr,
const TiledRange& target_trange,
typename ArrayImpl<Tile, Policy>::numeric_type new_value_fill =
typename ArrayImpl<Tile, Policy>::numeric_type{0}) {
typename ArrayImpl<Tile, Policy>::element_type new_value_fill =
typename ArrayImpl<Tile, Policy>::element_type{}) {
TA_ASSERT(source_array_sptr);
auto& source_array = *source_array_sptr;
auto& world = source_array.world();
Expand Down
2 changes: 1 addition & 1 deletion src/TiledArray/dist_array.h
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,7 @@ class DistArray : public madness::archive::ParallelSerializableObject {
/// This constructor remaps the data of \p other according to \p new_trange ,
/// with \p new_value_fill used to fill the new elements, if any
DistArray(const DistArray& other, const trange_type& new_trange,
numeric_type new_value_fill = numeric_type{0})
element_type new_value_fill = element_type{})
: pimpl_(
make_with_new_trange(other.pimpl(), new_trange, new_value_fill)) {
this->truncate();
Expand Down
6 changes: 3 additions & 3 deletions src/TiledArray/pmap/user_pmap.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class UserPmap : public Pmap {
UserPmap(World& world, size_type size, Index2Rank&& i2r)
: Pmap(world, size), index2rank_(std::forward<Index2Rank>(i2r)) {}

/// Constructs map that does not know the number of local elements
/// Constructs map that knows the number of local elements

/// \tparam Index2Rank a callable type with `size_type(size_t)` signature
/// \param world A reference to the world
Expand Down Expand Up @@ -88,10 +88,10 @@ class UserPmap : public Pmap {
virtual bool known_local_size() const { return known_local_size_; }

virtual const_iterator begin() const {
return Iterator(*this, 0, this->size_, 0, false);
return Iterator(*this, 0, this->size_, 0, /* checking = */ true);
}
virtual const_iterator end() const {
return Iterator(*this, 0, this->size_, this->size_, false);
return Iterator(*this, 0, this->size_, this->size_, /* checking = */ true);
}

private:
Expand Down
13 changes: 11 additions & 2 deletions src/TiledArray/tensor/kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -599,8 +599,17 @@ inline void inplace_tensor_op(Op&& op, TR& result, const Ts&... tensors) {
[&op, stride](
typename TR::pointer MADNESS_RESTRICT const result_data,
typename Ts::const_pointer MADNESS_RESTRICT const... tensors_data) {
for (decltype(result.range().volume()) i = 0ul; i < stride; ++i)
inplace_tensor_op(op, result_data[i], tensors_data[i]...);
for (decltype(result.range().volume()) i = 0ul; i < stride; ++i) {
if constexpr (std::is_invocable_v<
std::remove_reference_t<Op>,
typename std::remove_reference_t<TR>::value_type&,
typename std::remove_reference_t<
Ts>::value_type const&...>) {
std::forward<Op>(op)(result_data[i], tensors_data[i]...);
} else {
inplace_tensor_op(op, result_data[i], tensors_data[i]...);
}
}
};

for (std::decay_t<decltype(volume)> ord = 0ul; ord < volume; ord += stride)
Expand Down
13 changes: 8 additions & 5 deletions src/TiledArray/tensor/tensor_interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,9 @@ class TensorInterface {
template <typename X>
using numeric_t = typename TiledArray::detail::numeric_type<X>::type;

template <typename X>
using value_t = typename std::remove_reference_t<X>::value_type;

template <typename, typename, typename>
friend class TensorInterface;

Expand Down Expand Up @@ -188,16 +191,16 @@ class TensorInterface {
TA_ASSERT(data);
}

template <typename T1, typename std::enable_if<
detail::is_tensor<T1>::value>::type* = nullptr>
template <typename T1, typename std::enable_if<detail::is_nested_tensor<
T1>::value>::type* = nullptr>
TensorInterface_& operator=(const T1& other) {
if constexpr (std::is_same_v<numeric_type, numeric_t<T1>>) {
TA_ASSERT(data_ != other.data());
}

detail::inplace_tensor_op([](numeric_type& MADNESS_RESTRICT result,
const numeric_t<T1> arg) { result = arg; },
*this, other);
detail::inplace_tensor_op(
[](value_type& MADNESS_RESTRICT result, auto&& arg) { result = arg; },
*this, other);

return *this;
}
Expand Down
70 changes: 70 additions & 0 deletions tests/retile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,74 @@ BOOST_AUTO_TEST_CASE(retile_tensor) {
BOOST_CHECK_EQUAL(result_sparse.trange(), trange);
}

/// Checks that `TA::retile` keeps every element intact when a sparse array is
/// moved onto a different tiling: first for tensor-of-scalars tiles, then for
/// tensor-of-tensors tiles.
BOOST_AUTO_TEST_CASE(retile_more) {
  using Numeric = int;
  using T = TA::Tensor<Numeric>;
  using ToT = TA::Tensor<T>;
  using ArrayT = TA::DistArray<T, TA::SparsePolicy>;
  using ArrayToT = TA::DistArray<ToT, TA::SparsePolicy>;

  auto& world = TA::get_default_world();

  // Source and target tilings; the requirement right below insists that both
  // span the same number of elements.
  auto const tr_source = TA::TiledRange({{0, 2, 4, 8}, {0, 3, 5}});
  auto const tr_target = TA::TiledRange({{0, 4, 6, 8}, {0, 2, 4, 5}});
  auto const& elem_rng = tr_source.elements_range();

  BOOST_REQUIRE(elem_rng.volume() == tr_target.elements_range().volume());

  // Finds the tile of `arr` that contains element index `eix` and returns the
  // element stored there.
  auto get_elem = [](auto const& arr, auto const& eix) {
    auto owner_idx = arr.trange().element_to_tile(eix);
    auto&& owner_tile = arr.find(owner_idx).get(false);
    return owner_tile(eix);
  };

  // Builds a tensor over `rng` whose elements all come from MakeRandom.
  auto make_random_tensor = [](auto const& rng) -> T {
    return T(rng, [](auto&&) {
      return TA::detail::MakeRandom<Numeric>::generate_value();
    });
  };

  //
  // Part 1: tiles are plain tensors of scalars.
  //

  // Tile initializer handed to make_array: fills the tile and returns its norm.
  auto init_tensor_tile = [make_random_tensor](auto& tile, auto const& rng) {
    tile = make_random_tensor(rng);
    return tile.norm();
  };

  auto arr_source0 = TA::make_array<ArrayT>(world, tr_source, init_tensor_tile);
  auto arr_target0 = TA::retile(arr_source0, tr_target);

  for (auto&& eix : elem_rng) {
    BOOST_REQUIRE(get_elem(arr_source0, eix) == get_elem(arr_target0, eix));
  }

  //
  // Part 2: tiles are tensors whose elements are themselves tensors.
  //

  // Every inner tensor uses this one fixed range.
  auto const inner_rng = TA::Range({3, 3});

  // Builds an outer tensor over `rng`; each element is a random inner tensor.
  auto make_random_tot = [make_random_tensor,
                          inner_rng](auto const& rng) -> ToT {
    return ToT(rng, [make_random_tensor, inner_rng](auto&&) {
      return make_random_tensor(inner_rng);
    });
  };

  // Tile initializer handed to make_array: fills the tile and returns its norm.
  auto init_tot_tile = [make_random_tot](auto& tile, auto const& rng) {
    tile = make_random_tot(rng);
    return tile.norm();
  };

  auto arr_source = TA::make_array<ArrayToT>(world, tr_source, init_tot_tile);
  auto arr_target = TA::retile(arr_source, tr_target);

  // Replicate and truncate both arrays, then fence, before the element-wise
  // comparison.
  arr_source.make_replicated();
  arr_target.make_replicated();
  arr_source.truncate();
  arr_target.truncate();
  world.gop.fence();

  for (auto&& eix : elem_rng) {
    BOOST_REQUIRE(get_elem(arr_source, eix) == get_elem(arr_target, eix));
  }
}

BOOST_AUTO_TEST_SUITE_END()

0 comments on commit 486ae16

Please sign in to comment.