Skip to content

Commit

Permalink
fix: update to libslope 0.3.0
Browse files Browse the repository at this point in the history
  • Loading branch information
jolars committed Dec 11, 2023
1 parent a09ae51 commit 905424f
Show file tree
Hide file tree
Showing 10 changed files with 719 additions and 502 deletions.
33 changes: 17 additions & 16 deletions src/main.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
#include "slope/parameters.h"
#include "slope/slope.h"
#include <Eigen/Core>
#include <Eigen/LU>
Expand All @@ -18,24 +17,26 @@ fit_slope(const T& x,
const Eigen::ArrayXd& alpha,
const py::dict& args)
{
slope::SlopeParameters params;
slope::Slope model;

params.intercept = args["intercept"].cast<bool>();
params.standardize = args["standardize"].cast<bool>();
params.update_clusters = args["update_clusters"].cast<bool>();
params.alpha_min_ratio = args["alpha_min_ratio"].cast<double>();
params.objective = args["objective"].cast<std::string>();
params.path_length = args["path_length"].cast<int>();
params.pgd_freq = args["pgd_freq"].cast<int>();
params.tol = args["tol"].cast<double>();
params.max_it = args["max_it"].cast<int>();
params.max_it_outer = args["max_it_outer"].cast<int>();
params.print_level = args["print_level"].cast<int>();
model.setIntercept(args["intercept"].cast<bool>());
model.setStandardize(args["standardize"].cast<bool>());
model.setUpdateClusters(args["update_clusters"].cast<bool>());
model.setAlphaMinRatio(args["alpha_min_ratio"].cast<double>());
model.setObjective(args["objective"].cast<std::string>());
model.setPathLength(args["path_length"].cast<int>());
model.setPgdFreq(args["pgd_freq"].cast<int>());
model.setTol(args["tol"].cast<double>());
model.setMaxIt(args["max_it"].cast<int>());
model.setMaxItOuter(args["max_it_outer"].cast<int>());
model.setPrintLevel(args["print_level"].cast<int>());

auto result = slope::slope(x, y, alpha, lambda, params);
model.fit(x, y, alpha, lambda);

return py::make_tuple(
result.beta0s, result.betas, result.lambda, result.alpha);
return py::make_tuple(model.getIntercepts(),
model.getCoefs(),
model.getLambda(),
model.getAlpha());
}

pybind11::tuple
Expand Down
28 changes: 19 additions & 9 deletions src/slope/cd.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
#pragma once

#include "clusters.h"
#include "parameters.h"
#include "slope_threshold.h"
#include "sorted_l1_norm.h"
#include <Eigen/Core>
Expand All @@ -26,14 +25,22 @@ namespace slope {
* @param beta0 The intercept
* @param beta The coefficients
* @param residual The residual vector
* @param clusters The cluster information
* @param clusters The cluster information, stored in a Cluster object.
* @param x The design matrix
* @param w The weight vector
* @param z The response vector
* @param sl1_norm The sorted L1 norm object
* @param x_centers The center values of the data matrix columns
* @param x_scales The scale values of the data matrix columns
* @param params The SLOPE parameters
* @param intercept Shuold an intervept be fit?
* @param standardize Flag indicating whether to standardize the data matrix
* columns
* @param update_clusters Flag indicating whether to update the cluster
* information
* @param print_level The level of verbosity for printing debug information
*
* @see Clusters
* @see SortedL1Norm
*/
template<typename T>
void
Expand All @@ -47,7 +54,10 @@ coordinateDescent(double& beta0,
const SortedL1Norm& sl1_norm,
const Eigen::VectorXd& x_centers,
const Eigen::VectorXd& x_scales,
const SlopeParameters& params)
const bool intercept,
const bool standardize,
const bool update_clusters,
const int print_level)
{
using namespace Eigen;

Expand Down Expand Up @@ -75,7 +85,7 @@ coordinateDescent(double& beta0,
double s_k = sign(beta(k));
s.emplace_back(s_k);

if (params.standardize) {
if (standardize) {
gradient_j = -s_k *
(x.col(k).cwiseProduct(w).dot(residual) -
w.dot(residual) * x_centers(k)) /
Expand All @@ -102,7 +112,7 @@ coordinateDescent(double& beta0,
double s_k = sign(beta(k));
s.emplace_back(s_k);

if (params.standardize) {
if (standardize) {
x_s += x.col(k) * (s_k / x_scales(k));
x_s.array() -= x_centers(k) * s_k / x_scales(k);
} else {
Expand Down Expand Up @@ -132,7 +142,7 @@ coordinateDescent(double& beta0,
if (cluster_size == 1) {
int k = *clusters.cbegin(j);

if (params.standardize) {
if (standardize) {
residual += x.col(k) * (s[0] * c_diff / x_scales(k));
residual.array() -= x_centers(k) * s[0] * c_diff / x_scales(k);
} else {
Expand All @@ -143,13 +153,13 @@ coordinateDescent(double& beta0,
}
}

if (params.update_clusters) {
if (update_clusters) {
clusters.update(j, new_index, std::abs(c_tilde));
} else {
clusters.setCoeff(j, std::abs(c_tilde));
}

if (params.intercept) {
if (intercept) {
double beta0_update = residual.dot(w) / w.sum();
residual.array() -= beta0_update;
beta0 += beta0_update;
Expand Down
Loading

0 comments on commit 905424f

Please sign in to comment.