Stan Math Library  2.15.0
reverse mode automatic differentiation
falling_factorial.hpp
Go to the documentation of this file.
1 #ifndef STAN_MATH_REV_SCAL_FUN_FALLING_FACTORIAL_HPP
2 #define STAN_MATH_REV_SCAL_FUN_FALLING_FACTORIAL_HPP
3 
4 #include <stan/math/rev/core.hpp>
7 
8 namespace stan {
9  namespace math {
10 
11  namespace {
12 
13  class falling_factorial_vv_vari : public op_vv_vari {
14  public:
15  falling_factorial_vv_vari(vari* avi, vari* bvi) :
16  op_vv_vari(falling_factorial(avi->val_, bvi->val_),
17  avi, bvi) {
18  }
19  void chain() {
20  avi_->adj_ += adj_
21  * val_
22  * (digamma(avi_->val_ + 1)
23  - digamma(avi_->val_ - bvi_->val_ + 1));
24  bvi_->adj_ += adj_
25  * val_
26  * digamma(avi_->val_ - bvi_->val_ + 1);
27  }
28  };
29 
30  class falling_factorial_vd_vari : public op_vd_vari {
31  public:
32  falling_factorial_vd_vari(vari* avi, double b) :
33  op_vd_vari(falling_factorial(avi->val_, b), avi, b) {
34  }
35  void chain() {
36  avi_->adj_ += adj_
37  * val_
38  * (digamma(avi_->val_ + 1)
39  - digamma(avi_->val_ - bd_ + 1));
40  }
41  };
42 
43  class falling_factorial_dv_vari : public op_dv_vari {
44  public:
45  falling_factorial_dv_vari(double a, vari* bvi) :
46  op_dv_vari(falling_factorial(a, bvi->val_), a, bvi) {
47  }
48  void chain() {
49  bvi_->adj_ += adj_
50  * val_
51  * digamma(ad_ - bvi_->val_ + 1);
52  }
53  };
54  }
55 
56  inline var falling_factorial(const var& a,
57  double b) {
58  return var(new falling_factorial_vd_vari(a.vi_, b));
59  }
60 
61  inline var falling_factorial(const var& a,
62  const var& b) {
63  return var(new falling_factorial_vv_vari(a.vi_, b.vi_));
64  }
65 
66  inline var falling_factorial(double a,
67  const var& b) {
68  return var(new falling_factorial_dv_vari(a, b.vi_));
69  }
70 
71  }
72 }
73 #endif
Independent (input) and dependent (output) variables for gradients.
Definition: var.hpp:30
vari * vi_
Pointer to the implementation of this variable.
Definition: var.hpp:42
fvar< T > falling_factorial(const fvar< T > &x, const fvar< T > &n)
fvar< T > digamma(const fvar< T > &x)
Return the derivative of the log gamma function at the specified argument.
Definition: digamma.hpp:22

     [ Stan Home Page ] © 2011–2016, Stan Development Team.