Skip to content

Commit

Permalink
Worked on reduce_sum a bit (Issue stan-dev#2197)
Browse files Browse the repository at this point in the history
  • Loading branch information
bbbales2 committed Nov 16, 2020
1 parent c0eb7cd commit ed0c7ff
Show file tree
Hide file tree
Showing 4 changed files with 173 additions and 35 deletions.
40 changes: 28 additions & 12 deletions stan/math/prim/functor/reduce_sum.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,14 @@ struct reduce_sum_impl<ReduceFunction, require_arithmetic_t<ReturnType>,
struct recursive_reducer {
Vec vmapped_;
std::ostream* msgs_;
const ReduceFunction& f_;
std::tuple<Args...> args_tuple_;
return_type_t<Vec, Args...> sum_{0.0};

recursive_reducer(Vec&& vmapped, std::ostream* msgs, Args&&... args)
recursive_reducer(Vec&& vmapped, std::ostream* msgs,
const ReduceFunction& f, Args&&... args)
: vmapped_(std::forward<Vec>(vmapped)),
msgs_(msgs),
msgs_(msgs), f_(f),
args_tuple_(std::forward<Args>(args)...) {}

/**
Expand All @@ -61,7 +63,7 @@ struct reduce_sum_impl<ReduceFunction, require_arithmetic_t<ReturnType>,
*/
recursive_reducer(recursive_reducer& other, tbb::split)
: vmapped_(other.vmapped_),
msgs_(other.msgs_),
msgs_(other.msgs_), f_(other.f_),
args_tuple_(other.args_tuple_) {}

/**
Expand All @@ -85,8 +87,8 @@ struct reduce_sum_impl<ReduceFunction, require_arithmetic_t<ReturnType>,

sum_ += apply(
[&](auto&&... args) {
return ReduceFunction()(sub_slice, r.begin(), r.end() - 1, msgs_,
args...);
return f_(msgs_, sub_slice, r.begin(), r.end() - 1,
args...);
},
args_tuple_);
}
Expand Down Expand Up @@ -143,13 +145,14 @@ struct reduce_sum_impl<ReduceFunction, require_arithmetic_t<ReturnType>,
*/
inline ReturnType operator()(Vec&& vmapped, bool auto_partitioning,
int grainsize, std::ostream* msgs,
const ReduceFunction& f,
Args&&... args) const {
const std::size_t num_terms = vmapped.size();
if (vmapped.empty()) {
return 0.0;
}
recursive_reducer worker(std::forward<Vec>(vmapped), msgs,
std::forward<Args>(args)...);
f, std::forward<Args>(args)...);

if (auto_partitioning) {
tbb::parallel_reduce(
Expand Down Expand Up @@ -192,28 +195,41 @@ struct reduce_sum_impl<ReduceFunction, require_arithmetic_t<ReturnType>,
* @return Sum of terms
*/
template <typename ReduceFunction, typename Vec,
typename = require_vector_like_t<Vec>, typename... Args>
typename = require_vector_like_t<Vec>,
require_stan_closure_t<ReduceFunction>* = nullptr,
typename... Args>
inline auto reduce_sum(Vec&& vmapped, int grainsize, std::ostream* msgs,
Args&&... args) {
using return_type = return_type_t<Vec, Args...>;
const ReduceFunction& f, Args&&... args) {
using return_type = return_type_t<ReduceFunction, Vec, Args...>;

// NOTE(review): this function is shown as a diff whose +/- markers were
// stripped, so near-duplicate neighbouring lines are presumably
// (superseded line, replacement line) pairs -- confirm against the commit.
// The replacement signature takes the callable `f` as a runtime argument
// right after `msgs` and includes ReduceFunction in the return-type
// computation.
check_positive("reduce_sum", "grainsize", grainsize);

// Threaded build: delegate to internal::reduce_sum_impl. The literal `true`
// is passed in the position of its `auto_partitioning` parameter.
#ifdef STAN_THREADS
return internal::reduce_sum_impl<ReduceFunction, void, return_type, Vec,
Args...>()(std::forward<Vec>(vmapped), true,
grainsize, msgs,
grainsize, msgs, f,
std::forward<Args>(args)...);
#else
// Non-threaded build: an empty input yields return_type(0.0); otherwise the
// callable is evaluated once over the whole index range [0, size() - 1].
if (vmapped.empty()) {
return return_type(0.0);
}

return ReduceFunction()(std::forward<Vec>(vmapped), 0, vmapped.size() - 1,
msgs, std::forward<Args>(args)...);
// Replacement call: the callable receives `msgs` first, then the slice, the
// inclusive start/end indices and the shared arguments.
return f(msgs, std::forward<Vec>(vmapped), 0, vmapped.size() - 1,
std::forward<Args>(args)...);
#endif
}

template <typename ReduceFunction, typename Vec,
typename = require_vector_like_t<Vec>,
require_not_stan_closure_t<ReduceFunction>* = nullptr,
typename... Args>
inline auto reduce_sum(Vec&& vmapped, int grainsize, std::ostream* msgs,
Args&&... args) {
ReduceFunction f;
closure_adapter<ReduceFunction> cl(f);
return reduce_sum(vmapped, grainsize, msgs, cl, args...);
}

} // namespace math
} // namespace stan

Expand Down
25 changes: 19 additions & 6 deletions stan/math/prim/functor/reduce_sum_static.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,28 +41,41 @@ namespace math {
* @return Sum of terms
*/
template <typename ReduceFunction, typename Vec,
typename = require_vector_like_t<Vec>, typename... Args>
typename = require_vector_like_t<Vec>,
require_stan_closure_t<ReduceFunction>* = nullptr,
typename... Args>
auto reduce_sum_static(Vec&& vmapped, int grainsize, std::ostream* msgs,
Args&&... args) {
using return_type = return_type_t<Vec, Args...>;
const ReduceFunction& f, Args&&... args) {
using return_type = return_type_t<ReduceFunction, Vec, Args...>;

// NOTE(review): this function is shown as a diff whose +/- markers were
// stripped, so near-duplicate neighbouring lines are presumably
// (superseded line, replacement line) pairs -- confirm against the commit.
check_positive("reduce_sum", "grainsize", grainsize);

// Threaded build: delegate to internal::reduce_sum_impl. The literal `false`
// is passed in the position of its `auto_partitioning` parameter.
#ifdef STAN_THREADS
return internal::reduce_sum_impl<ReduceFunction, void, return_type, Vec,
Args...>()(std::forward<Vec>(vmapped), false,
grainsize, msgs,
grainsize, msgs, f,
std::forward<Args>(args)...);
#else
// Non-threaded build: an empty input yields return_type(0); otherwise the
// callable is evaluated once over the whole index range [0, size() - 1].
if (vmapped.empty()) {
return return_type(0);
}

return ReduceFunction()(std::forward<Vec>(vmapped), 0, vmapped.size() - 1,
msgs, std::forward<Args>(args)...);
// The callable is invoked msgs-first, exactly as reduce_sum_impl and
// reduce_sum invoke it; the slice-first order used here before would make the
// non-threaded build call `f` with a different argument layout than the
// threaded build.
return f(msgs, std::forward<Vec>(vmapped), 0, vmapped.size() - 1,
std::forward<Args>(args)...);
#endif
}

template <typename ReduceFunction, typename Vec,
typename = require_vector_like_t<Vec>,
require_not_stan_closure_t<ReduceFunction>* = nullptr,
typename... Args>
auto reduce_sum_static(Vec&& vmapped, int grainsize, std::ostream* msgs,
Args&&... args) {
ReduceFunction f;
internal::ode_closure_adapter<ReduceFunction> cl(f);
return reduce_sum_static(vmapped, grainsize, msgs, cl, args...);
}

} // namespace math
} // namespace stan

Expand Down
55 changes: 38 additions & 17 deletions stan/math/rev/functor/reduce_sum.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,23 +38,28 @@ struct reduce_sum_impl<ReduceFunction, require_var_t<ReturnType>, ReturnType,
*/
struct recursive_reducer {
const size_t num_vars_per_term_;
const size_t num_vars_closure_; // Number of vars in the closure
const size_t num_vars_shared_terms_; // Number of vars in shared arguments
double* sliced_partials_; // Points to adjoints of the partial calculations
Vec vmapped_;
std::ostream* msgs_;
const ReduceFunction& f_;
std::tuple<Args...> args_tuple_;
double sum_{0.0};
Eigen::VectorXd args_adjoints_{0};

template <typename VecT, typename... ArgsT>
recursive_reducer(size_t num_vars_per_term, size_t num_vars_shared_terms,
recursive_reducer(size_t num_vars_per_term,
size_t num_vars_closure,
size_t num_vars_shared_terms,
double* sliced_partials, VecT&& vmapped,
std::ostream* msgs, ArgsT&&... args)
std::ostream* msgs, const ReduceFunction& f, ArgsT&&... args)
: num_vars_per_term_(num_vars_per_term),
num_vars_closure_(num_vars_closure),
num_vars_shared_terms_(num_vars_shared_terms),
sliced_partials_(sliced_partials),
vmapped_(std::forward<VecT>(vmapped)),
msgs_(msgs),
msgs_(msgs), f_(f),
args_tuple_(std::forward<ArgsT>(args)...) {}

/*
Expand All @@ -65,10 +70,11 @@ struct reduce_sum_impl<ReduceFunction, require_var_t<ReturnType>, ReturnType,
*/
recursive_reducer(recursive_reducer& other, tbb::split)
: num_vars_per_term_(other.num_vars_per_term_),
num_vars_closure_(other.num_vars_closure_),
num_vars_shared_terms_(other.num_vars_shared_terms_),
sliced_partials_(other.sliced_partials_),
vmapped_(other.vmapped_),
msgs_(other.msgs_),
msgs_(other.msgs_), f_(other.f_),
args_tuple_(other.args_tuple_) {}

/**
Expand All @@ -90,7 +96,8 @@ struct reduce_sum_impl<ReduceFunction, require_var_t<ReturnType>, ReturnType,
}

if (args_adjoints_.size() == 0) {
args_adjoints_ = Eigen::VectorXd::Zero(num_vars_shared_terms_);
args_adjoints_ = Eigen::VectorXd::Zero(num_vars_closure_ +
num_vars_shared_terms_);
}

// Initialize nested autodiff stack
Expand All @@ -104,6 +111,9 @@ struct reduce_sum_impl<ReduceFunction, require_var_t<ReturnType>, ReturnType,
local_sub_slice.emplace_back(deep_copy_vars(vmapped_[i]));
}

// Create a copy of the functor
auto f_local_copy = deep_copy_vars(f_);

// Create nested autodiff copies of all shared arguments that do not point
// back to main autodiff stack
auto args_tuple_local_copy = apply(
Expand All @@ -116,8 +126,8 @@ struct reduce_sum_impl<ReduceFunction, require_var_t<ReturnType>, ReturnType,
// Perform calculation
var sub_sum_v = apply(
[&](auto&&... args) {
return ReduceFunction()(local_sub_slice, r.begin(), r.end() - 1,
msgs_, args...);
return f_local_copy(msgs_, local_sub_slice, r.begin(), r.end() - 1,
args...);
},
args_tuple_local_copy);

Expand All @@ -131,10 +141,13 @@ struct reduce_sum_impl<ReduceFunction, require_var_t<ReturnType>, ReturnType,
accumulate_adjoints(sliced_partials_ + r.begin() * num_vars_per_term_,
std::move(local_sub_slice));

// Accumulate adjoints of closure arguments
accumulate_adjoints(args_adjoints_.data(), f_local_copy);

// Accumulate adjoints of shared_arguments
apply(
[&](auto&&... args) {
accumulate_adjoints(args_adjoints_.data(),
accumulate_adjoints(args_adjoints_.data() + num_vars_closure_,
std::forward<decltype(args)>(args)...);
},
std::move(args_tuple_local_copy));
Expand Down Expand Up @@ -197,7 +210,8 @@ struct reduce_sum_impl<ReduceFunction, require_var_t<ReturnType>, ReturnType,
* @return Summation of all terms
*/
inline var operator()(Vec&& vmapped, bool auto_partitioning, int grainsize,
std::ostream* msgs, Args&&... args) const {
std::ostream* msgs, const ReduceFunction& f,
Args&&... args) const {
const std::size_t num_terms = vmapped.size();

if (vmapped.empty()) {
Expand All @@ -206,22 +220,27 @@ struct reduce_sum_impl<ReduceFunction, require_var_t<ReturnType>, ReturnType,

const std::size_t num_vars_per_term = count_vars(vmapped[0]);
const std::size_t num_vars_sliced_terms = num_terms * num_vars_per_term;
const std::size_t num_vars_closure = count_vars(f);
const std::size_t num_vars_shared_terms = count_vars(args...);

vari** varis = ChainableStack::instance_->memalloc_.alloc_array<vari*>(
num_vars_sliced_terms + num_vars_shared_terms);
num_vars_sliced_terms + num_vars_closure + num_vars_shared_terms);
double* partials = ChainableStack::instance_->memalloc_.alloc_array<double>(
num_vars_sliced_terms + num_vars_shared_terms);
num_vars_sliced_terms + num_vars_closure + num_vars_shared_terms);

save_varis(varis, vmapped);
save_varis(varis + num_vars_sliced_terms, args...);
save_varis(varis + num_vars_sliced_terms, f);
save_varis(varis + num_vars_sliced_terms + num_vars_closure, args...);

for (size_t i = 0; i < num_vars_sliced_terms; ++i) {
partials[i] = 0.0;
}

recursive_reducer worker(num_vars_per_term, num_vars_shared_terms, partials,
std::forward<Vec>(vmapped), msgs,
recursive_reducer worker(num_vars_per_term,
num_vars_closure,
num_vars_shared_terms,
partials,
std::forward<Vec>(vmapped), msgs, f,
std::forward<Args>(args)...);

if (auto_partitioning) {
Expand All @@ -234,13 +253,15 @@ struct reduce_sum_impl<ReduceFunction, require_var_t<ReturnType>, ReturnType,
partitioner);
}

for (size_t i = 0; i < num_vars_shared_terms; ++i) {
for (size_t i = 0; i < num_vars_closure + num_vars_shared_terms; ++i) {
partials[num_vars_sliced_terms + i] = worker.args_adjoints_(i);
}

return var(new precomputed_gradients_vari(
worker.sum_, num_vars_sliced_terms + num_vars_shared_terms, varis,
partials));
worker.sum_, num_vars_sliced_terms +
num_vars_closure +
num_vars_shared_terms,
varis, partials));
}
};
} // namespace internal
Expand Down
88 changes: 88 additions & 0 deletions test/unit/math/rev/functor/reduce_sum_closure_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
#include <stan/math.hpp>
#include <test/unit/math/prim/functor/reduce_sum_util.hpp>
#include <gtest/gtest.h>
#include <algorithm>
#include <sstream>
#include <tuple>
#include <vector>
#include <set>

