-
-
Notifications
You must be signed in to change notification settings - Fork 188
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closures #2384
Closures #2384
Changes from all commits
01f0a03
7382327
ee21600
a595b43
29c165f
5dce92a
bbabc92
a4032be
76a8991
a55ddb3
990c070
cbf48fa
977f54d
c043511
e72e6f4
e120064
1245fa6
9c25817
a749f61
6990768
2e3180a
610af2d
880e270
4f1f6eb
f883e42
75f6d30
0a609db
e0f6145
9eca190
9436a18
2b2bee2
ecec96a
03ab504
9fc569f
9fb3740
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,220 @@ | ||
#ifndef STAN_MATH_PRIM_FUNCTOR_CLOSURE_ADAPTER_HPP | ||
#define STAN_MATH_PRIM_FUNCTOR_CLOSURE_ADAPTER_HPP | ||
|
||
#include <stan/math/prim/meta.hpp> | ||
#include <stan/math/prim/functor/apply.hpp> | ||
#include <ostream> | ||
|
||
namespace stan { | ||
namespace math { | ||
namespace internal { | ||
|
||
/** | ||
* A closure that wraps a C++ lambda and captures values. | ||
* | ||
* @tparam Ref if true values are captured by reference | ||
* @tparam F the lambda functor type | ||
* @tparam Ts types of the captured values | ||
*/ | ||
template <bool Ref, typename F, typename... Ts> | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This all needs docs for template parameters etc. |
||
struct base_closure { | ||
using return_scalar_t_ = return_type_t<Ts...>; | ||
/*The base closure with `Ts` as the non-expression partials of `Ts`*/ | ||
using partials_closure_t_ | ||
= base_closure<false, F, decltype(eval(value_of(std::declval<Ts>())))...>; | ||
using Base_ = base_closure<false, F, Ts...>; | ||
std::decay_t<F> f_; | ||
std::tuple<closure_return_type_t<Ts, Ref>...> captures_; | ||
template <typename FF, require_same_t<FF, F>* = nullptr, typename... Args> | ||
explicit base_closure(FF&& f, Args&&... args) | ||
: f_(std::forward<FF>(f)), captures_(std::forward<Args>(args)...) {} | ||
|
||
template <typename... Args> | ||
auto operator()(std::ostream* msgs, const Args&... args) const { | ||
return apply( | ||
[this, msgs, &args...](const auto&... s) { | ||
return this->f_(s..., args..., msgs); | ||
}, | ||
captures_); | ||
} | ||
}; | ||
|
||
/** | ||
* A closure that takes rng argument. | ||
* | ||
* @tparam Ref if true values are captured by reference | ||
* @tparam F the lambda functor type | ||
* @tparam Ts types of the captured values | ||
*/ | ||
template <bool Ref, typename F, typename... Ts> | ||
struct closure_rng { | ||
using return_scalar_t_ = double; | ||
using partials_closure_t_ = closure_rng<false, F, Ts...>; | ||
using Base_ = closure_rng<false, F, Ts...>; | ||
std::decay_t<F> f_; | ||
std::tuple<closure_return_type_t<Ts, Ref>...> captures_; | ||
|
||
template <typename FF, require_same_t<FF, F>* = nullptr, typename... Args> | ||
explicit closure_rng(FF&& f, Args&&... args) | ||
: f_(std::forward<FF>(f)), captures_(std::forward<Args>(args)...) {} | ||
|
||
template <typename Rng, typename... Args> | ||
auto operator()(Rng& rng, std::ostream* msgs, const Args&... args) const { | ||
return apply( | ||
[this, &rng, msgs, &args...](const auto&... s) { | ||
return this->f_(s..., args..., rng, msgs); | ||
}, | ||
captures_); | ||
} | ||
}; | ||
|
||
/** | ||
* A closure that may compute an unnormalized propability density. | ||
* | ||
* @tparam Propto if true the function is unnormalized | ||
* @tparam Ref if true values are captured by reference | ||
* @tparam F the lambda functor type | ||
* @tparam Ts types of the captured values | ||
*/ | ||
template <bool Propto, bool Ref, typename F, typename... Ts> | ||
struct closure_lpdf { | ||
using return_scalar_t_ = return_type_t<Ts...>; | ||
using partials_closure_t_ = closure_lpdf<Propto, false, F, Ts...>; | ||
using Base_ = closure_lpdf<Propto, false, F, Ts...>; | ||
std::decay_t<F> f_; | ||
std::tuple<closure_return_type_t<Ts, Ref>...> captures_; | ||
|
||
template <typename FF, require_same_t<FF, F>* = nullptr, typename... Args> | ||
explicit closure_lpdf(FF&& f, Args&&... args) | ||
: f_(std::forward<FF>(f)), captures_(std::forward<Args>(args)...) {} | ||
|
||
template <bool propto> | ||
auto with_propto() { | ||
Comment on lines
+91
to
+92
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We use camelcase for template parameters. How is this propto different from There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this for like lupdf or something? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, it's an parameters {
real y[100];
}
model {
function
real higher_lpdf(real[] x, real(real[], int, int) f_lpdf) {
real lp = 0;
lp += reduce_sum(f_lpdf, x, 1); // <-- A
lp += reduce_sum(f_lupdf, x, 1); // <-- B
return lp;
}
function
real partial_lpdf(real[] x, int s, int e) {
return std_normal_lupdf(x|);
}
target += higher_lpdf( y| partial_lpdf); // <-- 1
target += higher_lupdf(y| partial_lupdf); // <-- 2
} Using The above compiles to C++ that looks something like auto higher_lpdf = from_lambda([&](auto f_lpdf) {
var lp = 0;
lp += reduce_sum(f_lpdf.with_propto<false>(), x, 1); // <-- A
lp += reduce_sum(f_lpdf.with_propto<true>(), x, 1); // <-- B
return lp;
});
auto partial_lpdf = from_lambda([]<bool propto>(auto x, int s, int e) {
return std_normal_lpdf<propto>(x);
});
lp_accum__.add(higher_lpdf(y, partial_lpdf.with_propto<false>()); // <-- 1
lp_accum__.add(higher_lpdf(y, partial_lpdf.with_propto<true>()); // <-- 2 Every time the closure object is passed to a higher-order function |
||
return apply( | ||
[this](const auto&... args) { | ||
return closure_lpdf < Propto && propto, true, F, | ||
Ts... > (this->f_, args...); | ||
}, | ||
captures_); | ||
} | ||
|
||
template <bool propto = false, typename... Args> | ||
auto operator()(std::ostream* msgs, const Args&... args) const { | ||
return apply( | ||
[this, msgs, &args...](const auto&... s) { | ||
return this->f_.template operator()<propto>(s..., args..., msgs); | ||
}, | ||
captures_); | ||
} | ||
}; | ||
|
||
/** | ||
* A closure that accesses logprob accumulator. | ||
* | ||
* @tparam Propto if true the logprob is unnormalized | ||
* @tparam Ref if true values are captured by reference | ||
* @tparam F the lambda functor type | ||
* @tparam Ts types of the captured values | ||
*/ | ||
template <bool Propto, bool Ref, typename F, typename... Ts> | ||
struct closure_lp { | ||
using return_scalar_t_ = return_type_t<Ts...>; | ||
using partials_closure_t_ = closure_lp<Propto, true, F, Ts...>; | ||
using Base_ = closure_lp<Propto, true, F, Ts...>; | ||
std::decay_t<F> f_; | ||
std::tuple<closure_return_type_t<Ts, Ref>...> captures_; | ||
|
||
template <typename FF, require_same_t<FF, F>* = nullptr, typename... Args> | ||
explicit closure_lp(FF&& f, Args&&... args) | ||
: f_(std::forward<FF>(f)), captures_(std::forward<Args>(args)...) {} | ||
|
||
template <bool propto = false, typename T_lp, typename T_lp_accum, | ||
typename... Args> | ||
auto operator()(T_lp& lp, T_lp_accum& lp_accum, std::ostream* msgs, | ||
const Args&... args) const { | ||
return apply( | ||
[this, &lp, &lp_accum, msgs, &args...](const auto&... s) { | ||
return this->f_.template operator()<propto>(s..., args..., lp, | ||
lp_accum, msgs); | ||
}, | ||
captures_); | ||
} | ||
}; | ||
|
||
} // namespace internal | ||
|
||
/** | ||
* Higher-order functor suitable for calling a closure inside variadic ODE | ||
* solvers. | ||
*/ | ||
struct ode_closure_adapter { | ||
template <typename F, typename T0, typename T1, typename... Args> | ||
auto operator()(const T0& t, const T1& y, std::ostream* msgs, F&& f, | ||
Args&&... args) const { | ||
return std::forward<F>(f)(msgs, t, y, std::forward<Args>(args)...); | ||
} | ||
}; | ||
|
||
struct integrate_ode_closure_adapter { | ||
template <typename F, typename T0, typename T1, typename... Args> | ||
auto operator()(const T0& t, const T1& y, std::ostream* msgs, F&& f, | ||
Args&&... args) const { | ||
return to_vector(std::forward<F>(f)(msgs, t, to_array_1d(y), | ||
std::forward<Args>(args)...)); | ||
} | ||
}; | ||
|
||
/** | ||
* Create a closure from a C++ lambda and captures. | ||
*/ | ||
template <typename F, typename... Args> | ||
auto from_lambda(F&& f, Args&&... args) { | ||
return internal::base_closure<true, F, Args...>(std::forward<F>(f), | ||
std::forward<Args>(args)...); | ||
} | ||
|
||
/** | ||
* Create a closure from an rng functor. | ||
*/ | ||
template <typename F, typename... Args> | ||
auto rng_from_lambda(F&& f, Args&&... args) { | ||
return internal::closure_rng<true, F, Args...>(std::forward<F>(f), | ||
std::forward<Args>(args)...); | ||
} | ||
|
||
/** | ||
* Create a closure from an lpdf functor. | ||
*/ | ||
template <bool propto, typename F, typename... Args> | ||
auto lpdf_from_lambda(F&& f, Args&&... args) { | ||
return internal::closure_lpdf<propto, true, F, Args...>( | ||
std::forward<F>(f), std::forward<Args>(args)...); | ||
} | ||
|
||
/** | ||
* Create a closure from a functor that needs access to logprob accumulator. | ||
*/ | ||
template <bool Propto, typename F, typename... Args> | ||
auto lp_from_lambda(F&& f, Args&&... args) { | ||
return internal::closure_lp<Propto, true, F, Args...>( | ||
std::forward<F>(f), std::forward<Args>(args)...); | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So for each kind of function we might have in Stan, we can also have a closure version of that. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Each function kind follows a different calling convention so each kind needs its own adapter closure. These aren't used in math library but stanc3 allows userdefined higher order functions that might need them. |
||
|
||
/** | ||
* Higher-order functor that invokes a closure inside a reduce_sum call. | ||
*/ | ||
struct reduce_sum_closure_adapter { | ||
template <typename F, typename T, typename... Args> | ||
auto operator()(const std::vector<T>& sub_slice, std::size_t start, | ||
std::size_t end, std::ostream* msgs, F&& f, | ||
Args&&... args) const { | ||
return std::forward<F>(f)(msgs, sub_slice, start + error_index::value, | ||
end + error_index::value, | ||
std::forward<Args>(args)...); | ||
} | ||
}; | ||
|
||
} // namespace math | ||
} // namespace stan | ||
|
||
#endif |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -236,7 +236,7 @@ inline double integrate_1d_impl(const F& f, double a, double b, | |
* @param relative_tolerance tolerance passed to Boost quadrature | ||
* @return numeric integral of function f | ||
*/ | ||
template <typename F> | ||
template <typename F, require_not_stan_closure_t<F>* = nullptr> | ||
inline double integrate_1d(const F& f, double a, double b, | ||
const std::vector<double>& theta, | ||
const std::vector<double>& x_r, | ||
|
@@ -247,6 +247,18 @@ inline double integrate_1d(const F& f, double a, double b, | |
msgs, theta, x_r, x_i); | ||
} | ||
|
||
template <typename F, require_stan_closure_t<F>* = nullptr, | ||
require_arithmetic_t<return_type_t<F>>* = nullptr> | ||
inline double integrate_1d(const F& f, double a, double b, | ||
const std::vector<double>& theta, | ||
const std::vector<double>& x_r, | ||
const std::vector<int>& x_i, std::ostream* msgs, | ||
const double relative_tolerance | ||
= std::sqrt(EPSILON)) { | ||
return integrate_1d_impl(integrate_1d_closure_adapter(), a, b, | ||
relative_tolerance, msgs, f, theta, x_r, x_i); | ||
} | ||
Comment on lines
+250
to
+260
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [Q] Think I'm just missing some context, why is There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
|
||
} // namespace math | ||
} // namespace stan | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good catch. We'll need to implement these.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually, thinking more, I don't think there's a big advantage to implementing this.
We could expose the variables captured by a closure to checks, but the Math checks wouldn't know in what order its getting them, and then depending on which function was accepting closures it would need to decide which checks to do on which inputs.
I think instead in the ODE solves we check only the arguments passed in explicitly (which this is effectively doing) or we get rid of the infinity checks on the inputs to the ODE solves. I'll make an issue and see if getting rid of the checks altogether is an option. (Edit: Issue #2406)