Stan 2.5.0
probability, sampling & optimization
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
log_sum_exp.hpp
Go to the documentation of this file.
1 #ifndef STAN__AGRAD__REV__FUNCTIONS__LOG_SUM_EXP_HPP
2 #define STAN__AGRAD__REV__FUNCTIONS__LOG_SUM_EXP_HPP
3 
4 #include <stan/agrad/rev/var.hpp>
13 
14 namespace stan {
15  namespace agrad {
16 
17  namespace {
18  double log_sum_exp_as_double(const std::vector<var>& x) {
19  using std::numeric_limits;
20  using std::exp;
21  using std::log;
22  double max = -numeric_limits<double>::infinity();
23  for (size_t i = 0; i < x.size(); ++i)
24  if (x[i] > max)
25  max = x[i].val();
26  double sum = 0.0;
27  for (size_t i = 0; i < x.size(); ++i)
28  if (x[i] != -numeric_limits<double>::infinity())
29  sum += exp(x[i].val() - max);
30  return max + log(sum);
31  }
32 
33  class log_sum_exp_vv_vari : public op_vv_vari {
34  public:
35  log_sum_exp_vv_vari(vari* avi, vari* bvi) :
36  op_vv_vari(stan::math::log_sum_exp(avi->val_, bvi->val_),
37  avi, bvi) {
38  }
39  void chain() {
40  avi_->adj_ += adj_ * calculate_chain(avi_->val_, val_);
41  bvi_->adj_ += adj_ * calculate_chain(bvi_->val_, val_);
42  }
43  };
44  class log_sum_exp_vd_vari : public op_vd_vari {
45  public:
46  log_sum_exp_vd_vari(vari* avi, double b) :
47  op_vd_vari(stan::math::log_sum_exp(avi->val_, b),
48  avi, b) {
49  }
50  void chain() {
51  avi_->adj_ += adj_ * calculate_chain(avi_->val_, val_);
52  }
53  };
54  class log_sum_exp_dv_vari : public op_dv_vari {
55  public:
56  log_sum_exp_dv_vari(double a, vari* bvi) :
57  op_dv_vari(stan::math::log_sum_exp(a, bvi->val_),
58  a, bvi) {
59  }
60  void chain() {
61  bvi_->adj_ += adj_ * calculate_chain(bvi_->val_, val_);
62  }
63  };
64 
65  class log_sum_exp_vector_vari : public op_vector_vari {
66  public:
67  log_sum_exp_vector_vari(const std::vector<var>& x) :
68  op_vector_vari(log_sum_exp_as_double(x), x) {
69  }
70  void chain() {
71  for (size_t i = 0; i < size_; ++i) {
72  vis_[i]->adj_ += adj_ * calculate_chain(vis_[i]->val_, val_);
73  }
74  }
75  };
76  }
77 
81  inline var log_sum_exp(const stan::agrad::var& a,
82  const stan::agrad::var& b) {
83  return var(new log_sum_exp_vv_vari(a.vi_, b.vi_));
84  }
88  inline var log_sum_exp(const stan::agrad::var& a,
89  const double& b) {
90  return var(new log_sum_exp_vd_vari(a.vi_, b));
91  }
95  inline var log_sum_exp(const double& a,
96  const stan::agrad::var& b) {
97  return var(new log_sum_exp_dv_vari(a, b.vi_));
98  }
102  inline var log_sum_exp(const std::vector<var>& x) {
103  return var(new log_sum_exp_vector_vari(x));
104  }
105 
106  }
107 }
108 #endif
fvar< T > log_sum_exp(const fvar< T > &x1, const fvar< T > &x2)
Definition: log_sum_exp.hpp:15
double max(const double a, const double b)
Definition: max.hpp:7
fvar< T > sum(const Eigen::Matrix< fvar< T >, R, C > &m)
Definition: sum.hpp:14
vari * vi_
Pointer to the implementation of this variable.
Definition: var.hpp:40
double calculate_chain(const double &x, const double &val)
Independent (input) and dependent (output) variables for gradients.
Definition: var.hpp:27
size_t size_
Definition: dot_self.hpp:18
fvar< T > log(const fvar< T > &x)
Definition: log.hpp:15
fvar< T > exp(const fvar< T > &x)
Definition: exp.hpp:16

     [ Stan Home Page ] © 2011–2014, Stan Development Team.