Stan  2.5.0
probability, sampling & optimization
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
quad_form.hpp
Go to the documentation of this file.
1 #ifndef STAN__AGRAD__REV__MATRIX__QUAD_FORM_HPP
2 #define STAN__AGRAD__REV__MATRIX__QUAD_FORM_HPP
3 
4 #include <boost/utility/enable_if.hpp>
5 #include <boost/type_traits.hpp>
8 #include <stan/agrad/rev.hpp>
15 
16 namespace stan {
17  namespace agrad {
18  namespace {
19  template<typename TA,int RA,int CA,typename TB,int RB,int CB>
20  class quad_form_vari_alloc : public chainable_alloc {
21  private:
22  inline void compute(const Eigen::Matrix<double,RA,CA> &A,
23  const Eigen::Matrix<double,RB,CB> &B)
24  {
25  int i,j;
26  Eigen::Matrix<double,CB,CB> Cd(B.transpose()*A*B);
27  for (j = 0; j < C_.cols(); j++) {
28  for (i = 0; i < C_.rows(); i++) {
29  if (_sym) {
30  C_(i,j) = var(new vari(0.5*(Cd(i,j) + Cd(j,i)),false));
31  }
32  else {
33  C_(i,j) = var(new vari(Cd(i,j),false));
34  }
35  }
36  }
37  }
38 
39  public:
40  quad_form_vari_alloc(const Eigen::Matrix<TA,RA,CA> &A,
41  const Eigen::Matrix<TB,RB,CB> &B,
42  bool symmetric = false)
43  : A_(A), B_(B), C_(B_.cols(),B_.cols()), _sym(symmetric)
44  {
45  compute(value_of(A),value_of(B));
46  }
47 
48  Eigen::Matrix<TA,RA,CA> A_;
49  Eigen::Matrix<TB,RB,CB> B_;
50  Eigen::Matrix<var,CB,CB> C_;
51  bool _sym;
52  };
53 
54  template<typename TA,int RA,int CA,typename TB,int RB,int CB>
55  class quad_form_vari : public vari {
56  protected:
57  inline void chainA(Eigen::Matrix<double,RA,CA> &A,
58  const Eigen::Matrix<double,RB,CB> &Bd,
59  const Eigen::Matrix<double,CB,CB> &adjC) {}
60  inline void chainB(Eigen::Matrix<double,RB,CB> &B,
61  const Eigen::Matrix<double,RA,CA> &Ad,
62  const Eigen::Matrix<double,RB,CB> &Bd,
63  const Eigen::Matrix<double,CB,CB> &adjC) {}
64 
65  inline void chainA(Eigen::Matrix<var,RA,CA> &A,
66  const Eigen::Matrix<double,RB,CB> &Bd,
67  const Eigen::Matrix<double,CB,CB> &adjC)
68  {
69  int i,j;
70  Eigen::Matrix<double,RA,CA> adjA(Bd*adjC*Bd.transpose());
71  for (j = 0; j < A.cols(); j++) {
72  for (i = 0; i < A.rows(); i++) {
73  A(i,j).vi_->adj_ += adjA(i,j);
74  }
75  }
76  }
77  inline void chainB(Eigen::Matrix<var,RB,CB> &B,
78  const Eigen::Matrix<double,RA,CA> &Ad,
79  const Eigen::Matrix<double,RB,CB> &Bd,
80  const Eigen::Matrix<double,CB,CB> &adjC)
81  {
82  int i,j;
83  Eigen::Matrix<double,RA,CA> adjB(Ad*Bd*adjC.transpose() + Ad.transpose()*Bd*adjC);
84  for (j = 0; j < B.cols(); j++)
85  for (i = 0; i < B.rows(); i++)
86  B(i,j).vi_->adj_ += adjB(i,j);
87  }
88 
89  inline void chainAB(Eigen::Matrix<TA,RA,CA> &A,
90  Eigen::Matrix<TB,RB,CB> &B,
91  const Eigen::Matrix<double,RA,CA> &Ad,
92  const Eigen::Matrix<double,RB,CB> &Bd,
93  const Eigen::Matrix<double,CB,CB> &adjC)
94  {
95  chainA(A,Bd,adjC);
96  chainB(B,Ad,Bd,adjC);
97  }
98 
99  public:
100  quad_form_vari(const Eigen::Matrix<TA,RA,CA> &A,
101  const Eigen::Matrix<TB,RB,CB> &B,
102  bool symmetric = false)
103  : vari(0.0) {
104  _impl = new quad_form_vari_alloc<TA,RA,CA,TB,RB,CB>(A,B,symmetric);
105  }
106 
107  virtual void chain() {
108  int i,j;
109  Eigen::Matrix<double,CB,CB> adjC(_impl->C_.rows(),_impl->C_.cols());
110 
111  for (j = 0; j < _impl->C_.cols(); j++)
112  for (i = 0; i < _impl->C_.rows(); i++)
113  adjC(i,j) = _impl->C_(i,j).vi_->adj_;
114 
115  chainAB(_impl->A_, _impl->B_,
116  value_of(_impl->A_), value_of(_impl->B_),
117  adjC);
118  };
119 
120  quad_form_vari_alloc<TA,RA,CA,TB,RB,CB> *_impl;
121  };
122  }
123 
124  template<typename TA,int RA,int CA,typename TB,int RB,int CB>
125  inline typename
126  boost::enable_if_c< boost::is_same<TA,var>::value ||
127  boost::is_same<TB,var>::value,
128  Eigen::Matrix<var,CB,CB> >::type
129  quad_form(const Eigen::Matrix<TA,RA,CA> &A,
130  const Eigen::Matrix<TB,RB,CB> &B)
131  {
132  stan::math::check_square("quad_form(%1%)",A,"A",(double*)0);
133  stan::math::check_multiplicable("quad_form(%1%)",A,"A",
134  B,"B",(double*)0);
135 
136  quad_form_vari<TA,RA,CA,TB,RB,CB> *baseVari = new quad_form_vari<TA,RA,CA,TB,RB,CB>(A,B);
137 
138  return baseVari->_impl->C_;
139  }
140  template<typename TA,int RA,int CA,typename TB,int RB>
141  inline typename
142  boost::enable_if_c< boost::is_same<TA,var>::value ||
143  boost::is_same<TB,var>::value,
144  var >::type
145  quad_form(const Eigen::Matrix<TA,RA,CA> &A,
146  const Eigen::Matrix<TB,RB,1> &B)
147  {
148  stan::math::check_square("quad_form(%1%)",A,"A",(double*)0);
149  stan::math::check_multiplicable("quad_form(%1%)",A,"A",
150  B,"B",(double*)0);
151 
152  quad_form_vari<TA,RA,CA,TB,RB,1> *baseVari = new quad_form_vari<TA,RA,CA,TB,RB,1>(A,B);
153 
154  return baseVari->_impl->C_(0,0);
155  }
156 
157  template<typename TA,int RA,int CA,typename TB,int RB,int CB>
158  inline typename
159  boost::enable_if_c< boost::is_same<TA,var>::value ||
160  boost::is_same<TB,var>::value,
161  Eigen::Matrix<var,CB,CB> >::type
162  quad_form_sym(const Eigen::Matrix<TA,RA,CA> &A,
163  const Eigen::Matrix<TB,RB,CB> &B)
164  {
165  stan::math::check_square("quad_form(%1%)",A,"A",(double*)0);
166  stan::math::check_symmetric("quad_form_sym(%1%)",A,"A",(double*)0);
167  stan::math::check_multiplicable("quad_form_sym(%1%)",A,"A",
168  B,"B",(double*)0);
169 
170  quad_form_vari<TA,RA,CA,TB,RB,CB> *baseVari = new quad_form_vari<TA,RA,CA,TB,RB,CB>(A,B,true);
171 
172  return baseVari->_impl->C_;
173  }
174  template<typename TA,int RA,int CA,typename TB,int RB>
175  inline typename
176  boost::enable_if_c< boost::is_same<TA,var>::value ||
177  boost::is_same<TB,var>::value,
178  var >::type
179  quad_form_sym(const Eigen::Matrix<TA,RA,CA> &A,
180  const Eigen::Matrix<TB,RB,1> &B)
181  {
182  stan::math::check_square("quad_form(%1%)",A,"A",(double*)0);
183  stan::math::check_symmetric("quad_form_sym(%1%)",A,"A",(double*)0);
184  stan::math::check_multiplicable("quad_form_sym(%1%)",A,"A",
185  B,"B",(double*)0);
186 
187  quad_form_vari<TA,RA,CA,TB,RB,1> *baseVari = new quad_form_vari<TA,RA,CA,TB,RB,1>(A,B,true);
188 
189  return baseVari->_impl->C_(0,0);
190  }
191  }
192 }
193 
194 #endif
quad_form_vari_alloc< TA, RA, CA, TB, RB, CB > * _impl
Definition: quad_form.hpp:118
bool _sym
Definition: quad_form.hpp:51
Eigen::Matrix< var, CB, CB > C_
Definition: quad_form.hpp:50
T value_of(const fvar< T > &v)
Return the value of the specified variable.
Definition: value_of.hpp:16
bool check_symmetric(const char *function, const Eigen::Matrix< T_y, Eigen::Dynamic, Eigen::Dynamic > &y, const char *name, T_result *result)
Return true if the specified matrix is symmetric.
bool check_multiplicable(const char *function, const T1 &y1, const char *name1, const T2 &y2, const char *name2, T_result *result)
Eigen::Matrix< TB, RB, CB > B_
Definition: quad_form.hpp:49
Eigen::Matrix< TA, RA, CA > A_
Definition: quad_form.hpp:48
size_t cols(const Eigen::Matrix< T, R, C > &m)
Definition: cols.hpp:12
boost::enable_if_c< boost::is_same< TA, var >::value||boost::is_same< TB, var >::value, Eigen::Matrix< var, CB, CB > >::type quad_form(const Eigen::Matrix< TA, RA, CA > &A, const Eigen::Matrix< TB, RB, CB > &B)
Definition: quad_form.hpp:129
boost::enable_if_c< boost::is_same< TA, var >::value||boost::is_same< TB, var >::value, Eigen::Matrix< var, CB, CB > >::type quad_form_sym(const Eigen::Matrix< TA, RA, CA > &A, const Eigen::Matrix< TB, RB, CB > &B)
Definition: quad_form.hpp:162
bool check_square(const char *function, const Eigen::Matrix< T_y, Eigen::Dynamic, Eigen::Dynamic > &y, const char *name, T_result *result)
Return true if the specified matrix is square.

     [ Stan Home Page ] © 2011–2014, Stan Development Team.