diff --git a/src/TiledArray/tensor/kernels.h b/src/TiledArray/tensor/kernels.h index 699496d77e..5d40ce5c14 100644 --- a/src/TiledArray/tensor/kernels.h +++ b/src/TiledArray/tensor/kernels.h @@ -599,8 +599,17 @@ inline void inplace_tensor_op(Op&& op, TR& result, const Ts&... tensors) { [&op, stride]( typename TR::pointer MADNESS_RESTRICT const result_data, typename Ts::const_pointer MADNESS_RESTRICT const... tensors_data) { - for (decltype(result.range().volume()) i = 0ul; i < stride; ++i) - inplace_tensor_op(op, result_data[i], tensors_data[i]...); + for (decltype(result.range().volume()) i = 0ul; i < stride; ++i) { + if constexpr (std::is_invocable_v< + std::remove_reference_t, + typename std::remove_reference_t::value_type&, + typename std::remove_reference_t< + Ts>::value_type const&...>) { + std::forward(op)(result_data[i], tensors_data[i]...); + } else { + inplace_tensor_op(op, result_data[i], tensors_data[i]...); + } + } }; for (std::decay_t ord = 0ul; ord < volume; ord += stride) diff --git a/src/TiledArray/tensor/tensor_interface.h b/src/TiledArray/tensor/tensor_interface.h index a9e67318d0..6ba8f0430e 100644 --- a/src/TiledArray/tensor/tensor_interface.h +++ b/src/TiledArray/tensor/tensor_interface.h @@ -198,18 +198,9 @@ class TensorInterface { TA_ASSERT(data_ != other.data()); } - if constexpr (detail::is_tensor_v) { - range_ = BlockRange(other.range(), other.range().lobound(), - other.range().upbound()); - data_ = new value_type[other.total_size()]; - auto cpy = other.clone(); - for (auto i = 0; i < other.total_size(); ++i) - std::swap(data_[i], cpy.data()[i]); - } else { - detail::inplace_tensor_op([](numeric_type& MADNESS_RESTRICT result, - const numeric_t arg) { result = arg; }, - *this, other); - } + detail::inplace_tensor_op( + [](value_type& MADNESS_RESTRICT result, auto&& arg) { result = arg; }, + *this, other); return *this; } diff --git a/tests/retile.cpp b/tests/retile.cpp index 8d72dc6903..0f4100d4c8 100644 --- a/tests/retile.cpp +++ b/tests/retile.cpp @@ -39,51 +39,60 @@ BOOST_AUTO_TEST_CASE(retile_more) { auto const tr_source = TA::TiledRange({{0, 2, 4, 8}, {0, 3, 5}}); auto const tr_target = TA::TiledRange({{0, 4, 6, 8}, {0, 2, 4, 5}}); + auto const& elem_rng = tr_source.elements_range(); + + BOOST_REQUIRE(elem_rng.volume() == tr_target.elements_range().volume()); - auto rand_num = [](auto&&) { - return TA::detail::MakeRandom::generate_value(); + auto const inner_rng = TA::Range({3, 3}); + + auto rand_tensor = [](auto const& rng) -> T { + return T(rng, [](auto&&) { + return TA::detail::MakeRandom::generate_value(); + }); }; - auto rand_tensor = [rand_num](auto const& rng) -> T { - return T(rng, rand_num); + auto set_random_tensor_tile = [rand_tensor](auto& tile, auto const& rng) { + tile = rand_tensor(rng); + return tile.norm(); }; - auto rand_tensor_of_tensor = [rand_tensor](auto const& inner_rng) { - return [rand_tensor, inner_rng](auto const& rng) -> ToT { - return ToT(rng, rand_tensor(inner_rng)); - }; + auto rand_tensor_of_tensor = [rand_tensor, + inner_rng](auto const& rng) -> ToT { + return ToT(rng, [rand_tensor, inner_rng](auto&&) { + return rand_tensor(inner_rng); + }); }; auto set_random_tensor_of_tensor_tile = [rand_tensor_of_tensor]( - auto const& inner_rng) { - return - [gen = rand_tensor_of_tensor(inner_rng)](auto& tile, auto const& rng) { - tile = gen(rng); - return tile.norm(); - }; + auto& tile, auto const& rng) { + tile = rand_tensor_of_tensor(rng); + return tile.norm(); }; - auto const inner_rng = TA::Range({3, 3}); - auto arr_source = TA::make_array( - world, tr_source, set_random_tensor_of_tensor_tile(inner_rng)); - arr_source.truncate(); + auto get_elem = [](auto const& arr, auto const& eix) { + auto tix = arr.trange().element_to_tile(eix); + auto&& tile = arr.find(tix).get(false); + return tile(eix); + }; + + auto arr_source0 = + TA::make_array(world, tr_source, set_random_tensor_tile); + auto arr_target0 = TA::retile(arr_source0, tr_target); + for (auto&& eix : elem_rng) { + BOOST_REQUIRE(get_elem(arr_source0, eix) == get_elem(arr_target0, eix)); + } + + auto arr_source = TA::make_array(world, tr_source, + set_random_tensor_of_tensor_tile); auto arr_target = TA::retile(arr_source, tr_target); arr_source.make_replicated(); - world.gop.fence(); arr_target.make_replicated(); + arr_source.truncate(); + arr_target.truncate(); world.gop.fence(); - auto const& elem_rng = tr_source.elements_range(); - BOOST_REQUIRE(elem_rng.volume() == tr_target.elements_range().volume()); - - auto get_elem = [](ArrayToT const& arr, auto const& eix) { - auto tix = arr.trange().element_to_tile(eix); - auto&& tile = arr.find(tix).get(false); - return tile(eix); - }; - for (auto&& eix : elem_rng) { BOOST_REQUIRE(get_elem(arr_source, eix) == get_elem(arr_target, eix)); }