/**
 * Stateless test helper that converts the functor calling convention
 * `(slice, start, end, msgs, f, args...)` into a call of `f` that receives
 * the message stream first: `f(msgs, slice, start, end, args...)`.
 */
struct closure_adapter {
  /**
   * Invoke `f` with `msgs` moved to the front of the argument list.
   *
   * The operator is `const` because the adapter holds no state; this lets it
   * be called through a `const` reference, which is how the reduce_sum
   * overloads in this commit hold their callable.
   *
   * @tparam F Type of the wrapped callable
   * @tparam T_slice Type of the slice
   * @tparam Args Types of the additional arguments (taken by value)
   * @param subslice Slice handed to `f` as its second argument
   * @param start Start index handed to `f`
   * @param end End index handed to `f`
   * @param msgs Stream pointer handed to `f` as its first argument
   * @param f Callable to invoke
   * @param args Additional arguments appended after `end`
   * @return The value returned by `f`
   */
  template <typename F, typename T_slice, typename... Args>
  auto operator()(const T_slice& subslice, std::size_t start, std::size_t end,
                  std::ostream* msgs, const F& f, Args... args) const {
    return f(msgs, subslice, start, end, args...);
  }
};

// Compares reduce_sum / reduce_sum_static, driven by a closure built with
// from_lambda, against a direct poisson_lpmf evaluation: the values must
// agree, and so must the adjoint of the first group's rate parameter.
TEST(StanMathRev_reduce_sum, grouped_gradient_closure) {
  using stan::math::from_lambda;
  using stan::math::var;
  using stan::math::test::get_new_msg;

  const std::size_t groups = 10;
  const std::size_t elems_per_group = 1000;
  const std::size_t elems = groups * elems_per_group;

  // data[i] = i, and element i is assigned to group i / elems_per_group.
  std::vector<int> data(elems);
  std::vector<int> gidx(elems);
  for (std::size_t i = 0; i != elems; ++i) {
    data[i] = static_cast<int>(i);
    gidx[i] = static_cast<int>(i / elems_per_group);
  }

  // One rate parameter per group: 0.2, 1.2, ..., 9.2.
  std::vector<var> vlambda_v;
  vlambda_v.reserve(groups);
  for (std::size_t i = 0; i != groups; ++i) {
    vlambda_v.push_back(i + 0.2);
  }

  var lambda_v = vlambda_v[0];

  // NOTE(review): the implementation files in this commit invoke the callable
  // as f(msgs, slice, start, end, args...); how from_lambda maps that onto
  // this lambda's (lambda, slice, start, end, gidx, msgs) parameter order is
  // not visible here -- confirm.
  auto functor = from_lambda(
      [](auto& lambda, auto& slice, std::size_t start, std::size_t end,
         auto& gidx, std::ostream* msgs) {
        const std::size_t num_terms = end - start + 1;
        // Expand the per-group rates to one rate per term of this slice.
        std::decay_t<decltype(lambda)> lambda_slice(num_terms);
        for (std::size_t i = 0; i != num_terms; ++i) {
          lambda_slice[i] = lambda[gidx[start + i]];
        }
        return stan::math::poisson_lpmf(slice, lambda_slice);
      },
      vlambda_v);

  var poisson_lpdf
      = stan::math::reduce_sum(data, 5, get_new_msg(), functor, gidx);

  // Reference: the same model evaluated in one call, one rate per element.
  std::vector<var> vref_lambda_v;
  vref_lambda_v.reserve(elems);
  for (std::size_t i = 0; i != elems; ++i) {
    vref_lambda_v.push_back(vlambda_v[gidx[i]]);
  }
  var lambda_ref = vlambda_v[0];
  var poisson_lpdf_ref = stan::math::poisson_lpmf(data, vref_lambda_v);

  EXPECT_FLOAT_EQ(value_of(poisson_lpdf), value_of(poisson_lpdf_ref));

  stan::math::grad(poisson_lpdf_ref.vi_);
  const double lambda_ref_adj = lambda_ref.adj();

  // Adjoints are reset before each gradient pass so the passes do not add up.
  stan::math::set_zero_all_adjoints();
  stan::math::grad(poisson_lpdf.vi_);
  const double lambda_adj = lambda_v.adj();

  EXPECT_FLOAT_EQ(lambda_adj, lambda_ref_adj)
      << "ref value of poisson lpdf : " << poisson_lpdf_ref.val() << '\n'
      << "ref gradient wrt to lambda: " << lambda_ref_adj << '\n'
      << "value of poisson lpdf : " << poisson_lpdf.val() << '\n'
      << "gradient wrt to lambda: " << lambda_adj << '\n';

  var poisson_lpdf_static
      = stan::math::reduce_sum_static(data, 5, get_new_msg(), functor, gidx);

  // The static variant is held to the same reference value and gradient.
  EXPECT_FLOAT_EQ(value_of(poisson_lpdf_static), value_of(poisson_lpdf_ref));

  stan::math::set_zero_all_adjoints();
  stan::math::grad(poisson_lpdf_static.vi_);
  const double lambda_adj_static = lambda_v.adj();
  EXPECT_FLOAT_EQ(lambda_adj_static, lambda_ref_adj);

  stan::math::recover_memory();
}

0 comments on commit ed0c7ff

Please sign in to comment.