-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor: use Welford's algorithm for stddev computation
Also update the functions to use C++17 return-value sugar (std::tuple with structured bindings)
- Loading branch information
Showing
5 changed files
with
142 additions
and
79 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,20 +1,60 @@ | ||
#pragma once | ||
|
||
#include "parameters.h" | ||
#include <Eigen/Sparse> | ||
|
||
namespace slope { | ||
|
||
std::pair<Eigen::VectorXd, Eigen::VectorXd> | ||
computeMeanAndStdDev(const Eigen::MatrixXd& x); | ||
/** | ||
* Standardizes the given matrix column-wise. | ||
* | ||
* This function uses Welford's algorithm to compute the means and standard | ||
* deviation. | ||
* | ||
* @tparam T The type of the input matrix. | ||
* @param x The input matrix. | ||
* @param standardize Flag indicating whether to standardize the matrix. | ||
* @return A tuple containing the means and standard deviations of the columns. | ||
*/ | ||
template<typename T> | ||
std::tuple<Eigen::VectorXd, Eigen::VectorXd> | ||
standardize(const T& x, const bool standardize) | ||
{ | ||
const int n = x.rows(); | ||
const int p = x.cols(); | ||
|
||
std::pair<Eigen::VectorXd, Eigen::VectorXd> | ||
computeMeanAndStdDev(const Eigen::SparseMatrix<double>& x); | ||
Eigen::VectorXd x_means(p); | ||
Eigen::VectorXd x_stddevs(p); | ||
|
||
std::pair<double, Eigen::VectorXd> | ||
unstandardizeCoefficients(double beta0, | ||
Eigen::VectorXd beta, | ||
const Eigen::VectorXd& x_means, | ||
const Eigen::VectorXd& x_stddevs, | ||
const bool intercept); | ||
for (int j = 0; j < p; ++j) { | ||
double mean = 0.0; | ||
double m2 = 0.0; | ||
int count = 0; | ||
|
||
for (typename T::InnerIterator it(x, j); it; ++it) { | ||
double delta = it.value() - mean; | ||
mean += delta / (++count); | ||
m2 += delta * (it.value() - mean); | ||
} | ||
|
||
// Account for zeros in the column | ||
double delta = -mean; | ||
while (count < n) { | ||
count++; | ||
mean += delta / count; | ||
m2 -= mean * delta; | ||
} | ||
|
||
x_means(j) = mean; | ||
x_stddevs(j) = std::sqrt(m2 / n); | ||
} | ||
return { x_means, x_stddevs }; | ||
} | ||
|
||
std::tuple<double, Eigen::VectorXd> | ||
rescaleCoefficients(double beta0, | ||
Eigen::VectorXd beta, | ||
const Eigen::VectorXd& x_centers, | ||
const Eigen::VectorXd& x_scales, | ||
const SlopeParameters& params); | ||
} // namespace slope |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
#include "test_helpers.hpp" | ||
#include <Eigen/Core> | ||
#include <catch2/catch_test_macros.hpp> | ||
#include <catch2/matchers/catch_matchers.hpp> | ||
#include <catch2/matchers/catch_matchers_floating_point.hpp> | ||
#include <slope/standardize.h> | ||
|
||
std::tuple<Eigen::VectorXd, Eigen::VectorXd> | ||
computeMeanAndStdDev(const Eigen::MatrixXd& x) | ||
{ | ||
const int n = x.rows(); | ||
const int p = x.cols(); | ||
|
||
Eigen::VectorXd x_means = x.colwise().mean(); | ||
Eigen::VectorXd x_stddevs(p); | ||
|
||
for (int j = 0; j < p; ++j) { | ||
x_stddevs(j) = | ||
std::sqrt((x.col(j).array() - x_means(j)).square().sum() / n); | ||
} | ||
|
||
return { x_means, x_stddevs }; | ||
} | ||
|
||
std::tuple<Eigen::VectorXd, Eigen::VectorXd> | ||
computeMeanAndStdDev(const Eigen::SparseMatrix<double>& x) | ||
{ | ||
const int n = x.rows(); | ||
const int p = x.cols(); | ||
|
||
Eigen::VectorXd x_means(p); | ||
Eigen::VectorXd x_stddevs(p); | ||
|
||
for (int j = 0; j < p; ++j) { | ||
x_means(j) = x.col(j).sum() / n; | ||
// TODO: Reconsider this implementation since it might overflow. | ||
x_stddevs(j) = | ||
std::sqrt(x.col(j).squaredNorm() / n - std::pow(x_means(j), 2)); | ||
} | ||
|
||
return { x_means, x_stddevs }; | ||
} | ||
|
||
// Checks that slope::standardize returns the same column means and standard
// deviations as the reference implementation, for both the sparse matrix and
// its dense copy.
TEST_CASE("Check that standardization algorithm works",
          "[utils][standardization]")
{
  Eigen::SparseMatrix<double> x(3, 3);

  // Column 0 is fully populated, column 1 is left empty (all implicit
  // zeros), and column 2 has a single implicit zero in its last row.
  x.coeffRef(0, 0) = 1;
  x.coeffRef(1, 0) = 98.2;
  x.coeffRef(2, 0) = -1007;
  x.coeffRef(0, 2) = 1000;
  x.coeffRef(1, 2) = 34;

  Eigen::MatrixXd x_dense = x;

  auto [x_centers_ref, x_scales_ref] = computeMeanAndStdDev(x);

  auto [x_centers, x_scales] = slope::standardize(x, true);
  auto [x_centers_dense, x_scales_dense] = slope::standardize(x_dense, true);

  REQUIRE_THAT(x_centers, VectorApproxEqual(x_centers_ref));
  REQUIRE_THAT(x_scales, VectorApproxEqual(x_scales_ref));

  // The dense path must be checked against its own results; comparing
  // x_scales a second time would leave x_scales_dense unverified.
  REQUIRE_THAT(x_centers_dense, VectorApproxEqual(x_centers_ref));
  REQUIRE_THAT(x_scales_dense, VectorApproxEqual(x_scales_ref));
}