1 #ifndef STAN_MATH_REV_MAT_FUN_QUAD_FORM_HPP 2 #define STAN_MATH_REV_MAT_FUN_QUAD_FORM_HPP 4 #include <boost/utility/enable_if.hpp> 5 #include <boost/type_traits.hpp> 20 template <
typename TA,
int RA,
int CA,
typename TB,
int RB,
int CB>
21 class quad_form_vari_alloc :
public chainable_alloc {
23 inline void compute(
const Eigen::Matrix<double, RA, CA>& A,
24 const Eigen::Matrix<double, RB, CB>& B) {
25 Eigen::Matrix<double, CB, CB> Cd(B.transpose()*A*B);
26 for (
int j = 0; j <
C_.cols(); j++) {
27 for (
int i = 0; i <
C_.rows(); i++) {
29 C_(i, j) = var(
new vari(0.5*(Cd(i, j) + Cd(j, i)),
false));
31 C_(i, j) = var(
new vari(Cd(i, j),
false));
38 quad_form_vari_alloc(
const Eigen::Matrix<TA, RA, CA>& A,
39 const Eigen::Matrix<TB, RB, CB>& B,
40 bool symmetric =
false)
45 Eigen::Matrix<TA, RA, CA>
A_;
46 Eigen::Matrix<TB, RB, CB>
B_;
47 Eigen::Matrix<var, CB, CB>
C_;
51 template <
typename TA,
int RA,
int CA,
typename TB,
int RB,
int CB>
52 class quad_form_vari :
public vari {
54 inline void chainA(Eigen::Matrix<double, RA, CA>& A,
55 const Eigen::Matrix<double, RB, CB>& Bd,
56 const Eigen::Matrix<double, CB, CB>& adjC) {}
57 inline void chainB(Eigen::Matrix<double, RB, CB>& B,
58 const Eigen::Matrix<double, RA, CA>& Ad,
59 const Eigen::Matrix<double, RB, CB>& Bd,
60 const Eigen::Matrix<double, CB, CB>& adjC) {}
62 inline void chainA(Eigen::Matrix<var, RA, CA>& A,
63 const Eigen::Matrix<double, RB, CB>& Bd,
64 const Eigen::Matrix<double, CB, CB>& adjC) {
65 Eigen::Matrix<double, RA, CA> adjA(Bd*adjC*Bd.transpose());
66 for (
int j = 0; j < A.cols(); j++) {
67 for (
int i = 0; i < A.rows(); i++) {
68 A(i, j).vi_->adj_ += adjA(i, j);
72 inline void chainB(Eigen::Matrix<var, RB, CB>& B,
73 const Eigen::Matrix<double, RA, CA>& Ad,
74 const Eigen::Matrix<double, RB, CB>& Bd,
75 const Eigen::Matrix<double, CB, CB>& adjC) {
76 Eigen::Matrix<double, RA, CA> adjB(Ad * Bd * adjC.transpose()
77 + Ad.transpose()*Bd*adjC);
78 for (
int j = 0; j < B.cols(); j++)
79 for (
int i = 0; i < B.rows(); i++)
80 B(i, j).vi_->adj_ += adjB(i, j);
83 inline void chainAB(Eigen::Matrix<TA, RA, CA>& A,
84 Eigen::Matrix<TB, RB, CB>& B,
85 const Eigen::Matrix<double, RA, CA>& Ad,
86 const Eigen::Matrix<double, RB, CB>& Bd,
87 const Eigen::Matrix<double, CB, CB>& adjC) {
89 chainB(B, Ad, Bd, adjC);
93 quad_form_vari(
const Eigen::Matrix<TA, RA, CA>& A,
94 const Eigen::Matrix<TB, RB, CB>& B,
95 bool symmetric =
false)
98 =
new quad_form_vari_alloc<TA, RA, CA, TB, RB, CB>(A, B, symmetric);
101 virtual void chain() {
102 Eigen::Matrix<double, CB, CB> adjC(
impl_->C_.rows(),
105 for (
int j = 0; j <
impl_->C_.cols(); j++)
106 for (
int i = 0; i <
impl_->C_.rows(); i++)
107 adjC(i, j) =
impl_->C_(i, j).vi_->adj_;
114 quad_form_vari_alloc<TA, RA, CA, TB, RB, CB> *
impl_;
118 template <
typename TA,
int RA,
int CA,
typename TB,
int RB,
int CB>
120 boost::enable_if_c< boost::is_same<TA, var>::value ||
121 boost::is_same<TB, var>::value,
122 Eigen::Matrix<var, CB, CB> >::type
124 const Eigen::Matrix<TB, RB, CB>& B) {
130 quad_form_vari<TA, RA, CA, TB, RB, CB> *baseVari
131 =
new quad_form_vari<TA, RA, CA, TB, RB, CB>(A, B);
133 return baseVari->impl_->C_;
135 template <
typename TA,
int RA,
int CA,
typename TB,
int RB>
137 boost::enable_if_c< boost::is_same<TA, var>::value ||
138 boost::is_same<TB, var>::value,
141 const Eigen::Matrix<TB, RB, 1>& B) {
147 quad_form_vari<TA, RA, CA, TB, RB, 1> *baseVari
148 =
new quad_form_vari<TA, RA, CA, TB, RB, 1>(A, B);
150 return baseVari->impl_->C_(0, 0);
T value_of(const fvar< T > &v)
Return the value of the specified variable.
Independent (input) and dependent (output) variables for gradients.
int cols(const Eigen::Matrix< T, R, C > &m)
Return the number of columns in the specified matrix, vector, or row vector.
void check_multiplicable(const char *function, const char *name1, const T1 &y1, const char *name2, const T2 &y2)
Check if the matrices can be multiplied.
void check_square(const char *function, const char *name, const Eigen::Matrix< T_y, Eigen::Dynamic, Eigen::Dynamic > &y)
Check if the specified matrix is square.
Eigen::Matrix< T, CB, CB > quad_form(const Eigen::Matrix< T, RA, CA > &A, const Eigen::Matrix< T, RB, CB > &B)
Compute B^T A B.