1 #ifndef STAN__AGRAD__REV__MATRIX__QUAD_FORM_HPP
2 #define STAN__AGRAD__REV__MATRIX__QUAD_FORM_HPP
4 #include <boost/utility/enable_if.hpp>
5 #include <boost/type_traits.hpp>
19 template<
typename TA,
int RA,
int CA,
typename TB,
int RB,
int CB>
20 class quad_form_vari_alloc :
public chainable_alloc {
22 inline void compute(
const Eigen::Matrix<double,RA,CA> &A,
23 const Eigen::Matrix<double,RB,CB> &B)
26 Eigen::Matrix<double,CB,CB> Cd(B.transpose()*A*B);
27 for (j = 0; j <
C_.cols(); j++) {
28 for (i = 0; i <
C_.rows(); i++) {
30 C_(i,j) = var(
new vari(0.5*(Cd(i,j) + Cd(j,i)),
false));
33 C_(i,j) = var(
new vari(Cd(i,j),
false));
40 quad_form_vari_alloc(
const Eigen::Matrix<TA,RA,CA> &A,
41 const Eigen::Matrix<TB,RB,CB> &B,
42 bool symmetric =
false)
48 Eigen::Matrix<TA,RA,CA>
A_;
49 Eigen::Matrix<TB,RB,CB>
B_;
50 Eigen::Matrix<var,CB,CB>
C_;
54 template<
typename TA,
int RA,
int CA,
typename TB,
int RB,
int CB>
55 class quad_form_vari :
public vari {
57 inline void chainA(Eigen::Matrix<double,RA,CA> &A,
58 const Eigen::Matrix<double,RB,CB> &Bd,
59 const Eigen::Matrix<double,CB,CB> &adjC) {}
60 inline void chainB(Eigen::Matrix<double,RB,CB> &B,
61 const Eigen::Matrix<double,RA,CA> &Ad,
62 const Eigen::Matrix<double,RB,CB> &Bd,
63 const Eigen::Matrix<double,CB,CB> &adjC) {}
65 inline void chainA(Eigen::Matrix<var,RA,CA> &A,
66 const Eigen::Matrix<double,RB,CB> &Bd,
67 const Eigen::Matrix<double,CB,CB> &adjC)
70 Eigen::Matrix<double,RA,CA> adjA(Bd*adjC*Bd.transpose());
71 for (j = 0; j < A.cols(); j++) {
72 for (i = 0; i < A.rows(); i++) {
73 A(i,j).vi_->adj_ += adjA(i,j);
77 inline void chainB(Eigen::Matrix<var,RB,CB> &B,
78 const Eigen::Matrix<double,RA,CA> &Ad,
79 const Eigen::Matrix<double,RB,CB> &Bd,
80 const Eigen::Matrix<double,CB,CB> &adjC)
83 Eigen::Matrix<double,RA,CA> adjB(Ad*Bd*adjC.transpose() + Ad.transpose()*Bd*adjC);
84 for (j = 0; j < B.cols(); j++)
85 for (i = 0; i < B.rows(); i++)
86 B(i,j).vi_->adj_ += adjB(i,j);
89 inline void chainAB(Eigen::Matrix<TA,RA,CA> &A,
90 Eigen::Matrix<TB,RB,CB> &B,
91 const Eigen::Matrix<double,RA,CA> &Ad,
92 const Eigen::Matrix<double,RB,CB> &Bd,
93 const Eigen::Matrix<double,CB,CB> &adjC)
100 quad_form_vari(
const Eigen::Matrix<TA,RA,CA> &A,
101 const Eigen::Matrix<TB,RB,CB> &B,
102 bool symmetric =
false)
104 _impl =
new quad_form_vari_alloc<TA,RA,CA,TB,RB,CB>(A,B,symmetric);
107 virtual void chain() {
109 Eigen::Matrix<double,CB,CB> adjC(
_impl->C_.rows(),
_impl->C_.cols());
111 for (j = 0; j <
_impl->C_.cols(); j++)
112 for (i = 0; i <
_impl->C_.rows(); i++)
113 adjC(i,j) =
_impl->C_(i,j).vi_->adj_;
120 quad_form_vari_alloc<TA,RA,CA,TB,RB,CB> *
_impl;
124 template<
typename TA,
int RA,
int CA,
typename TB,
int RB,
int CB>
126 boost::enable_if_c< boost::is_same<TA,var>::value ||
127 boost::is_same<TB,var>::value,
128 Eigen::Matrix<var,CB,CB> >::type
130 const Eigen::Matrix<TB,RB,CB> &B)
136 quad_form_vari<TA,RA,CA,TB,RB,CB> *baseVari =
new quad_form_vari<TA,RA,CA,TB,RB,CB>(A,B);
138 return baseVari->_impl->C_;
140 template<
typename TA,
int RA,
int CA,
typename TB,
int RB>
142 boost::enable_if_c< boost::is_same<TA,var>::value ||
143 boost::is_same<TB,var>::value,
146 const Eigen::Matrix<TB,RB,1> &B)
152 quad_form_vari<TA,RA,CA,TB,RB,1> *baseVari =
new quad_form_vari<TA,RA,CA,TB,RB,1>(A,B);
154 return baseVari->_impl->C_(0,0);
157 template<
typename TA,
int RA,
int CA,
typename TB,
int RB,
int CB>
159 boost::enable_if_c< boost::is_same<TA,var>::value ||
160 boost::is_same<TB,var>::value,
161 Eigen::Matrix<var,CB,CB> >::type
163 const Eigen::Matrix<TB,RB,CB> &B)
170 quad_form_vari<TA,RA,CA,TB,RB,CB> *baseVari =
new quad_form_vari<TA,RA,CA,TB,RB,CB>(A,B,
true);
172 return baseVari->_impl->C_;
174 template<
typename TA,
int RA,
int CA,
typename TB,
int RB>
176 boost::enable_if_c< boost::is_same<TA,var>::value ||
177 boost::is_same<TB,var>::value,
180 const Eigen::Matrix<TB,RB,1> &B)
187 quad_form_vari<TA,RA,CA,TB,RB,1> *baseVari =
new quad_form_vari<TA,RA,CA,TB,RB,1>(A,B,
true);
189 return baseVari->_impl->C_(0,0);
T value_of(const fvar< T > &v)
Return the value of the specified variable.
bool check_symmetric(const char *function, const Eigen::Matrix< T_y, Eigen::Dynamic, Eigen::Dynamic > &y, const char *name, T_result *result)
Return true if the specified matrix is symmetric.
bool check_multiplicable(const char *function, const T1 &y1, const char *name1, const T2 &y2, const char *name2, T_result *result)
size_t cols(const Eigen::Matrix< T, R, C > &m)
boost::enable_if_c< boost::is_same< TA, var >::value||boost::is_same< TB, var >::value, Eigen::Matrix< var, CB, CB > >::type quad_form(const Eigen::Matrix< TA, RA, CA > &A, const Eigen::Matrix< TB, RB, CB > &B)
boost::enable_if_c< boost::is_same< TA, var >::value||boost::is_same< TB, var >::value, Eigen::Matrix< var, CB, CB > >::type quad_form_sym(const Eigen::Matrix< TA, RA, CA > &A, const Eigen::Matrix< TB, RB, CB > &B)
bool check_square(const char *function, const Eigen::Matrix< T_y, Eigen::Dynamic, Eigen::Dynamic > &y, const char *name, T_result *result)
Return true if the specified matrix is square.