Skip to content

Commit

Permalink
ToT support for math/linalg functions and concat function.
Browse files Browse the repository at this point in the history
  • Loading branch information
bimalgaudel committed Jul 7, 2024
1 parent 644f0e9 commit 52b0960
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
5 changes: 3 additions & 2 deletions src/TiledArray/conversions/concat.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,9 @@ DistArray<Tile, Policy> concat(
DistArray<Tile, Policy> result(*target_world, tr);
const auto annot = detail::dummy_annotation(r);
for (auto i = 0ul; i != arrays.size(); ++i) {
result(annot).block(tile_begin_end[i].first, tile_begin_end[i].second) =
arrays[i](annot);
result.make_tsrexpr(annot).block(tile_begin_end[i].first,
tile_begin_end[i].second) =
arrays[i].make_tsrexpr(annot);
}
result.world().gop.fence();

Expand Down
7 changes: 4 additions & 3 deletions src/TiledArray/math/linalg/basic.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,14 +79,14 @@ template <typename Tile, typename Policy>
inline void vec_multiply(DistArray<Tile, Policy>& a1,
const DistArray<Tile, Policy>& a2) {
auto vars = TiledArray::detail::dummy_annotation(rank(a1));
a1(vars) = a1(vars) * a2(vars);
a1.make_tsrexpr(vars) = a1.make_tsrexpr(vars) * a2.make_tsrexpr(vars);
}

template <typename Tile, typename Policy, typename S>
inline void scale(DistArray<Tile, Policy>& a, S scaling_factor) {
using numeric_type = typename DistArray<Tile, Policy>::numeric_type;
auto vars = TiledArray::detail::dummy_annotation(rank(a));
a(vars) = numeric_type(scaling_factor) * a(vars);
a.make_tsrexpr(vars) = numeric_type(scaling_factor) * a.make_tsrexpr(vars);
}

template <typename Tile, typename Policy>
Expand All @@ -99,7 +99,8 @@ inline void axpy(DistArray<Tile, Policy>& y, S alpha,
const DistArray<Tile, Policy>& x) {
using numeric_type = typename DistArray<Tile, Policy>::numeric_type;
auto vars = TiledArray::detail::dummy_annotation(rank(y));
y(vars) = y(vars) + numeric_type(alpha) * x(vars);
y.make_tsrexpr(vars) =
y.make_tsrexpr(vars) + numeric_type(alpha) * x.make_tsrexpr(vars);
}

/// selector for concat
Expand Down

0 comments on commit 52b0960

Please sign in to comment.