Stan Math Library  2.15.0
reverse mode automatic differentiation
trace_inv_quad_form_ldlt.hpp
Go to the documentation of this file.
1 #ifndef STAN_MATH_REV_MAT_FUN_TRACE_INV_QUAD_FORM_LDLT_HPP
2 #define STAN_MATH_REV_MAT_FUN_TRACE_INV_QUAD_FORM_LDLT_HPP
3 
5 #include <stan/math/rev/core.hpp>
8 #include <boost/utility/enable_if.hpp>
11 
12 namespace stan {
13  namespace math {
14 
15  namespace {
16  template <typename T2, int R2, int C2, typename T3, int R3, int C3>
17  class trace_inv_quad_form_ldlt_impl : public chainable_alloc {
18  protected:
19  inline void initializeB(const Eigen::Matrix<var, R3, C3> &B,
20  bool haveD) {
21  Eigen::Matrix<double, R3, C3> Bd(B.rows(), B.cols());
22  variB_.resize(B.rows(), B.cols());
23  for (int j = 0; j < B.cols(); j++) {
24  for (int i = 0; i < B.rows(); i++) {
25  variB_(i, j) = B(i, j).vi_;
26  Bd(i, j) = B(i, j).val();
27  }
28  }
29  AinvB_ = ldlt_.solve(Bd);
30  if (haveD)
31  C_.noalias() = Bd.transpose()*AinvB_;
32  else
33  value_ = (Bd.transpose()*AinvB_).trace();
34  }
35  inline void initializeB(const Eigen::Matrix<double, R3, C3> &B,
36  bool haveD) {
37  AinvB_ = ldlt_.solve(B);
38  if (haveD)
39  C_.noalias() = B.transpose()*AinvB_;
40  else
41  value_ = (B.transpose()*AinvB_).trace();
42  }
43 
44  template<int R1, int C1>
45  inline void initializeD(const Eigen::Matrix<var, R1, C1> &D) {
46  D_.resize(D.rows(), D.cols());
47  variD_.resize(D.rows(), D.cols());
48  for (int j = 0; j < D.cols(); j++) {
49  for (int i = 0; i < D.rows(); i++) {
50  variD_(i, j) = D(i, j).vi_;
51  D_(i, j) = D(i, j).val();
52  }
53  }
54  }
55  template<int R1, int C1>
56  inline void initializeD(const Eigen::Matrix<double, R1, C1> &D) {
57  D_ = D;
58  }
59 
60  public:
61  template<typename T1, int R1, int C1>
62  trace_inv_quad_form_ldlt_impl(const Eigen::Matrix<T1, R1, C1> &D,
63  const LDLT_factor<T2, R2, C2>
64  &A,
65  const Eigen::Matrix<T3, R3, C3> &B)
66  : Dtype_(stan::is_var<T1>::value),
67  ldlt_(A) {
68  initializeB(B, true);
69  initializeD(D);
70 
71  value_ = (D_*C_).trace();
72  }
73 
74  trace_inv_quad_form_ldlt_impl(const LDLT_factor<T2, R2, C2>
75  &A,
76  const Eigen::Matrix<T3, R3, C3> &B)
77  : Dtype_(2),
78  ldlt_(A) {
79  initializeB(B, false);
80  }
81 
82  const int Dtype_; // 0 = double, 1 = var, 2 = missing
83  LDLT_factor<T2, R2, C2> ldlt_;
84  Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic> D_;
85  Eigen::Matrix<vari*, Eigen::Dynamic, Eigen::Dynamic> variD_;
86  Eigen::Matrix<vari*, R3, C3> variB_;
87  Eigen::Matrix<double, R3, C3> AinvB_;
88  Eigen::Matrix<double, C3, C3> C_;
89  double value_;
90  };
91 
92  template <typename T2, int R2, int C2, typename T3, int R3, int C3>
93  class trace_inv_quad_form_ldlt_vari : public vari {
94  protected:
95  static inline
96  void
97  chainA(double adj,
98  trace_inv_quad_form_ldlt_impl<double, R2, C2, T3, R3, C3>
99  *impl) {
100  }
101  static inline
102  void
103  chainB(double adj,
104  trace_inv_quad_form_ldlt_impl<T2, R2, C2, double, R3, C3>
105  *impl) {
106  }
107 
108  static inline
109  void
110  chainA(double adj,
111  trace_inv_quad_form_ldlt_impl<var, R2, C2, T3, R3, C3> *impl) {
112  Eigen::Matrix<double, R2, C2> aA;
113 
114  if (impl->Dtype_ != 2)
115  aA.noalias() = -adj * (impl->AinvB_ * impl->D_.transpose()
116  * impl->AinvB_.transpose());
117  else
118  aA.noalias() = -adj*(impl->AinvB_ * impl->AinvB_.transpose());
119 
120  for (int j = 0; j < aA.cols(); j++)
121  for (int i = 0; i < aA.rows(); i++)
122  impl->ldlt_.alloc_->variA_(i, j)->adj_ += aA(i, j);
123  }
124  static inline
125  void
126  chainB(double adj,
127  trace_inv_quad_form_ldlt_impl<T2, R2, C2, var, R3, C3> *impl) {
128  Eigen::Matrix<double, R3, C3> aB;
129 
130  if (impl->Dtype_ != 2)
131  aB.noalias() = adj*impl->AinvB_*(impl->D_ + impl->D_.transpose());
132  else
133  aB.noalias() = 2*adj*impl->AinvB_;
134 
135  for (int j = 0; j < aB.cols(); j++)
136  for (int i = 0; i < aB.rows(); i++)
137  impl->variB_(i, j)->adj_ += aB(i, j);
138  }
139 
140  public:
141  explicit trace_inv_quad_form_ldlt_vari
142  (trace_inv_quad_form_ldlt_impl<T2, R2, C2, T3, R3, C3> *impl)
143  : vari(impl->value_), impl_(impl)
144  { }
145 
146  virtual void chain() {
147  // F = trace(D * B' * inv(A) * B)
148  // aA = -aF * inv(A') * B * D' * B' * inv(A')
149  // aB = aF*(inv(A) * B * D + inv(A') * B * D')
150  // aD = aF*(B' * inv(A) * B)
151  chainA(adj_, impl_);
152 
153  chainB(adj_, impl_);
154 
155  if (impl_->Dtype_ == 1) {
156  for (int j = 0; j < impl_->variD_.cols(); j++)
157  for (int i = 0; i < impl_->variD_.rows(); i++)
158  impl_->variD_(i, j)->adj_ += adj_*impl_->C_(i, j);
159  }
160  }
161 
162  trace_inv_quad_form_ldlt_impl<T2, R2, C2, T3, R3, C3> *impl_;
163  };
164 
165  }
166 
172  template <typename T2, int R2, int C2, typename T3, int R3, int C3>
173  inline typename
174  boost::enable_if_c<stan::is_var<T2>::value ||
176  var>::type
178  const Eigen::Matrix<T3, R3, C3> &B) {
179  check_multiplicable("trace_inv_quad_form_ldlt",
180  "A", A,
181  "B", B);
182 
183  trace_inv_quad_form_ldlt_impl<T2, R2, C2, T3, R3, C3> *impl_
184  = new trace_inv_quad_form_ldlt_impl<T2, R2, C2, T3, R3, C3>(A, B);
185 
186  return var(new trace_inv_quad_form_ldlt_vari<T2, R2, C2, T3, R3, C3>
187  (impl_));
188  }
189 
190  }
191 }
192 #endif
boost::enable_if_c<!stan::is_var< T1 >::value &&!stan::is_var< T2 >::value, typename boost::math::tools::promote_args< T1, T2 >::type >::type trace_inv_quad_form_ldlt(const LDLT_factor< T1, R2, C2 > &A, const Eigen::Matrix< T2, R3, C3 > &B)
const int Dtype_
Eigen::Matrix< double, C3, C3 > C_
Independent (input) and dependent (output) variables for gradients.
Definition: var.hpp:30
double value_
Eigen::Matrix< double, Eigen::Dynamic, Eigen::Dynamic > D_
Eigen::Matrix< vari *, Eigen::Dynamic, Eigen::Dynamic > variD_
LDLT_factor is a thin wrapper on Eigen::LDLT to allow for reusing factorizations and efficient autodi...
Definition: LDLT_factor.hpp:63
void check_multiplicable(const char *function, const char *name1, const T1 &y1, const char *name2, const T2 &y2)
Check if the matrices can be multiplied.
LDLT_factor< T2, R2, C2 > ldlt_
Eigen::Matrix< double, R3, C3 > AinvB_
Eigen::Matrix< vari *, R3, C3 > variB_
T trace(const Eigen::Matrix< T, Eigen::Dynamic, Eigen::Dynamic > &m)
Returns the trace of the specified matrix.
Definition: trace.hpp:19
trace_inv_quad_form_ldlt_impl< T2, R2, C2, T3, R3, C3 > * impl_

     [ Stan Home Page ] © 2011–2016, Stan Development Team.