Skip to content

Commit

Permalink
Merge pull request #2797 from andrjohns/hyper3f2-expose
Browse files Browse the repository at this point in the history
Expose `hypergeometric_3F2` function
  • Loading branch information
andrjohns authored Oct 23, 2022
2 parents 4d2b936 + 4575585 commit c4b0717
Show file tree
Hide file tree
Showing 15 changed files with 289 additions and 185 deletions.
14 changes: 8 additions & 6 deletions stan/math/fwd/fun/inv_inc_beta.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
#include <stan/math/prim/fun/lbeta.hpp>
#include <stan/math/prim/fun/lgamma.hpp>
#include <stan/math/prim/fun/digamma.hpp>
#include <stan/math/prim/fun/F32.hpp>
#include <stan/math/prim/fun/hypergeometric_3F2.hpp>

namespace stan {
namespace math {
Expand Down Expand Up @@ -53,19 +53,21 @@ inline fvar<partials_return_t<T1, T2, T3>> inv_inc_beta(const T1& a,
T_return inv_d_(0);

if (is_fvar<T1>::value) {
std::vector<T_return> da_a{a_val, a_val, one_m_b};
std::vector<T_return> da_b{ap1, ap1};
auto da1 = exp(one_m_b * log1m_w + one_m_a * log_w);
auto da2
= exp(a_val * log_w + 2 * lgamma(a_val)
+ log(F32(a_val, a_val, one_m_b, ap1, ap1, w)) - 2 * lgamma(ap1));
auto da2 = exp(a_val * log_w + 2 * lgamma(a_val)
+ log(hypergeometric_3F2(da_a, da_b, w)) - 2 * lgamma(ap1));
auto da3 = inc_beta(a_val, b_val, w) * exp(lbeta_ab)
* (log_w - digamma(a_val) + digamma_apb);
inv_d_ += forward_as<fvar<T_return>>(a).d_ * da1 * (da2 - da3);
}

if (is_fvar<T2>::value) {
std::vector<T_return> db_a{b_val, b_val, one_m_a};
std::vector<T_return> db_b{bp1, bp1};
auto db1 = (w - 1) * exp(-b_val * log1m_w + one_m_a * log_w);
auto db2 = 2 * lgamma(b_val)
+ log(F32(b_val, b_val, one_m_a, bp1, bp1, one_m_w))
auto db2 = 2 * lgamma(b_val) + log(hypergeometric_3F2(db_a, db_b, one_m_w))
- 2 * lgamma(bp1) + b_val * log1m_w;

auto db3 = inc_beta(b_val, a_val, one_m_w) * exp(lbeta_ab)
Expand Down
2 changes: 1 addition & 1 deletion stan/math/prim/fun.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,6 @@
#include <stan/math/prim/fun/exp.hpp>
#include <stan/math/prim/fun/exp2.hpp>
#include <stan/math/prim/fun/expm1.hpp>
#include <stan/math/prim/fun/F32.hpp>
#include <stan/math/prim/fun/fabs.hpp>
#include <stan/math/prim/fun/factor_U.hpp>
#include <stan/math/prim/fun/factor_cov_matrix.hpp>
Expand Down Expand Up @@ -132,6 +131,7 @@
#include <stan/math/prim/fun/head.hpp>
#include <stan/math/prim/fun/hypergeometric_2F1.hpp>
#include <stan/math/prim/fun/hypergeometric_2F2.hpp>
#include <stan/math/prim/fun/hypergeometric_3F2.hpp>
#include <stan/math/prim/fun/hypergeometric_pFq.hpp>
#include <stan/math/prim/fun/hypot.hpp>
#include <stan/math/prim/fun/identity_constrain.hpp>
Expand Down
99 changes: 0 additions & 99 deletions stan/math/prim/fun/F32.hpp

This file was deleted.

22 changes: 14 additions & 8 deletions stan/math/prim/fun/grad_pFq.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -234,12 +234,18 @@ void grad_pFq_impl(TupleT&& grad_tuple, const Ta& a, const Tb& b, const Tz& z,
log_phammer_1n += log1p(n);
log_phammer_2_mpn += log(2 + m + n);

log_phammer_ap1_n += log(stan::math::fabs(ap1n));
log_phammer_bp1_n += log(stan::math::fabs(bp1n));
log_phammer_an += log(stan::math::fabs(an));
log_phammer_bn += log(stan::math::fabs(bn));
log_phammer_ap1_mpn += log(stan::math::fabs(ap1mn));
log_phammer_bp1_mpn += log(stan::math::fabs(bp1mn));
log_phammer_ap1_n.array()
+= log(math::fabs((ap1n.array() == 0).select(1.0, ap1n.array())));
log_phammer_bp1_n.array()
+= log(math::fabs((bp1n.array() == 0).select(1.0, bp1n.array())));
log_phammer_an.array()
+= log(math::fabs((an.array() == 0).select(1.0, an.array())));
log_phammer_bn.array()
+= log(math::fabs((bn.array() == 0).select(1.0, bn.array())));
log_phammer_ap1_mpn.array()
+= log(math::fabs((ap1mn.array() == 0).select(1.0, ap1mn.array())));
log_phammer_bp1_mpn.array()
+= log(math::fabs((bp1mn.array() == 0).select(1.0, bp1mn.array())));

z_pow_mn_sign *= z_sign;
log_phammer_ap1n_sign.array() *= sign(value_of_rec(ap1n)).array();
Expand All @@ -266,9 +272,9 @@ void grad_pFq_impl(TupleT&& grad_tuple, const Ta& a, const Tb& b, const Tz& z,
log_z_m += log_z;
log_phammer_1m += log1p(m);
log_phammer_2m += log(2 + m);
log_phammer_ap1_m += log(stan::math::fabs(ap1m));
log_phammer_ap1_m += log(math::fabs(ap1m));
log_phammer_ap1m_sign.array() *= sign(value_of_rec(ap1m)).array();
log_phammer_bp1_m += log(stan::math::fabs(bp1m));
log_phammer_bp1_m += log(math::fabs(bp1m));
log_phammer_bp1m_sign.array() *= sign(value_of_rec(bp1m)).array();

m += 1;
Expand Down
153 changes: 153 additions & 0 deletions stan/math/prim/fun/hypergeometric_3F2.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
#ifndef STAN_MATH_PRIM_FUN_HYPERGEOMETRIC_3F2_HPP
#define STAN_MATH_PRIM_FUN_HYPERGEOMETRIC_3F2_HPP

#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/err.hpp>
#include <stan/math/prim/fun/append_row.hpp>
#include <stan/math/prim/fun/as_array_or_scalar.hpp>
#include <stan/math/prim/fun/to_vector.hpp>
#include <stan/math/prim/fun/constants.hpp>
#include <stan/math/prim/fun/fabs.hpp>
#include <stan/math/prim/fun/hypergeometric_pFq.hpp>
#include <stan/math/prim/fun/sum.hpp>
#include <stan/math/prim/fun/sign.hpp>
#include <stan/math/prim/fun/value_of_rec.hpp>

namespace stan {
namespace math {
namespace internal {
template <typename Ta, typename Tb, typename Tz,
typename T_return = return_type_t<Ta, Tb, Tz>,
typename ArrayAT = Eigen::Array<scalar_type_t<Ta>, 3, 1>,
typename ArrayBT = Eigen::Array<scalar_type_t<Ta>, 3, 1>,
require_all_vector_t<Ta, Tb>* = nullptr,
require_stan_scalar_t<Tz>* = nullptr>
T_return hypergeometric_3F2_infsum(const Ta& a, const Tb& b, const Tz& z,
double precision = 1e-6,
int max_steps = 1e5) {
ArrayAT a_array = as_array_or_scalar(a);
ArrayBT b_array = append_row(as_array_or_scalar(b), 1.0);
check_3F2_converges("hypergeometric_3F2", a_array[0], a_array[1], a_array[2],
b_array[0], b_array[1], z);

T_return t_acc = 1.0;
T_return log_t = 0.0;
T_return log_z = log(fabs(z));
Eigen::ArrayXi a_signs = sign(value_of_rec(a_array));
Eigen::ArrayXi b_signs = sign(value_of_rec(b_array));
plain_type_t<decltype(a_array)> apk = a_array;
plain_type_t<decltype(b_array)> bpk = b_array;
int z_sign = sign(value_of_rec(z));
int t_sign = z_sign * a_signs.prod() * b_signs.prod();

int k = 0;
while (k <= max_steps && log_t >= log(precision)) {
// Replace zero values with 1 prior to taking the log so that we accumulate
// 0.0 rather than -inf
const auto& abs_apk = math::fabs((apk == 0).select(1.0, apk));
const auto& abs_bpk = math::fabs((bpk == 0).select(1.0, bpk));
T_return p = sum(log(abs_apk)) - sum(log(abs_bpk));
if (p == NEGATIVE_INFTY) {
return t_acc;
}

log_t += p + log_z;
t_acc += t_sign * exp(log_t);

if (is_inf(t_acc)) {
throw_domain_error("hypergeometric_3F2", "sum (output)", t_acc,
"overflow hypergeometric function did not converge.");
}
k++;
apk.array() += 1.0;
bpk.array() += 1.0;
a_signs = sign(value_of_rec(apk));
b_signs = sign(value_of_rec(bpk));
t_sign = a_signs.prod() * b_signs.prod() * t_sign;
}
if (k == max_steps) {
throw_domain_error("hypergeometric_3F2", "k (internal counter)", max_steps,
"exceeded iterations, hypergeometric function did not ",
"converge.");
}
return t_acc;
}
} // namespace internal

/**
* Hypergeometric function (3F2).
*
* Function reference: http://dlmf.nist.gov/16.2
*
* \f[
* _3F_2 \left(
* \begin{matrix}a_1 a_2 a[2] \\ b_1 b_2\end{matrix}; z
* \right) = \sum_k=0^\infty
* \frac{(a_1)_k(a_2)_k(a_3)_k}{(b_1)_k(b_2)_k}\frac{z^k}{k!} \f]
*
* Where $(a_1)_k$ is an upper shifted factorial.
*
* Calculate the hypergeometric function (3F2) as the power series
* directly to within <code>precision</code> or until
* <code>max_steps</code> terms.
*
* This function does not have a closed form but will converge if:
* - <code>|z|</code> is less than 1
* - <code>|z|</code> is equal to one and <code>b[0] + b[1] < a[0] + a[1] +
* a[2]</code> This function is a rational polynomial if
* - <code>a[0]</code>, <code>a[1]</code>, or <code>a[2]</code> is a
* non-positive integer
* This function can be treated as a rational polynomial if
* - <code>b[0]</code> or <code>b[1]</code> is a non-positive integer
* and the series is terminated prior to the final term.
*
* @tparam Ta type of Eigen/Std vector 'a' arguments
* @tparam Tb type of Eigen/Std vector 'b' arguments
* @tparam Tz type of z argument
* @param[in] a Always called with a[1] > 1, a[2] <= 0
* @param[in] b Always called with int b[0] < |a[2]|, <= 1)
* @param[in] z z (is always called with 1 from beta binomial cdfs)
* @param[in] precision precision of the infinite sum. defaults to 1e-6
* @param[in] max_steps number of steps to take. defaults to 1e5
* @return The 3F2 generalized hypergeometric function applied to the
* arguments {a1, a2, a3}, {b1, b2}
*/
template <typename Ta, typename Tb, typename Tz,
require_all_vector_t<Ta, Tb>* = nullptr,
require_stan_scalar_t<Tz>* = nullptr>
auto hypergeometric_3F2(const Ta& a, const Tb& b, const Tz& z) {
check_3F2_converges("hypergeometric_3F2", a[0], a[1], a[2], b[0], b[1], z);
// Boost's pFq throws convergence errors in some cases, fallback to naive
// infinite-sum approach (tests pass for these)
if (z == 1.0 && (sum(b) - sum(a)) < 0.0) {
return internal::hypergeometric_3F2_infsum(a, b, z);
}
return hypergeometric_pFq(to_vector(a), to_vector(b), z);
}

/**
* Hypergeometric function (3F2).
*
* Overload for initializer_list inputs
*
* @tparam Ta type of scalar 'a' arguments
* @tparam Tb type of scalar 'b' arguments
* @tparam Tz type of z argument
* @param[in] a Always called with a[1] > 1, a[2] <= 0
* @param[in] b Always called with int b[0] < |a[2]|, <= 1)
* @param[in] z z (is always called with 1 from beta binomial cdfs)
* @param[in] precision precision of the infinite sum. defaults to 1e-6
* @param[in] max_steps number of steps to take. defaults to 1e5
* @return The 3F2 generalized hypergeometric function applied to the
* arguments {a1, a2, a3}, {b1, b2}
*/
template <typename Ta, typename Tb, typename Tz,
require_all_stan_scalar_t<Ta, Tb, Tz>* = nullptr>
auto hypergeometric_3F2(const std::initializer_list<Ta>& a,
const std::initializer_list<Tb>& b, const Tz& z) {
return hypergeometric_3F2(std::vector<Ta>(a), std::vector<Tb>(b), z);
}

} // namespace math
} // namespace stan
#endif
3 changes: 2 additions & 1 deletion stan/math/prim/fun/log_sum_exp_signed.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/fun/constants.hpp>
#include <stan/math/prim/fun/Eigen.hpp>
#include <stan/math/prim/fun/log1p_exp.hpp>
#include <stan/math/prim/fun/log_diff_exp.hpp>
#include <stan/math/prim/fun/log_sum_exp.hpp>
#include <cmath>
#include <vector>

Expand Down
6 changes: 3 additions & 3 deletions stan/math/prim/prob/beta_binomial_cdf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#include <stan/math/prim/fun/constants.hpp>
#include <stan/math/prim/fun/digamma.hpp>
#include <stan/math/prim/fun/exp.hpp>
#include <stan/math/prim/fun/F32.hpp>
#include <stan/math/prim/fun/hypergeometric_3F2.hpp>
#include <stan/math/prim/fun/grad_F32.hpp>
#include <stan/math/prim/fun/lbeta.hpp>
#include <stan/math/prim/fun/max_size.hpp>
Expand Down Expand Up @@ -100,8 +100,8 @@ return_type_t<T_size1, T_size2> beta_binomial_cdf(const T_n& n, const T_N& N,
const T_partials_return nu = beta_dbl + N_minus_n - 1;
const T_partials_return one = 1;

const T_partials_return F
= F32(one, mu, 1 - N_minus_n, n_dbl + 2, 1 - nu, one);
const T_partials_return F = hypergeometric_3F2({one, mu, 1 - N_minus_n},
{n_dbl + 2, 1 - nu}, one);

T_partials_return C = lbeta(nu, mu) - lbeta(alpha_dbl, beta_dbl)
- lbeta(N_minus_n, n_dbl + 2);
Expand Down
6 changes: 3 additions & 3 deletions stan/math/prim/prob/beta_binomial_lccdf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#include <stan/math/prim/fun/constants.hpp>
#include <stan/math/prim/fun/digamma.hpp>
#include <stan/math/prim/fun/exp.hpp>
#include <stan/math/prim/fun/F32.hpp>
#include <stan/math/prim/fun/hypergeometric_3F2.hpp>
#include <stan/math/prim/fun/grad_F32.hpp>
#include <stan/math/prim/fun/lbeta.hpp>
#include <stan/math/prim/fun/log.hpp>
Expand Down Expand Up @@ -101,8 +101,8 @@ return_type_t<T_size1, T_size2> beta_binomial_lccdf(const T_n& n, const T_N& N,
const T_partials_return nu = beta_dbl + N_dbl - n_dbl - 1;
const T_partials_return one = 1;

const T_partials_return F
= F32(one, mu, -N_dbl + n_dbl + 1, n_dbl + 2, 1 - nu, one);
const T_partials_return F = hypergeometric_3F2(
{one, mu, -N_dbl + n_dbl + 1}, {n_dbl + 2, 1 - nu}, one);
T_partials_return C = lbeta(nu, mu) - lbeta(alpha_dbl, beta_dbl)
- lbeta(N_dbl - n_dbl, n_dbl + 2);
C = F * exp(C) / (N_dbl + 1);
Expand Down
Loading

0 comments on commit c4b0717

Please sign in to comment.