Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

function for element-wise operations with coordinate index #452

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 70 additions & 4 deletions src/TiledArray/conversions/foreach.h
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,8 @@ inline std::enable_if_t<is_dense_v<Policy>, DistArray<Tile, Policy>> foreach (
/// function will fence before AND after the data is modified
template <typename Tile, typename Policy, typename Op,
typename = typename std::enable_if<!TiledArray::detail::is_array<
typename std::decay<Op>::type>::value>::type>
typename std::decay<Op>::type>::value>::type,
typename = typename std::enable_if<detail::is_invocable<Op, Tile&>::value>::type>
inline std::enable_if_t<is_dense_v<Policy>, void> foreach_inplace(
DistArray<Tile, Policy>& arg, Op&& op, bool fence = true) {
// The tile data is being modified in place, which means we may need to
Expand All @@ -495,6 +496,68 @@ inline std::enable_if_t<is_dense_v<Policy>, void> foreach_inplace(
if (fence) arg.world().gop.fence();
}

/// Modify each element of an Array object

/// This function modifies the elements of a \c DistArray object with a const reference to the
/// index of the current element. This allows the user to modify specific elements of the array
/// based on their indices. Users must provide a function/functor that modifies each element. The provided function
/// should take a reference to a \c Tile object and a reference to a \c std::vector<std::size_t>
/// representing the indices of the current element within the tile. For example,
/// to copy the upper triangular elements of a nxnxn array to a c++ vector of size n^3:
/// \code
/// std::vector<double> vec(n*n*n);
/// forall(array, [&vec] (auto& tile, const auto& index) {
/// size_t i = index[0], j = index[1], k = index[2];
/// if (i <= j && j <= k) {
/// vec[i*n*n+j*n+k] = tile[index];
/// } else {
/// vec[i*n*n+j*n+k] = 0.0;
/// }
/// });
/// \endcode
/// Similarly, to set each upper triangular element of a nxnxn array to the square root of values in a c++ vector of size n^3:
/// \code
/// vector<double> vec(n*n*n);
/// std::generate(v.begin(), v.end(), std::rand);
/// forall(array, [&vec] (Tile& tile, index_type& index) {
/// size_t i = index[0], j = index[1], k = index[2];
/// if (i <= j && j <= k) {
/// tile[index] = std::sqrt(vec[i*n*n+j*n+k]);
/// } else {
/// tile[index] = 0.0;
/// }
/// });
/// \endcode
/// The expected signature of the element operation is:
/// \code
/// void op(Tile& tile, Range::index_type& index);
/// \endcode
/// \tparam Tile The tile type of \c arg
/// \tparam Policy The policy type of \c arg
/// \tparam Op Mutating element operation
/// \param arg The argument array to be modified
/// \param op The mutating element function
/// \param fence If \c true, this function will fence before and after the data is modified
template <typename Tile, typename Policy, typename Op,
typename = typename std::enable_if<!TiledArray::detail::is_array<
typename std::decay<Op>::type>::value>::type,
typename = typename std::enable_if<detail::is_invocable<Op, Tile&,
const Range::index_type&>::value>::type>
inline void foreach_inplace(
DistArray<Tile, Policy>& arg, Op&& op, bool fence = true) {

// wrap Op into a shallow-copy copyable handle
auto op_shared_handle = make_op_shared_handle(std::forward<Op>(op));

// Use foreach_inplace to iterate over tiles and modify elements
foreach_inplace(
arg,
[op = std::move(op_shared_handle)](Tile& tile) mutable {
for (const Range::index_type& index : tile.range())
op(tile, index);
}, fence); // Fence before and after the data is modified
}

/// Apply a function to each tile of a sparse Array

/// This function uses an \c Array object to generate a new \c Array where the
Expand Down Expand Up @@ -587,7 +650,8 @@ inline std::enable_if_t<!is_dense_v<Policy>, DistArray<Tile, Policy>> foreach (
/// function will fence before AND after the data is modified
template <typename Tile, typename Policy, typename Op,
typename = typename std::enable_if<!TiledArray::detail::is_array<
typename std::decay<Op>::type>::value>::type>
typename std::decay<Op>::type>::value>::type,
typename = typename std::enable_if<detail::is_invocable<Op, Tile&>::value>::type>
inline std::enable_if_t<!is_dense_v<Policy>, void> foreach_inplace(
DistArray<Tile, Policy>& arg, Op&& op, bool fence = true) {
// The tile data is being modified in place, which means we may need to
Expand Down Expand Up @@ -629,7 +693,8 @@ inline std::
}

/// This function takes two input tiles and put result into the left tile
template <typename LeftTile, typename RightTile, typename Policy, typename Op>
template <typename LeftTile, typename RightTile, typename Policy, typename Op,
typename = typename std::enable_if<detail::is_invocable<Op, LeftTile&, const RightTile>::value>::type>
inline std::enable_if_t<is_dense_v<Policy>, void> foreach_inplace(
DistArray<LeftTile, Policy>& left,
const DistArray<RightTile, Policy>& right, Op&& op, bool fence = true) {
Expand Down Expand Up @@ -675,7 +740,8 @@ inline std::
}

/// This function takes two input tiles and put result into the left tile
template <typename LeftTile, typename RightTile, typename Policy, typename Op>
template <typename LeftTile, typename RightTile, typename Policy, typename Op,
typename = typename std::enable_if<detail::is_invocable<Op, LeftTile&, const RightTile>::value>::type>
inline std::enable_if_t<!is_dense_v<Policy>, void> foreach_inplace(
DistArray<LeftTile, Policy>& left,
const DistArray<RightTile, Policy>& right, Op&& op,
Expand Down
20 changes: 20 additions & 0 deletions tests/foreach.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,26 @@ BOOST_AUTO_TEST_CASE(foreach_unary) {
}
}

BOOST_AUTO_TEST_CASE(foreach_w_idx) {

TArrayI result = a.clone();
foreach_inplace(result, [](TensorI& tile, const Range::index_type &coord_idx) {
long fac = (coord_idx[0] < coord_idx[1]) ? coord_idx[0] : coord_idx[1];
tile[coord_idx] = fac * tile[coord_idx];
}, true);

for (auto index : *result.pmap()) {
TensorI tile0 = a.find(index).get();
TensorI tile = result.find(index).get();
const Range &range = tile0.range();
for (std::size_t i = 0; i < tile.size(); ++i) {
const Range::index_type &coord_idx = range.idx(i);
long fac = coord_idx[0] < coord_idx[1] ? coord_idx[0] : coord_idx[1];
BOOST_CHECK_EQUAL(tile[i], fac * tile0[i]);
}
}
}

BOOST_AUTO_TEST_CASE(foreach_unary_sparse) {
TSpArrayI result =
foreach (c, [](TensorI& result, const TensorI& arg) -> float {
Expand Down