-
-
Notifications
You must be signed in to change notification settings - Fork 189
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closures for ODEs #2094
Closures for ODEs #2094
Changes from all commits
bafd17f
c0eb7cd
ed0c7ff
be52b7b
a012957
a89a6eb
d6f3dd4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
#ifndef STAN_MATH_PRIM_FUNCTOR_CLOSURE_ADAPTER_HPP | ||
#define STAN_MATH_PRIM_FUNCTOR_CLOSURE_ADAPTER_HPP | ||
|
||
#include <stan/math/prim/meta/return_type.hpp> | ||
#include <ostream> | ||
|
||
namespace stan { | ||
namespace math { | ||
|
||
template <typename F> | ||
struct closure_adapter { | ||
using captured_scalar_t__ = double; | ||
using ValueOf__ = closure_adapter<F>; | ||
static const size_t vars_count__ = 0; | ||
F f_; | ||
|
||
explicit closure_adapter(const F& f) : f_(f) {} | ||
|
||
template <typename... Args> | ||
auto operator()(std::ostream* msgs, Args... args) const { | ||
return f_(args..., msgs); | ||
} | ||
auto value_of__() const { return closure_adapter<F>(f_); } | ||
auto deep_copy_vars__() const { return closure_adapter<F>(f_); } | ||
void zero_adjoints__() const {} | ||
double* accumulate_adjoints__(double* dest) const { return dest; } | ||
template <typename Vari> | ||
Vari** save_varis(Vari** dest) const { | ||
return dest; | ||
} | ||
}; | ||
|
||
template <typename F, typename T> | ||
struct simple_closure { | ||
using captured_scalar_t__ = return_type_t<T>; | ||
using ValueOf__ = simple_closure<F, decltype(value_of(std::declval<T>()))>; | ||
const size_t vars_count__; | ||
F f_; | ||
T s_; | ||
|
||
explicit simple_closure(const F& f, T s) | ||
: f_(f), s_(s), vars_count__(count_vars(s)) {} | ||
|
||
template <typename... Args> | ||
auto operator()(std::ostream* msgs, Args... args) const { | ||
return f_(s_, args..., msgs); | ||
} | ||
auto value_of__() const { return ValueOf__(f_, value_of(s_)); } | ||
auto deep_copy_vars__() const { | ||
return simple_closure<F, T>(f_, deep_copy_vars(s_)); | ||
} | ||
void zero_adjoints__() { zero_adjoints(s_); } | ||
double* accumulate_adjoints__(double* dest) const { | ||
return accumulate_adjoints(dest, s_); | ||
} | ||
template <typename Vari> | ||
Vari** save_varis__(Vari** dest) const { | ||
return save_varis(dest, s_); | ||
} | ||
}; | ||
|
||
template <typename F> | ||
auto from_lambda(F f) { | ||
return closure_adapter<F>(f); | ||
} | ||
|
||
template <typename F, typename T> | ||
auto from_lambda(F f, T a) { | ||
return simple_closure<F, T>(f, a); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This seems to be the key to how all this works. This accepts a lambda and an argument. This binds THe closure itself, because it can contain vars, is treated like one in all the ODE code (so there are specializations for save_varis, deep_copy_vars, etc.). This sound right? Presumably we'd need to expand There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, that's right. Actually, Other higher-order functions just need support for variadic arguments. Shortly after submitting this PR I realized you could revert all the changes to ODEs and instead have an adapter like this struct closure_adapter_ode {
template<typename F, typename T0, typename T1, typename... Args>
auto operator()(const T0& t, const Eigen::Matrix<T1, Eigen::Dynamic, 1>& y,
std::ostream* msgs, F& f, Args... args) {
return f(msgs, t, y, args...);
}
}
template <typename F, typename T_y0, typename T_t0, typename T_ts,
typename... Args, require_stan_closure_t<F>* = nullptr>
std::vector<Eigen::Matrix<stan::return_type_t<T_y0, T_t0, T_ts, Args...>,
Eigen::Dynamic, 1>>
ode_rk45_tol_impl(const char* function_name, const F& f,
const Eigen::Matrix<T_y0, Eigen::Dynamic, 1>& y0_arg, T_t0 t0,
const std::vector<T_ts>& ts, double relative_tolerance,
double absolute_tolerance,
long int max_num_steps,
std::ostream* msgs, const Args&... args) {
closure_adapter_ode adapter;
return ode_rk45_tol_impl(function_name, adapter, y0_arg, t0, ts,
relative_tolerance, absolute_tolerance,
max_num_steps, msgs, f, args...);
} Same goes for |
||
} | ||
|
||
namespace internal { | ||
|
||
template <typename F> | ||
struct ode_closure_adapter { | ||
using captured_scalar_t__ = double; | ||
using ValueOf__ = ode_closure_adapter<F>; | ||
static const size_t vars_count__ = 0; | ||
const F f_; | ||
|
||
explicit ode_closure_adapter(const F& f) : f_(f) {} | ||
|
||
template <typename T0, typename T1, typename... Args> | ||
auto operator()(std::ostream* msgs, const T0& t, | ||
const Eigen::Matrix<T1, Eigen::Dynamic, 1>& y, | ||
Args... args) const { | ||
return f_(t, y, msgs, args...); | ||
} | ||
auto value_of__() const { return ode_closure_adapter<F>(f_); } | ||
auto deep_copy_vars__() const { return ode_closure_adapter<F>(f_); } | ||
void zero_adjoints__() const {} | ||
double* accumulate_adjoints__(double* dest) const { return dest; } | ||
template <typename Vari> | ||
Vari** save_varis(Vari** dest) const { | ||
return dest; | ||
} | ||
}; | ||
|
||
template <typename F> | ||
struct reduce_sum_closure_adapter { | ||
using captured_scalar_t__ = double; | ||
using ValueOf__ = reduce_sum_closure_adapter<F>; | ||
static const size_t vars_count__ = 0; | ||
const F f_; | ||
|
||
explicit reduce_sum_closure_adapter(const F& f) : f_(f) {} | ||
|
||
template <typename T, typename... Args> | ||
auto operator()(std::ostream* msgs, const std::vector<T>& sub_slice, | ||
std::size_t start, std::size_t end, Args... args) const { | ||
return f_(sub_slice, start, end, msgs, args...); | ||
} | ||
auto value_of__() const { return reduce_sum_closure_adapter<F>(f_); } | ||
auto deep_copy_vars__() const { return reduce_sum_closure_adapter<F>(f_); } | ||
void zero_adjoints__() const {} | ||
double* accumulate_adjoints__(double* dest) const { return dest; } | ||
template <typename Vari> | ||
Vari** save_varis(Vari** dest) const { | ||
return dest; | ||
} | ||
}; | ||
|
||
} // namespace internal | ||
|
||
} // namespace math | ||
} // namespace stan | ||
|
||
#endif |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess a closure has variables, so we'd want to check that the values of the variables in the closure are finite
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, I guess. I was mostly fighting the compiler here.
I'm not sure why the ODEs try and check the inputs are finite. Infinite inputs don't necessarily cause infinite outputs.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah I would like to get rid of these checks too.