Stan Math Library  2.11.0
reverse mode automatic differentiation
trace_gen_quad_form.hpp
Go to the documentation of this file.
1 #ifndef STAN_MATH_REV_MAT_FUN_TRACE_GEN_QUAD_FORM_HPP
2 #define STAN_MATH_REV_MAT_FUN_TRACE_GEN_QUAD_FORM_HPP
3 
4 #include <boost/utility/enable_if.hpp>
5 #include <boost/type_traits.hpp>
8 #include <stan/math/rev/core.hpp>
15 
16 namespace stan {
17  namespace math {
18  namespace {
19  template <typename TD, int RD, int CD,
20  typename TA, int RA, int CA,
21  typename TB, int RB, int CB>
22  class trace_gen_quad_form_vari_alloc : public chainable_alloc {
23  public:
24  trace_gen_quad_form_vari_alloc(const Eigen::Matrix<TD, RD, CD>& D,
25  const Eigen::Matrix<TA, RA, CA>& A,
26  const Eigen::Matrix<TB, RB, CB>& B)
27  : D_(D), A_(A), B_(B)
28  { }
29 
30  double compute() {
33  value_of(A_),
34  value_of(B_));
35  }
36 
37  Eigen::Matrix<TD, RD, CD> D_;
38  Eigen::Matrix<TA, RA, CA> A_;
39  Eigen::Matrix<TB, RB, CB> B_;
40  };
41 
42  template <typename TD, int RD, int CD,
43  typename TA, int RA, int CA,
44  typename TB, int RB, int CB>
45  class trace_gen_quad_form_vari : public vari {
46  protected:
47  static inline void
48  computeAdjoints(const double& adj,
49  const Eigen::Matrix<double, RD, CD>& D,
50  const Eigen::Matrix<double, RA, CA>& A,
51  const Eigen::Matrix<double, RB, CB>& B,
52  Eigen::Matrix<var, RD, CD> *varD,
53  Eigen::Matrix<var, RA, CA> *varA,
54  Eigen::Matrix<var, RB, CB> *varB) {
55  Eigen::Matrix<double, CA, CB> AtB;
56  Eigen::Matrix<double, RA, CB> BD;
57  if (varB || varA)
58  BD.noalias() = B*D;
59  if (varB || varD)
60  AtB.noalias() = A.transpose()*B;
61 
62  if (varB) {
63  Eigen::Matrix<double, RB, CB> adjB(adj*(A*BD + AtB*D.transpose()));
64  for (int j = 0; j < B.cols(); j++)
65  for (int i = 0; i < B.rows(); i++)
66  (*varB)(i, j).vi_->adj_ += adjB(i, j);
67  }
68  if (varA) {
69  Eigen::Matrix<double, RA, CA> adjA(adj*(B*BD.transpose()));
70  for (int j = 0; j < A.cols(); j++)
71  for (int i = 0; i < A.rows(); i++)
72  (*varA)(i, j).vi_->adj_ += adjA(i, j);
73  }
74  if (varD) {
75  Eigen::Matrix<double, RD, CD> adjD(adj*(B.transpose()*AtB));
76  for (int j = 0; j < D.cols(); j++)
77  for (int i = 0; i < D.rows(); i++)
78  (*varD)(i, j).vi_->adj_ += adjD(i, j);
79  }
80  }
81 
82 
83  public:
84  explicit
85  trace_gen_quad_form_vari(trace_gen_quad_form_vari_alloc
86  <TD, RD, CD, TA, RA, CA, TB, RB, CB> *impl)
87  : vari(impl->compute()), _impl(impl) { }
88 
89  virtual void chain() {
91  computeAdjoints(adj_,
92  value_of(_impl->D_),
93  value_of(_impl->A_),
94  value_of(_impl->B_),
95  reinterpret_cast<Eigen::Matrix<var, RD, CD> *>
96  (boost::is_same<TD, var>::value?(&_impl->D_):NULL),
97  reinterpret_cast<Eigen::Matrix<var, RA, CA> *>
98  (boost::is_same<TA, var>::value?(&_impl->A_):NULL),
99  reinterpret_cast<Eigen::Matrix<var, RB, CB> *>
100  (boost::is_same<TB, var>::value?(&_impl->B_):NULL));
101  }
102 
103  trace_gen_quad_form_vari_alloc<TD, RD, CD, TA, RA, CA, TB, RB, CB>
105  };
106  }
107 
108  template <typename TD, int RD, int CD,
109  typename TA, int RA, int CA,
110  typename TB, int RB, int CB>
111  inline typename
112  boost::enable_if_c< boost::is_same<TD, var>::value ||
113  boost::is_same<TA, var>::value ||
114  boost::is_same<TB, var>::value,
115  var >::type
116  trace_gen_quad_form(const Eigen::Matrix<TD, RD, CD>& D,
117  const Eigen::Matrix<TA, RA, CA>& A,
118  const Eigen::Matrix<TB, RB, CB>& B) {
119  stan::math::check_square("trace_gen_quad_form", "A", A);
120  stan::math::check_square("trace_gen_quad_form", "D", D);
121  stan::math::check_multiplicable("trace_gen_quad_form",
122  "A", A,
123  "B", B);
124  stan::math::check_multiplicable("trace_gen_quad_form",
125  "B", B,
126  "D", D);
127 
128  trace_gen_quad_form_vari_alloc<TD, RD, CD, TA, RA, CA, TB, RB, CB>
129  *baseVari
130  = new trace_gen_quad_form_vari_alloc<TD, RD, CD, TA, RA, CA, TB, RB, CB>
131  (D, A, B);
132 
133  return var(new trace_gen_quad_form_vari
134  <TD, RD, CD, TA, RA, CA, TB, RB, CB>(baseVari));
135  }
136  }
137 }
138 
139 #endif
Eigen::Matrix< TB, RB, CB > B_
Eigen::Matrix< TA, RA, CA > A_
T value_of(const fvar< T > &v)
Return the value of the specified variable.
Definition: value_of.hpp:16
trace_gen_quad_form_vari_alloc< TD, RD, CD, TA, RA, CA, TB, RB, CB > * _impl
Eigen::Matrix< TD, RD, CD > D_
Independent (input) and dependent (output) variables for gradients.
Definition: var.hpp:31
fvar< T > trace_gen_quad_form(const Eigen::Matrix< fvar< T >, RD, CD > &D, const Eigen::Matrix< fvar< T >, RA, CA > &A, const Eigen::Matrix< fvar< T >, RB, CB > &B)
bool check_multiplicable(const char *function, const char *name1, const T1 &y1, const char *name2, const T2 &y2)
Return true if the matrices can be multiplied.
bool check_square(const char *function, const char *name, const Eigen::Matrix< T_y, Eigen::Dynamic, Eigen::Dynamic > &y)
Return true if the specified matrix is square.

     [ Stan Home Page ] © 2011–2016, Stan Development Team.