Stan  2.5.0
probability, sampling & optimization
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
log_sum_exp.hpp
Go to the documentation of this file.
1 #ifndef STAN__AGRAD__REV__MATRIX__LOG_SUM_EXP_HPP
2 #define STAN__AGRAD__REV__MATRIX__LOG_SUM_EXP_HPP
3 
4 #include <stan/agrad/rev/var.hpp>
14 
15 namespace stan {
16  namespace agrad {
17 
18  namespace {
19 
20  // these function and the following class just translate
21  // log_sum_exp for std::vector for Eigen::Matrix
22 
23  template <int R, int C>
24  double log_sum_exp_as_double(const Eigen::Matrix<var,R,C>& x) {
25  using std::numeric_limits;
26  using std::exp;
27  using std::log;
28  double max = -numeric_limits<double>::infinity();
29  for (int i = 0; i < x.size(); ++i)
30  if (x(i) > max)
31  max = x(i).val();
32  double sum = 0.0;
33  for (int i = 0; i < x.size(); ++i)
34  if (x(i) != -numeric_limits<double>::infinity())
35  sum += exp(x(i).val() - max);
36  return max + log(sum);
37  }
38 
39  class log_sum_exp_matrix_vari : public op_matrix_vari {
40  public:
41  template <int R, int C>
42  log_sum_exp_matrix_vari(const Eigen::Matrix<var,R,C>& x) :
43  op_matrix_vari(log_sum_exp_as_double(x), x) {
44  }
45  void chain() {
46  for (size_t i = 0; i < size_; ++i) {
47  vis_[i]->adj_ += adj_ * calculate_chain(vis_[i]->val_, val_);
48  }
49  }
50  };
51  }
52 
58  template <int R, int C>
59  inline var log_sum_exp(const Eigen::Matrix<var,R,C>& x) {
60  return var(new log_sum_exp_matrix_vari(x));
61  }
62 
63  }
64 }
65 #endif
fvar< T > log_sum_exp(const fvar< T > &x1, const fvar< T > &x2)
Definition: log_sum_exp.hpp:15
double max(const double a, const double b)
Definition: max.hpp:7
fvar< T > sum(const Eigen::Matrix< fvar< T >, R, C > &m)
Definition: sum.hpp:14
double calculate_chain(const double &x, const double &val)
Independent (input) and dependent (output) variables for gradients.
Definition: var.hpp:27
size_t size_
Definition: dot_self.hpp:18
fvar< T > log(const fvar< T > &x)
Definition: log.hpp:15
fvar< T > exp(const fvar< T > &x)
Definition: exp.hpp:16

     [ Stan Home Page ] © 2011–2014, Stan Development Team.