Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closures for ODEs #2094

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions stan/math/prim/err/check_finite.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,9 @@ inline void check_finite(const char* function, const char* name,
}
}

template <typename T, require_stan_closure_t<T>* = nullptr>
inline void check_finite(const char* function, const char* name, const T& y) {}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess a closure has variables, so we'd want to check that the values of the variables in the closure are finite

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I guess. I was mostly fighting the compiler here.
I'm not sure why the ODEs try and check the inputs are finite. Infinite inputs don't necessarily cause infinite outputs.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I would like to get rid of these checks too.


} // namespace math
} // namespace stan

Expand Down
13 changes: 13 additions & 0 deletions stan/math/prim/fun/value_of.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,19 @@ inline auto value_of(EigMat&& M) {
std::forward<EigMat>(M));
}

/**
* Closures that capture non-arithmetic types have value_of__() method.
*
* @tparam F Input element type
* @param[in] f Input closure
* @return closure
**/
template <typename F, require_stan_closure_t<F>* = nullptr,
require_not_st_arithmetic<F>* = nullptr>
inline auto value_of(const F& f) {
return f.value_of__();
}

} // namespace math
} // namespace stan

Expand Down
1 change: 1 addition & 0 deletions stan/math/prim/functor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <stan/math/prim/functor/apply_scalar_binary.hpp>
#include <stan/math/prim/functor/apply_vector_unary.hpp>
#include <stan/math/prim/functor/coupled_ode_system.hpp>
#include <stan/math/prim/functor/closure_adapter.hpp>
#include <stan/math/prim/functor/finite_diff_gradient.hpp>
#include <stan/math/prim/functor/finite_diff_gradient_auto.hpp>
#include <stan/math/prim/functor/finite_diff_hessian.hpp>
Expand Down
128 changes: 128 additions & 0 deletions stan/math/prim/functor/closure_adapter.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
#ifndef STAN_MATH_PRIM_FUNCTOR_CLOSURE_ADAPTER_HPP
#define STAN_MATH_PRIM_FUNCTOR_CLOSURE_ADAPTER_HPP

#include <stan/math/prim/meta/return_type.hpp>
#include <ostream>

