Skip to content

Commit

Permalink
test outer product with sparse KroneckerDeltaTile
Browse files Browse the repository at this point in the history
  • Loading branch information
evaleev committed Aug 24, 2024
1 parent af1d41c commit 828699e
Showing 1 changed file with 18 additions and 3 deletions.
21 changes: 18 additions & 3 deletions tests/expressions_mixed.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
*
*/

#include "TiledArray/special/diagonal_array.h"
#include "TiledArray/special/kronecker_delta.h"
#include "range_fixture.h"
#include "sparse_tile.h"
Expand All @@ -38,6 +39,7 @@ struct MixedExpressionsFixture : public TiledRangeFixture {
typedef DistArray<EigenSparseTile<double, tag<0>>, DensePolicy> TArrayDS1;
typedef DistArray<EigenSparseTile<double, tag<1>>, DensePolicy> TArrayDS2;
typedef DistArray<KroneckerDeltaTile, DensePolicy> ArrayKronDelta1;
typedef DistArray<KroneckerDeltaTile, SparsePolicy> SpArrayKronDelta1;

MixedExpressionsFixture()
: u(*GlobalFixture::world, trange2),
Expand All @@ -51,19 +53,26 @@ struct MixedExpressionsFixture : public TiledRangeFixture {
*GlobalFixture::world, trange2.tiles_range().volume())),
delta1e(*GlobalFixture::world, trange2e, DenseShape(),
std::make_shared<detail::ReplicatedPmap>(
*GlobalFixture::world, trange2e.tiles_range().volume())) {
*GlobalFixture::world, trange2e.tiles_range().volume())),
spe2(*GlobalFixture::world, trange2e),
spdelta1(*GlobalFixture::world, trange2,
SparseShape(detail::diagonal_shape(trange2, 1), trange2),
std::make_shared<detail::ReplicatedPmap>(
*GlobalFixture::world, trange2.tiles_range().volume())) {
random_fill(u);
random_fill(v);
u2.fill(0);
random_fill(e2);
e4.fill(0);
init_kronecker_delta(delta1);
init_kronecker_delta(delta1e);
random_fill(spe2);
init_kronecker_delta(spdelta1);
GlobalFixture::world->gop.fence();
}

template <typename Tile>
static void random_fill(DistArray<Tile>& array) {
template <typename Tile, typename Policy>
static void random_fill(DistArray<Tile, Policy>& array) {
array.fill_random();
}

Expand Down Expand Up @@ -133,6 +142,8 @@ struct MixedExpressionsFixture : public TiledRangeFixture {
TArrayDS2 w;
ArrayKronDelta1 delta1;
ArrayKronDelta1 delta1e;
TSpArrayD spe2;
SpArrayKronDelta1 spdelta1;
}; // MixedExpressionsFixture

// Instantiate static variables for fixture
Expand Down Expand Up @@ -194,6 +205,10 @@ BOOST_AUTO_TEST_CASE(outer_product_factories) {
// ok
BOOST_CHECK_NO_THROW(u2("a,b,c,d") += delta1("a,b") * u("c,d"));

// ok
TSpArrayD tmp;
BOOST_CHECK_NO_THROW(tmp("a,b,c,d") = spdelta1("a,b") * spe2("c,d"));

// ok
BOOST_CHECK_NO_THROW(e4("a,c,b,d") += delta1e("a,b") * e2("c,d"));
}
Expand Down

0 comments on commit 828699e

Please sign in to comment.