diff --git a/src/TiledArray/math/solvers/cp.h b/src/TiledArray/math/solvers/cp.h index f94ea259ff..b957d33539 100644 --- a/src/TiledArray/math/solvers/cp.h +++ b/src/TiledArray/math/solvers/cp.h @@ -28,10 +28,12 @@ #include #include +#include #include namespace TiledArray { using TiledArray::math::cp::CP_ALS; +using TiledArray::math::cp::CP_THC_ALS; using TiledArray::math::cp::cp_reconstruct; } // namespace TiledArray diff --git a/src/TiledArray/math/solvers/cp/cp.h b/src/TiledArray/math/solvers/cp/cp.h index 9065776211..dcde333055 100644 --- a/src/TiledArray/math/solvers/cp/cp.h +++ b/src/TiledArray/math/solvers/cp/cp.h @@ -91,7 +91,7 @@ class CP { /// \returns the fit: \f$ 1.0 - |T_{\text{exact}} - T_{\text{approx}} | \f$ double compute_rank(size_t rank, size_t rank_block_size = 0, bool build_rank = false, double epsilonALS = 1e-3, - bool verbose = false) { + bool verbose = false, int niters = 100) { rank_block_size = (rank_block_size == 0 ? rank : rank_block_size); double epsilon = 1.0; fit_tol = epsilonALS; @@ -101,13 +101,13 @@ class CP { do { rank_trange = TiledRange1::make_uniform(cur_rank, rank_block_size); build_guess(cur_rank, rank_trange); - ALS(cur_rank, 100, verbose); + ALS(cur_rank, niters, verbose); ++cur_rank; } while (cur_rank < rank); } else { rank_trange = TiledRange1::make_uniform(rank, rank_block_size); build_guess(rank, rank_trange); - ALS(rank, 100, verbose); + ALS(rank, niters, verbose); } return epsilon; } @@ -185,7 +185,8 @@ class CP { final_fit, // The final fit of the ALS // optimization at fixed rank. fit_tol, // Tolerance for the ALS solver - norm_reference; // used in determining the CP fit. + norm_reference, // used in determining the CP fit. + norm_ref_sq; std::size_t converged_num = 0; // How many times the ALS solver // has changed less than the tolerance in a row @@ -370,16 +371,16 @@ class CP { for (size_t i = 1; i < ndim - 1; ++i, ++gram_ptr) { W("r,rp") *= (*gram_ptr)("r,rp"); } - auto result = sqrt(W("r,rp").dot( - (unNormalized_Factor("r,n") * unNormalized_Factor("rp,n")))); + auto result = W("r,rp").dot( + (unNormalized_Factor("r,n") * unNormalized_Factor("rp,n"))); // not sure why need to fence here, but hang periodically without it W.world().gop.fence(); return result; }; // compute the error in the loss function and find the fit const auto norm_cp = factor_norm(); // ||T_CP||_2 - const auto squared_norm_error = norm_reference * norm_reference + - norm_cp * norm_cp - + const auto squared_norm_error = norm_ref_sq + + norm_cp - 2.0 * ref_dot_cp; // ||T - T_CP||_2^2 // N.B. squared_norm_error is very noisy // TA_ASSERT(squared_norm_error >= - 1e-8); diff --git a/src/TiledArray/math/solvers/cp/cp_als.h b/src/TiledArray/math/solvers/cp/cp_als.h index bd5c5b9eee..4f44fa4aa1 100644 --- a/src/TiledArray/math/solvers/cp/cp_als.h +++ b/src/TiledArray/math/solvers/cp/cp_als.h @@ -69,6 +69,7 @@ class CP_ALS : public CP { first_gemm_dim_last.pop_back(); this->norm_reference = norm2(tref); + this->norm_ref_sq = this->norm_reference * this->norm_reference; } protected: diff --git a/src/TiledArray/math/solvers/cp/cp_thc_als.h b/src/TiledArray/math/solvers/cp/cp_thc_als.h new file mode 100644 index 0000000000..8aa919aa3d --- /dev/null +++ b/src/TiledArray/math/solvers/cp/cp_thc_als.h @@ -0,0 +1,315 @@ +/* +* This file is a part of TiledArray. +* Copyright (C) 2023 Virginia Tech +* +* This program is free software: you can redistribute it and/or modify +* it under the terms of the GNU General Public License as published by +* the Free Software Foundation, either version 3 of the License, or +* (at your option) any later version. +* +* This program is distributed in the hope that it will be useful, +* but WITHOUT ANY WARRANTY; without even the implied warranty of +* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +* GNU General Public License for more details. +* +* You should have received a copy of the GNU General Public License +* along with this program. If not, see . +* +* Karl Pierce +* Department of Chemistry, Virginia Tech +* +* cp.h +* April 17, 2022 +* +*/ + +#ifndef TILEDARRAY_MATH_SOLVERS_CP_CP_THC_ALS__H +#define TILEDARRAY_MATH_SOLVERS_CP_CP_THC_ALS__H + +#include +#include +#include + +namespace TiledArray::math::cp { + +/** +* This is a canonical polyadic (CP) optimization class which +* takes a reference order-N tensor that is expressed in the THC format +* and decomposes it into a set of order-2 tensors all coupled by +* a hyperdimension called the rank. These factors are optimized +* using an alternating least squares algorithm. +* +* @tparam Tile typing for the DistArray tiles +* @tparam Policy policy of the DistArray +**/ +template +class CP_THC_ALS : public CP { +public: + using CP::ndim; + using CP::cp_factors; + + /// Default CP_ALS constructor + CP_THC_ALS() = default; + + /// CP_ALS constructor function + /// takes, as a constant reference, the tensor to be decomposed + /// \param[in] tref A constant reference to the tensor to be decomposed. + // for now I am going to assume an order-4 THC but later this will be used for + // arbitrary order. + CP_THC_ALS(const DistArray& tref1, const DistArray& tref2, const DistArray& tref3) + : CP(2 * rank(tref3)), ref_orb_a(tref1), ref_orb_b(tref2), ref_core(tref3), world(tref1.world()) { + + DistArray pr, pq; + pr("r,rp") = (ref_orb_a("a,r") * ref_orb_a("a,rp")) * (ref_orb_b("i,r") * ref_orb_b("i,rp")); + pq("p,q") = ref_core("p,r") * pr("r,rp") * ref_core("q,rp"); + this->norm_ref_sq = pq("r,rp").dot(pr("r,rp")).get(); + this->norm_reference = sqrt(this->norm_ref_sq); + } + +protected: + const DistArray& ref_orb_a, ref_orb_b, ref_core; + madness::World& world; + std::vector lambda; + std::vector> THC_times_CPD; + TiledRange1 rank_trange1; + size_t size_of_core; + + /// This function constructs the initial CP factor matrices + /// stores them in CP::cp_factors vector. + /// In general the initial guess is constructed using quasi-random numbers + /// generated between [-1, 1] + /// \param[in] rank rank of the CP approximation + /// \param[in] rank_trange TiledRange1 of the rank dimension. + void build_guess(const size_t rank, const TiledRange1 rank_trange) override { + rank_trange1 = rank_trange; + if (cp_factors.size() == 0) { + cp_factors.emplace_back(this->construct_random_factor( + world, rank, ref_orb_a.trange().elements_range().extent(0), + rank_trange, ref_orb_a.trange().data()[0])); + cp_factors.emplace_back(this->construct_random_factor( + world, rank, ref_orb_b.trange().elements_range().extent(0), + rank_trange, ref_orb_b.trange().data()[0])); + cp_factors.emplace_back(this->construct_random_factor( + world, rank, ref_orb_a.trange().elements_range().extent(0), + rank_trange, ref_orb_a.trange().data()[0])); + cp_factors.emplace_back(this->construct_random_factor( + world, rank, ref_orb_b.trange().elements_range().extent(0), + rank_trange, ref_orb_b.trange().data()[0])); + } else { + TA_EXCEPTION("Currently no implementation to increase or change rank"); + } + + return; + } + + /// This function is specified by the CP solver + /// optimizes the rank @c rank CP approximation + /// stored in cp_factors. + /// \param[in] rank rank of the CP approximation + /// \param[in] max_iter max number of ALS iterations + /// \param[in] verbose Should ALS print fit information while running? + void ALS(size_t rank, size_t max_iter, bool verbose = false) override { + size_t iter = 0; + bool converged = false; + auto nthc = TA::rank(ref_core); + // initialize partial grammians + { + auto ptr = this->partial_grammian.begin(); + for (auto& i : cp_factors) { + (*ptr)("r,rp") = i("r,n") * i("rp, n"); + ++ptr; + } + DistArray pq; + pq("p,q") = ref_orb_a("a,p") * cp_factors[2]("q,a"); + THC_times_CPD.emplace_back(pq); + pq("p,q") *= ref_orb_b("b,p") * cp_factors[3]("q,b"); + pq.truncate(); + THC_times_CPD.emplace_back(pq); + + } +// auto factor_begin = cp_factors.data(), +// gram_begin = this->partial_grammian.data(); +// DistArray abr, tref; +// abr = einsum(ref_orb_a("a,r"), ref_orb_b("b,r"), "a,b,r"); +// tref("a,b,c,d") = abr("a,b,r") * ref_core("r,rp") * abr("c,d,rp"); +// +// std::cout << "Norm2 : " << norm2(tref) << std::endl; +// std::cout << "set: " << this->norm_reference << std::endl; + do { + update_factors_left(); + update_factors_right(); + converged = this->check_fit(verbose); + // for (auto i = 0; i < nthc; ++i) { + // update_factor(i, rank); + // } + ++iter; + } while (iter < max_iter && !converged); + } + + void update_factor(size_t mode, size_t rank){ + + size_t pos = 2 * mode, pos_plus_one = pos + 1; + // going through the core 0 is associated with factors 0 and 1 + // core 1 associated with factors 2 and 3 ... + // First we need to take other side of the problem and contract it with the core + // if core is greater than 2 I need to contract all the centers but + // here there's only one so just do one contract + size_t other_side = (mode + 1) % 2; + DistArray env, pq, W_env, W; + { + DistArray An; + env("p,q") = ref_core("p,r") * THC_times_CPD[other_side]("r,q"); + + pq("p,q") = ref_orb_b("b,p") * cp_factors[pos_plus_one]("q,b"); + An("q,a") = (pq("p,q") * env("p,q")) * ref_orb_a("a,p"); + + // TODO check to see if the Cholesky will fail. If it does + // use SVD + DistArray W; + other_side *= 2; + W_env("p,q") = this->partial_grammian[other_side]("p,q") * + this->partial_grammian[other_side + 1]("p,q"); + W("p,q") = this->partial_grammian[pos_plus_one]("p,q") * W_env("p,q"); + + this->cholesky_inverse(An, W); + world.gop.fence(); // N.B. seems to deadlock without this + + this->normalize_factor(An); + cp_factors[pos] = An; + auto& gram = this->partial_grammian[pos]; + gram("r,rp") = An("r,n") * An("rp,n"); + pq("p,q") = ref_orb_a("a,p") * cp_factors[pos]("q,a"); + THC_times_CPD[mode] = pq; + } + + // Finished with the first factor in THC. + // Starting the second factor + { + DistArray Bn; + Bn = DistArray(); + Bn("q,b") = einsum(pq("p,q"), env("p,q"), "p,q")("p,q") * ref_orb_b("b,p"); + + this->MTtKRP = Bn; + + // TODO check to see if the Cholesky will fail. If it does + // use SVD + W("p,q") = this->partial_grammian[pos]("p,q") * W_env("p,q"); + this->cholesky_inverse(Bn, W); + world.gop.fence(); // N.B. seems to deadlock without this + + this->unNormalized_Factor = Bn.clone(); + this->normalize_factor(Bn); + cp_factors[pos_plus_one] = Bn; + auto& gram = this->partial_grammian[pos_plus_one]; + gram("r,rp") = Bn("r,n") * Bn("rp,n"); + THC_times_CPD[mode]("p,q") *= ref_orb_b("b,p") * cp_factors[pos_plus_one]("q,b"); + } + } + void update_factors_left(){ + + // going through the core 0 is associated with factors 0 and 1 + // core 1 associated with factors 2 and 3 ... + // First we need to take other side of the problem and contract it with the core + // if core is greater than 2 I need to contract all the centers but + // here there's only one so just do one contract + DistArray env, pq, W_env, W; + { + DistArray An; + env("p,q") = ref_core("p,r") * THC_times_CPD[1]("r,q"); + + pq("p,q") = ref_orb_b("b,p") * cp_factors[1]("q,b"); + An("q,a") = (pq("p,q") * env("p,q")) * ref_orb_a("a,p"); + + // TODO check to see if the Cholesky will fail. If it does + // use SVD + W_env("p,q") = this->partial_grammian[2]("p,q") * + this->partial_grammian[3]("p,q"); + W("p,q") = this->partial_grammian[1]("p,q") * W_env("p,q"); + + this->cholesky_inverse(An, W); + world.gop.fence(); // N.B. seems to deadlock without this + + this->normalize_factor(An); + cp_factors[0] = An; + this->partial_grammian[0]("r,rp") = An("r,n") * An("rp,n"); + pq("p,q") = ref_orb_a("a,p") * An("q,a"); + THC_times_CPD[0] = pq; + } + + // Finished with the first factor in THC. + // Starting the second factor + { + DistArray Bn; + Bn = DistArray(); + Bn("q,b") = (pq("p,q") * env("p,q")) * ref_orb_b("b,p"); + + // TODO check to see if the Cholesky will fail. If it does + // use SVD + W("p,q") = this->partial_grammian[0]("p,q") * W_env("p,q"); + this->cholesky_inverse(Bn, W); + world.gop.fence(); // N.B. seems to deadlock without this + + this->normalize_factor(Bn); + cp_factors[1] = Bn; + this->partial_grammian[1]("r,rp") = Bn("r,n") * Bn("rp,n"); + THC_times_CPD[0]("p,q") *= ref_orb_b("b,p") * Bn("q,b"); + } + } + void update_factors_right(){ + + // going through the core 0 is associated with factors 0 and 1 + // core 1 associated with factors 2 and 3 ... + // First we need to take other side of the problem and contract it with the core + // if core is greater than 2 I need to contract all the centers but + // here there's only one so just do one contract + DistArray env, pq, W_env, W; + { + DistArray An; + env("p,q") = ref_core("r,p") * THC_times_CPD[0]("r,q"); + + pq("p,q") = ref_orb_b("b,p") * cp_factors[3]("q,b"); + An("q,a") = (pq("p,q") * env("p,q")) * ref_orb_a("a,p"); + + // TODO check to see if the Cholesky will fail. If it does + // use SVD + W_env("p,q") = this->partial_grammian[0]("p,q") * + this->partial_grammian[1]("p,q"); + W("p,q") = this->partial_grammian[3]("p,q") * W_env("p,q"); + + this->cholesky_inverse(An, W); + world.gop.fence(); // N.B. seems to deadlock without this + + this->normalize_factor(An); + cp_factors[2] = An; + this->partial_grammian[2]("r,rp") = An("r,n") * An("rp,n"); + pq("p,q") = ref_orb_a("a,p") * An("q,a"); + THC_times_CPD[1] = pq; + } + + // Finished with the first factor in THC. + // Starting the second factor + { + DistArray Bn; + Bn = DistArray(); + Bn("q,b") = (pq("p,q") * env("p,q")) * ref_orb_b("b,p"); + + this->MTtKRP = Bn; + + // TODO check to see if the Cholesky will fail. If it does + // use SVD + W("p,q") = this->partial_grammian[2]("p,q") * W_env("p,q"); + this->cholesky_inverse(Bn, W); + world.gop.fence(); // N.B. seems to deadlock without this + + this->unNormalized_Factor = Bn.clone(); + this->normalize_factor(Bn); + cp_factors[3] = Bn; + this->partial_grammian[3]("r,rp") = Bn("r,n") * Bn("rp,n"); + THC_times_CPD[1]("p,q") *= ref_orb_b("b,p") * Bn("q,b"); + } + } +}; + +} // namespace TiledArray::math::cp + +#endif // TILEDARRAY_MATH_SOLVERS_CP_CP_THC_ALS__H