Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose hypergeometric_3F2 function #2797

Merged
merged 22 commits into from
Oct 23, 2022
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions stan/math/fwd/fun/inv_inc_beta.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
#include <stan/math/prim/fun/lbeta.hpp>
#include <stan/math/prim/fun/lgamma.hpp>
#include <stan/math/prim/fun/digamma.hpp>
#include <stan/math/prim/fun/F32.hpp>
#include <stan/math/prim/fun/hypergeometric_3F2.hpp>

namespace stan {
namespace math {
Expand Down Expand Up @@ -53,19 +53,21 @@ inline fvar<partials_return_t<T1, T2, T3>> inv_inc_beta(const T1& a,
T_return inv_d_(0);

if (is_fvar<T1>::value) {
std::vector<T_return> da_a{a_val, a_val, one_m_b};
std::vector<T_return> da_b{ap1, ap1};
auto da1 = exp(one_m_b * log1m_w + one_m_a * log_w);
auto da2
= exp(a_val * log_w + 2 * lgamma(a_val)
+ log(F32(a_val, a_val, one_m_b, ap1, ap1, w)) - 2 * lgamma(ap1));
auto da2 = exp(a_val * log_w + 2 * lgamma(a_val)
+ log(hypergeometric_3F2(da_a, da_b, w)) - 2 * lgamma(ap1));
auto da3 = inc_beta(a_val, b_val, w) * exp(lbeta_ab)
* (log_w - digamma(a_val) + digamma_apb);
inv_d_ += forward_as<fvar<T_return>>(a).d_ * da1 * (da2 - da3);
}

if (is_fvar<T2>::value) {
std::vector<T_return> db_a{b_val, b_val, one_m_a};
std::vector<T_return> db_b{bp1, bp1};
auto db1 = (w - 1) * exp(-b_val * log1m_w + one_m_a * log_w);
auto db2 = 2 * lgamma(b_val)
+ log(F32(b_val, b_val, one_m_a, bp1, bp1, one_m_w))
auto db2 = 2 * lgamma(b_val) + log(hypergeometric_3F2(db_a, db_b, one_m_w))
- 2 * lgamma(bp1) + b_val * log1m_w;

auto db3 = inc_beta(b_val, a_val, one_m_w) * exp(lbeta_ab)
Expand Down
2 changes: 1 addition & 1 deletion stan/math/prim/fun.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,6 @@
#include <stan/math/prim/fun/exp.hpp>
#include <stan/math/prim/fun/exp2.hpp>
#include <stan/math/prim/fun/expm1.hpp>
#include <stan/math/prim/fun/F32.hpp>
#include <stan/math/prim/fun/fabs.hpp>
#include <stan/math/prim/fun/factor_U.hpp>
#include <stan/math/prim/fun/factor_cov_matrix.hpp>
Expand Down Expand Up @@ -132,6 +131,7 @@
#include <stan/math/prim/fun/head.hpp>
#include <stan/math/prim/fun/hypergeometric_2F1.hpp>
#include <stan/math/prim/fun/hypergeometric_2F2.hpp>
#include <stan/math/prim/fun/hypergeometric_3F2.hpp>
#include <stan/math/prim/fun/hypergeometric_pFq.hpp>
#include <stan/math/prim/fun/hypot.hpp>
#include <stan/math/prim/fun/identity_constrain.hpp>
Expand Down
99 changes: 0 additions & 99 deletions stan/math/prim/fun/F32.hpp

This file was deleted.

22 changes: 14 additions & 8 deletions stan/math/prim/fun/grad_pFq.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -234,12 +234,18 @@ void grad_pFq_impl(TupleT&& grad_tuple, const Ta& a, const Tb& b, const Tz& z,
log_phammer_1n += log1p(n);
log_phammer_2_mpn += log(2 + m + n);

log_phammer_ap1_n += log(stan::math::fabs(ap1n));
log_phammer_bp1_n += log(stan::math::fabs(bp1n));
log_phammer_an += log(stan::math::fabs(an));
log_phammer_bn += log(stan::math::fabs(bn));
log_phammer_ap1_mpn += log(stan::math::fabs(ap1mn));
log_phammer_bp1_mpn += log(stan::math::fabs(bp1mn));
log_phammer_ap1_n.array()
+= log(math::fabs((ap1n.array() == 0).select(1.0, ap1n.array())));
log_phammer_bp1_n.array()
+= log(math::fabs((bp1n.array() == 0).select(1.0, bp1n.array())));
log_phammer_an.array()
+= log(math::fabs((an.array() == 0).select(1.0, an.array())));
log_phammer_bn.array()
+= log(math::fabs((bn.array() == 0).select(1.0, bn.array())));
log_phammer_ap1_mpn.array()
+= log(math::fabs((ap1mn.array() == 0).select(1.0, ap1mn.array())));
log_phammer_bp1_mpn.array()
+= log(math::fabs((bp1mn.array() == 0).select(1.0, bp1mn.array())));

z_pow_mn_sign *= z_sign;
log_phammer_ap1n_sign.array() *= sign(value_of_rec(ap1n)).array();
Expand All @@ -266,9 +272,9 @@ void grad_pFq_impl(TupleT&& grad_tuple, const Ta& a, const Tb& b, const Tz& z,
log_z_m += log_z;
log_phammer_1m += log1p(m);
log_phammer_2m += log(2 + m);
log_phammer_ap1_m += log(stan::math::fabs(ap1m));
log_phammer_ap1_m += log(math::fabs(ap1m));
log_phammer_ap1m_sign.array() *= sign(value_of_rec(ap1m)).array();
log_phammer_bp1_m += log(stan::math::fabs(bp1m));
log_phammer_bp1_m += log(math::fabs(bp1m));
log_phammer_bp1m_sign.array() *= sign(value_of_rec(bp1m)).array();

m += 1;
Expand Down
152 changes: 152 additions & 0 deletions stan/math/prim/fun/hypergeometric_3F2.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
#ifndef STAN_MATH_PRIM_FUN_HYPERGEOMETRIC_3F2_HPP
#define STAN_MATH_PRIM_FUN_HYPERGEOMETRIC_3F2_HPP

#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/err.hpp>
#include <stan/math/prim/fun/append_row.hpp>
#include <stan/math/prim/fun/as_array_or_scalar.hpp>
#include <stan/math/prim/fun/to_vector.hpp>
#include <stan/math/prim/fun/constants.hpp>
#include <stan/math/prim/fun/fabs.hpp>
#include <stan/math/prim/fun/hypergeometric_pFq.hpp>
#include <stan/math/prim/fun/sum.hpp>
#include <stan/math/prim/fun/sign.hpp>
#include <stan/math/prim/fun/value_of_rec.hpp>

namespace stan {
namespace math {
namespace internal {
/**
 * Evaluate the 3F2 hypergeometric function by summing its power series
 * term by term.
 *
 * Each term is tracked as a log-magnitude (<code>log_t</code>) and a separate
 * sign (<code>t_sign</code>). The 'b' arguments are extended with a trailing
 * 1 so that the <code>k!</code> factor of the series is accumulated together
 * with the shifted factorials of <code>b</code>.
 *
 * Summation stops when the magnitude of the latest term drops below
 * <code>precision</code>, when the series has terminated (a factor of exactly
 * zero was reached, so every later term is zero), or after more than
 * <code>max_steps</code> iterations.
 *
 * @tparam Ta type of vector 'a' arguments (three elements)
 * @tparam Tb type of vector 'b' arguments (two elements)
 * @tparam Tz type of z argument
 * @tparam T_return scalar return type
 * @tparam ArrayAT fixed-size array type holding the 'a' arguments
 * @tparam ArrayBT fixed-size array type holding the 'b' arguments and the
 *   appended 1
 * @param[in] a vector of 'a' arguments
 * @param[in] b vector of 'b' arguments
 * @param[in] z scalar argument of the series
 * @param[in] precision magnitude below which a term ends the summation
 * @param[in] max_steps iteration budget for the summation
 * @return partial sum of the series
 * @throws reports a domain error through <code>throw_domain_error</code> if
 *   the sum overflows or the iteration budget is exhausted before the series
 *   converges or terminates
 */
template <typename Ta, typename Tb, typename Tz,
          typename T_return = return_type_t<Ta, Tb, Tz>,
          typename ArrayAT = Eigen::Array<scalar_type_t<Ta>, 3, 1>,
          typename ArrayBT = Eigen::Array<scalar_type_t<Tb>, 3, 1>,
          require_all_vector_t<Ta, Tb>* = nullptr,
          require_stan_scalar_t<Tz>* = nullptr>
T_return hypergeometric_3F2_infsum(const Ta& a, const Tb& b, const Tz& z,
                                   double precision = 1e-6,
                                   int max_steps = 1e5) {
  ArrayAT a_array = as_array_or_scalar(a);
  // The appended 1 makes the running product over b also generate k!
  ArrayBT b_array = append_row(as_array_or_scalar(b), 1.0);
  check_3F2_converges("hypergeometric_3F2", a_array[0], a_array[1], a_array[2],
                      b_array[0], b_array[1], z);

  T_return t_acc = 1.0;
  T_return log_t = 0.0;
  T_return log_z = log(fabs(z));
  Eigen::ArrayXi a_signs = sign(value_of_rec(a_array));
  Eigen::ArrayXi b_signs = sign(value_of_rec(b_array));
  plain_type_t<decltype(a_array)> apk = a_array;
  plain_type_t<decltype(b_array)> bpk = b_array;
  int z_sign = sign(value_of_rec(z));
  // Sign of the first non-constant term
  int t_sign = z_sign * a_signs.prod() * b_signs.prod();

  // Loop invariant, so evaluate it once instead of on every iteration
  const auto log_precision = log(precision);

  int k = 0;
  // t_sign == 0 means a factor of exactly zero has been reached: the series
  // has terminated and every remaining term is zero, so stop summing.
  while (k <= max_steps && t_sign != 0 && log_t >= log_precision) {
    // Replace zero values with 1 prior to taking the log so that we accumulate
    // 0.0 rather than -inf
    const auto& abs_apk = math::fabs((apk == 0).select(1.0, apk));
    const auto& abs_bpk = math::fabs((bpk == 0).select(1.0, bpk));
    T_return p = sum(log(abs_apk)) - sum(log(abs_bpk));
    if (p == NEGATIVE_INFTY) {
      return t_acc;
    }

    log_t += p + log_z;
    t_acc += t_sign * exp(log_t);

    if (is_inf(t_acc)) {
      throw_domain_error("hypergeometric_3F2", "sum (output)", t_acc,
                         "overflow hypergeometric function did not converge.");
    }
    k++;
    apk.array() += 1.0;
    bpk.array() += 1.0;
    a_signs = sign(value_of_rec(apk));
    b_signs = sign(value_of_rec(bpk));
    // Each successive term gains one more factor of z, so its sign must be
    // folded in on every step, not only for the first term.
    t_sign *= z_sign * a_signs.prod() * b_signs.prod();
  }
  // Only report failure when the budget ran out while the terms were still
  // non-negligible and the series had not terminated.
  if (k > max_steps && t_sign != 0 && log_t >= log_precision) {
    throw_domain_error("hypergeometric_3F2", "k (internal counter)", max_steps,
                       "exceeded iterations, hypergeometric function did not ",
                       "converge.");
  }
  return t_acc;
}
} // namespace internal

/**
 * Hypergeometric function (3F2).
 *
 * Function reference: http://dlmf.nist.gov/16.2
 *
 * \f[
 *   _3F_2 \left(
 *     \begin{matrix}a_1, a_2, a_3 \\ b_1, b_2\end{matrix}; z
 *   \right) = \sum_{k=0}^\infty
 *   \frac{(a_1)_k(a_2)_k(a_3)_k}{(b_1)_k(b_2)_k}\frac{z^k}{k!}
 * \f]
 *
 * Where \f$(a_1)_k\f$ is an upper shifted factorial.
 *
 * The value is computed with <code>hypergeometric_pFq</code>, except when
 * <code>z</code> equals 1 and <code>sum(b) < sum(a)</code>, where the power
 * series is summed directly instead.
 *
 * This function does not have a closed form but will converge if:
 *   - <code>|z|</code> is less than 1
 *   - <code>|z|</code> is equal to one and <code>b[0] + b[1] > a[0] + a[1] +
 *     a[2]</code>
 * This function is a rational polynomial if
 *   - <code>a[0]</code>, <code>a[1]</code>, or <code>a[2]</code> is a
 *     non-positive integer
 * This function can be treated as a rational polynomial if
 *   - <code>b[0]</code> or <code>b[1]</code> is a non-positive integer
 *     and the series is terminated prior to the final term.
 *
 * @tparam Ta type of Eigen/Std vector 'a' arguments
 * @tparam Tb type of Eigen/Std vector 'b' arguments
 * @tparam Tz type of z argument
 * @param[in] a vector of the three 'a' arguments
 * @param[in] b vector of the two 'b' arguments
 * @param[in] z scalar argument of the series
 * @return The 3F2 generalized hypergeometric function applied to the
 *   arguments {a1, a2, a3}, {b1, b2}
 * @throws std::invalid_argument if <code>a</code> does not have exactly three
 *   elements or <code>b</code> does not have exactly two
 * @throws whatever <code>check_3F2_converges</code> raises when it rejects
 *   the supplied arguments
 */
template <typename Ta, typename Tb, typename Tz,
          require_all_vector_t<Ta, Tb>* = nullptr,
          require_stan_scalar_t<Tz>* = nullptr>
auto hypergeometric_3F2(const Ta& a, const Tb& b, const Tz& z) {
  static const char* function = "hypergeometric_3F2";
  // The fixed-index accesses below would read out of bounds on short inputs
  check_size_match(function, "size of a", a.size(), "expected size", 3);
  check_size_match(function, "size of b", b.size(), "expected size", 2);
  check_3F2_converges(function, a[0], a[1], a[2], b[0], b[1], z);
  // Boost's pFq throws convergence errors in some cases, fallback to naive
  // infinite-sum approach (tests pass for these)
  if (z == 1.0 && (sum(b) - sum(a)) < 0.0) {
    return internal::hypergeometric_3F2_infsum(a, b, z);
  }
  return hypergeometric_pFq(to_vector(a), to_vector(b), z);
}

/**
 * Hypergeometric function (3F2).
 *
 * Convenience overload accepting brace-enclosed lists of scalars. Both lists
 * are copied into <code>std::vector</code>s and the call is forwarded to the
 * vector overload.
 *
 * @tparam Ta type of scalar 'a' arguments
 * @tparam Tb type of scalar 'b' arguments
 * @tparam Tz type of z argument
 * @param[in] a list of 'a' arguments
 * @param[in] b list of 'b' arguments
 * @param[in] z scalar argument of the series
 * @return Generalized hypergeometric function applied to the inputs
 */
template <typename Ta, typename Tb, typename Tz,
          require_all_stan_scalar_t<Ta, Tb, Tz>* = nullptr>
auto hypergeometric_3F2(const std::initializer_list<Ta>& a,
                        const std::initializer_list<Tb>& b, const Tz& z) {
  const std::vector<Ta> a_vec(a);
  const std::vector<Tb> b_vec(b);
  return hypergeometric_3F2(a_vec, b_vec, z);
}

} // namespace math
} // namespace stan
#endif
3 changes: 2 additions & 1 deletion stan/math/prim/fun/log_sum_exp_signed.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/fun/constants.hpp>
#include <stan/math/prim/fun/Eigen.hpp>
#include <stan/math/prim/fun/log1p_exp.hpp>
#include <stan/math/prim/fun/log_diff_exp.hpp>
#include <stan/math/prim/fun/log_sum_exp.hpp>
#include <cmath>
#include <vector>

Expand Down
6 changes: 3 additions & 3 deletions stan/math/prim/prob/beta_binomial_cdf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#include <stan/math/prim/fun/constants.hpp>
#include <stan/math/prim/fun/digamma.hpp>
#include <stan/math/prim/fun/exp.hpp>
#include <stan/math/prim/fun/F32.hpp>
#include <stan/math/prim/fun/hypergeometric_3F2.hpp>
#include <stan/math/prim/fun/grad_F32.hpp>
#include <stan/math/prim/fun/lbeta.hpp>
#include <stan/math/prim/fun/max_size.hpp>
Expand Down Expand Up @@ -100,8 +100,8 @@ return_type_t<T_size1, T_size2> beta_binomial_cdf(const T_n& n, const T_N& N,
const T_partials_return nu = beta_dbl + N_minus_n - 1;
const T_partials_return one = 1;

const T_partials_return F
= F32(one, mu, 1 - N_minus_n, n_dbl + 2, 1 - nu, one);
const T_partials_return F = hypergeometric_3F2({one, mu, 1 - N_minus_n},
{n_dbl + 2, 1 - nu}, one);

T_partials_return C = lbeta(nu, mu) - lbeta(alpha_dbl, beta_dbl)
- lbeta(N_minus_n, n_dbl + 2);
Expand Down
6 changes: 3 additions & 3 deletions stan/math/prim/prob/beta_binomial_lccdf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#include <stan/math/prim/fun/constants.hpp>
#include <stan/math/prim/fun/digamma.hpp>
#include <stan/math/prim/fun/exp.hpp>
#include <stan/math/prim/fun/F32.hpp>
#include <stan/math/prim/fun/hypergeometric_3F2.hpp>
#include <stan/math/prim/fun/grad_F32.hpp>
#include <stan/math/prim/fun/lbeta.hpp>
#include <stan/math/prim/fun/log.hpp>
Expand Down Expand Up @@ -101,8 +101,8 @@ return_type_t<T_size1, T_size2> beta_binomial_lccdf(const T_n& n, const T_N& N,
const T_partials_return nu = beta_dbl + N_dbl - n_dbl - 1;
const T_partials_return one = 1;

const T_partials_return F
= F32(one, mu, -N_dbl + n_dbl + 1, n_dbl + 2, 1 - nu, one);
const T_partials_return F = hypergeometric_3F2(
{one, mu, -N_dbl + n_dbl + 1}, {n_dbl + 2, 1 - nu}, one);
T_partials_return C = lbeta(nu, mu) - lbeta(alpha_dbl, beta_dbl)
- lbeta(N_dbl - n_dbl, n_dbl + 2);
C = F * exp(C) / (N_dbl + 1);
Expand Down
Loading