namespace stan {
namespace math {

template <typename F>
struct closure_adapter {
using captured_scalar_t__ = double;
using ValueOf__ = closure_adapter<F>;
static const size_t vars_count__ = 0;
F f_;

explicit closure_adapter(const F& f) : f_(f) {}

template <typename... Args>
auto operator()(std::ostream* msgs, Args... args) const {
return f_(args..., msgs);
}
auto value_of__() const { return closure_adapter<F>(f_); }
auto deep_copy_vars__() const { return closure_adapter<F>(f_); }
void zero_adjoints__() const {}
double* accumulate_adjoints__(double* dest) const { return dest; }
template <typename Vari>
Vari** save_varis(Vari** dest) const {
return dest;
}
};

template <typename F, typename T>
struct simple_closure {
using captured_scalar_t__ = return_type_t<T>;
using ValueOf__ = simple_closure<F, decltype(value_of(std::declval<T>()))>;
const size_t vars_count__;
F f_;
T s_;

explicit simple_closure(const F& f, T s)
: f_(f), s_(s), vars_count__(count_vars(s)) {}

template <typename... Args>
auto operator()(std::ostream* msgs, Args... args) const {
return f_(s_, args..., msgs);
}
auto value_of__() const { return ValueOf__(f_, value_of(s_)); }
auto deep_copy_vars__() const {
return simple_closure<F, T>(f_, deep_copy_vars(s_));
}
void zero_adjoints__() { zero_adjoints(s_); }
double* accumulate_adjoints__(double* dest) const {
return accumulate_adjoints(dest, s_);
}
template <typename Vari>
Vari** save_varis__(Vari** dest) const {
return save_varis(dest, s_);
}
};

template <typename F>
auto from_lambda(F f) {
return closure_adapter<F>(f);
}

template <typename F, typename T>
auto from_lambda(F f, T a) {
return simple_closure<F, T>(f, a);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems to be the key to how all this works. This accepts a lambda and an argument. This binds a to the first argument of the function f, and this object is treated as a closure (which can be passed around wherever).

THe closure itself, because it can contain vars, is treated like one in all the ODE code (so there are specializations for save_varis, deep_copy_vars, etc.).

This sound right?

Presumably we'd need to expand from_lamba to take a variable list of arguments and modify all the other higher order functions to work with this?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that's right.

Actually, from_lambda isn't necessary because my plan was to leave implementing closure structs to the stanc codegen. require_stan_closure_t<...> recognizes closure simply by the presence of ::captured_scalar_t__; they are effectively duck-typed. simple_closure is one example of how to make one.
Of course, an alternative (complementary?) path forward is expanding simple_closure so that the template type T is a parameter pack.

Other higher-order functions just need support for variadic arguments. Shortly after submitting this PR I realized you could revert all the changes to ODEs and instead have an adapter like this

struct closure_adapter_ode {
  template<typename F, typename T0, typename T1, typename... Args>
  auto operator()(const T0& t, const Eigen::Matrix<T1, Eigen::Dynamic, 1>& y,
                  std::ostream* msgs, F& f, Args... args) {
    return f(msgs, t, y, args...);
  }
}

template <typename F, typename T_y0, typename T_t0, typename T_ts,
          typename... Args, require_stan_closure_t<F>* = nullptr>
std::vector<Eigen::Matrix<stan::return_type_t<T_y0, T_t0, T_ts, Args...>,
                          Eigen::Dynamic, 1>>
ode_rk45_tol_impl(const char* function_name, const F& f,
                  const Eigen::Matrix<T_y0, Eigen::Dynamic, 1>& y0_arg, T_t0 t0,
                  const std::vector<T_ts>& ts, double relative_tolerance,
                  double absolute_tolerance,
                  long int max_num_steps,
                  std::ostream* msgs, const Args&... args) {
  closure_adapter_ode adapter;
  return ode_rk45_tol_impl(function_name, adapter, y0_arg, t0, ts,
                           relative_tolerance, absolute_tolerance,
                           max_num_steps, msgs, f, args...);
}

Same goes for reduce_sum and others.

}

namespace internal {

template <typename F>
struct ode_closure_adapter {
using captured_scalar_t__ = double;
using ValueOf__ = ode_closure_adapter<F>;
static const size_t vars_count__ = 0;
const F f_;

explicit ode_closure_adapter(const F& f) : f_(f) {}

template <typename T0, typename T1, typename... Args>
auto operator()(std::ostream* msgs, const T0& t,
const Eigen::Matrix<T1, Eigen::Dynamic, 1>& y,
Args... args) const {
return f_(t, y, msgs, args...);
}
auto value_of__() const { return ode_closure_adapter<F>(f_); }
auto deep_copy_vars__() const { return ode_closure_adapter<F>(f_); }
void zero_adjoints__() const {}
double* accumulate_adjoints__(double* dest) const { return dest; }
template <typename Vari>
Vari** save_varis(Vari** dest) const {
return dest;
}
};

template <typename F>
struct reduce_sum_closure_adapter {
using captured_scalar_t__ = double;
using ValueOf__ = reduce_sum_closure_adapter<F>;
static const size_t vars_count__ = 0;
const F f_;

explicit reduce_sum_closure_adapter(const F& f) : f_(f) {}

template <typename T, typename... Args>
auto operator()(std::ostream* msgs, const std::vector<T>& sub_slice,
std::size_t start, std::size_t end, Args... args) const {
return f_(sub_slice, start, end, msgs, args...);
}
auto value_of__() const { return reduce_sum_closure_adapter<F>(f_); }
auto deep_copy_vars__() const { return reduce_sum_closure_adapter<F>(f_); }
void zero_adjoints__() const {}
double* accumulate_adjoints__(double* dest) const { return dest; }
template <typename Vari>
Vari** save_varis(Vari** dest) const {
return dest;
}
};

} // namespace internal

} // namespace math
} // namespace stan

