Stan Math Library  2.11.0
reverse mode automatic differentiation
trace_inv_quad_form_ldlt.hpp
Go to the documentation of this file.
1 #ifndef STAN_MATH_REV_MAT_FUN_TRACE_INV_QUAD_FORM_LDLT_HPP
2 #define STAN_MATH_REV_MAT_FUN_TRACE_INV_QUAD_FORM_LDLT_HPP
3 
5 #include <stan/math/rev/core.hpp>
8 #include <boost/utility/enable_if.hpp>
11 
12 namespace stan {
13  namespace math {
14  namespace {
15  template <typename T2, int R2, int C2, typename T3, int R3, int C3>
16  class trace_inv_quad_form_ldlt_impl : public chainable_alloc {
17  protected:
18  inline void initializeB(const Eigen::Matrix<var, R3, C3> &B,
19  bool haveD) {
20  Eigen::Matrix<double, R3, C3> Bd(B.rows(), B.cols());
21  _variB.resize(B.rows(), B.cols());
22  for (int j = 0; j < B.cols(); j++) {
23  for (int i = 0; i < B.rows(); i++) {
24  _variB(i, j) = B(i, j).vi_;
25  Bd(i, j) = B(i, j).val();
26  }
27  }
28  AinvB_ = _ldlt.solve(Bd);
29  if (haveD)
30  C_.noalias() = Bd.transpose()*AinvB_;
31  else
32  _value = (Bd.transpose()*AinvB_).trace();
33  }
34  inline void initializeB(const Eigen::Matrix<double, R3, C3> &B,
35  bool haveD) {
36  AinvB_ = _ldlt.solve(B);
37  if (haveD)
38  C_.noalias() = B.transpose()*AinvB_;
39  else
40  _value = (B.transpose()*AinvB_).trace();
41  }
42 
43  template<int R1, int C1>
44  inline void initializeD(const Eigen::Matrix<var, R1, C1> &D) {
45  D_.resize(D.rows(), D.cols());
46  _variD.resize(D.rows(), D.cols());
47  for (int j = 0; j < D.cols(); j++) {
48  for (int i = 0; i < D.rows(); i++) {
49  _variD(i, j) = D(i, j).vi_;
50  D_(i, j) = D(i, j).val();
51  }
52  }
53  }
54  template<int R1, int C1>
55  inline void initializeD(const Eigen::Matrix<double, R1, C1> &D) {
56  D_ = D;
57  }
58 
59  public:
60  template<typename T1, int R1, int C1>
61  trace_inv_quad_form_ldlt_impl(const Eigen::Matrix<T1, R1, C1> &D,
63  &A,
64  const Eigen::Matrix<T3, R3, C3> &B)
65  : Dtype_(stan::is_var<T1>::value),
66  _ldlt(A) {
67  initializeB(B, true);
68  initializeD(D);
69 
70  _value = (D_*C_).trace();
71  }
72 
73  trace_inv_quad_form_ldlt_impl(const stan::math::LDLT_factor<T2, R2, C2>
74  &A,
75  const Eigen::Matrix<T3, R3, C3> &B)
76  : Dtype_(2),
77  _ldlt(A) {
78  initializeB(B, false);
79  }
80 
81  const int Dtype_; // 0 = double, 1 = var, 2 = missing
83  Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic> D_;
84  Eigen::Matrix<vari*, Eigen::Dynamic, Eigen::Dynamic> _variD;
85  Eigen::Matrix<vari*, R3, C3> _variB;
86  Eigen::Matrix<double, R3, C3> AinvB_;
87  Eigen::Matrix<double, C3, C3> C_;
88  double _value;
89  };
90 
91  template <typename T2, int R2, int C2, typename T3, int R3, int C3>
92  class trace_inv_quad_form_ldlt_vari : public vari {
93  protected:
94  static inline
95  void
96  chainA(const double &adj,
97  trace_inv_quad_form_ldlt_impl<double, R2, C2, T3, R3, C3>
98  *impl) {
99  }
100  static inline
101  void
102  chainB(const double &adj,
103  trace_inv_quad_form_ldlt_impl<T2, R2, C2, double, R3, C3>
104  *impl) {
105  }
106 
107  static inline
108  void
109  chainA(const double &adj,
110  trace_inv_quad_form_ldlt_impl<var, R2, C2, T3, R3, C3> *impl) {
111  Eigen::Matrix<double, R2, C2> aA;
112 
113  if (impl->Dtype_ != 2)
114  aA.noalias() = -adj * (impl->AinvB_ * impl->D_.transpose()
115  * impl->AinvB_.transpose());
116  else
117  aA.noalias() = -adj*(impl->AinvB_ * impl->AinvB_.transpose());
118 
119  for (int j = 0; j < aA.cols(); j++)
120  for (int i = 0; i < aA.rows(); i++)
121  impl->_ldlt._alloc->_variA(i, j)->adj_ += aA(i, j);
122  }
123  static inline
124  void
125  chainB(const double &adj,
126  trace_inv_quad_form_ldlt_impl<T2, R2, C2, var, R3, C3> *impl) {
127  Eigen::Matrix<double, R3, C3> aB;
128 
129  if (impl->Dtype_ != 2)
130  aB.noalias() = adj*impl->AinvB_*(impl->D_ + impl->D_.transpose());
131  else
132  aB.noalias() = 2*adj*impl->AinvB_;
133 
134  for (int j = 0; j < aB.cols(); j++)
135  for (int i = 0; i < aB.rows(); i++)
136  impl->_variB(i, j)->adj_ += aB(i, j);
137  }
138 
139  public:
140  explicit trace_inv_quad_form_ldlt_vari
141  (trace_inv_quad_form_ldlt_impl<T2, R2, C2, T3, R3, C3> *impl)
142  : vari(impl->_value), _impl(impl)
143  { }
144 
145  virtual void chain() {
146  // F = trace(D * B' * inv(A) * B)
147  // aA = -aF * inv(A') * B * D' * B' * inv(A')
148  // aB = aF*(inv(A) * B * D + inv(A') * B * D')
149  // aD = aF*(B' * inv(A) * B)
150  chainA(adj_, _impl);
151 
152  chainB(adj_, _impl);
153 
154  if (_impl->Dtype_ == 1) {
155  for (int j = 0; j < _impl->_variD.cols(); j++)
156  for (int i = 0; i < _impl->_variD.rows(); i++)
157  _impl->_variD(i, j)->adj_ += adj_*_impl->C_(i, j);
158  }
159  }
160 
161  trace_inv_quad_form_ldlt_impl<T2, R2, C2, T3, R3, C3> *_impl;
162  };
163 
164  }
165 
166 
172  template <typename T2, int R2, int C2, typename T3, int R3, int C3>
173  inline typename
174  boost::enable_if_c<stan::is_var<T2>::value ||
176  var>::type
178  const Eigen::Matrix<T3, R3, C3> &B) {
179  stan::math::check_multiplicable("trace_inv_quad_form_ldlt",
180  "A", A,
181  "B", B);
182 
183  trace_inv_quad_form_ldlt_impl<T2, R2, C2, T3, R3, C3> *_impl
184  = new trace_inv_quad_form_ldlt_impl<T2, R2, C2, T3, R3, C3>(A, B);
185 
186  return var(new trace_inv_quad_form_ldlt_vari<T2, R2, C2, T3, R3, C3>
187  (_impl));
188  }
189 
190  }
191 }
192 
193 #endif
Eigen::Matrix< vari *, Eigen::Dynamic, Eigen::Dynamic > _variD
const int Dtype_
Eigen::Matrix< double, C3, C3 > C_
double _value
boost::enable_if_c<!stan::is_var< T1 >::value &&!stan::is_var< T2 >::value, typename boost::math::tools::promote_args< T1, T2 >::type >::type trace_inv_quad_form_ldlt(const stan::math::LDLT_factor< T1, R2, C2 > &A, const Eigen::Matrix< T2, R3, C3 > &B)
Independent (input) and dependent (output) variables for gradients.
Definition: var.hpp:31
Eigen::Matrix< double, Eigen::Dynamic, Eigen::Dynamic > D_
bool check_multiplicable(const char *function, const char *name1, const T1 &y1, const char *name2, const T2 &y2)
Return true if the matrices can be multiplied.
stan::math::LDLT_factor< T2, R2, C2 > _ldlt
Eigen::Matrix< double, R3, C3 > AinvB_
T trace(const Eigen::Matrix< T, Eigen::Dynamic, Eigen::Dynamic > &m)
Returns the trace of the specified matrix.
Definition: trace.hpp:20
Eigen::Matrix< vari *, R3, C3 > _variB
trace_inv_quad_form_ldlt_impl< T2, R2, C2, T3, R3, C3 > * _impl

     [ Stan Home Page ] © 2011–2016, Stan Development Team.