diff --git a/src/TiledArray/math/linalg/scalapack/block_cyclic.h b/src/TiledArray/math/linalg/scalapack/block_cyclic.h index bea4fd5dfa..7c36ec0869 100644 --- a/src/TiledArray/math/linalg/scalapack/block_cyclic.h +++ b/src/TiledArray/math/linalg/scalapack/block_cyclic.h @@ -30,11 +30,11 @@ #include #if TILEDARRAY_HAS_SCALAPACK +#include #include #include #include #include -#include #include #include @@ -58,7 +58,10 @@ class BlockCyclicMatrix : public madness::WorldObject> { col_major_mat_t local_mat_; ///< Local block cyclic buffer std::pair dims_; ///< Dims of the matrix - void put_tile(const Tensor& tile) { + template >> + void put_tile(const Tile& tile) { // Extract Tile information const auto* lo = tile.range().lobound_data(); const auto* up = tile.range().upbound_data(); @@ -104,9 +107,22 @@ class BlockCyclicMatrix : public madness::WorldObject> { } else { // Send the subblock to a remote rank for processing - Tensor subblock(tile.block({i, j}, {i_last, j_last})); - world_base_t::send(owner(i, j), &BlockCyclicMatrix::put_tile, - subblock); + Tensor subblock; + if constexpr (TiledArray::detail::is_ta_tensor_v) + subblock = tile.block({i, j}, {i_last, j_last}); + else { + auto tile_blk_range = TiledArray::BlockRange( + TiledArray::detail::make_ta_range(tile.range()), {i, j}, + {i_last, j_last}); + using std::data; + auto tile_blk_view = + TiledArray::make_const_map(data(tile), tile_blk_range); + subblock = tile_blk_view; + } + world_base_t::send( + owner(i, j), + &BlockCyclicMatrix::template put_tile, + subblock); } } // for (j) @@ -114,7 +130,10 @@ class BlockCyclicMatrix : public madness::WorldObject> { } // put_tile - Tensor extract_submatrix(std::vector lo, std::vector up) { + template >> + Tile extract_submatrix(std::vector lo, std::vector up) { assert(bc_dist_.i_own(lo[0], lo[1])); auto [i_st, j_st] = bc_dist_.local_indx(lo[0], lo[1]); @@ -123,7 +142,7 @@ class BlockCyclicMatrix : public madness::WorldObject> { auto j_extent = up[1] - lo[1]; Range range(lo, up); - Tensor tile(range); + Tile tile(range); auto tile_map = eigen_map(tile); @@ -172,7 +191,7 @@ class BlockCyclicMatrix : public madness::WorldObject> { array.trange().dim(1).extent(), MB, NB) { TA_ASSERT(array.trange().rank() == 2); - for (auto it = array.begin(); it != array.end(); ++it) put_tile(*it); + for (auto it = array.begin(); it != array.end(); ++it) put_tile(it->get()); world_base_t::process_pending(); } @@ -197,8 +216,9 @@ class BlockCyclicMatrix : public madness::WorldObject> { template Array tensor_from_matrix(const TiledRange& trange) const { - auto construct_tile = [&](Tensor& tile, const Range& range) { - tile = Tensor(range); + using Tile = typename Array::value_type; + auto construct_tile = [&](Tile& tile, const Range& range) { + tile = Tile(range); // Extract Tile information const auto* lo = tile.range().lobound_data(); @@ -246,13 +266,24 @@ class BlockCyclicMatrix : public madness::WorldObject> { std::vector lo{i, j}; std::vector up{i_last, j_last}; madness::Future> remtile_fut = world_base_t::send( - owner(i, j), &BlockCyclicMatrix::extract_submatrix, lo, up); - - tile.block(lo, up) = remtile_fut.get(); + owner(i, j), + &BlockCyclicMatrix::template extract_submatrix>, + lo, up); + + if constexpr (TiledArray::detail::is_ta_tensor_v) + tile.block(lo, up) = remtile_fut.get(); + else { + auto tile_blk_range = TiledArray::BlockRange( + TiledArray::detail::make_ta_range(tile.range()), lo, up); + using std::data; + auto tile_blk_view = + TiledArray::make_map(data(tile), tile_blk_range); + tile_blk_view = remtile_fut.get(); + } } } - return tile.norm(); + return norm(tile); }; return make_array(world_base_t::get_world(), trange, construct_tile); @@ -298,7 +329,7 @@ std::remove_cv_t block_cyclic_to_array( return matrix.template tensor_from_matrix>(trange); } -} // namespace TiledArray +} // namespace TiledArray::math::linalg::scalapack #endif // TILEDARRAY_HAS_SCALAPACK #endif // TILEDARRAY_MATH_LINALG_SCALAPACK_TO_BLOCKCYCLIC_H__INCLUDED diff --git a/src/TiledArray/tensor/tensor.h b/src/TiledArray/tensor/tensor.h index 06b1c8cc4b..fe76b07bd0 100644 --- a/src/TiledArray/tensor/tensor.h +++ b/src/TiledArray/tensor/tensor.h @@ -369,6 +369,7 @@ class Tensor { template ::value>::type* = nullptr> Tensor_& operator=(const T1& other) { + pimpl_ = std::make_shared(detail::clone_range(other)); detail::inplace_tensor_op( [](reference MADNESS_RESTRICT tr, typename T1::const_reference MADNESS_RESTRICT t1) { tr = t1; }, diff --git a/tests/linalg.cpp b/tests/linalg.cpp index b5c67e14fc..de9139e081 100644 --- a/tests/linalg.cpp +++ b/tests/linalg.cpp @@ -44,9 +44,9 @@ struct ReferenceFixture { #endif } - inline double make_ta_reference(TA::Tensor& t, - TA::Range const& range) { - t = TA::Tensor(range, 0.0); + template + inline double make_ta_reference(Tile& t, TA::Range const& range) { + t = Tile(range, 0.0); auto lo = range.lobound_data(); auto up = range.upbound_data(); for (auto m = lo[0]; m < up[0]; ++m) { @@ -55,7 +55,7 @@ struct ReferenceFixture { } } - return t.norm(); + return norm(t); }; ReferenceFixture(int64_t N = 1000) @@ -308,23 +308,35 @@ BOOST_AUTO_TEST_CASE(bc_to_sparse_tiled_array_test) { auto trange = gen_trange(N, {static_cast(NB)}); - auto ref_ta = TA::make_array>( - *GlobalFixture::world, trange, - [this](TA::Tensor& t, TA::Range const& range) -> double { - return this->make_ta_reference(t, range); - }); + // test with TA and btas tile + using typelist_t = + std::tuple, btas::Tensor>; + typelist_t typevals; - GlobalFixture::world->gop.fence(); - auto test_ta = scalapack::block_cyclic_to_array>( - ref_matrix, trange); - GlobalFixture::world->gop.fence(); + auto test = [&](const auto& typeval_ref) { + using Tile = std::decay_t; + using Array = TA::DistArray; - auto norm_diff = - (ref_ta("i,j") - test_ta("i,j")).norm(*GlobalFixture::world).get(); + auto ref_ta = TA::make_array( + *GlobalFixture::world, trange, + [this](Tile& t, TA::Range const& range) -> double { + return this->make_ta_reference(t, range); + }); - BOOST_CHECK_SMALL(norm_diff, std::numeric_limits::epsilon()); + GlobalFixture::world->gop.fence(); + auto test_ta = scalapack::block_cyclic_to_array(ref_matrix, trange); + GlobalFixture::world->gop.fence(); - GlobalFixture::world->gop.fence(); + auto norm_diff = + (ref_ta("i,j") - test_ta("i,j")).norm(*GlobalFixture::world).get(); + + BOOST_CHECK_SMALL(norm_diff, std::numeric_limits::epsilon()); + + GlobalFixture::world->gop.fence(); + }; + + test(std::get<0>(typevals)); + test(std::get<1>(typevals)); }; BOOST_AUTO_TEST_CASE(sparse_tiled_array_to_bc_test) { @@ -337,29 +349,42 @@ BOOST_AUTO_TEST_CASE(sparse_tiled_array_to_bc_test) { auto trange = gen_trange(N, {static_cast(NB)}); - auto ref_ta = TA::make_array>( - *GlobalFixture::world, trange, - [this](TA::Tensor& t, TA::Range const& range) -> double { - return this->make_ta_reference(t, range); - }); + // test with TA and btas tile + using typelist_t = + std::tuple, btas::Tensor>; + typelist_t typevals; - GlobalFixture::world->gop.fence(); - auto test_matrix = scalapack::array_to_block_cyclic(ref_ta, grid, NB, NB); - GlobalFixture::world->gop.fence(); + auto test = [&](const auto& typeval_ref) { + using Tile = std::decay_t; + using Array = TA::DistArray; - double local_norm_diff = - (test_matrix.local_mat() - ref_matrix.local_mat()).norm(); - local_norm_diff *= local_norm_diff; + auto ref_ta = TA::make_array( + *GlobalFixture::world, trange, + [this](Tile& t, TA::Range const& range) -> double { + return this->make_ta_reference(t, range); + }); - double norm_diff; - MPI_Allreduce(&local_norm_diff, &norm_diff, 1, MPI_DOUBLE, MPI_SUM, - MPI_COMM_WORLD); + GlobalFixture::world->gop.fence(); + auto test_matrix = scalapack::array_to_block_cyclic(ref_ta, grid, NB, NB); + GlobalFixture::world->gop.fence(); - norm_diff = std::sqrt(norm_diff); + double local_norm_diff = + (test_matrix.local_mat() - ref_matrix.local_mat()).norm(); + local_norm_diff *= local_norm_diff; - BOOST_CHECK_SMALL(norm_diff, std::numeric_limits::epsilon()); + double norm_diff; + MPI_Allreduce(&local_norm_diff, &norm_diff, 1, MPI_DOUBLE, MPI_SUM, + MPI_COMM_WORLD); - GlobalFixture::world->gop.fence(); + norm_diff = std::sqrt(norm_diff); + + BOOST_CHECK_SMALL(norm_diff, std::numeric_limits::epsilon()); + + GlobalFixture::world->gop.fence(); + }; + + test(std::get<0>(typevals)); + test(std::get<1>(typevals)); }; BOOST_AUTO_TEST_CASE(const_tiled_array_to_bc_test) {