1 #ifndef STAN_MATH_REV_SCAL_FUN_FMA_HPP 2 #define STAN_MATH_REV_SCAL_FUN_FMA_HPP 15 class fma_vvv_vari :
public op_vvv_vari {
17 fma_vvv_vari(vari* avi, vari* bvi, vari* cvi) :
18 op_vvv_vari(
fma(avi->val_, bvi->val_, cvi->val_), avi, bvi, cvi) {
24 avi_->adj_ = std::numeric_limits<double>::quiet_NaN();
25 bvi_->adj_ = std::numeric_limits<double>::quiet_NaN();
26 cvi_->adj_ = std::numeric_limits<double>::quiet_NaN();
28 avi_->adj_ += adj_ * bvi_->val_;
29 bvi_->adj_ += adj_ * avi_->val_;
35 class fma_vvd_vari :
public op_vvd_vari {
37 fma_vvd_vari(vari* avi, vari* bvi,
double c) :
38 op_vvd_vari(
fma(avi->val_, bvi->val_, c), avi, bvi, c) {
44 avi_->adj_ = std::numeric_limits<double>::quiet_NaN();
45 bvi_->adj_ = std::numeric_limits<double>::quiet_NaN();
47 avi_->adj_ += adj_ * bvi_->val_;
48 bvi_->adj_ += adj_ * avi_->val_;
53 class fma_vdv_vari :
public op_vdv_vari {
55 fma_vdv_vari(vari* avi,
double b, vari* cvi) :
56 op_vdv_vari(
fma(avi->val_ , b, cvi->val_), avi, b, cvi) {
62 avi_->adj_ = std::numeric_limits<double>::quiet_NaN();
63 cvi_->adj_ = std::numeric_limits<double>::quiet_NaN();
65 avi_->adj_ += adj_ * bd_;
71 class fma_vdd_vari :
public op_vdd_vari {
73 fma_vdd_vari(vari* avi,
double b,
double c) :
74 op_vdd_vari(
fma(avi->val_ , b, c), avi, b, c) {
80 avi_->adj_ = std::numeric_limits<double>::quiet_NaN();
82 avi_->adj_ += adj_ * bd_;
86 class fma_ddv_vari :
public op_ddv_vari {
88 fma_ddv_vari(
double a,
double b, vari* cvi) :
89 op_ddv_vari(
fma(a, b, cvi->val_), a, b, cvi) {
95 cvi_->adj_ = std::numeric_limits<double>::quiet_NaN();
141 return var(
new fma_vvd_vari(a.
vi_, b.
vi_, c));
161 return var(
new fma_vdv_vari(a.
vi_, b, c.
vi_));
182 return var(
new fma_vdd_vari(a.
vi_, b, c));
200 return var(
new fma_vdd_vari(b.
vi_, a, c));
218 return var(
new fma_ddv_vari(a, b, c.
vi_));
238 return var(
new fma_vdv_vari(b.
vi_, a, c.
vi_));
Independent (input) and dependent (output) variables for gradients.
fvar< typename stan::return_type< T1, T2, T3 >::type > fma(const fvar< T1 > &x1, const fvar< T2 > &x2, const fvar< T3 > &x3)
The fused multiply-add operation (C99).
vari * vi_
Pointer to the implementation of this variable.
int is_nan(const fvar< T > &x)
Returns 1 if the input's value is NaN and 0 otherwise.