Stan  2.5.0
probability, sampling & optimization
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
trace_gen_quad_form.hpp
Go to the documentation of this file.
1 #ifndef STAN__AGRAD__REV__MATRIX__TRACE_GEN_QUAD_FORM_HPP
2 #define STAN__AGRAD__REV__MATRIX__TRACE_GEN_QUAD_FORM_HPP
3 
4 #include <boost/utility/enable_if.hpp>
5 #include <boost/type_traits.hpp>
9 #include <stan/agrad/rev/var.hpp>
10 #include <stan/agrad/rev/vari.hpp>
15 
16 namespace stan {
17  namespace agrad {
18  namespace {
19  template<typename TD,int RD,int CD,
20  typename TA,int RA,int CA,
21  typename TB,int RB,int CB>
22  class trace_gen_quad_form_vari_alloc : public chainable_alloc {
23  public:
24  trace_gen_quad_form_vari_alloc(const Eigen::Matrix<TD,RD,CD> &D,
25  const Eigen::Matrix<TA,RA,CA> &A,
26  const Eigen::Matrix<TB,RB,CB> &B)
27  : D_(D), A_(A), B_(B)
28  { }
29 
30  double compute() {
32  value_of(A_),
33  value_of(B_));
34  }
35 
36  Eigen::Matrix<TD,RD,CD> D_;
37  Eigen::Matrix<TA,RA,CA> A_;
38  Eigen::Matrix<TB,RB,CB> B_;
39  };
40 
41  template<typename TD,int RD,int CD,
42  typename TA,int RA,int CA,
43  typename TB,int RB,int CB>
44  class trace_gen_quad_form_vari : public vari {
45  protected:
46  static inline void computeAdjoints(const double &adj,
47  const Eigen::Matrix<double,RD,CD> &D,
48  const Eigen::Matrix<double,RA,CA> &A,
49  const Eigen::Matrix<double,RB,CB> &B,
50  Eigen::Matrix<var,RD,CD> *varD,
51  Eigen::Matrix<var,RA,CA> *varA,
52  Eigen::Matrix<var,RB,CB> *varB)
53  {
54  Eigen::Matrix<double,CA,CB> AtB;
55  Eigen::Matrix<double,RA,CB> BD;
56  if (varB || varA)
57  BD.noalias() = B*D;
58  if (varB || varD)
59  AtB.noalias() = A.transpose()*B;
60 
61  if (varB) {
62  Eigen::Matrix<double,RB,CB> adjB(adj*(A*BD + AtB*D.transpose()));
63  int i,j;
64  for (j = 0; j < B.cols(); j++)
65  for (i = 0; i < B.rows(); i++)
66  (*varB)(i,j).vi_->adj_ += adjB(i,j);
67  }
68  if (varA) {
69  Eigen::Matrix<double,RA,CA> adjA(adj*(B*BD.transpose()));
70  int i,j;
71  for (j = 0; j < A.cols(); j++)
72  for (i = 0; i < A.rows(); i++)
73  (*varA)(i,j).vi_->adj_ += adjA(i,j);
74  }
75  if (varD) {
76  Eigen::Matrix<double,RD,CD> adjD(adj*(B.transpose()*AtB));
77  int i,j;
78  for (j = 0; j < D.cols(); j++)
79  for (i = 0; i < D.rows(); i++)
80  (*varD)(i,j).vi_->adj_ += adjD(i,j);
81  }
82  }
83 
84 
85  public:
86  trace_gen_quad_form_vari(trace_gen_quad_form_vari_alloc<TD,RD,CD,TA,RA,CA,TB,RB,CB> *impl)
87  : vari(impl->compute()), _impl(impl) { }
88 
89  virtual void chain() {
90  computeAdjoints(adj_,
91  value_of(_impl->D_),
92  value_of(_impl->A_),
93  value_of(_impl->B_),
94  (Eigen::Matrix<var,RD,CD>*)(boost::is_same<TD,var>::value?(&_impl->D_):NULL),
95  (Eigen::Matrix<var,RA,CA>*)(boost::is_same<TA,var>::value?(&_impl->A_):NULL),
96  (Eigen::Matrix<var,RB,CB>*)(boost::is_same<TB,var>::value?(&_impl->B_):NULL));
97  }
98 
99  trace_gen_quad_form_vari_alloc<TD,RD,CD,TA,RA,CA,TB,RB,CB> *_impl;
100  };
101  }
102 
103  template<typename TD,int RD,int CD,
104  typename TA,int RA,int CA,
105  typename TB,int RB,int CB>
106  inline typename
107  boost::enable_if_c< boost::is_same<TD,var>::value ||
108  boost::is_same<TA,var>::value ||
109  boost::is_same<TB,var>::value,
110  var >::type
111  trace_gen_quad_form(const Eigen::Matrix<TD,RD,CD> &D,
112  const Eigen::Matrix<TA,RA,CA> &A,
113  const Eigen::Matrix<TB,RB,CB> &B)
114  {
115  stan::math::check_square("trace_gen_quad_form(%1%)",A,"A",(double*)0);
116  stan::math::check_square("trace_gen_quad_form(%1%)",D,"D",(double*)0);
117  stan::math::check_multiplicable("trace_gen_quad_form(%1%)",A,"A",
118  B,"B",(double*)0);
119  stan::math::check_multiplicable("trace_gen_quad_form(%1%)",B,"B",
120  D,"D",(double*)0);
121 
122  trace_gen_quad_form_vari_alloc<TD,RD,CD,TA,RA,CA,TB,RB,CB> *baseVari = new trace_gen_quad_form_vari_alloc<TD,RD,CD,TA,RA,CA,TB,RB,CB>(D,A,B);
123 
124  return var(new trace_gen_quad_form_vari<TD,RD,CD,TA,RA,CA,TB,RB,CB>(baseVari));
125  }
126  }
127 }
128 
129 #endif
fvar< T > trace_gen_quad_form(const Eigen::Matrix< fvar< T >, RD, CD > &D, const Eigen::Matrix< fvar< T >, RA, CA > &A, const Eigen::Matrix< fvar< T >, RB, CB > &B)
T value_of(const fvar< T > &v)
Return the value of the specified variable.
Definition: value_of.hpp:16
bool check_multiplicable(const char *function, const T1 &y1, const char *name1, const T2 &y2, const char *name2, T_result *result)
trace_gen_quad_form_vari_alloc< TD, RD, CD, TA, RA, CA, TB, RB, CB > * _impl
Eigen::Matrix< TB, RB, CB > B_
Eigen::Matrix< TD, RD, CD > D_
Eigen::Matrix< TA, RA, CA > A_
Independent (input) and dependent (output) variables for gradients.
Definition: var.hpp:27
double trace_gen_quad_form(const Eigen::Matrix< double, RD, CD > &D, const Eigen::Matrix< double, RA, CA > &A, const Eigen::Matrix< double, RB, CB > &B)
Compute trace(D B^T A B).
bool check_square(const char *function, const Eigen::Matrix< T_y, Eigen::Dynamic, Eigen::Dynamic > &y, const char *name, T_result *result)
Return true if the specified matrix is square.

     [ Stan Home Page ] © 2011–2014, Stan Development Team.