Stan  2.5.0
probability, sampling & optimization
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
partials_vari.hpp
Go to the documentation of this file.
1 #ifndef STAN__AGRAD__PARTIALS_VARI_HPP
2 #define STAN__AGRAD__PARTIALS_VARI_HPP
3 
4 #include <stan/meta/traits.hpp>
5 #include <stan/agrad/rev/var.hpp>
7 
8 namespace stan {
9  namespace agrad {
10 
11  class partials_vari : public vari {
12  private:
13  const size_t N_;
14  vari** operands_;
15  double* partials_;
16  public:
17  partials_vari(double value,
18  size_t N,
19  vari** operands, double* partials)
20  : vari(value),
21  N_(N),
22  operands_(operands),
23  partials_(partials) { }
24  void chain() {
25  for (size_t n = 0; n < N_; ++n)
26  operands_[n]->adj_ += adj_ * partials_[n];
27  }
28  };
29 
30  namespace {
31  template<typename T>
32  T partials_to_var(double logp, size_t /* nvaris */,
33  agrad::vari** /* all_varis */,
34  double* /* all_partials */) {
35  return logp;
36  }
37  template<>
38  var partials_to_var<var>(double logp, size_t nvaris,
39  agrad::vari** all_varis,
40  double* all_partials) {
41  return var(new agrad::partials_vari(logp, nvaris, all_varis, all_partials));
42  }
43 
44  template<typename T,
45  bool is_vec = is_vector<T>::value,
46  bool is_const = is_constant_struct<T>::value>
47  struct set_varis {
48  inline size_t set(agrad::vari** /*varis*/, const T& /*x*/) {
49  return 0U;
50  }
51  };
52  template<typename T>
53  struct set_varis <T,true,false>{
54  inline size_t set(agrad::vari** varis, const T& x) {
55  for (size_t n = 0; n < length(x); n++)
56  varis[n] = x[n].vi_;
57  return length(x);
58  }
59  };
60  template<>
61  struct set_varis<agrad::var, false, false> {
62  inline size_t set(agrad::vari** varis, const agrad::var& x) {
63  varis[0] = x.vi_;
64  return (1);
65  }
66  };
67  }
68 
/**
 * OperandsAndPartials: gathers the vari pointers of up to six operands
 * x1..x6 into one array and provides a same-length array of partial
 * derivatives; to_var() then builds the result through partials_to_var.
 *
 * NOTE(review): this listing omits several source lines (the embedded line
 * numbers jump, e.g. 75 -> 78 and 98 -> 101). The struct header, the
 * all_varis member declaration and the d_x1..d_x6 members are not visible
 * here; these comments describe only the lines shown.
 */
73  template<typename T1=double, typename T2=double, typename T3=double,
74  typename T4=double, typename T5=double, typename T6=double,
75  typename T_return_type=typename return_type<T1,T2,T3,T4,T5,T6>::type>
// Total number of scalar operands that need gradients (see nvaris below).
78  size_t nvaris;
// Partial derivatives, index-aligned with all_varis; nvaris entries.
80  double* all_partials;
81 
88 
// Each operand defaults to 0, so callers may pass fewer than six arguments.
89  OperandsAndPartials(const T1& x1=0, const T2& x2=0, const T3& x3=0,
90  const T4& x4=0, const T5& x5=0, const T6& x6=0)
// nvaris: sum of length(xk) over the operands whose type is not constant;
// operands of constant type contribute 0 via the negated trait.
91  : nvaris(!is_constant_struct<T1>::value * length(x1) +
92  !is_constant_struct<T2>::value * length(x2) +
93  !is_constant_struct<T3>::value * length(x3) +
94  !is_constant_struct<T4>::value * length(x4) +
95  !is_constant_struct<T5>::value * length(x5) +
96  !is_constant_struct<T6>::value * length(x6)),
// Both arrays get nvaris slots from agrad::chainable::operator new; nothing
// here frees them (memory handling is presumably left to the autodiff
// allocator -- confirm in chainable.hpp).
97  all_varis((agrad::vari**)agrad::chainable::operator new(sizeof(agrad::vari*) * nvaris)),
98  all_partials((double*)agrad::chainable::operator new(sizeof(double) * nvaris)),
// NOTE(review): the omitted lines (99, 100, 102, ...) presumably start the
// d_x1..d_x6 initializers (VectorView<double*, ...> members in the doxygen
// member list); the fragments below are the cumulative operand lengths
// that offset each d_xK into all_partials -- confirm against full source.
101  + (!is_constant_struct<T1>::value) * length(x1)),
103  + (!is_constant_struct<T1>::value) * length(x1)
104  + (!is_constant_struct<T2>::value) * length(x2)),
106  + (!is_constant_struct<T1>::value) * length(x1)
107  + (!is_constant_struct<T2>::value) * length(x2)
108  + (!is_constant_struct<T3>::value) * length(x3)),
110  + (!is_constant_struct<T1>::value) * length(x1)
111  + (!is_constant_struct<T2>::value) * length(x2)
112  + (!is_constant_struct<T3>::value) * length(x3)
113  + (!is_constant_struct<T4>::value) * length(x4)),
115  + (!is_constant_struct<T1>::value) * length(x1)
116  + (!is_constant_struct<T2>::value) * length(x2)
117  + (!is_constant_struct<T3>::value) * length(x3)
118  + (!is_constant_struct<T4>::value) * length(x4)
119  + (!is_constant_struct<T5>::value) * length(x5))
120  {
// Fill all_varis with the operands' vari pointers in order x1..x6; base is
// the running write offset, advanced by the count each set_varis reports.
121  size_t base = 0;
123  base += set_varis<T1>().set(&all_varis[base], x1);
125  base += set_varis<T2>().set(&all_varis[base], x2);
127  base += set_varis<T3>().set(&all_varis[base], x3);
129  base += set_varis<T4>().set(&all_varis[base], x4);
131  base += set_varis<T5>().set(&all_varis[base], x5);
133  set_varis<T6>().set(&all_varis[base], x6);
// Start every partial at zero; callers presumably accumulate into them
// (e.g. through d_x1..d_x6) before calling to_var().
134  std::fill(all_partials, all_partials+nvaris, 0);
135  }
136 
// Builds the result from logp via partials_to_var, passing the collected
// varis and partials.
137  T_return_type
138  to_var(double logp) {
139  return partials_to_var<T_return_type>(logp, nvaris, all_varis, all_partials);
140  }
141  };
142 
143  }
144 }
145 
146 
147 #endif
Metaprogramming struct to detect whether a given type is constant in the mathematical sense (not the C++ const sense).
Definition: traits.hpp:43
T_return_type to_var(double logp)
Abstract base class for variable implementations that handles memory management and applying the chain rule.
Definition: chainable.hpp:14
size_t length(const T &)
Definition: traits.hpp:159
OperandsAndPartials(const T1 &x1=0, const T2 &x2=0, const T3 &x3=0, const T4 &x4=0, const T5 &x5=0, const T6 &x6=0)
A variable implementation that stores operands and derivatives with respect to the variable...
Metaprogram to determine if a type has a base scalar type that can be assigned to type double...
Definition: traits.hpp:57
VectorView< double *, is_vector< T2 >::value, is_constant_struct< T2 >::value > d_x2
void chain()
Apply the chain rule to this variable based on the variables on which it depends. ...
void fill(T &x, const S &y)
Fill the specified container with the specified value.
Definition: fill.hpp:22
The variable implementation base class.
Definition: vari.hpp:28
VectorView< double *, is_vector< T6 >::value, is_constant_struct< T6 >::value > d_x6
VectorView< double *, is_vector< T4 >::value, is_constant_struct< T4 >::value > d_x4
VectorView< double *, is_vector< T5 >::value, is_constant_struct< T5 >::value > d_x5
VectorView< double *, is_vector< T1 >::value, is_constant_struct< T1 >::value > d_x1
VectorView< double *, is_vector< T3 >::value, is_constant_struct< T3 >::value > d_x3
partials_vari(double value, size_t N, vari **operands, double *partials)
VectorView is a template metaprogram that takes its argument and allows it to be used like a vector...
Definition: traits.hpp:275
double adj_
The adjoint of this variable, which is the partial derivative of this variable with respect to the root variable.
Definition: vari.hpp:43

     [ Stan Home Page ] © 2011–2014, Stan Development Team.