Skip to content

Commit

Permalink
Update doc, use log_sum_exp_signed
Browse files Browse the repository at this point in the history
  • Loading branch information
andrjohns committed Oct 21, 2022
1 parent 53a02e3 commit 864f65f
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 6 deletions.
13 changes: 8 additions & 5 deletions stan/math/prim/fun/hypergeometric_3F2.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <stan/math/prim/fun/sum.hpp>
#include <stan/math/prim/fun/sign.hpp>
#include <stan/math/prim/fun/value_of_rec.hpp>
#include <stan/math/prim/fun/log_sum_exp_signed.hpp>

namespace stan {
namespace math {
Expand All @@ -30,7 +31,7 @@ T_return hypergeometric_3F2_infsum(const Ta& a, const Tb& b, const Tz& z,
check_3F2_converges("hypergeometric_3F2", a_array[0], a_array[1], a_array[2],
b_array[0], b_array[1], z);

T_return t_acc = 1.0;
T_return t_acc = 0.0;
T_return log_t = 0.0;
T_return log_z = log(fabs(z));
Eigen::ArrayXi a_signs = sign(value_of_rec(a_array));
Expand All @@ -39,7 +40,7 @@ T_return hypergeometric_3F2_infsum(const Ta& a, const Tb& b, const Tz& z,
plain_type_t<decltype(b_array)> bpk = b_array;
int z_sign = sign(value_of_rec(z));
int t_sign = z_sign * a_signs.prod() * b_signs.prod();

int acc_sign = 1;
int k = 0;
while (k <= max_steps && log_t >= log(precision)) {
// Replace zero values with 1 prior to taking the log so that we accumulate
Expand All @@ -52,7 +53,8 @@ T_return hypergeometric_3F2_infsum(const Ta& a, const Tb& b, const Tz& z,
}

log_t += p + log_z;
t_acc += t_sign * exp(log_t);
std::forward_as_tuple(t_acc, acc_sign)
= log_sum_exp_signed(t_acc, acc_sign, log_t, t_sign);

if (is_inf(t_acc)) {
throw_domain_error("hypergeometric_3F2", "sum (output)", t_acc,
Expand All @@ -70,7 +72,7 @@ T_return hypergeometric_3F2_infsum(const Ta& a, const Tb& b, const Tz& z,
"exceeded iterations, hypergeometric function did not ",
"converge.");
}
return t_acc;
return acc_sign * exp(t_acc);
}
} // namespace internal

Expand Down Expand Up @@ -109,7 +111,8 @@ T_return hypergeometric_3F2_infsum(const Ta& a, const Tb& b, const Tz& z,
* @param[in] z z (is always called with 1 from beta binomial cdfs)
* @param[in] precision precision of the infinite sum. defaults to 1e-6
* @param[in] max_steps number of steps to take. defaults to 1e5
* @return Generalized hypergeometric function applied to the inputs
* The 3F2 generalized hypergeometric function applied to the
* arguments {a1, a2, a3}, {b1, b2}
*/
template <typename Ta, typename Tb, typename Tz,
require_all_vector_t<Ta, Tb>* = nullptr,
Expand Down
3 changes: 2 additions & 1 deletion stan/math/prim/fun/log_sum_exp_signed.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/fun/constants.hpp>
#include <stan/math/prim/fun/Eigen.hpp>
#include <stan/math/prim/fun/log1p_exp.hpp>
#include <stan/math/prim/fun/log_diff_exp.hpp>
#include <stan/math/prim/fun/log_sum_exp.hpp>
#include <cmath>
#include <vector>

Expand Down

0 comments on commit 864f65f

Please sign in to comment.