1 #ifndef STAN_MATH_PRIM_SCAL_FUN_GRAD_2F1_HPP 2 #define STAN_MATH_PRIM_SCAL_FUN_GRAD_2F1_HPP 35 void grad_2F1(T& g_a1, T& g_b1,
const T& a1,
const T& a2,
const T& b1,
36 const T& z,
const T& precision = 1
e-10,
int max_steps = 1e5) {
47 for (
int i = 0; i < 2; ++i)
48 log_g_old[i] = -std::numeric_limits<T>::infinity();
55 double log_t_new_sign = 1.0;
56 double log_t_old_sign = 1.0;
57 double log_g_old_sign[2];
58 for (
int i = 0; i < 2; ++i)
59 log_g_old_sign[i] = 1.0;
61 for (
int k = 0; k <= max_steps; ++k) {
62 T p = (a1 + k) * (a2 + k) / ((b1 + k) * (1 + k));
66 log_t_new +=
log(
fabs(p)) + log_z;
67 log_t_new_sign = p >= 0.0 ? log_t_new_sign : -log_t_new_sign;
69 T term = log_g_old_sign[0] * log_t_old_sign *
70 exp(log_g_old[0] - log_t_old) + 1 / (a1 + k);
71 log_g_old[0] = log_t_new +
log(
fabs(term));
72 log_g_old_sign[0] = term >= 0.0 ? log_t_new_sign : -log_t_new_sign;
74 term = log_g_old_sign[1] * log_t_old_sign *
75 exp(log_g_old[1] - log_t_old) - 1 / (b1 + k);
76 log_g_old[1] = log_t_new +
log(
fabs(term));
77 log_g_old_sign[1] = term >= 0.0 ? log_t_new_sign : -log_t_new_sign;
79 g_a1 += log_g_old_sign[0] > 0 ?
exp(log_g_old[0]) : -
exp(log_g_old[0]);
80 g_b1 += log_g_old_sign[1] > 0 ?
exp(log_g_old[1]) : -
exp(log_g_old[1]);
82 if (log_t_new <=
log(precision))
85 log_t_old = log_t_new;
86 log_t_old_sign = log_t_new_sign;
88 domain_error(
"grad_2F1",
"k (internal counter)", max_steps,
89 "exceeded ",
" iterations, hypergeometric function gradient " fvar< T > fabs(const fvar< T > &x)
fvar< T > log(const fvar< T > &x)
void grad_2F1(T &g_a1, T &g_b1, const T &a1, const T &a2, const T &b1, const T &z, const T &precision=1e-10, int max_steps=1e5)
Gradients of the hypergeometric function, 2F1.
fvar< T > exp(const fvar< T > &x)
void domain_error(const char *function, const char *name, const T &y, const char *msg1, const char *msg2)
Throw a domain error with a consistently formatted message.
double e()
Return the base of the natural logarithm.
void check_2F1_converges(const char *function, const T_a1 &a1, const T_a2 &a2, const T_b1 &b1, const T_z &z)
Check if the hypergeometric function (2F1) called with supplied arguments will converge, assuming arguments are finite values.