Stan  2.5.0
probability, sampling & optimization
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
base_nuts.hpp
Go to the documentation of this file.
1 #ifndef STAN__MCMC__BASE__NUTS__BETA
2 #define STAN__MCMC__BASE__NUTS__BETA
3 
4 #include <math.h>
5 #include <boost/math/special_functions/fpclassify.hpp>
9 
10 namespace stan {
11 
12  namespace mcmc {
13 
/**
 * Scratch state threaded through base_nuts::build_tree by reference.
 *
 * The first group is fixed while a subtree is being built; the second
 * group is accumulated across every recursive call of one transition.
 */
struct nuts_util {
  double log_u;     ///< log of the slice variable drawn once per transition
  double H0;        ///< Hamiltonian of the initial point of the transition
  int sign;         ///< integration direction in time for the current subtree: +1 or -1

  int n_tree;       ///< leapfrog steps taken so far (one per base case)
  double sum_prob;  ///< running sum of min(1, exp(H0 - H)) over those steps
  bool criterion;   ///< cleared once the trajectory must stop growing
};
27 
28  // The No-U-Turn Sampler (NUTS).
29 
30  template <class M, class P, template<class, class> class H,
31  template<class, class> class I, class BaseRNG>
32  class base_nuts: public base_hmc<M, P, H, I, BaseRNG>
33  {
34 
35  public:
36 
37  base_nuts(M &m, BaseRNG& rng, std::ostream* o, std::ostream* e):
38  base_hmc<M, P, H, I, BaseRNG>(m, rng, o, e),
40  {};
41 
42  ~base_nuts() {};
43 
44  void set_max_depth(const int d) {
45  if(d > 0)
46  max_depth_ = d;
47  }
48 
49  void set_max_delta(const double d) {
50  max_delta_ = d;
51  }
52 
53  int get_max_depth() { return this->max_depth_; }
54  double get_max_delta() { return this->max_delta_; }
55 
56  sample transition(sample& init_sample) {
57 
58  // Initialize the algorithm
59  this->sample_stepsize();
60 
61  nuts_util util;
62 
63  this->seed(init_sample.cont_params());
64 
65  this->hamiltonian_.sample_p(this->z_, this->rand_int_);
66  this->hamiltonian_.init(this->z_);
67 
68  ps_point z_plus(this->z_);
69  ps_point z_minus(z_plus);
70 
71  ps_point z_sample(z_plus);
72  ps_point z_propose(z_plus);
73 
74  int n_cont = init_sample.cont_params().size();
75 
76  Eigen::VectorXd rho_init = this->z_.p;
77  Eigen::VectorXd rho_plus(n_cont); rho_plus.setZero();
78  Eigen::VectorXd rho_minus(n_cont); rho_minus.setZero();
79 
80  util.H0 = this->hamiltonian_.H(this->z_);
81 
82  // Sample the slice variable
83  util.log_u = std::log(this->rand_uniform_());
84 
85  // Build a balanced binary tree until the NUTS criterion fails
86  util.criterion = true;
87  int n_valid = 0;
88 
89  this->depth_ = 0;
90  this->n_divergent_ = 0;
91 
92  util.n_tree = 0;
93  util.sum_prob = 0;
94 
95  while (util.criterion && (this->depth_ <= this->max_depth_) ) {
96 
97  // Randomly sample a direction in time
98  ps_point* z = 0;
99  Eigen::VectorXd* rho = 0;
100 
101  if (this->rand_uniform_() > 0.5) {
102  z = &z_plus;
103  rho = &rho_plus;
104  util.sign = 1;
105  }
106  else {
107  z = &z_minus;
108  rho = &rho_minus;
109  util.sign = -1;
110  }
111 
112  // And build a new subtree in that direction
113  this->z_.ps_point::operator=(*z);
114 
115  int n_valid_subtree = build_tree(depth_, *rho, 0, z_propose, util);
116  ++(this->depth_);
117 
118  *z = this->z_;
119 
120  // Metropolis-Hastings sample the fresh subtree
121  if (!util.criterion)
122  break;
123 
124  double subtree_prob = 0;
125 
126  if (n_valid) {
127  subtree_prob = static_cast<double>(n_valid_subtree) /
128  static_cast<double>(n_valid);
129  } else {
130  subtree_prob = n_valid_subtree ? 1 : 0;
131  }
132 
133  if (this->rand_uniform_() < subtree_prob)
134  z_sample = z_propose;
135 
136  n_valid += n_valid_subtree;
137 
138  // Check validity of completed tree
139  this->z_.ps_point::operator=(z_plus);
140  Eigen::VectorXd delta_rho = rho_minus + rho_init + rho_plus;
141 
142  util.criterion = compute_criterion(z_minus, this->z_, delta_rho);
143 
144  }
145 
146  this->n_leapfrog_ = util.n_tree;
147 
148  double accept_prob = util.sum_prob / static_cast<double>(util.n_tree);
149 
150  this->z_.ps_point::operator=(z_sample);
151  return sample(this->z_.q, - this->z_.V, accept_prob);
152 
153  }
154 
155  void write_sampler_param_names(std::ostream& o) {
156  o << "stepsize__,treedepth__,n_leapfrog__,n_divergent__,";
157  }
158 
159  void write_sampler_params(std::ostream& o) {
160  o << this->epsilon_ << "," << this->depth_ << ","
161  << this->n_leapfrog_ << "," << this->n_divergent_ << ",";
162  }
163 
164  void get_sampler_param_names(std::vector<std::string>& names) {
165  names.push_back("stepsize__");
166  names.push_back("treedepth__");
167  names.push_back("n_leapfrog__");
168  names.push_back("n_divergent__");
169  }
170 
171  void get_sampler_params(std::vector<double>& values) {
172  values.push_back(this->epsilon_);
173  values.push_back(this->depth_);
174  values.push_back(this->n_leapfrog_);
175  values.push_back(this->n_divergent_);
176  }
177 
178  virtual bool compute_criterion(ps_point& start, P& finish, Eigen::VectorXd& rho) = 0;
179 
180  // Returns number of valid points in the completed subtree
181  int build_tree(int depth, Eigen::VectorXd& rho,
182  ps_point* z_init_parent, ps_point& z_propose,
183  nuts_util& util)
184  {
185 
186  // Base case
187  if (depth == 0)
188  {
189 
190  this->integrator_.evolve(this->z_, this->hamiltonian_,
191  util.sign * this->epsilon_);
192 
193  rho += this->z_.p;
194 
195  if (z_init_parent) *z_init_parent = this->z_;
196  z_propose = this->z_;
197 
198  double h = this->hamiltonian_.H(this->z_);
199  if (boost::math::isnan(h)) h = std::numeric_limits<double>::infinity();
200 
201  util.criterion = util.log_u + (h - util.H0) < this->max_delta_;
202  if (!util.criterion) ++(this->n_divergent_);
203 
204  util.sum_prob += stan::math::min(1, std::exp(util.H0 - h));
205  util.n_tree += 1;
206 
207  return (util.log_u + (h - util.H0) < 0);
208 
209  }
210  // General recursion
211  else
212  {
213 
214  Eigen::VectorXd left_subtree_rho(rho.size()); left_subtree_rho.setZero();
215  ps_point z_init(this->z_);
216 
217  int n1 = build_tree(depth - 1, left_subtree_rho, &z_init, z_propose, util);
218 
219  if (z_init_parent) *z_init_parent = z_init;
220 
221  if (!util.criterion) return 0;
222 
223  Eigen::VectorXd right_subtree_rho(rho.size()); right_subtree_rho.setZero();
224  ps_point z_propose_right(z_init);
225 
226  int n2 = build_tree(depth - 1, right_subtree_rho, 0, z_propose_right, util);
227 
228  double accept_prob = static_cast<double>(n2) /
229  static_cast<double>(n1 + n2);
230 
231  if ( util.criterion && (this->rand_uniform_() < accept_prob) )
232  z_propose = z_propose_right;
233 
234  Eigen::VectorXd& subtree_rho = left_subtree_rho;
235  subtree_rho += right_subtree_rho;
236 
237  rho += subtree_rho;
238 
239  util.criterion &= compute_criterion(z_init, this->z_, subtree_rho);
240 
241  return n1 + n2;
242 
243  }
244 
245  }
246 
247  int depth_;
249  double max_delta_;
250 
253 
254  };
255 
256  } // mcmc
257 
258 } // stan
259 
260 
261 #endif
bool isnan(const stan::agrad::var &v)
Checks if the given number is NaN.
H< M, BaseRNG > hamiltonian_
Definition: base_hmc.hpp:135
virtual bool compute_criterion(ps_point &start, P &finish, Eigen::VectorXd &rho)=0
sample transition(sample &init_sample)
Definition: base_nuts.hpp:56
void get_sampler_param_names(std::vector< std::string > &names)
Definition: base_nuts.hpp:164
void write_sampler_param_names(std::ostream &o)
Definition: base_nuts.hpp:155
double cont_params(int k) const
Definition: sample.hpp:30
base_nuts(M &m, BaseRNG &rng, std::ostream *o, std::ostream *e)
Definition: base_nuts.hpp:37
double e()
Return the base of the natural logarithm.
Definition: constants.hpp:86
void set_max_depth(const int d)
Definition: base_nuts.hpp:44
void get_sampler_params(std::vector< double > &values)
Definition: base_nuts.hpp:171
double min(const double a, const double b)
Definition: min.hpp:7
void seed(const Eigen::VectorXd &q)
Definition: base_hmc.hpp:51
boost::uniform_01< BaseRNG & > rand_uniform_
Definition: base_hmc.hpp:140
int build_tree(int depth, Eigen::VectorXd &rho, ps_point *z_init_parent, ps_point &z_propose, nuts_util &util)
Definition: base_nuts.hpp:181
fvar< T > log(const fvar< T > &x)
Definition: log.hpp:15
fvar< T > exp(const fvar< T > &x)
Definition: exp.hpp:16
void write_sampler_params(std::ostream &o)
Definition: base_nuts.hpp:159
void set_max_delta(const double d)
Definition: base_nuts.hpp:49
void sample(stan::mcmc::base_mcmc *sampler, int num_warmup, int num_samples, int num_thin, int refresh, bool save, stan::io::mcmc_writer< Model, SampleRecorder, DiagnosticRecorder, MessageRecorder > &writer, stan::mcmc::sample &init_s, Model &model, RNG &base_rng, const std::string &prefix, const std::string &suffix, std::ostream &o, StartTransitionCallback &callback)
Definition: sample.hpp:13
I< H< M, BaseRNG >, P > integrator_
Definition: base_hmc.hpp:134

     [ Stan Home Page ] © 2011–2014, Stan Development Team.