[ci skip][wip] TA::retile support for DistArray with tensor-of-tensor tiles.
bimalgaudel committed Sep 16, 2024
1 parent 8af44bd commit ede81f3
Showing 4 changed files with 84 additions and 8 deletions.
7 changes: 5 additions & 2 deletions src/TiledArray/array_impl.h
@@ -425,6 +425,9 @@ class ArrayImpl : public TensorImpl<Policy>,
typedef typename TensorImpl_::pmap_interface
pmap_interface; ///< process map interface type
typedef Tile value_type; ///< Tile or data type
typedef typename Tile::value_type
element_type; ///< The value type of a tile. It is the numeric_type for
///< tensor-of-scalars tiles.
typedef
typename eval_trait<Tile>::type eval_type; ///< The tile evaluation type
typedef typename numeric_type<value_type>::type
@@ -854,8 +857,8 @@ template <typename Tile, typename Policy>
std::shared_ptr<ArrayImpl<Tile, Policy>> make_with_new_trange(
const std::shared_ptr<const ArrayImpl<Tile, Policy>>& source_array_sptr,
const TiledRange& target_trange,
-     typename ArrayImpl<Tile, Policy>::numeric_type new_value_fill =
-         typename ArrayImpl<Tile, Policy>::numeric_type{0}) {
+     typename ArrayImpl<Tile, Policy>::element_type new_value_fill =
+         typename ArrayImpl<Tile, Policy>::element_type{}) {
TA_ASSERT(source_array_sptr);
auto& source_array = *source_array_sptr;
auto& world = source_array.world();
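The element_type alias added above resolves to the tile's value_type, which coincides with the numeric type only for tensor-of-scalars tiles. A minimal sketch of what the alias means for the two tile kinds exercised by this commit (illustration only, not part of the diff; assumes TiledArray's umbrella header is available):

// Sketch: what Tile::value_type (the new element_type) is for a scalar tile
// versus a tensor-of-tensors tile.
#include <tiledarray.h>
#include <type_traits>

using T = TA::Tensor<double>;                // tensor-of-scalars tile
using ToT = TA::Tensor<TA::Tensor<double>>;  // tensor-of-tensors tile

// For scalar tiles the element type is the numeric type ...
static_assert(std::is_same_v<T::value_type, double>);
// ... but for nested tiles it is the inner tensor type, so a scalar default
// fill such as numeric_type{0} is no longer meaningful; hence the switch to a
// value-initialized element_type{} in make_with_new_trange above.
static_assert(std::is_same_v<ToT::value_type, TA::Tensor<double>>);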
2 changes: 1 addition & 1 deletion src/TiledArray/dist_array.h
@@ -461,7 +461,7 @@ class DistArray : public madness::archive::ParallelSerializableObject {
/// This constructor remaps the data of \p other according to \p new_trange ,
/// with \p new_value_fill used to fill the new elements, if any
DistArray(const DistArray& other, const trange_type& new_trange,
-             numeric_type new_value_fill = numeric_type{0})
+             element_type new_value_fill = element_type{})
: pimpl_(
make_with_new_trange(other.pimpl(), new_trange, new_value_fill)) {
this->truncate();
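A hypothetical usage sketch of the updated constructor (not in the commit; assumes <tiledarray.h> is included and that the target TiledRange spans the same element range, here [0,8) x [0,10), as the source array):

// Re-tile an existing scalar-tile array onto new tile boundaries; any elements
// absent from the source (none in this case) would be set to new_value_fill.
TA::TSpArrayD retile_to_new_bounds(TA::TSpArrayD const& arr) {
  auto const new_trange = TA::TiledRange({{0, 4, 8}, {0, 5, 10}});
  return TA::TSpArrayD(arr, new_trange, /*new_value_fill=*/0.0);
}

With this change the same constructor also accepts tensor-of-tensors arrays, where new_value_fill is an inner tensor rather than a scalar.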
22 changes: 17 additions & 5 deletions src/TiledArray/tensor/tensor_interface.h
@@ -110,6 +110,9 @@ class TensorInterface {
template <typename X>
using numeric_t = typename TiledArray::detail::numeric_type<X>::type;

template <typename X>
using value_t = typename std::remove_reference_t<X>::value_type;

template <typename, typename, typename>
friend class TensorInterface;

@@ -188,16 +191,25 @@ class TensorInterface {
TA_ASSERT(data);
}

- template <typename T1, typename std::enable_if<
-     detail::is_tensor<T1>::value>::type* = nullptr>
+ template <typename T1, typename std::enable_if<detail::is_nested_tensor<
+     T1>::value>::type* = nullptr>
TensorInterface_& operator=(const T1& other) {
if constexpr (std::is_same_v<numeric_type, numeric_t<T1>>) {
TA_ASSERT(data_ != other.data());
}

- detail::inplace_tensor_op([](numeric_type& MADNESS_RESTRICT result,
-                              const numeric_t<T1> arg) { result = arg; },
-                           *this, other);
+ if constexpr (detail::is_tensor_v<value_type>) {
+   range_ = BlockRange(other.range(), other.range().lobound(),
+                       other.range().upbound());
+   data_ = new value_type[other.total_size()];
+   auto cpy = other.clone();
+   for (auto i = 0; i < other.total_size(); ++i)
+     std::swap(data_[i], cpy.data()[i]);
+ } else {
+   detail::inplace_tensor_op([](numeric_type& MADNESS_RESTRICT result,
+                                const numeric_t<T1> arg) { result = arg; },
+                             *this, other);
+ }

return *this;
}
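The nested-tensor branch added to operator= above allocates fresh storage, clones the source once, and swaps the clones into place, so the inner tensors are not deep-copied a second time. A standalone sketch of that clone-then-swap pattern, using std::vector<double> as a stand-in for an inner tile type (illustration only, not TiledArray code):

#include <cstddef>
#include <utility>
#include <vector>

using Inner = std::vector<double>;  // stand-in for an inner tensor

// Copy n nested elements from src into dest: one deep copy up front, then
// element-wise swaps, so each inner buffer changes hands instead of being
// copied again into the destination.
void assign_nested(Inner* dest, const Inner* src, std::size_t n) {
  std::vector<Inner> cpy(src, src + n);  // the "clone" step
  for (std::size_t i = 0; i < n; ++i) std::swap(dest[i], cpy[i]);
}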
61 changes: 61 additions & 0 deletions tests/retile.cpp
@@ -28,4 +28,65 @@ BOOST_AUTO_TEST_CASE(retile_tensor) {
BOOST_CHECK_EQUAL(result_sparse.trange(), trange);
}

BOOST_AUTO_TEST_CASE(retile_more) {
using Numeric = int;
using T = TA::Tensor<Numeric>;
using ToT = TA::Tensor<T>;
using ArrayT = TA::DistArray<T, TA::SparsePolicy>;
using ArrayToT = TA::DistArray<ToT, TA::SparsePolicy>;

auto& world = TA::get_default_world();

auto const tr_source = TA::TiledRange({{0, 2, 4, 8}, {0, 3, 5}});
auto const tr_target = TA::TiledRange({{0, 4, 6, 8}, {0, 2, 4, 5}});

auto rand_num = [](auto&&) {
return TA::detail::MakeRandom<Numeric>::generate_value();
};

auto rand_tensor = [rand_num](auto const& rng) -> T {
return T(rng, rand_num);
};

auto rand_tensor_of_tensor = [rand_tensor](auto const& inner_rng) {
return [rand_tensor, inner_rng](auto const& rng) -> ToT {
return ToT(rng, rand_tensor(inner_rng));
};
};

auto set_random_tensor_of_tensor_tile = [rand_tensor_of_tensor](
auto const& inner_rng) {
return
[gen = rand_tensor_of_tensor(inner_rng)](auto& tile, auto const& rng) {
tile = gen(rng);
return tile.norm();
};
};

auto const inner_rng = TA::Range({3, 3});
auto arr_source = TA::make_array<ArrayToT>(
world, tr_source, set_random_tensor_of_tensor_tile(inner_rng));
arr_source.truncate();

auto arr_target = TA::retile(arr_source, tr_target);

arr_source.make_replicated();
world.gop.fence();
arr_target.make_replicated();
world.gop.fence();

auto const& elem_rng = tr_source.elements_range();
BOOST_REQUIRE(elem_rng.volume() == tr_target.elements_range().volume());

auto get_elem = [](ArrayToT const& arr, auto const& eix) {
auto tix = arr.trange().element_to_tile(eix);
auto&& tile = arr.find(tix).get(false);
return tile(eix);
};

for (auto&& eix : elem_rng) {
BOOST_REQUIRE(get_elem(arr_source, eix) == get_elem(arr_target, eix));
}
}

BOOST_AUTO_TEST_SUITE_END()
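For reference, a standalone sketch (not part of the test; assumes <tiledarray.h> and <cassert> are included) of the tensor-of-tensors tile construction that rand_tensor_of_tensor performs above, with a constant inner tensor in place of random data:

// One tensor-of-tensors tile: a 2x2 outer range whose every element is a copy
// of a 3x3 inner tensor filled with ones.
void tot_tile_example() {
  auto const inner = TA::Tensor<int>(TA::Range({3, 3}), 1);
  auto const tile = TA::Tensor<TA::Tensor<int>>(TA::Range({2, 2}), inner);
  assert(tile(0, 1)(2, 2) == 1);  // outer element {0,1}, inner element {2,2}
}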