#endif
6 changes: 3 additions & 3 deletions stan/math/prim/functor/coupled_ode_system.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ struct coupled_ode_system_impl<true, F, T_y0, Args...> {
dz_dt.resize(y.size());

Eigen::VectorXd f_y_t
= apply([&](const Args&... args) { return f_(t, y, msgs_, args...); },
= apply([&](const Args&... args) { return f_(msgs_, t, y, args...); },
args_tuple_);

check_size_match("coupled_ode_system", "dy_dt", f_y_t.size(), "states",
Expand Down Expand Up @@ -104,13 +104,13 @@ struct coupled_ode_system_impl<true, F, T_y0, Args...> {
template <typename F, typename T_y0, typename... Args>
struct coupled_ode_system
: public coupled_ode_system_impl<
std::is_arithmetic<return_type_t<T_y0, Args...>>::value, F, T_y0,
std::is_arithmetic<return_type_t<F, T_y0, Args...>>::value, F, T_y0,
Args...> {
coupled_ode_system(const F& f,
const Eigen::Matrix<T_y0, Eigen::Dynamic, 1>& y0,
std::ostream* msgs, const Args&... args)
: coupled_ode_system_impl<
std::is_arithmetic<return_type_t<T_y0, Args...>>::value, F, T_y0,
std::is_arithmetic<return_type_t<F, T_y0, Args...>>::value, F, T_y0,
Args...>(f, y0, msgs, args...) {}
};

Expand Down
18 changes: 16 additions & 2 deletions stan/math/prim/functor/integrate_ode_rk45.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ namespace math {
* @deprecated use <code>ode_rk45</code>
*/
template <typename F, typename T_y0, typename T_param, typename T_t0,
typename T_ts>
typename T_ts, require_stan_closure_t<F>* = nullptr>
inline auto integrate_ode_rk45(
const F& f, const std::vector<T_y0>& y0, const T_t0& t0,
const std::vector<T_ts>& ts, const std::vector<T_param>& theta,
Expand All @@ -26,7 +26,7 @@ inline auto integrate_ode_rk45(
ts, relative_tolerance, absolute_tolerance,
max_num_steps, msgs, theta, x, x_int);

std::vector<std::vector<return_type_t<T_y0, T_param, T_t0, T_ts>>>
std::vector<std::vector<fn_return_type_t<F, T_y0, T_param, T_t0, T_ts>>>
y_converted;
y_converted.reserve(y.size());
for (size_t i = 0; i < y.size(); ++i)
Expand All @@ -35,6 +35,20 @@ inline auto integrate_ode_rk45(
return y_converted;
}

template <typename F, typename T_y0, typename T_param, typename T_t0,
typename T_ts, require_not_stan_closure_t<F>* = nullptr>
inline auto integrate_ode_rk45(
const F& f, const std::vector<T_y0>& y0, const T_t0& t0,
const std::vector<T_ts>& ts, const std::vector<T_param>& theta,
const std::vector<double>& x, const std::vector<int>& x_int,
std::ostream* msgs = nullptr, double relative_tolerance = 1e-6,
double absolute_tolerance = 1e-6, int max_num_steps = 1e6) {
closure_adapter<F> cl(f);
return integrate_ode_rk45(cl, y0, t0, ts, theta, x, x_int, msgs,
relative_tolerance, absolute_tolerance,
max_num_steps);
}

} // namespace math
} // namespace stan

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,34 @@ namespace internal {
*/
template <typename F>
struct integrate_ode_std_vector_interface_adapter {
using captured_scalar_t__ = typename F::captured_scalar_t__;
using ValueOf__
= integrate_ode_std_vector_interface_adapter<typename F::ValueOf__>;
const int vars_count__;
const F f_;

explicit integrate_ode_std_vector_interface_adapter(const F& f) : f_(f) {}
explicit integrate_ode_std_vector_interface_adapter(const F& f)
: vars_count__(f.vars_count__), f_(f) {}

template <typename T0, typename T1, typename T2>
auto operator()(const T0& t, const Eigen::Matrix<T1, Eigen::Dynamic, 1>& y,
std::ostream* msgs, const std::vector<T2>& theta,
const std::vector<double>& x,
auto operator()(std::ostream* msgs, const T0& t,
const Eigen::Matrix<T1, Eigen::Dynamic, 1>& y,
const std::vector<T2>& theta, const std::vector<double>& x,
const std::vector<int>& x_int) const {
return to_vector(f_(t, to_array_1d(y), theta, x, x_int, msgs));
return to_vector(f_(msgs, t, to_array_1d(y), theta, x, x_int));
}

auto value_of__() const { return ValueOf__(f_.value_of__()); }
auto deep_copy_vars__() const {
return integrate_ode_std_vector_interface_adapter<F>(f_.deep_copy_vars__());
}
void zero_adjoints__() const { f_.zero_adjoints__(); }
double* accumulate_adjoints__(double* dest) const {
return f_.accumulate_adjoints__(dest);
}
template <typename Vari>
Vari** save_varis__(Vari** dest) const {
return f_.save_varis__(dest);
}
};

Expand Down
28 changes: 23 additions & 5 deletions stan/math/prim/functor/ode_rk45.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define STAN_MATH_PRIM_FUNCTOR_ODE_RK45_HPP

#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/functor/closure_adapter.hpp>
#include <stan/math/prim/functor/apply.hpp>
#include <stan/math/prim/functor/coupled_ode_system.hpp>
#include <stan/math/prim/functor/ode_store_sensitivities.hpp>
Expand Down Expand Up @@ -53,8 +54,9 @@ namespace math {
* @return Solution to ODE at times \p ts
*/
template <typename F, typename T_y0, typename T_t0, typename T_ts,
typename... Args, require_eigen_vector_t<T_y0>* = nullptr>
std::vector<Eigen::Matrix<stan::return_type_t<T_y0, T_t0, T_ts, Args...>,
typename... Args, require_eigen_vector_t<T_y0>* = nullptr,
require_stan_closure_t<F>* = nullptr>
std::vector<Eigen::Matrix<stan::return_type_t<F, T_y0, T_t0, T_ts, Args...>,
Eigen::Dynamic, 1>>
ode_rk45_tol_impl(const char* function_name, const F& f, const T_y0& y0_arg,
T_t0 t0, const std::vector<T_ts>& ts,
Expand Down Expand Up @@ -100,7 +102,7 @@ ode_rk45_tol_impl(const char* function_name, const F& f, const T_y0& y0_arg,
absolute_tolerance);
check_positive(function_name, "max_num_steps", max_num_steps);

using return_t = return_type_t<T_y0, T_t0, T_ts, Args...>;
using return_t = return_type_t<F, T_y0, T_t0, T_ts, Args...>;
// creates basic or coupled system by template specializations
auto&& coupled_system = apply(
[&](const auto&... args_ref) {
Expand Down Expand Up @@ -158,6 +160,22 @@ ode_rk45_tol_impl(const char* function_name, const F& f, const T_y0& y0_arg,
return y;
}

template <typename F, typename T_y0, typename T_t0, typename T_ts,
typename... Args, require_not_stan_closure_t<F>* = nullptr>
std::vector<Eigen::Matrix<stan::return_type_t<T_y0, T_t0, T_ts, Args...>,
Eigen::Dynamic, 1>>
ode_rk45_tol_impl(const char* function_name, const F& f,
const Eigen::Matrix<T_y0, Eigen::Dynamic, 1>& y0_arg, T_t0 t0,
const std::vector<T_ts>& ts, double relative_tolerance,
double absolute_tolerance,
long int max_num_steps, // NOLINT(runtime/int)
std::ostream* msgs, const Args&... args) {
internal::ode_closure_adapter<F> cl(f);
return ode_rk45_tol_impl(function_name, cl, y0_arg, t0, ts,
relative_tolerance, absolute_tolerance,
max_num_steps, msgs, args...);
}

/**
* Solve the ODE initial value problem y' = f(t, y), y(t0) = y0 at a set of
* times, { t1, t2, t3, ... } using the non-stiff Runge-Kutta 45 solver in
Expand Down Expand Up @@ -196,7 +214,7 @@ ode_rk45_tol_impl(const char* function_name, const F& f, const T_y0& y0_arg,
*/
template <typename F, typename T_y0, typename T_t0, typename T_ts,
typename... Args, require_eigen_vector_t<T_y0>* = nullptr>
std::vector<Eigen::Matrix<stan::return_type_t<T_y0, T_t0, T_ts, Args...>,
std::vector<Eigen::Matrix<stan::fn_return_type_t<F, T_y0, T_t0, T_ts, Args...>,
Eigen::Dynamic, 1>>
ode_rk45_tol(const F& f, const T_y0& y0_arg, T_t0 t0,
const std::vector<T_ts>& ts, double relative_tolerance,
Expand Down Expand Up @@ -242,7 +260,7 @@ ode_rk45_tol(const F& f, const T_y0& y0_arg, T_t0 t0,
*/
template <typename F, typename T_y0, typename T_t0, typename T_ts,
typename... Args, require_eigen_vector_t<T_y0>* = nullptr>
std::vector<Eigen::Matrix<stan::return_type_t<T_y0, T_t0, T_ts, Args...>,
std::vector<Eigen::Matrix<stan::fn_return_type_t<F, T_y0, T_t0, T_ts, Args...>,
Eigen::Dynamic, 1>>
ode_rk45(const F& f, const T_y0& y0, T_t0 t0, const std::vector<T_ts>& ts,
std::ostream* msgs, const Args&... args) {
Expand Down
8 changes: 4 additions & 4 deletions stan/math/prim/functor/ode_store_sensitivities.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@ namespace math {
* @param args Extra arguments passed unmodified through to ODE right hand side
* @return ODE state
*/
template <
typename F, typename T_y0_t0, typename T_t0, typename T_t, typename... Args,
typename
= require_all_arithmetic_t<T_y0_t0, T_t0, T_t, scalar_type_t<Args>...>>
template <typename F, typename T_y0_t0, typename T_t0, typename T_t,
typename... Args,
typename = require_all_arithmetic_t<scalar_type_t<F>, T_y0_t0, T_t0,
T_t, scalar_type_t<Args>...>>
Eigen::VectorXd ode_store_sensitivities(
const F& f, const std::vector<double>& coupled_state,
const Eigen::Matrix<T_y0_t0, Eigen::Dynamic, 1>& y0, T_t0 t0, T_t t,
Expand Down
Loading