// grad_F32 -- gradients of the hypergeometric function 3F2.
//
// NOTE(review): this is a fragmentary source listing, not compilable code.
// Chunks carry embedded original line numbers ("1", "2", "36", "37", ...),
// and many original lines are absent: 3-35 (presumably includes, namespace,
// doc comment and template header), 39-44, 46-48, 51-56, 65-67, 109-110,
// 112-113, 116, and everything after 119. The comments below describe only
// what the visible lines do. Statements about missing lines are marked as
// assumptions.
//
// Parameters (from the signature):
//   g          output array with six slots. Each series iteration adds its
//              gradient contribution into g[0..5] with `+=`. The slots
//              are presumably zeroed on a missing line -- TODO confirm.
//   a1, a2, a3 numerator parameters (gradient slots 0, 1, 2).
//   b1, b2     denominator parameters (gradient slots 3, 4).
//   z          argument (gradient slot 5).
//   precision  the stopping test compares the log-magnitude of the current
//              series term against log(precision). Default 1e-6; the
//              listing splits this literal across two lines.
//   max_steps  the loop counter k runs from 0 through max_steps inclusive.
//              Default 1e5.
// Errors: after the loop, domain_error("grad_F32", ...) reports that the
// gradient did not converge.
1 #ifndef STAN_MATH_PRIM_SCAL_FUN_GRAD_F32_HPP 2 #define STAN_MATH_PRIM_SCAL_FUN_GRAD_F32_HPP 36 void grad_F32(T* g,
const T& a1,
const T& a2,
const T& a3,
const T& b1,
37 const T& b2,
const T& z,
const T& precision = 1
e-6,
38 int max_steps = 1e5) {
// Original lines 39-44 are missing. The body of this loop (original line
// 46) is also missing; presumably it zeroes g[i] -- TODO confirm.
45 for (
int i = 0; i < 6; ++i)
// Start each slot's log-magnitude at -infinity, i.e. a magnitude of 0.
// The declaration of log_g_old is on a missing line.
49 for (
int i = 0; i < 6; ++i)
50 log_g_old[i] = -std::numeric_limits<double>::infinity();
// Signs are tracked separately from log-magnitudes so that negative values
// can be carried in log space. log_t_old, log_t_new and log_z are
// presumably declared and initialized on missing lines 51-56.
57 double log_t_new_sign = 1.0;
58 double log_t_old_sign = 1.0;
59 double log_g_old_sign[6];
60 for (
int i = 0; i < 6; ++i)
61 log_g_old_sign[i] = 1.0;
// Series loop. p is the ratio of consecutive series coefficients. The
// current term t_new is then t_old * p * z, kept as log|t| plus a sign.
63 for (
int k = 0; k <= max_steps; ++k) {
64 T p = (a1 + k) * (a2 + k) * (a3 + k) / ((b1 + k) * (b2 + k) * (1 + k));
// NOTE(review): original lines 65-67 are missing here. If p is exactly 0
// (e.g. a1 + k == 0), log(fabs(p)) below is -infinity; confirm whether
// the missing lines guard against that case.
68 log_t_new +=
log(
fabs(p)) + log_z;
69 log_t_new_sign = p >= 0.0 ? log_t_new_sign : -log_t_new_sign;
// Gradient recurrence, applied per slot i with parameter theta:
//   term    = g_i_old / t_old + d/dtheta log(p * z)
//   g_i_new = t_new * term
// The sign products rebuild the signed ratio g_i_old / t_old from the
// stored magnitudes and signs. The result is stored back as a
// log-magnitude plus sign.
// Slot 0 (a1): d/da1 log(p) = 1 / (a1 + k).
72 T term = log_g_old_sign[0] * log_t_old_sign *
73 exp(log_g_old[0] - log_t_old) +
inv(a1 + k);
74 log_g_old[0] = log_t_new +
log(
fabs(term));
75 log_g_old_sign[0] = term >= 0.0 ? log_t_new_sign : -log_t_new_sign;
// Slot 1 (a2): adds 1 / (a2 + k).
78 term = log_g_old_sign[1] * log_t_old_sign *
79 exp(log_g_old[1] - log_t_old) +
inv(a2 + k);
80 log_g_old[1] = log_t_new +
log(
fabs(term));
81 log_g_old_sign[1] = term >= 0.0 ? log_t_new_sign : -log_t_new_sign;
// Slot 2 (a3): adds 1 / (a3 + k).
84 term = log_g_old_sign[2] * log_t_old_sign *
85 exp(log_g_old[2] - log_t_old) +
inv(a3 + k);
86 log_g_old[2] = log_t_new +
log(
fabs(term));
87 log_g_old_sign[2] = term >= 0.0 ? log_t_new_sign : -log_t_new_sign;
// Slot 3 (b1): b1 is in the denominator of p, so 1 / (b1 + k) is subtracted.
90 term = log_g_old_sign[3] * log_t_old_sign *
91 exp(log_g_old[3] - log_t_old) -
inv(b1 + k);
92 log_g_old[3] = log_t_new +
log(
fabs(term));
93 log_g_old_sign[3] = term >= 0.0 ? log_t_new_sign : -log_t_new_sign;
// Slot 4 (b2): subtracted, as for b1.
96 term = log_g_old_sign[4] * log_t_old_sign *
97 exp(log_g_old[4] - log_t_old) -
inv(b2 + k);
98 log_g_old[4] = log_t_new +
log(
fabs(term));
99 log_g_old_sign[4] = term >= 0.0 ? log_t_new_sign : -log_t_new_sign;
// Slot 5 (z): each term gains one more factor of z, so the per-step
// increment is 1 / z.
102 term = log_g_old_sign[5] * log_t_old_sign *
103 exp(log_g_old[5] - log_t_old) +
inv(z);
104 log_g_old[5] = log_t_new +
log(
fabs(term));
105 log_g_old_sign[5] = term >= 0.0 ? log_t_new_sign : -log_t_new_sign;
// Add this iteration's signed contribution for each slot into g.
// The closing brace of this loop (original line 109) is missing.
107 for (
int i = 0; i < 6; ++i) {
108 g[i] += log_g_old_sign[i] *
exp(log_g_old[i]);
// Convergence test on the current series term's log-magnitude.
// NOTE(review): the statement this `if` guards (original lines 112-113)
// is missing; presumably it is an early return. As shown, the next line
// would wrongly become its body.
// NOTE(review): the test checks only the series term, not the gradient
// contributions. For slot 5 the contribution is t_new * (k + 1) / z,
// which can stay large after t_new falls below precision. Confirm that
// this criterion is adequate.
111 if (log_t_new <=
log(precision))
// Carry the current term forward for the next iteration's recurrence.
114 log_t_old = log_t_new;
115 log_t_old_sign = log_t_new_sign;
// Presumably reached only when the loop exhausts k = 0..max_steps without
// taking the early exit above. Reports non-convergence via domain_error.
// The loop's closing brace (original line 116) and the function's closing
// brace are missing from the listing.
117 domain_error(
"grad_F32",
"k (internal counter)", max_steps,
118 "exceeded ",
" iterations, hypergeometric function gradient " 119 "did not converge.");
fvar< T > fabs(const fvar< T > &x)
void check_3F2_converges(const char *function, const T_a1 &a1, const T_a2 &a2, const T_a3 &a3, const T_b1 &b1, const T_b2 &b2, const T_z &z)
Check if the hypergeometric function (3F2) called with supplied arguments will converge, assuming arguments are finite values.
fvar< T > log(const fvar< T > &x)
fvar< T > exp(const fvar< T > &x)
void domain_error(const char *function, const char *name, const T &y, const char *msg1, const char *msg2)
Throw a domain error with a consistently formatted message.
double e()
Return the base of the natural logarithm.
void grad_F32(T *g, const T &a1, const T &a2, const T &a3, const T &b1, const T &b2, const T &z, const T &precision=1e-6, int max_steps=1e5)
Gradients of the hypergeometric function 3F2 with respect to its parameters a1, a2, a3, b1, b2 and its argument z.
fvar< T > inv(const fvar< T > &x)