Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update integrate_1d to use variadic autodiff stuff internally in preparation for closures #2397

Merged
merged 28 commits into from
Mar 31, 2021
Merged
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
6e0f680
Saving work
bbbales2 Nov 22, 2020
340e62c
integrate_1d_new working (variadic integrate_1d) (Issue #2110)
bbbales2 Nov 23, 2020
767ebb6
Merge remote-tracking branch 'origin/develop' into feature/variadic-i…
bbbales2 Feb 24, 2021
f84f7eb
Bit of cleanup and added adapter file (Issue #2197)
bbbales2 Feb 24, 2021
5019637
Renamed tests (Issue #2197)
bbbales2 Feb 24, 2021
403987f
Turned on reverse mode tests (Issue #2197)
bbbales2 Feb 24, 2021
7f53b90
Merge remote-tracking branch 'origin/develop' into feature/variadic-i…
bbbales2 Feb 27, 2021
132cd2c
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Feb 27, 2021
59d1099
Switch binds to lambdas
bbbales2 Mar 25, 2021
989cb49
Merge commit 'a426eea0ec9d9a7547061bc776a08e509d3406f3' into HEAD
yashikno Mar 25, 2021
a6cebfc
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Mar 25, 2021
abaa602
use double nested reverse pass to save making N tuple copies
SteveBronder Mar 25, 2021
82624e8
Update stan/math/rev/functor/integrate_1d.hpp
bbbales2 Mar 29, 2021
3469fa1
Update stan/math/prim/functor/integrate_1d.hpp
bbbales2 Mar 29, 2021
eaaeeb2
Updated docs
bbbales2 Mar 29, 2021
95b4032
Update stan/math/rev/functor/integrate_1d.hpp
bbbales2 Mar 29, 2021
485e089
Merge branch 'feature/variadic-integrate-1d' of github.com:stan-dev/m…
bbbales2 Mar 29, 2021
c4540c5
Reordered if
bbbales2 Mar 29, 2021
e4eb66f
Merge remote-tracking branch 'origin/review/variadic-integrate-1d' in…
bbbales2 Mar 29, 2021
aba98b6
Put error checks into functions
bbbales2 Mar 29, 2021
231b0e9
Merge commit 'f390d823261266c2a0999eeaedcd5ac216f857b3' into HEAD
yashikno Mar 29, 2021
77164cc
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Mar 29, 2021
4e882f2
small cleanups and use get the adjoint in each integration by a looku…
SteveBronder Mar 29, 2021
a2e8e57
Merge remote-tracking branch 'origin/develop' into feature/variadic-i…
bbbales2 Mar 30, 2021
36847d9
Merge remote-tracking branch 'origin/review/integrate-1d-variadic-2' …
bbbales2 Mar 30, 2021
cd1b0cb
Use varis to get nth gradient
bbbales2 Mar 30, 2021
9360084
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Mar 30, 2021
49ba955
remove changes to math::get() now that we don't need it in integrate1d
SteveBronder Mar 30, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions stan/math/prim/functor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <stan/math/prim/functor/finite_diff_gradient_auto.hpp>
#include <stan/math/prim/functor/for_each.hpp>
#include <stan/math/prim/functor/integrate_1d.hpp>
#include <stan/math/prim/functor/integrate_1d_adapter.hpp>
#include <stan/math/prim/functor/integrate_ode_rk45.hpp>
#include <stan/math/prim/functor/integrate_ode_std_vector_interface_adapter.hpp>
#include <stan/math/prim/functor/ode_ckrk.hpp>
Expand Down
111 changes: 71 additions & 40 deletions stan/math/prim/functor/integrate_1d.hpp
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
#ifndef STAN_MATH_PRIM_FUNCTOR_integrate_1d_HPP
#define STAN_MATH_PRIM_FUNCTOR_integrate_1d_HPP
#ifndef STAN_MATH_PRIM_FUNCTOR_INTEGRATE_1D_HPP
#define STAN_MATH_PRIM_FUNCTOR_INTEGRATE_1D_HPP

#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/err.hpp>
#include <stan/math/prim/fun/constants.hpp>
#include <stan/math/prim/functor/integrate_1d_adapter.hpp>
#include <boost/math/quadrature/exp_sinh.hpp>
#include <boost/math/quadrature/sinh_sinh.hpp>
#include <boost/math/quadrature/tanh_sinh.hpp>
Expand Down Expand Up @@ -50,17 +51,44 @@ namespace math {
template <typename F>
inline double integrate(const F& f, double a, double b,
double relative_tolerance) {
static constexpr const char* function = "integrate";
double error1 = 0.0;
double error2 = 0.0;
double L1 = 0.0;
double L2 = 0.0;
bool used_two_integrals = false;
size_t levels;
double Q = 0.0;

auto one_integral_convergence_check = [&]() {
if (error1 > relative_tolerance * L1) {
throw_domain_error(
function, "error estimate of integral", error1, "",
" exceeds the given relative tolerance times norm of integral");
}
};

auto two_integral_convergence_check = [&]() {
if (error1 > relative_tolerance * L1) {
throw_domain_error(function, "error estimate of integral below zero",
error1, "",
" exceeds the given relative tolerance times norm of "
"integral below zero");
}
if (error2 > relative_tolerance * L2) {
throw_domain_error(function, "error estimate of integral above zero",
error2, "",
" exceeds the given relative tolerance times norm of "
"integral above zero");
}
};

// if a or b is infinite, set xc argument to NaN (see docs above for user
// function for xc info)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[optional] We can do this some other time but it would be nice to make the errors in the if (used_two_integrals) into a function that's just called in the places we use two integrals

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I split these into lambda functions. That look better or is used_two_integrals clearer?

auto f_wrap = [&](double x) { return f(x, NOT_A_NUMBER); };
if (std::isinf(a) && std::isinf(b)) {
boost::math::quadrature::sinh_sinh<double> integrator;
Q = integrator.integrate(f_wrap, relative_tolerance, &error1, &L1, &levels);
one_integral_convergence_check();
} else if (std::isinf(a)) {
boost::math::quadrature::exp_sinh<double> integrator;
/**
Expand All @@ -71,26 +99,28 @@ inline double integrate(const F& f, double a, double b,
if (b <= 0.0) {
Q = integrator.integrate(f_wrap, a, b, relative_tolerance, &error1, &L1,
&levels);
one_integral_convergence_check();
} else {
boost::math::quadrature::tanh_sinh<double> integrator_right;
Q = integrator.integrate(f_wrap, a, 0.0, relative_tolerance, &error1, &L1,
&levels)
+ integrator_right.integrate(f_wrap, 0.0, b, relative_tolerance,
&error2, &L2, &levels);
used_two_integrals = true;
two_integral_convergence_check();
}
} else if (std::isinf(b)) {
boost::math::quadrature::exp_sinh<double> integrator;
if (a >= 0.0) {
Q = integrator.integrate(f_wrap, a, b, relative_tolerance, &error1, &L1,
&levels);
one_integral_convergence_check();
} else {
boost::math::quadrature::tanh_sinh<double> integrator_left;
Q = integrator_left.integrate(f_wrap, a, 0, relative_tolerance, &error1,
&L1, &levels)
+ integrator.integrate(f_wrap, relative_tolerance, &error2, &L2,
&levels);
used_two_integrals = true;
two_integral_convergence_check();
}
} else {
auto f_wrap = [&](double x, double xc) { return f(x, xc); };
Expand All @@ -100,35 +130,48 @@ inline double integrate(const F& f, double a, double b,
&levels)
+ integrator.integrate(f_wrap, 0.0, b, relative_tolerance, &error2,
&L2, &levels);
used_two_integrals = true;
two_integral_convergence_check();
} else {
Q = integrator.integrate(f_wrap, a, b, relative_tolerance, &error1, &L1,
&levels);
one_integral_convergence_check();
}
}

static const char* function = "integrate";
if (used_two_integrals) {
if (error1 > relative_tolerance * L1) {
throw_domain_error(function, "error estimate of integral below zero",
error1, "",
" exceeds the given relative tolerance times norm of "
"integral below zero");
}
if (error2 > relative_tolerance * L2) {
throw_domain_error(function, "error estimate of integral above zero",
error2, "",
" exceeds the given relative tolerance times norm of "
"integral above zero");
return Q;
}

/**
* Compute the integral of the single variable function f from a to b to within
* a specified relative tolerance. a and b can be finite or infinite.
*
* @tparam T Type of f
* @param f the function to be integrated
* @param a lower limit of integration
* @param b upper limit of integration
* @param relative_tolerance tolerance passed to Boost quadrature
* @param[in, out] msgs the print stream for warning messages
* @param args additional arguments passed to f
* @return numeric integral of function f
*/
template <typename F, typename... Args,
require_all_not_st_var<Args...>* = nullptr>
inline double integrate_1d_impl(const F& f, double a, double b,
double relative_tolerance, std::ostream* msgs,
const Args&... args) {
static constexpr const char* function = "integrate_1d";
check_less_or_equal(function, "lower limit", a, b);
bbbales2 marked this conversation as resolved.
Show resolved Hide resolved

if (a == b) {
if (std::isinf(a)) {
throw_domain_error(function, "Integration endpoints are both", a, "", "");
}
return 0.0;
} else {
if (error1 > relative_tolerance * L1) {
throw_domain_error(
function, "error estimate of integral", error1, "",
" exceeds the given relative tolerance times norm of integral");
}
return integrate(
[&](const auto& x, const auto& xc) { return f(x, xc, msgs, args...); },
a, b, relative_tolerance);
}
return Q;
}

/**
Expand Down Expand Up @@ -178,26 +221,14 @@ inline double integrate(const F& f, double a, double b,
* @return numeric integral of function f
*/
template <typename F>
inline double integrate_1d(const F& f, const double a, const double b,
inline double integrate_1d(const F& f, double a, double b,
const std::vector<double>& theta,
const std::vector<double>& x_r,
const std::vector<int>& x_i, std::ostream* msgs,
const double relative_tolerance
= std::sqrt(EPSILON)) {
static const char* function = "integrate_1d";
check_less_or_equal(function, "lower limit", a, b);

if (a == b) {
if (std::isinf(a)) {
throw_domain_error(function, "Integration endpoints are both", a, "", "");
}
return 0.0;
} else {
return integrate(
std::bind<double>(f, std::placeholders::_1, std::placeholders::_2,
theta, x_r, x_i, msgs),
a, b, relative_tolerance);
}
return integrate_1d_impl(integrate_1d_adapter<F>(f), a, b, relative_tolerance,
msgs, theta, x_r, x_i);
}

} // namespace math
Expand Down
28 changes: 28 additions & 0 deletions stan/math/prim/functor/integrate_1d_adapter.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
#ifndef STAN_MATH_PRIM_FUNCTOR_INTEGRATE_1D_ADAPTER_HPP
#define STAN_MATH_PRIM_FUNCTOR_INTEGRATE_1D_ADAPTER_HPP

#include <ostream>
#include <vector>

/**
* Adapt the non-variadic integrate_1d arguments to the variadic
* integrate_1d_impl interface
*
* @tparam F type of function to adapt
*/
template <typename F>
struct integrate_1d_adapter {
const F& f_;

explicit integrate_1d_adapter(const F& f) : f_(f) {}

template <typename T_a, typename T_b, typename T_theta>
auto operator()(const T_a& x, const T_b& xc, std::ostream* msgs,
const std::vector<T_theta>& theta,
const std::vector<double>& x_r,
const std::vector<int>& x_i) const {
return f_(x, xc, theta, x_r, x_i, msgs);
}
};

#endif
Loading