Stan  2.5.0
probability, sampling & optimization
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
trace_quad_form.hpp
Go to the documentation of this file.
1 #ifndef STAN__AGRAD__REV__MATRIX__TRACE_QUAD_FORM_HPP
2 #define STAN__AGRAD__REV__MATRIX__TRACE_QUAD_FORM_HPP
3 
4 #include <boost/utility/enable_if.hpp>
5 #include <boost/type_traits.hpp>
9 #include <stan/agrad/rev/var.hpp>
10 #include <stan/agrad/rev/vari.hpp>
15 
16 namespace stan {
17  namespace agrad {
18  namespace {
19  template<typename TA,int RA,int CA,typename TB,int RB,int CB>
20  class trace_quad_form_vari_alloc : public chainable_alloc {
21  public:
22  trace_quad_form_vari_alloc(const Eigen::Matrix<TA,RA,CA> &A,
23  const Eigen::Matrix<TB,RB,CB> &B)
24  : A_(A), B_(B)
25  { }
26 
27  double compute() {
29  value_of(B_));
30  }
31 
32  Eigen::Matrix<TA,RA,CA> A_;
33  Eigen::Matrix<TB,RB,CB> B_;
34  };
35 
36  template<typename TA,int RA,int CA,typename TB,int RB,int CB>
37  class trace_quad_form_vari : public vari {
38  protected:
39  static inline void chainA(Eigen::Matrix<double,RA,CA> &A,
40  const Eigen::Matrix<double,RB,CB> &Bd,
41  const double &adjC) {}
42  static inline void chainB(Eigen::Matrix<double,RB,CB> &B,
43  const Eigen::Matrix<double,RA,CA> &Ad,
44  const Eigen::Matrix<double,RB,CB> &Bd,
45  const double &adjC) {}
46 
47  static inline void chainA(Eigen::Matrix<var,RA,CA> &A,
48  const Eigen::Matrix<double,RB,CB> &Bd,
49  const double &adjC)
50  {
51  int i,j;
52  Eigen::Matrix<double,RA,CA> adjA(adjC*Bd*Bd.transpose());
53  for (j = 0; j < A.cols(); j++)
54  for (i = 0; i < A.rows(); i++)
55  A(i,j).vi_->adj_ += adjA(i,j);
56  }
57  static inline void chainB(Eigen::Matrix<var,RB,CB> &B,
58  const Eigen::Matrix<double,RA,CA> &Ad,
59  const Eigen::Matrix<double,RB,CB> &Bd,
60  const double &adjC)
61  {
62  int i,j;
63  Eigen::Matrix<double,RA,CA> adjB(adjC*(Ad + Ad.transpose())*Bd);
64  for (j = 0; j < B.cols(); j++)
65  for (i = 0; i < B.rows(); i++)
66  B(i,j).vi_->adj_ += adjB(i,j);
67  }
68 
69  inline void chainAB(Eigen::Matrix<TA,RA,CA> &A,
70  Eigen::Matrix<TB,RB,CB> &B,
71  const Eigen::Matrix<double,RA,CA> &Ad,
72  const Eigen::Matrix<double,RB,CB> &Bd,
73  const double &adjC)
74  {
75  chainA(A,Bd,adjC);
76  chainB(B,Ad,Bd,adjC);
77  }
78 
79 
80  public:
81  trace_quad_form_vari(trace_quad_form_vari_alloc<TA,RA,CA,TB,RB,CB> *impl)
82  : vari(impl->compute()), _impl(impl) { }
83 
84  virtual void chain() {
85  chainAB(_impl->A_, _impl->B_,
86  value_of(_impl->A_), value_of(_impl->B_),
87  adj_);
88  };
89 
90  trace_quad_form_vari_alloc<TA,RA,CA,TB,RB,CB> *_impl;
91  };
92  }
93 
94  template<typename TA,int RA,int CA,typename TB,int RB,int CB>
95  inline typename
96  boost::enable_if_c< boost::is_same<TA,var>::value ||
97  boost::is_same<TB,var>::value,
98  var >::type
99  trace_quad_form(const Eigen::Matrix<TA,RA,CA> &A,
100  const Eigen::Matrix<TB,RB,CB> &B)
101  {
102  stan::math::check_square("trace_quad_form(%1%)",A,"A",(double*)0);
103  stan::math::check_multiplicable("trace_quad_form(%1%)",A,"A",
104  B,"B",(double*)0);
105 
106  trace_quad_form_vari_alloc<TA,RA,CA,TB,RB,CB> *baseVari = new trace_quad_form_vari_alloc<TA,RA,CA,TB,RB,CB>(A,B);
107 
108  return var(new trace_quad_form_vari<TA,RA,CA,TB,RB,CB>(baseVari));
109  }
110  }
111 }
112 
113 #endif
T value_of(const fvar< T > &v)
Return the value of the specified variable.
Definition: value_of.hpp:16
Eigen::Matrix< TB, RB, CB > B_
stan::agrad::fvar< T > trace_quad_form(const Eigen::Matrix< stan::agrad::fvar< T >, RA, CA > &A, const Eigen::Matrix< stan::agrad::fvar< T >, RB, CB > &B)
bool check_multiplicable(const char *function, const T1 &y1, const char *name1, const T2 &y2, const char *name2, T_result *result)
double trace_quad_form(const Eigen::Matrix< double, RA, CA > &A, const Eigen::Matrix< double, RB, CB > &B)
Compute trace(B^T A B).
trace_quad_form_vari_alloc< TA, RA, CA, TB, RB, CB > * _impl
Independent (input) and dependent (output) variables for gradients.
Definition: var.hpp:27
bool check_square(const char *function, const Eigen::Matrix< T_y, Eigen::Dynamic, Eigen::Dynamic > &y, const char *name, T_result *result)
Return true if the specified matrix is square.
Eigen::Matrix< TA, RA, CA > A_

     [ Stan Home Page ] © 2011–2014, Stan Development Team.