Stan  2.5.0
probability, sampling & optimization
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
trace_inv_quad_form_ldlt.hpp
Go to the documentation of this file.
1 #ifndef STAN__AGRAD__REV__MATRIX__TRACE_INV_QUAD_FORM_LDLT_HPP
2 #define STAN__AGRAD__REV__MATRIX__TRACE_INV_QUAD_FORM_LDLT_HPP
3 
5 #include <stan/agrad/rev/var.hpp>
8 #include <boost/type_traits/is_same.hpp>
9 #include <boost/utility/enable_if.hpp>
11 
12 namespace stan {
13  namespace agrad {
14  namespace {
15  template <typename T2, int R2,int C2,typename T3,int R3,int C3>
16  class trace_inv_quad_form_ldlt_impl : public chainable_alloc {
17  protected:
18  inline void initializeB(const Eigen::Matrix<var,R3,C3> &B,bool haveD) {
19  Eigen::Matrix<double,R3,C3> Bd(B.rows(),B.cols());
20  _variB.resize(B.rows(),B.cols());
21  for (int j = 0; j < B.cols(); j++) {
22  for (int i = 0; i < B.rows(); i++) {
23  _variB(i,j) = B(i,j).vi_;
24  Bd(i,j) = B(i,j).val();
25  }
26  }
27  AinvB_ = _ldlt.solve(Bd);
28  if (haveD)
29  C_.noalias() = Bd.transpose()*AinvB_;
30  else
31  _value = (Bd.transpose()*AinvB_).trace();
32  }
33  inline void initializeB(const Eigen::Matrix<double,R3,C3> &B,bool haveD) {
34  AinvB_ = _ldlt.solve(B);
35  if (haveD)
36  C_.noalias() = B.transpose()*AinvB_;
37  else
38  _value = (B.transpose()*AinvB_).trace();
39  }
40 
41  template<int R1,int C1>
42  inline void initializeD(const Eigen::Matrix<var,R1,C1> &D) {
43  D_.resize(D.rows(),D.cols());
44  _variD.resize(D.rows(),D.cols());
45  for (int j = 0; j < D.cols(); j++) {
46  for (int i = 0; i < D.rows(); i++) {
47  _variD(i,j) = D(i,j).vi_;
48  D_(i,j) = D(i,j).val();
49  }
50  }
51  }
52  template<int R1,int C1>
53  inline void initializeD(const Eigen::Matrix<double,R1,C1> &D) {
54  D_ = D;
55  }
56 
57  public:
58  template<typename T1, int R1,int C1>
59  trace_inv_quad_form_ldlt_impl(const Eigen::Matrix<T1,R1,C1> &D,
61  const Eigen::Matrix<T3,R3,C3> &B)
62  : Dtype_(boost::is_same<T1,var>::value?1:0),
63  _ldlt(A)
64  {
65  initializeB(B,true);
66  initializeD(D);
67 
68  _value = (D_*C_).trace();
69  }
70 
71  trace_inv_quad_form_ldlt_impl(const stan::math::LDLT_factor<T2,R2,C2> &A,
72  const Eigen::Matrix<T3,R3,C3> &B)
73  : Dtype_(2),
74  _ldlt(A)
75  {
76  initializeB(B,false);
77  }
78 
79  const int Dtype_; // 0 = double, 1 = var, 2 = missing
81  Eigen::Matrix<double,Eigen::Dynamic,Eigen::Dynamic> D_;
82  Eigen::Matrix<vari*,Eigen::Dynamic,Eigen::Dynamic> _variD;
83  Eigen::Matrix<vari*,R3,C3> _variB;
84  Eigen::Matrix<double,R3,C3> AinvB_;
85  Eigen::Matrix<double,C3,C3> C_;
86  double _value;
87  };
88 
89  template <typename T2,int R2,int C2,typename T3,int R3,int C3>
90  class trace_inv_quad_form_ldlt_vari : public vari {
91  protected:
92  static inline void chainA(const double &adj,
93  trace_inv_quad_form_ldlt_impl<double,R2,C2,T3,R3,C3> *impl) {
94  }
95  static inline void chainB(const double &adj,
96  trace_inv_quad_form_ldlt_impl<T2,R2,C2,double,R3,C3> *impl) {
97  }
98 
99  static inline void chainA(const double &adj,
100  trace_inv_quad_form_ldlt_impl<var,R2,C2,T3,R3,C3> *impl) {
101  Eigen::Matrix<double,R2,C2> aA;
102 
103  if (impl->Dtype_ != 2)
104  aA.noalias() = -adj*(impl->AinvB_ * impl->D_.transpose() * impl->AinvB_.transpose());
105  else
106  aA.noalias() = -adj*(impl->AinvB_ * impl->AinvB_.transpose());
107 
108  for (int j = 0; j < aA.cols(); j++)
109  for (int i = 0; i < aA.rows(); i++)
110  impl->_ldlt._alloc->_variA(i,j)->adj_ += aA(i,j);
111  }
112  static inline void chainB(const double &adj,
113  trace_inv_quad_form_ldlt_impl<T2,R2,C2,var,R3,C3> *impl) {
114  Eigen::Matrix<double,R3,C3> aB;
115 
116  if (impl->Dtype_ != 2)
117  aB.noalias() = adj*impl->AinvB_*(impl->D_ + impl->D_.transpose());
118  else
119  aB.noalias() = 2*adj*impl->AinvB_;
120 
121  for (int j = 0; j < aB.cols(); j++)
122  for (int i = 0; i < aB.rows(); i++)
123  impl->_variB(i,j)->adj_ += aB(i,j);
124  }
125 
126  public:
127  trace_inv_quad_form_ldlt_vari(trace_inv_quad_form_ldlt_impl<T2,R2,C2,T3,R3,C3> *impl)
128  : vari(impl->_value), _impl(impl)
129  {}
130 
131  virtual void chain() {
132  // F = trace(D * B' * inv(A) * B)
133  // aA = -aF * inv(A') * B * D' * B' * inv(A')
134  // aB = aF*(inv(A) * B * D + inv(A') * B * D')
135  // aD = aF*(B' * inv(A) * B)
136  chainA(adj_, _impl);
137 
138  chainB(adj_, _impl);
139 
140  if (_impl->Dtype_ == 1) {
141  for (int j = 0; j < _impl->_variD.cols(); j++)
142  for (int i = 0; i < _impl->_variD.rows(); i++)
143  _impl->_variD(i,j)->adj_ += adj_*_impl->C_(i,j);
144  }
145  }
146 
147  trace_inv_quad_form_ldlt_impl<T2,R2,C2,T3,R3,C3> *_impl;
148  };
149 
150  }
151 
152 
158  template <typename T2,int R2,int C2,typename T3,int R3,int C3>
159  inline typename
160  boost::enable_if_c<boost::is_same<T2,var>::value ||
161  boost::is_same<T3,var>::value, var>::type
163  const Eigen::Matrix<T3,R3,C3> &B)
164  {
165  stan::math::check_multiplicable("trace_inv_quad_form_ldlt(%1%)",A,"A",
166  B,"B",(double*)0);
167 
168  trace_inv_quad_form_ldlt_impl<T2,R2,C2,T3,R3,C3> *_impl = new trace_inv_quad_form_ldlt_impl<T2,R2,C2,T3,R3,C3>(A,B);
169 
170  return var(new trace_inv_quad_form_ldlt_vari<T2,R2,C2,T3,R3,C3>(_impl));
171  }
172 
173  }
174 }
175 
176 #endif
const int Dtype_
trace_inv_quad_form_ldlt_impl< T2, R2, C2, T3, R3, C3 > * _impl
Eigen::Matrix< vari *, Eigen::Dynamic, Eigen::Dynamic > _variD
bool check_multiplicable(const char *function, const T1 &y1, const char *name1, const T2 &y2, const char *name2, T_result *result)
Eigen::Matrix< double, R3, C3 > AinvB_
Eigen::Matrix< vari *, R3, C3 > _variB
double _value
Independent (input) and dependent (output) variables for gradients.
Definition: var.hpp:27
Eigen::Matrix< double, Eigen::Dynamic, Eigen::Dynamic > D_
stan::math::LDLT_factor< T2, R2, C2 > _ldlt
T trace(const Eigen::Matrix< T, Eigen::Dynamic, Eigen::Dynamic > &m)
Returns the trace of the specified matrix.
Definition: trace.hpp:20
Eigen::Matrix< double, C3, C3 > C_
boost::enable_if_c< boost::is_same< T2, var >::value||boost::is_same< T3, var >::value, var >::type trace_inv_quad_form_ldlt(const stan::math::LDLT_factor< T2, R2, C2 > &A, const Eigen::Matrix< T3, R3, C3 > &B)
Compute the trace of an inverse quadratic form.

     [ Stan Home Page ] © 2011–2014, Stan Development Team.