diff --git a/stan/math/prim/fun/hypergeometric_3F2.hpp b/stan/math/prim/fun/hypergeometric_3F2.hpp index 4508e3b4c84..5f5dbedea69 100644 --- a/stan/math/prim/fun/hypergeometric_3F2.hpp +++ b/stan/math/prim/fun/hypergeometric_3F2.hpp @@ -12,6 +12,7 @@ #include #include #include +#include namespace stan { namespace math { @@ -30,7 +31,7 @@ T_return hypergeometric_3F2_infsum(const Ta& a, const Tb& b, const Tz& z, check_3F2_converges("hypergeometric_3F2", a_array[0], a_array[1], a_array[2], b_array[0], b_array[1], z); - T_return t_acc = 1.0; + T_return t_acc = 0.0; T_return log_t = 0.0; T_return log_z = log(fabs(z)); Eigen::ArrayXi a_signs = sign(value_of_rec(a_array)); @@ -39,7 +40,7 @@ T_return hypergeometric_3F2_infsum(const Ta& a, const Tb& b, const Tz& z, plain_type_t bpk = b_array; int z_sign = sign(value_of_rec(z)); int t_sign = z_sign * a_signs.prod() * b_signs.prod(); - + int acc_sign = 1; int k = 0; while (k <= max_steps && log_t >= log(precision)) { // Replace zero values with 1 prior to taking the log so that we accumulate @@ -52,7 +53,8 @@ T_return hypergeometric_3F2_infsum(const Ta& a, const Tb& b, const Tz& z, } log_t += p + log_z; - t_acc += t_sign * exp(log_t); + std::forward_as_tuple(t_acc, acc_sign) + = log_sum_exp_signed(t_acc, acc_sign, log_t, t_sign); if (is_inf(t_acc)) { throw_domain_error("hypergeometric_3F2", "sum (output)", t_acc, @@ -70,7 +72,7 @@ T_return hypergeometric_3F2_infsum(const Ta& a, const Tb& b, const Tz& z, "exceeded iterations, hypergeometric function did not ", "converge."); } - return t_acc; + return acc_sign * exp(t_acc); } } // namespace internal @@ -109,7 +111,8 @@ T_return hypergeometric_3F2_infsum(const Ta& a, const Tb& b, const Tz& z, * @param[in] z z (is always called with 1 from beta binomial cdfs) * @param[in] precision precision of the infinite sum. defaults to 1e-6 * @param[in] max_steps number of steps to take. defaults to 1e5 - * @return Generalized hypergeometric function applied to the inputs + * The 3F2 generalized hypergeometric function applied to the + * arguments {a1, a2, a3}, {b1, b2} */ template * = nullptr, diff --git a/stan/math/prim/fun/log_sum_exp_signed.hpp b/stan/math/prim/fun/log_sum_exp_signed.hpp index 196093eb3c5..97a74ac4407 100644 --- a/stan/math/prim/fun/log_sum_exp_signed.hpp +++ b/stan/math/prim/fun/log_sum_exp_signed.hpp @@ -4,7 +4,8 @@ #include #include #include -#include +#include +#include #include #include