Skip to content

Commit

Permalink
Make a new CP ALS which takes the THC format
Browse files Browse the repository at this point in the history
  • Loading branch information
kmp5VT committed Dec 24, 2024
1 parent 65f75f1 commit a52b369
Show file tree
Hide file tree
Showing 4 changed files with 327 additions and 8 deletions.
2 changes: 2 additions & 0 deletions src/TiledArray/math/solvers/cp.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,12 @@

#include <TiledArray/math/solvers/cp/cp.h>
#include <TiledArray/math/solvers/cp/cp_als.h>
#include <TiledArray/math/solvers/cp/cp_thc_als.h>
#include <TiledArray/math/solvers/cp/cp_reconstruct.h>

namespace TiledArray {
using TiledArray::math::cp::CP_ALS;
using TiledArray::math::cp::CP_THC_ALS;
using TiledArray::math::cp::cp_reconstruct;
} // namespace TiledArray

Expand Down
17 changes: 9 additions & 8 deletions src/TiledArray/math/solvers/cp/cp.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ class CP {
/// \returns the fit: \f$ 1.0 - |T_{\text{exact}} - T_{\text{approx}} | \f$
double compute_rank(size_t rank, size_t rank_block_size = 0,
bool build_rank = false, double epsilonALS = 1e-3,
bool verbose = false) {
bool verbose = false, int niters = 100) {
rank_block_size = (rank_block_size == 0 ? rank : rank_block_size);
double epsilon = 1.0;
fit_tol = epsilonALS;
Expand All @@ -101,13 +101,13 @@ class CP {
do {
rank_trange = TiledRange1::make_uniform(cur_rank, rank_block_size);
build_guess(cur_rank, rank_trange);
ALS(cur_rank, 100, verbose);
ALS(cur_rank, niters, verbose);
++cur_rank;
} while (cur_rank < rank);
} else {
rank_trange = TiledRange1::make_uniform(rank, rank_block_size);
build_guess(rank, rank_trange);
ALS(rank, 100, verbose);
ALS(rank, niters, verbose);
}
return epsilon;
}
Expand Down Expand Up @@ -185,7 +185,8 @@ class CP {
final_fit, // The final fit of the ALS
// optimization at fixed rank.
fit_tol, // Tolerance for the ALS solver
norm_reference; // used in determining the CP fit.
norm_reference, // used in determining the CP fit.
norm_ref_sq;
std::size_t converged_num =
0; // How many times the ALS solver
// has changed less than the tolerance in a row
Expand Down Expand Up @@ -370,16 +371,16 @@ class CP {
for (size_t i = 1; i < ndim - 1; ++i, ++gram_ptr) {
W("r,rp") *= (*gram_ptr)("r,rp");
}
auto result = sqrt(W("r,rp").dot(
(unNormalized_Factor("r,n") * unNormalized_Factor("rp,n"))));
auto result = W("r,rp").dot(
(unNormalized_Factor("r,n") * unNormalized_Factor("rp,n")));
// not sure why need to fence here, but hang periodically without it
W.world().gop.fence();
return result;
};
// compute the error in the loss function and find the fit
const auto norm_cp = factor_norm(); // ||T_CP||_2
const auto squared_norm_error = norm_reference * norm_reference +
norm_cp * norm_cp -
const auto squared_norm_error = norm_ref_sq +
norm_cp -
2.0 * ref_dot_cp; // ||T - T_CP||_2^2
// N.B. squared_norm_error is very noisy
// TA_ASSERT(squared_norm_error >= - 1e-8);
Expand Down
1 change: 1 addition & 0 deletions src/TiledArray/math/solvers/cp/cp_als.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ class CP_ALS : public CP<Tile, Policy> {
first_gemm_dim_last.pop_back();

this->norm_reference = norm2(tref);
this->norm_ref_sq = this->norm_reference * this->norm_reference;
}

protected:
Expand Down
315 changes: 315 additions & 0 deletions src/TiledArray/math/solvers/cp/cp_thc_als.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,315 @@
/*
* This file is a part of TiledArray.
* Copyright (C) 2023 Virginia Tech
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*
* Karl Pierce
* Department of Chemistry, Virginia Tech
*
* cp.h
* April 17, 2022
*
*/

#ifndef TILEDARRAY_MATH_SOLVERS_CP_CP_THC_ALS__H
#define TILEDARRAY_MATH_SOLVERS_CP_CP_THC_ALS__H

#include <TiledArray/math/solvers/cp/cp.h>
#include <TiledArray/expressions/einsum.h>
#include <TiledArray/math/solvers/cp/cp_reconstruct.h>

namespace TiledArray::math::cp {

/**
* This is a canonical polyadic (CP) optimization class which
* takes a reference order-N tensor that is expressed in the THC format
* and decomposes it into a set of order-2 tensors all coupled by
* a hyperdimension called the rank. These factors are optimized
* using an alternating least squares algorithm.
*
* @tparam Tile typing for the DistArray tiles
* @tparam Policy policy of the DistArray
**/
template <typename Tile, typename Policy>
class CP_THC_ALS : public CP<Tile, Policy> {
public:
using CP<Tile, Policy>::ndim;
using CP<Tile, Policy>::cp_factors;

/// Default CP_ALS constructor
CP_THC_ALS() = default;

/// CP_ALS constructor function
/// takes, as a constant reference, the tensor to be decomposed
/// \param[in] tref A constant reference to the tensor to be decomposed.
// for now I am going to assume an order-4 THC but later this will be used for
// arbitrary order.
CP_THC_ALS(const DistArray<Tile, Policy>& tref1, const DistArray<Tile, Policy>& tref2, const DistArray<Tile, Policy>& tref3)
: CP<Tile, Policy>(2 * rank(tref3)), ref_orb_a(tref1), ref_orb_b(tref2), ref_core(tref3), world(tref1.world()) {

DistArray<Tile, Policy> pr, pq;
pr("r,rp") = (ref_orb_a("a,r") * ref_orb_a("a,rp")) * (ref_orb_b("i,r") * ref_orb_b("i,rp"));
pq("p,q") = ref_core("p,r") * pr("r,rp") * ref_core("q,rp");
this->norm_ref_sq = pq("r,rp").dot(pr("r,rp")).get();
this->norm_reference = sqrt(this->norm_ref_sq);
}

protected:
const DistArray<Tile, Policy>& ref_orb_a, ref_orb_b, ref_core;
madness::World& world;
std::vector<typename Tile::value_type> lambda;
std::vector<DistArray<Tile, Policy>> THC_times_CPD;
TiledRange1 rank_trange1;
size_t size_of_core;

/// This function constructs the initial CP factor matrices
/// stores them in CP::cp_factors vector.
/// In general the initial guess is constructed using quasi-random numbers
/// generated between [-1, 1]
/// \param[in] rank rank of the CP approximation
/// \param[in] rank_trange TiledRange1 of the rank dimension.
void build_guess(const size_t rank, const TiledRange1 rank_trange) override {
rank_trange1 = rank_trange;
if (cp_factors.size() == 0) {
cp_factors.emplace_back(this->construct_random_factor(
world, rank, ref_orb_a.trange().elements_range().extent(0),
rank_trange, ref_orb_a.trange().data()[0]));
cp_factors.emplace_back(this->construct_random_factor(
world, rank, ref_orb_b.trange().elements_range().extent(0),
rank_trange, ref_orb_b.trange().data()[0]));
cp_factors.emplace_back(this->construct_random_factor(
world, rank, ref_orb_a.trange().elements_range().extent(0),
rank_trange, ref_orb_a.trange().data()[0]));
cp_factors.emplace_back(this->construct_random_factor(
world, rank, ref_orb_b.trange().elements_range().extent(0),
rank_trange, ref_orb_b.trange().data()[0]));
} else {
TA_EXCEPTION("Currently no implementation to increase or change rank");
}

return;
}

/// This function is specified by the CP solver
/// optimizes the rank @c rank CP approximation
/// stored in cp_factors.
/// \param[in] rank rank of the CP approximation
/// \param[in] max_iter max number of ALS iterations
/// \param[in] verbose Should ALS print fit information while running?
void ALS(size_t rank, size_t max_iter, bool verbose = false) override {
size_t iter = 0;
bool converged = false;
auto nthc = TA::rank(ref_core);
// initialize partial grammians
{
auto ptr = this->partial_grammian.begin();
for (auto& i : cp_factors) {
(*ptr)("r,rp") = i("r,n") * i("rp, n");
++ptr;
}
DistArray<Tile, Policy> pq;
pq("p,q") = ref_orb_a("a,p") * cp_factors[2]("q,a");
THC_times_CPD.emplace_back(pq);
pq("p,q") *= ref_orb_b("b,p") * cp_factors[3]("q,b");
pq.truncate();
THC_times_CPD.emplace_back(pq);

}
// auto factor_begin = cp_factors.data(),
// gram_begin = this->partial_grammian.data();
// DistArray<Tile, Policy> abr, tref;
// abr = einsum(ref_orb_a("a,r"), ref_orb_b("b,r"), "a,b,r");
// tref("a,b,c,d") = abr("a,b,r") * ref_core("r,rp") * abr("c,d,rp");
//
// std::cout << "Norm2 : " << norm2(tref) << std::endl;
// std::cout << "set: " << this->norm_reference << std::endl;
do {
update_factors_left();
update_factors_right();
converged = this->check_fit(verbose);
// for (auto i = 0; i < nthc; ++i) {
// update_factor(i, rank);
// }
++iter;
} while (iter < max_iter && !converged);
}

void update_factor(size_t mode, size_t rank){

size_t pos = 2 * mode, pos_plus_one = pos + 1;
// going through the core 0 is associated with factors 0 and 1
// core 1 associated with factors 2 and 3 ...
// First we need to take other side of the problem and contract it with the core
// if core is greater than 2 I need to contract all the centers but
// here there's only one so just do one contract
size_t other_side = (mode + 1) % 2;
DistArray<Tile, Policy> env, pq, W_env, W;
{
DistArray<Tile, Policy> An;
env("p,q") = ref_core("p,r") * THC_times_CPD[other_side]("r,q");

pq("p,q") = ref_orb_b("b,p") * cp_factors[pos_plus_one]("q,b");
An("q,a") = (pq("p,q") * env("p,q")) * ref_orb_a("a,p");

// TODO check to see if the Cholesky will fail. If it does
// use SVD
DistArray<Tile, Policy> W;
other_side *= 2;
W_env("p,q") = this->partial_grammian[other_side]("p,q") *
this->partial_grammian[other_side + 1]("p,q");
W("p,q") = this->partial_grammian[pos_plus_one]("p,q") * W_env("p,q");

this->cholesky_inverse(An, W);
world.gop.fence(); // N.B. seems to deadlock without this

this->normalize_factor(An);
cp_factors[pos] = An;
auto& gram = this->partial_grammian[pos];
gram("r,rp") = An("r,n") * An("rp,n");
pq("p,q") = ref_orb_a("a,p") * cp_factors[pos]("q,a");
THC_times_CPD[mode] = pq;
}

// Finished with the first factor in THC.
// Starting the second factor
{
DistArray<Tile, Policy> Bn;
Bn = DistArray<Tile, Policy>();
Bn("q,b") = einsum(pq("p,q"), env("p,q"), "p,q")("p,q") * ref_orb_b("b,p");

this->MTtKRP = Bn;

// TODO check to see if the Cholesky will fail. If it does
// use SVD
W("p,q") = this->partial_grammian[pos]("p,q") * W_env("p,q");
this->cholesky_inverse(Bn, W);
world.gop.fence(); // N.B. seems to deadlock without this

this->unNormalized_Factor = Bn.clone();
this->normalize_factor(Bn);
cp_factors[pos_plus_one] = Bn;
auto& gram = this->partial_grammian[pos_plus_one];
gram("r,rp") = Bn("r,n") * Bn("rp,n");
THC_times_CPD[mode]("p,q") *= ref_orb_b("b,p") * cp_factors[pos_plus_one]("q,b");
}
}
void update_factors_left(){

// going through the core 0 is associated with factors 0 and 1
// core 1 associated with factors 2 and 3 ...
// First we need to take other side of the problem and contract it with the core
// if core is greater than 2 I need to contract all the centers but
// here there's only one so just do one contract
DistArray<Tile, Policy> env, pq, W_env, W;
{
DistArray<Tile, Policy> An;
env("p,q") = ref_core("p,r") * THC_times_CPD[1]("r,q");

pq("p,q") = ref_orb_b("b,p") * cp_factors[1]("q,b");
An("q,a") = (pq("p,q") * env("p,q")) * ref_orb_a("a,p");

// TODO check to see if the Cholesky will fail. If it does
// use SVD
W_env("p,q") = this->partial_grammian[2]("p,q") *
this->partial_grammian[3]("p,q");
W("p,q") = this->partial_grammian[1]("p,q") * W_env("p,q");

this->cholesky_inverse(An, W);
world.gop.fence(); // N.B. seems to deadlock without this

this->normalize_factor(An);
cp_factors[0] = An;
this->partial_grammian[0]("r,rp") = An("r,n") * An("rp,n");
pq("p,q") = ref_orb_a("a,p") * An("q,a");
THC_times_CPD[0] = pq;
}

// Finished with the first factor in THC.
// Starting the second factor
{
DistArray<Tile, Policy> Bn;
Bn = DistArray<Tile, Policy>();
Bn("q,b") = (pq("p,q") * env("p,q")) * ref_orb_b("b,p");

// TODO check to see if the Cholesky will fail. If it does
// use SVD
W("p,q") = this->partial_grammian[0]("p,q") * W_env("p,q");
this->cholesky_inverse(Bn, W);
world.gop.fence(); // N.B. seems to deadlock without this

this->normalize_factor(Bn);
cp_factors[1] = Bn;
this->partial_grammian[1]("r,rp") = Bn("r,n") * Bn("rp,n");
THC_times_CPD[0]("p,q") *= ref_orb_b("b,p") * Bn("q,b");
}
}
void update_factors_right(){

// going through the core 0 is associated with factors 0 and 1
// core 1 associated with factors 2 and 3 ...
// First we need to take other side of the problem and contract it with the core
// if core is greater than 2 I need to contract all the centers but
// here there's only one so just do one contract
DistArray<Tile, Policy> env, pq, W_env, W;
{
DistArray<Tile, Policy> An;
env("p,q") = ref_core("r,p") * THC_times_CPD[0]("r,q");

pq("p,q") = ref_orb_b("b,p") * cp_factors[3]("q,b");
An("q,a") = (pq("p,q") * env("p,q")) * ref_orb_a("a,p");

// TODO check to see if the Cholesky will fail. If it does
// use SVD
W_env("p,q") = this->partial_grammian[0]("p,q") *
this->partial_grammian[1]("p,q");
W("p,q") = this->partial_grammian[3]("p,q") * W_env("p,q");

this->cholesky_inverse(An, W);
world.gop.fence(); // N.B. seems to deadlock without this

this->normalize_factor(An);
cp_factors[2] = An;
this->partial_grammian[2]("r,rp") = An("r,n") * An("rp,n");
pq("p,q") = ref_orb_a("a,p") * An("q,a");
THC_times_CPD[1] = pq;
}

// Finished with the first factor in THC.
// Starting the second factor
{
DistArray<Tile, Policy> Bn;
Bn = DistArray<Tile, Policy>();
Bn("q,b") = (pq("p,q") * env("p,q")) * ref_orb_b("b,p");

this->MTtKRP = Bn;

// TODO check to see if the Cholesky will fail. If it does
// use SVD
W("p,q") = this->partial_grammian[2]("p,q") * W_env("p,q");
this->cholesky_inverse(Bn, W);
world.gop.fence(); // N.B. seems to deadlock without this

this->unNormalized_Factor = Bn.clone();
this->normalize_factor(Bn);
cp_factors[3] = Bn;
this->partial_grammian[3]("r,rp") = Bn("r,n") * Bn("rp,n");
THC_times_CPD[1]("p,q") *= ref_orb_b("b,p") * Bn("q,b");
}
}
};

} // namespace TiledArray::math::cp

#endif // TILEDARRAY_MATH_SOLVERS_CP_CP_THC_ALS__H

0 comments on commit a52b369

Please sign in to comment.