Stan  2.5.0
probability, sampling & optimization
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
stan_csv_reader.hpp
Go to the documentation of this file.
1 #ifndef STAN__IO__STAN_CSV_READER_HPP
2 #define STAN__IO__STAN_CSV_READER_HPP
3 
#include <algorithm>
#include <iostream>
#include <istream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <boost/algorithm/string.hpp>
#include <boost/lexical_cast.hpp>
#include <stan/math/matrix.hpp>
11 
12 namespace stan {
13  namespace io {
14 
15  // FIXME: should consolidate with the options from the command line in stan::gm
20 
21  std::string model;
22  std::string data;
23  std::string init;
24  size_t chain_id;
25  size_t seed;
27  size_t num_samples;
28  size_t num_warmup;
30  size_t thin;
32  std::string algorithm;
33  std::string engine;
34 
37  model(""), data(""), init(""),
38  chain_id(1), seed(0), random_seed(false),
39  num_samples(0), num_warmup(0), save_warmup(false), thin(0),
40  append_samples(false),
41  algorithm(""), engine("") {}
42  };
43 
45  double step_size;
46  Eigen::MatrixXd metric;
47 
49  : step_size(0), metric(0,0) {}
50  };
51 
52  struct stan_csv_timing {
53  double warmup;
54  double sampling;
55 
57  : warmup(0), sampling(0) {}
58  };
59 
60  struct stan_csv {
62  Eigen::Matrix<std::string, Eigen::Dynamic, 1> header;
64  Eigen::MatrixXd samples;
66  };
67 
72 
73  public:
74 
77 
78  static bool read_metadata(std::istream& in, stan_csv_metadata& metadata) {
79 
80  std::stringstream ss;
81  std::string line;
82 
83  if (in.peek() != '#')
84  return false;
85  while (in.peek() == '#') {
86  std::getline(in, line);
87  ss << line << '\n';
88  }
89  ss.seekg(std::ios_base::beg);
90 
91  char comment;
92  std::string lhs;
93 
94  std::string name;
95  std::string value;
96 
97  while (ss.good()) {
98 
99  ss >> comment;
100  std::getline(ss, lhs);
101 
102  size_t equal = lhs.find("=");
103  if (equal != std::string::npos) {
104  name = lhs.substr(0, equal);
105  boost::trim(name);
106  value = lhs.substr(equal + 2, lhs.size());
107  boost::replace_first(value, " (Default)", "");
108  } else {
109  continue;
110  }
111 
112  if (name.compare("stan_version_major") == 0) {
113  metadata.stan_version_major = boost::lexical_cast<int>(value);
114  } else if (name.compare("stan_version_minor") == 0) {
115  metadata.stan_version_minor = boost::lexical_cast<int>(value);
116  } else if (name.compare("stan_version_patch") == 0) {
117  metadata.stan_version_patch = boost::lexical_cast<int>(value);
118  } else if (name.compare("model") == 0) {
119  metadata.model = value;
120  } else if (name.compare("num_samples") == 0) {
121  metadata.num_samples = boost::lexical_cast<int>(value);
122  } else if (name.compare("num_warmup") == 0) {
123  metadata.num_warmup = boost::lexical_cast<int>(value);
124  } else if (name.compare("save_warmup") == 0) {
125  metadata.save_warmup = boost::lexical_cast<bool>(value);
126  } else if (name.compare("thin") == 0) {
127  metadata.thin = boost::lexical_cast<int>(value);
128  } else if (name.compare("chain_id") == 0) {
129  metadata.chain_id = boost::lexical_cast<int>(value);
130  } else if (name.compare("data") == 0) {
131  metadata.data = value;
132  } else if (name.compare("init") == 0) {
133  metadata.init = value;
134  boost::trim(metadata.init);
135  } else if (name.compare("seed") == 0) {
136  metadata.seed = boost::lexical_cast<unsigned int>(value);
137  metadata.random_seed = false;
138  } else if (name.compare("append_samples") == 0) {
139  metadata.append_samples = boost::lexical_cast<bool>(value);
140  } else if (name.compare("algorithm") == 0) {
141  metadata.algorithm = value;
142  } else if (name.compare("engine") == 0) {
143  metadata.engine = value;
144  }
145 
146  }
147 
148  if (ss.good() == true)
149  return false;
150 
151  return true;
152 
153  } // read_metadata
154 
155  static bool read_header(std::istream& in, Eigen::Matrix<std::string, Eigen::Dynamic, 1>& header) {
156 
157  std::string line;
158 
159  if (in.peek() != 'l')
160  return false;
161 
162  std::getline(in, line);
163  std::stringstream ss(line);
164 
165  header.resize(std::count(line.begin(), line.end(), ',') + 1);
166  int idx = 0;
167  while (ss.good()) {
168  std::string token;
169  std::getline(ss, token, ',');
170  boost::trim(token);
171 
172  int pos = token.find('.');
173  if (pos > 0) {
174  token.replace(pos, 1, "[");
175  std::replace(token.begin(), token.end(), '.', ',');
176  token += "]";
177  }
178  header(idx++) = token;
179  }
180  return true;
181  }
182 
183  static bool read_adaptation(std::istream& in, stan_csv_adaptation& adaptation) {
184 
185  std::stringstream ss;
186  std::string line;
187  int lines = 0;
188 
189  if (in.peek() != '#' || in.good() == false)
190  return false;
191 
192  while (in.peek() == '#') {
193  std::getline(in, line);
194  ss << line << std::endl;
195  lines++;
196  }
197  ss.seekg(std::ios_base::beg);
198 
199  char comment; // Buffer for comment indicator, #
200 
201  // Skip first two lines
202  std::getline(ss, line);
203 
204  // Stepsize
205  std::getline(ss, line, '=');
206  boost::trim(line);
207  ss >> adaptation.step_size;
208 
209  // Metric parameters
210  std::getline(ss, line);
211  std::getline(ss, line);
212  std::getline(ss, line);
213 
214  int rows = lines - 3;
215  int cols = std::count(line.begin(), line.end(), ',') + 1;
216  adaptation.metric.resize(rows, cols);
217 
218  for (int row = 0; row < rows; row++) {
219 
220  std::stringstream line_ss;
221  line_ss.str(line);
222  line_ss >> comment;
223 
224  for (int col = 0; col < cols; col++) {
225  std::string token;
226  std::getline(line_ss, token, ',');
227  boost::trim(token);
228  adaptation.metric(row, col) = boost::lexical_cast<double>(token);
229  }
230 
231  std::getline(ss, line); // Read in next line
232 
233  }
234 
235  if (ss.good())
236  return false;
237  else
238  return true;
239 
240  }
241 
242  static bool read_samples(std::istream& in, Eigen::MatrixXd& samples, stan_csv_timing& timing) {
243 
244  std::stringstream ss;
245  std::string line;
246 
247  int rows = 0;
248  int cols = -1;
249 
250  if (in.peek() == '#' || in.good() == false)
251  return false;
252 
253  while (in.good()) {
254 
255  bool comment_line = (in.peek() == '#');
256  bool empty_line = (in.peek() == '\n');
257 
258  std::getline(in, line);
259 
260  if (empty_line) continue;
261  if (!line.length()) break;
262 
263  if (comment_line) {
264 
265  if (line.find("(Warm-up)") != std::string::npos) {
266  int left = 17;
267  int right = line.find(" seconds");
268  timing.warmup += boost::lexical_cast<double>(line.substr(left, right - left));
269  } else if (line.find("(Sampling)") != std::string::npos) {
270  int left = 17;
271  int right = line.find(" seconds");
272  timing.sampling += boost::lexical_cast<double>(line.substr(left, right - left));
273  }
274 
275  }
276  else {
277 
278  ss << line << '\n';
279 
280  int current_cols = std::count(line.begin(), line.end(), ',') + 1;
281  if (cols == -1) {
282  cols = current_cols;
283  } else if (cols != current_cols) {
284  std::cout << "Error: expected " << cols << " columns, but found "
285  << current_cols << " instead for row " << rows + 1 << std::endl;
286  return false;
287  }
288  rows++;
289 
290  }
291 
292  in.peek();
293 
294  }
295 
296  ss.seekg(std::ios_base::beg);
297 
298  if (rows > 0) {
299  samples.resize(rows, cols);
300  for (int row = 0; row < rows; row++) {
301  std::getline(ss, line);
302  std::stringstream ls(line);
303  for (int col = 0; col < cols; col++) {
304  std::getline(ls, line, ',');
305  boost::trim(line);
306  //std::cout << "line: !" << line << "@" << std::endl;
307  samples(row, col) = boost::lexical_cast<double>(line);
308  //std::cout << "after" << std::endl << std::endl;
309  }
310  }
311  }
312  return true;
313  }
314 
319  static stan_csv parse(std::istream& in) {
320 
321  stan_csv data;
322 
323  if (!read_metadata(in, data.metadata)) {
324  std::cout << "Warning: non-fatal error reading metadata" << std::endl;
325  }
326 
327  if (!read_header(in, data.header)) {
328  std::cout << "Error: error reading header" << std::endl;
329  throw std::invalid_argument("Error with header of input file in parse");
330  }
331 
332  if (!read_adaptation(in, data.adaptation)) {
333  std::cout << "Warning: non-fatal error reading adapation data" << std::endl;
334  }
335 
336  data.timing.warmup = 0;
337  data.timing.sampling = 0;
338 
339  if (!read_samples(in, data.samples, data.timing)) {
340  std::cout << "Warning: non-fatal error reading samples" << std::endl;
341  }
342 
343  return data;
344 
345  }
346 
347  };
348 
349  } // io
350 
351 } // stan
352 
353 #endif
Eigen::MatrixXd samples
static bool read_samples(std::istream &in, Eigen::MatrixXd &samples, stan_csv_timing &timing)
Eigen::Matrix< T, 1, Eigen::Dynamic > row(const Eigen::Matrix< T, Eigen::Dynamic, Eigen::Dynamic > &m, size_t i)
Return the specified row of the specified matrix, using start-at-1 indexing.
Definition: row.hpp:26
static bool read_adaptation(std::istream &in, stan_csv_adaptation &adaptation)
Eigen::Matrix< T, Eigen::Dynamic, 1 > col(const Eigen::Matrix< T, Eigen::Dynamic, Eigen::Dynamic > &m, size_t j)
Return the specified column of the specified matrix using start-at-1 indexing.
Definition: col.hpp:24
static bool read_header(std::istream &in, Eigen::Matrix< std::string, Eigen::Dynamic, 1 > &header)
stan_csv_adaptation adaptation
size_t rows(const Eigen::Matrix< T, R, C > &m)
Definition: rows.hpp:12
size_t cols(const Eigen::Matrix< T, R, C > &m)
Definition: cols.hpp:12
static bool read_metadata(std::istream &in, stan_csv_metadata &metadata)
Reads from a Stan output csv file.
static stan_csv parse(std::istream &in)
Parses the file.
stan_csv_metadata metadata
Eigen::Matrix< std::string, Eigen::Dynamic, 1 > header
stan_csv_timing timing

     [ Stan Home Page ] © 2011–2014, Stan Development Team.