Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve numerical stability of normal quantile gradients #3139

Open
wants to merge 5 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions stan/math/rev/fun/inv_Phi.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@

#include <stan/math/rev/meta.hpp>
#include <stan/math/rev/core.hpp>
#include <stan/math/prim/fun/constants.hpp>
#include <stan/math/prim/fun/exp.hpp>
#include <stan/math/prim/fun/inv_Phi.hpp>
#include <stan/math/prim/prob/std_normal_lpdf.hpp>
#include <stan/math/prim/functor/apply_scalar_binary.hpp>
#include <cmath>

namespace stan {
Expand All @@ -19,8 +21,9 @@ namespace math {
* @return The unit normal inverse cdf evaluated at p
*/
inline var inv_Phi(const var& p) {
return make_callback_var(inv_Phi(p.val()), [p](auto& vi) mutable {
p.adj() += vi.adj() * SQRT_TWO_PI / std::exp(-0.5 * vi.val() * vi.val());
double val = inv_Phi(p.val());
return make_callback_var(val, [p, val](auto& vi) mutable {
p.adj() += vi.adj() * exp(-std_normal_lpdf(val));
});
}

Expand All @@ -33,9 +36,12 @@ inline var inv_Phi(const var& p) {
*/
template <typename T, require_var_matrix_t<T>* = nullptr>
inline auto inv_Phi(const T& p) {
return make_callback_var(inv_Phi(p.val()), [p](auto& vi) mutable {
p.adj().array() += vi.adj().array() * SQRT_TWO_PI
/ (-0.5 * vi.val().array().square()).exp();
const auto& arena_rtn = to_arena(inv_Phi(p.val()));
return make_callback_var(arena_rtn, [p, arena_rtn](auto& vi) mutable {
p.adj() += apply_scalar_binary(
vi.adj(), arena_rtn.val(), [](const double adj, const double rtn_val) {
return adj * exp(-std_normal_lpdf(rtn_val));
});
});
}

Expand Down
22 changes: 10 additions & 12 deletions stan/math/rev/prob/std_normal_log_qf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@

#include <stan/math/rev/meta.hpp>
#include <stan/math/rev/core.hpp>
#include <stan/math/prim/fun/constants.hpp>
#include <stan/math/prim/fun/sign.hpp>
#include <stan/math/prim/prob/std_normal_log_qf.hpp>
#include <stan/math/prim/functor/apply_scalar_binary.hpp>
#include <stan/math/prim/fun/elt_multiply.hpp>
#include <cmath>

namespace stan {
Expand All @@ -19,16 +19,14 @@ namespace math {
*/
template <typename T, require_stan_scalar_or_eigen_t<T>* = nullptr>
inline auto std_normal_log_qf(const var_value<T>& log_p) {
return make_callback_var(
std_normal_log_qf(log_p.val()), [log_p](auto& vi) mutable {
auto vi_array = as_array_or_scalar(vi.val());
auto vi_sign = sign(as_array_or_scalar(vi.adj()));

const auto& deriv = as_array_or_scalar(log_p).val()
+ log(as_array_or_scalar(vi.adj()) * vi_sign)
- NEG_LOG_SQRT_TWO_PI + 0.5 * square(vi_array);
as_array_or_scalar(log_p).adj() += vi_sign * exp(deriv);
});
const auto& arena_rtn = to_arena(std_normal_log_qf(log_p.val()));
return make_callback_var(arena_rtn, [log_p, arena_rtn](auto& vi) mutable {
auto deriv = apply_scalar_binary(
log_p.val(), arena_rtn, [](const auto& logp_val, const auto& rtn_val) {
return exp(logp_val - std_normal_lpdf(rtn_val));
});
Comment on lines +24 to +27
Copy link
Collaborator

@SteveBronder SteveBronder Jan 15, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of using apply_scalar_binary here, how do you feel about just applying std_normal_lpdf to the return value? The rest can then use standard vectorization from Eigen.

Suggested change
auto deriv = apply_scalar_binary(
log_p.val(), arena_rtn, [](const auto& logp_val, const auto& rtn_val) {
return exp(logp_val - std_normal_lpdf(rtn_val));
});
if constexpr (is_eigen<decltype(arena_rtn)>::value) {
auto deriv = exp(log_p.val() - arena_rtn.unaryExpr([](auto x) {
return std_normal_lpdf(x);}));
log_p.adj() += elt_multiply(vi.adj(), deriv);
} else {
auto deriv = exp(log_p.val() - std_normal_lpdf(arena_rtn));
log_p.adj() += vi.adj() * deriv;
}

Same thing for the other change.

log_p.adj() += elt_multiply(vi.adj(), deriv);
});
}

} // namespace math
Expand Down
Loading