1 #ifndef STAN__AGRAD__REV__MATRIX__TRACE_QUAD_FORM_HPP
2 #define STAN__AGRAD__REV__MATRIX__TRACE_QUAD_FORM_HPP
4 #include <boost/utility/enable_if.hpp>
5 #include <boost/type_traits.hpp>
19 template<
typename TA,
int RA,
int CA,
typename TB,
int RB,
int CB>
20 class trace_quad_form_vari_alloc :
public chainable_alloc {
22 trace_quad_form_vari_alloc(
const Eigen::Matrix<TA,RA,CA> &A,
23 const Eigen::Matrix<TB,RB,CB> &B)
32 Eigen::Matrix<TA,RA,CA>
A_;
33 Eigen::Matrix<TB,RB,CB>
B_;
36 template<
typename TA,
int RA,
int CA,
typename TB,
int RB,
int CB>
37 class trace_quad_form_vari :
public vari {
39 static inline void chainA(Eigen::Matrix<double,RA,CA> &A,
40 const Eigen::Matrix<double,RB,CB> &Bd,
41 const double &adjC) {}
42 static inline void chainB(Eigen::Matrix<double,RB,CB> &B,
43 const Eigen::Matrix<double,RA,CA> &Ad,
44 const Eigen::Matrix<double,RB,CB> &Bd,
45 const double &adjC) {}
47 static inline void chainA(Eigen::Matrix<var,RA,CA> &A,
48 const Eigen::Matrix<double,RB,CB> &Bd,
52 Eigen::Matrix<double,RA,CA> adjA(adjC*Bd*Bd.transpose());
53 for (j = 0; j < A.cols(); j++)
54 for (i = 0; i < A.rows(); i++)
55 A(i,j).vi_->adj_ += adjA(i,j);
57 static inline void chainB(Eigen::Matrix<var,RB,CB> &B,
58 const Eigen::Matrix<double,RA,CA> &Ad,
59 const Eigen::Matrix<double,RB,CB> &Bd,
63 Eigen::Matrix<double,RA,CA> adjB(adjC*(Ad + Ad.transpose())*Bd);
64 for (j = 0; j < B.cols(); j++)
65 for (i = 0; i < B.rows(); i++)
66 B(i,j).vi_->adj_ += adjB(i,j);
69 inline void chainAB(Eigen::Matrix<TA,RA,CA> &A,
70 Eigen::Matrix<TB,RB,CB> &B,
71 const Eigen::Matrix<double,RA,CA> &Ad,
72 const Eigen::Matrix<double,RB,CB> &Bd,
81 trace_quad_form_vari(trace_quad_form_vari_alloc<TA,RA,CA,TB,RB,CB> *impl)
82 : vari(impl->compute()),
_impl(impl) { }
84 virtual void chain() {
90 trace_quad_form_vari_alloc<TA,RA,CA,TB,RB,CB> *
_impl;
94 template<
typename TA,
int RA,
int CA,
typename TB,
int RB,
int CB>
96 boost::enable_if_c< boost::is_same<TA,var>::value ||
97 boost::is_same<TB,var>::value,
100 const Eigen::Matrix<TB,RB,CB> &B)
106 trace_quad_form_vari_alloc<TA,RA,CA,TB,RB,CB> *baseVari =
new trace_quad_form_vari_alloc<TA,RA,CA,TB,RB,CB>(A,B);
108 return var(
new trace_quad_form_vari<TA,RA,CA,TB,RB,CB>(baseVari));
T value_of(const fvar< T > &v)
Return the value of the specified variable.
stan::agrad::fvar< T > trace_quad_form(const Eigen::Matrix< stan::agrad::fvar< T >, RA, CA > &A, const Eigen::Matrix< stan::agrad::fvar< T >, RB, CB > &B)
bool check_multiplicable(const char *function, const T1 &y1, const char *name1, const T2 &y2, const char *name2, T_result *result)
double trace_quad_form(const Eigen::Matrix< double, RA, CA > &A, const Eigen::Matrix< double, RB, CB > &B)
Compute trace(B^T A B).
Independent (input) and dependent (output) variables for gradients.
bool check_square(const char *function, const Eigen::Matrix< T_y, Eigen::Dynamic, Eigen::Dynamic > &y, const char *name, T_result *result)
Return true if the specified matrix is square.