1 #ifndef STAN_MATH_REV_MAT_FUN_TRACE_INV_QUAD_FORM_LDLT_HPP 2 #define STAN_MATH_REV_MAT_FUN_TRACE_INV_QUAD_FORM_LDLT_HPP 8 #include <boost/utility/enable_if.hpp> 16 template <
typename T2,
int R2,
int C2,
typename T3,
int R3,
int C3>
17 class trace_inv_quad_form_ldlt_impl :
public chainable_alloc {
19 inline void initializeB(
const Eigen::Matrix<var, R3, C3> &B,
21 Eigen::Matrix<double, R3, C3> Bd(B.rows(), B.cols());
22 variB_.resize(B.rows(), B.cols());
23 for (
int j = 0; j < B.cols(); j++) {
24 for (
int i = 0; i < B.rows(); i++) {
25 variB_(i, j) = B(i, j).vi_;
26 Bd(i, j) = B(i, j).val();
31 C_.noalias() = Bd.transpose()*
AinvB_;
35 inline void initializeB(
const Eigen::Matrix<double, R3, C3> &B,
44 template<
int R1,
int C1>
45 inline void initializeD(
const Eigen::Matrix<var, R1, C1> &D) {
46 D_.resize(D.rows(), D.cols());
47 variD_.resize(D.rows(), D.cols());
48 for (
int j = 0; j < D.cols(); j++) {
49 for (
int i = 0; i < D.rows(); i++) {
50 variD_(i, j) = D(i, j).vi_;
51 D_(i, j) = D(i, j).val();
55 template<
int R1,
int C1>
56 inline void initializeD(
const Eigen::Matrix<double, R1, C1> &D) {
61 template<
typename T1,
int R1,
int C1>
62 trace_inv_quad_form_ldlt_impl(
const Eigen::Matrix<T1, R1, C1> &D,
63 const LDLT_factor<T2, R2, C2>
65 const Eigen::Matrix<T3, R3, C3> &B)
74 trace_inv_quad_form_ldlt_impl(
const LDLT_factor<T2, R2, C2>
76 const Eigen::Matrix<T3, R3, C3> &B)
79 initializeB(B,
false);
84 Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic>
D_;
85 Eigen::Matrix<vari*, Eigen::Dynamic, Eigen::Dynamic>
variD_;
87 Eigen::Matrix<double, R3, C3>
AinvB_;
88 Eigen::Matrix<double, C3, C3>
C_;
92 template <
typename T2,
int R2,
int C2,
typename T3,
int R3,
int C3>
93 class trace_inv_quad_form_ldlt_vari :
public vari {
98 trace_inv_quad_form_ldlt_impl<double, R2, C2, T3, R3, C3>
104 trace_inv_quad_form_ldlt_impl<T2, R2, C2, double, R3, C3>
111 trace_inv_quad_form_ldlt_impl<var, R2, C2, T3, R3, C3> *impl) {
112 Eigen::Matrix<double, R2, C2> aA;
114 if (impl->Dtype_ != 2)
115 aA.noalias() = -adj * (impl->AinvB_ * impl->D_.transpose()
116 * impl->AinvB_.transpose());
118 aA.noalias() = -adj*(impl->AinvB_ * impl->AinvB_.transpose());
120 for (
int j = 0; j < aA.cols(); j++)
121 for (
int i = 0; i < aA.rows(); i++)
122 impl->ldlt_.alloc_->variA_(i, j)->adj_ += aA(i, j);
127 trace_inv_quad_form_ldlt_impl<T2, R2, C2, var, R3, C3> *impl) {
128 Eigen::Matrix<double, R3, C3> aB;
130 if (impl->Dtype_ != 2)
131 aB.noalias() = adj*impl->AinvB_*(impl->D_ + impl->D_.transpose());
133 aB.noalias() = 2*adj*impl->AinvB_;
135 for (
int j = 0; j < aB.cols(); j++)
136 for (
int i = 0; i < aB.rows(); i++)
137 impl->variB_(i, j)->adj_ += aB(i, j);
141 explicit trace_inv_quad_form_ldlt_vari
142 (trace_inv_quad_form_ldlt_impl<T2, R2, C2, T3, R3, C3> *impl)
143 : vari(impl->value_),
impl_(impl)
146 virtual void chain() {
155 if (
impl_->Dtype_ == 1) {
156 for (
int j = 0; j <
impl_->variD_.cols(); j++)
157 for (
int i = 0; i <
impl_->variD_.rows(); i++)
158 impl_->variD_(i, j)->adj_ += adj_*
impl_->C_(i, j);
162 trace_inv_quad_form_ldlt_impl<T2, R2, C2, T3, R3, C3> *
impl_;
172 template <
typename T2,
int R2,
int C2,
typename T3,
int R3,
int C3>
174 boost::enable_if_c<stan::is_var<T2>::value ||
178 const Eigen::Matrix<T3, R3, C3> &B) {
183 trace_inv_quad_form_ldlt_impl<T2, R2, C2, T3, R3, C3> *
impl_ 184 =
new trace_inv_quad_form_ldlt_impl<T2, R2, C2, T3, R3, C3>(A, B);
186 return var(
new trace_inv_quad_form_ldlt_vari<T2, R2, C2, T3, R3, C3>
boost::enable_if_c<!stan::is_var< T1 >::value &&!stan::is_var< T2 >::value, typename boost::math::tools::promote_args< T1, T2 >::type >::type trace_inv_quad_form_ldlt(const LDLT_factor< T1, R2, C2 > &A, const Eigen::Matrix< T2, R3, C3 > &B)
Independent (input) and dependent (output) variables for gradients.
LDLT_factor is a thin wrapper on Eigen::LDLT to allow for reusing factorizations and efficient autodi...
void check_multiplicable(const char *function, const char *name1, const T1 &y1, const char *name2, const T2 &y2)
Check if the matrices can be multiplied.
T trace(const Eigen::Matrix< T, Eigen::Dynamic, Eigen::Dynamic > &m)
Returns the trace of the specified matrix.