1 #ifndef __STAN__GM__AST_DEF_HPP__
2 #define __STAN__GM__AST_DEF_HPP__
4 #include <boost/variant/apply_visitor.hpp>
5 #include <boost/variant/recursive_variant.hpp>
58 : base_type_(base_type),
63 : base_type_(base_type),
71 return !(*
this == et);
139 if (sigs_ == 0)
return;
151 std::pair<std::string,
154 user_defined_set_.insert(name_sig);
160 return user_defined_set_.find(name_sig) != user_defined_set_.end();
165 const std::vector<function_signature_t> sigs = sigs_map_[name];
166 for (
size_t i = 0; i < sigs.size(); ++i)
167 if (sig.second == sigs[i].second)
173 const std::vector<expr_type>& arg_types) {
179 std::vector<expr_type> arg_types;
180 add(name,result_type,arg_types);
185 std::vector<expr_type> arg_types;
186 arg_types.push_back(arg_type);
187 add(name,result_type,arg_types);
193 std::vector<expr_type> arg_types;
194 arg_types.push_back(arg_type1);
195 arg_types.push_back(arg_type2);
196 add(name,result_type,arg_types);
203 std::vector<expr_type> arg_types;
204 arg_types.push_back(arg_type1);
205 arg_types.push_back(arg_type2);
206 arg_types.push_back(arg_type3);
207 add(name,result_type,arg_types);
215 std::vector<expr_type> arg_types;
216 arg_types.push_back(arg_type1);
217 arg_types.push_back(arg_type2);
218 arg_types.push_back(arg_type3);
219 arg_types.push_back(arg_type4);
220 add(name,result_type,arg_types);
229 std::vector<expr_type> arg_types;
230 arg_types.push_back(arg_type1);
231 arg_types.push_back(arg_type2);
232 arg_types.push_back(arg_type3);
233 arg_types.push_back(arg_type4);
234 arg_types.push_back(arg_type5);
235 add(name,result_type,arg_types);
245 std::vector<expr_type> arg_types;
246 arg_types.push_back(arg_type1);
247 arg_types.push_back(arg_type2);
248 arg_types.push_back(arg_type3);
249 arg_types.push_back(arg_type4);
250 arg_types.push_back(arg_type5);
251 arg_types.push_back(arg_type6);
252 add(name,result_type,arg_types);
263 std::vector<expr_type> arg_types;
264 arg_types.push_back(arg_type1);
265 arg_types.push_back(arg_type2);
266 arg_types.push_back(arg_type3);
267 arg_types.push_back(arg_type4);
268 arg_types.push_back(arg_type5);
269 arg_types.push_back(arg_type6);
270 arg_types.push_back(arg_type7);
271 add(name,result_type,arg_types);
289 const std::vector<expr_type>& call_args,
290 const std::vector<expr_type>& sig_args) {
291 if (call_args.size() != sig_args.size()) {
295 for (
size_t i = 0; i < call_args.size(); ++i) {
296 if (call_args[i] == sig_args[i]) {
298 }
else if (call_args[i].is_primitive_int()
299 && sig_args[i].is_primitive_double()) {
308 const std::vector<expr_type>& args,
311 std::vector<function_signature_t> signatures = sigs_map_[name];
313 size_t num_matches = 0;
314 for (
size_t i = 0; i < signatures.size(); ++i) {
315 signature = signatures[i];
317 if (promotions < 0)
continue;
318 size_t promotions_ui =
static_cast<size_t>(promotions);
319 if (promotions_ui < min_promotions) {
320 min_promotions = promotions_ui;
322 }
else if (promotions_ui == min_promotions) {
330 const std::string& name,
331 const std::vector<expr_type>& args,
332 std::ostream& error_msgs) {
333 std::vector<function_signature_t> signatures = sigs_map_[name];
334 size_t match_index = 0;
336 size_t num_matches = 0;
338 for (
size_t i = 0; i < signatures.size(); ++i) {
340 if (promotions < 0)
continue;
341 size_t promotions_ui =
static_cast<size_t>(promotions);
342 if (promotions_ui < min_promotions) {
343 min_promotions = promotions_ui;
346 }
else if (promotions_ui == min_promotions) {
351 if (num_matches == 1) {
352 return signatures[match_index].first;
353 }
else if (num_matches == 0) {
354 error_msgs <<
"no matches for function name=\"" << name <<
"\""
357 error_msgs << num_matches <<
" matches with "
358 << min_promotions <<
" integer promotions "
359 <<
"for function name=\"" << name <<
"\"" << std::endl;
361 for (
size_t i = 0; i < args.size(); ++i)
362 error_msgs <<
" arg " << i <<
" type=" << args[i] << std::endl;
364 error_msgs <<
"available function signatures for "
365 << name <<
":" << std::endl;
366 for (
size_t i = 0; i < signatures.size(); ++i) {
367 error_msgs << i <<
". " << name <<
"(";
368 for (
size_t j = 0; j < signatures[i].second.size(); ++j) {
369 if (j > 0) error_msgs <<
", ";
370 error_msgs << signatures[i].second[j];
372 error_msgs <<
") : " << signatures[i].first << std::endl;
376 function_signatures::function_signatures() {
377 #include <stan/gm/function_signatures.h>
379 std::set<std::string>
386 for (map<
string,vector<function_signature_t> >::const_iterator
387 it = sigs_map_.begin();
388 it != sigs_map_.end();
390 result.insert(it->first);
399 const std::string& name)
400 : arg_type_(arg_type),
404 std::vector<expression>
dims;
412 const std::string& name,
413 const std::vector<arg_decl>& arg_decls,
416 : return_type_(return_type),
418 arg_decls_(arg_decls),
425 : decl_defs_(decl_defs) {
429 std::ostream& error_msgs)
430 : return_type_(return_type),
431 error_msgs_(error_msgs) {
434 error_msgs_ <<
"Expecting return, found nil statement."
439 error_msgs_ <<
"Expecting return, found assignment statement."
444 error_msgs_ <<
"Expecting return, found sampling statement."
450 error_msgs_ <<
"Expecting return, found increment_log_prob statement."
455 error_msgs_ <<
"Expecting return, found increment_log_prob statement."
460 error_msgs_ <<
"Expecting return, found print statement."
465 error_msgs_ <<
"Expecting return, found reject statement."
470 error_msgs_ <<
"Expecting return, found no_op statement."
479 " statement sequence with empty body." )
497 error_msgs_ << (
"Expecting return, found conditional"
498 " without final else.")
502 for (
size_t i = 0; i < st.
bodies_.size(); ++i)
511 "Returned expression does not match return type",
517 std::ostream& error_msgs) {
518 if (return_type ==
VOID_T)
521 return boost::apply_visitor(vis,statement.
statement_);
528 const std::vector<statement>& stmts)
529 : local_decl_(local_decl),
572 return boost::apply_visitor(vis,
expr_);
593 : printable_(printable) { }
595 : printable_(printable.printable_) { }
598 : var_map_(var_map) {
610 for (
size_t i = 0; i < e.
args_.size(); ++i)
611 if (boost::apply_visitor(*
this,e.
args_[i].expr_))
622 for (
size_t i = 0; i < e.
args_.size(); ++i)
623 if (boost::apply_visitor(*
this,e.
args_[i].expr_))
629 return boost::apply_visitor(*
this, e.
y0_.
expr_)
634 return boost::apply_visitor(*
this,e.
expr_.
expr_);
637 return boost::apply_visitor(*
this,e.
left.
expr_)
647 || name ==
"append_col"
650 || name ==
"diagonal"
653 || name ==
"negative_infinity"
654 || name ==
"not_a_number"
655 || name ==
"append_row"
656 || name ==
"rep_matrix"
657 || name ==
"rep_row_vector"
658 || name ==
"rep_vector"
661 || name ==
"positive_infinity"
663 || name ==
"subtract"
666 || name ==
"to_vector"
667 || name ==
"to_row_vector"
668 || name ==
"to_matrix"
669 || name ==
"to_array_1d"
670 || name ==
"to_array_2d"
671 || name ==
"transpose"
678 return boost::apply_visitor(vis,e.
expr_);
682 : var_map_(var_map) {
694 for (
size_t i = 0; i < e.
args_.size(); ++i)
695 if (boost::apply_visitor(*
this,e.
args_[i].expr_))
706 return boost::apply_visitor(*
this, e.
y0_.
expr_)
712 for (
size_t i = 0; i < e.
args_.size(); ++i)
713 if (boost::apply_visitor(*
this,e.
args_[i].expr_))
717 for (
size_t i = 0; i < e.
args_.size(); ++i)
724 return boost::apply_visitor(*
this,e.
expr_.
expr_);
739 if (e.
op ==
"*" || e.
op ==
"/")
751 return boost::apply_visitor(vis,e.
expr_);
770 return boost::apply_visitor(ino,e.
expr_);
775 std::vector<expression>
const&
dims)
843 : system_function_name_(system_function_name),
855 std::vector<expression>
const& args)
865 size_t total_dims(
const std::vector<std::vector<expression> >& dimss) {
867 for (
size_t i = 0; i < dimss.size(); ++i)
868 total += dimss[i].
size();
874 size_t num_expr_dims,
875 size_t num_index_dims) {
876 if (num_index_dims <= num_expr_dims)
877 return expr_type(expr_base_type,num_expr_dims - num_index_dims);
878 if (num_index_dims == (num_expr_dims + 1)) {
884 if (num_index_dims == (num_expr_dims + 2))
893 size_t num_index_dims) {
902 const std::vector<std::vector<expression> >& dimss)
913 const std::string& op,
919 right.expression_type())) {
950 o <<
"transformed data";
954 o <<
"transformed parameter";
956 o <<
"generated quantities";
960 o <<
"function argument";
962 o <<
"function argument '_lp' suffixed";
964 o <<
"function argument '_rng' suffixed";
966 o <<
"void function argument";
968 o <<
"void function argument '_lp' suffixed";
970 o <<
"void function argument '_rng' suffixed";
972 o <<
"UNKNOWN ORIGIN=" << vo;
978 : base_type_(base_type) {
981 const std::vector<expression>&
dims,
985 base_type_(base_type) {
989 return map_.find(name) !=
map_.end();
993 throw std::invalid_argument(
"variable does not exist");
994 return map_.find(name)->second.first;
997 return get(name).base_type_;
1000 return get(name).dims_.size();
1004 throw std::invalid_argument(
"variable does not exist");
1005 return map_.find(name)->second.second;
1021 std::string
const& name,
1022 std::vector<expression>
const&
dims)
1034 std::string
const& name,
1035 std::vector<expression>
const&
dims)
1045 std::string
const& name,
1046 std::vector<expression>
const&
dims)
1056 std::string
const& name,
1057 std::vector<expression>
const&
dims)
1067 std::string
const& name,
1068 std::vector<expression>
const&
dims)
1078 std::string
const& name,
1079 std::vector<expression>
const&
dims)
1088 std::string
const& name,
1089 std::vector<expression>
const&
dims)
1098 std::string
const& name,
1099 std::vector<expression>
const&
dims)
1109 std::string
const& name,
1110 std::vector<expression>
const&
dims)
1123 std::string
const& name,
1124 std::vector<expression>
const&
dims)
1134 std::string
const& name,
1135 std::vector<expression>
const&
dims)
1143 std::string
const& name,
1144 std::vector<expression>
const&
dims)
1151 std::string
const& name,
1152 std::vector<expression>
const&
dims)
1294 : log_prob_(log_prob) {
1302 : variable_(variable),
1311 : condition_(condition),
1319 const std::vector<statement>& bodies)
1320 : conditions_(conditions),
1326 : return_value_(expr) {
1332 : printables_(printables) {
1338 : printables_(printables) {
1343 const std::vector<var_decl>& data_decl,
1344 const std::pair<std::vector<var_decl>,
1345 std::vector<statement> >& derived_data_decl,
1346 const std::vector<var_decl>& parameter_decl,
1347 const std::pair<std::vector<var_decl>,
1348 std::vector<statement> >& derived_decl,
1350 const std::pair<std::vector<var_decl>,
1351 std::vector<statement> >& generated_decl)
1352 : function_decl_defs_(function_decl_defs),
1353 data_decl_(data_decl),
1354 derived_data_decl_(derived_data_decl),
1355 parameter_decl_(parameter_decl),
1356 derived_decl_(derived_decl),
1358 generated_decl_(generated_decl) {
1382 : var_dims_(var_dims),
1424 const std::vector<expression>& args) {
1425 std::vector<expr_type> arg_types;
1426 for (
size_t i = 0; i < args.size(); ++i)
1427 arg_types.push_back(args[i].expression_type());
1434 std::pair<std::string, function_signature_t>
1435 name_sig(name, sig);
1441 const std::vector<expression>& params) {
1442 std::vector<expression> variate_params;
1443 variate_params.push_back(variate);
1444 for (
size_t i = 0; i < params.size(); ++i)
1445 variate_params.push_back(params[i]);
1455 const std::string& failure_message,
1456 std::ostream& error_msgs) {
1457 bool assignable =
true;
1460 error_msgs <<
"Mismatched array dimensions.";
1465 error_msgs <<
"Base type mismatch. ";
1468 error_msgs << failure_message
1470 <<
" LHS type = " << l_type
1471 <<
"; RHS type = " << r_type
1478 const std::string& s) {
1479 size_t idx = s.rfind(suffix);
1480 return idx != std::string::npos
1481 && idx == (s.size() - suffix.size());
bool operator==(const expr_type &et) const
Placeholder struct for boost::variant default ctors.
expression & operator/=(const expression &rhs)
bool is_user_defined(const std::pair< std::string, function_signature_t > &name_sig)
bool is_no_op_statement() const
void add_unary(const ::std::string &name)
bool ends_with(const std::string &suffix, const std::string &s)
std::vector< expression > conditions_
void add(const std::string &name, const expr_type &result_type, const std::vector< expr_type > &arg_types)
base_expr_type base_type_
void add_nullary(const ::std::string &name)
increment_log_prob_statement()
const int parameter_origin
void add(const std::string &name, const base_var_decl &base_decl, const var_origin &vo)
bool is_user_defined_prob_function(const std::string &name, const expression &variate, const std::vector< expression > ¶ms)
bool is_defined(const std::string &name, const function_signature_t &sig)
static function_signatures & instance()
std::map< std::string, range_t > map_
boost::variant< boost::recursive_wrapper< std::string >, boost::recursive_wrapper< expression > > printable_t
std::vector< std::vector< expression > > dimss_
bool operator>=(const expr_type &et) const
bool is_primitive_int() const
void set_user_defined(const std::pair< std::string, function_signature_t > &name_sig)
bool is_primitive_double() const
void print_var_origin(std::ostream &o, const var_origin &vo)
const int function_argument_origin
bool is_user_defined(const std::string &name, const std::vector< expression > &args)
expression & operator-=(const expression &rhs)
double max(const double a, const double b)
Metaprogram to calculate the base scalar return type resulting from promoting all the scalar types of...
int get_signature_matches(const std::string &name, const std::vector< expr_type > &args, function_signature_t &signature)
expression & operator*=(const expression &rhs)
size_t total_dims(const std::vector< std::vector< expression > > &dimss)
const int void_function_argument_origin_rng
const int function_argument_origin_rng
bool has_rng_suffix(const std::string &s)
std::vector< expression > args_
void add_binary(const ::std::string &name)
bool operator()(const nil &e) const
bool is_ill_formed() const
size_t get_num_dims(const std::string &name) const
std::vector< expression > args_
std::ostream & error_msgs_
base_expr_type type() const
bool has_non_param_var(const expression &e, const variable_map &var_map)
std::ostream & operator<<(std::ostream &o, const expr_type &et)
std::vector< statement > bodies_
bool is_linear_function(const std::string &name)
int num_promotions(const std::vector< expr_type > &call_args, const std::vector< expr_type > &sig_args)
expr_type promote_primitive(const expr_type &et)
void dims(const T &x, std::vector< int > &result)
returns_type_vis(const expr_type &return_type, std::ostream &error_msgs)
boost::variant< boost::recursive_wrapper< nil >, boost::recursive_wrapper< int_literal >, boost::recursive_wrapper< double_literal >, boost::recursive_wrapper< array_literal >, boost::recursive_wrapper< variable >, boost::recursive_wrapper< integrate_ode >, boost::recursive_wrapper< fun >, boost::recursive_wrapper< index_op >, boost::recursive_wrapper< binary_op >, boost::recursive_wrapper< unary_op > > expression_t
contains_nonparam_var(const variable_map &var_map)
const int void_function_argument_origin_lp
bool has_var(const expression &e, const variable_map &var_map)
const int void_function_argument_origin
void add_quaternary(const ::std::string &name)
const int model_name_origin
bool has_lp_suffix(const std::string &s)
boost::variant< boost::recursive_wrapper< nil >, boost::recursive_wrapper< assignment >, boost::recursive_wrapper< sample >, boost::recursive_wrapper< increment_log_prob_statement >, boost::recursive_wrapper< expression >, boost::recursive_wrapper< statements >, boost::recursive_wrapper< for_statement >, boost::recursive_wrapper< conditional_statement >, boost::recursive_wrapper< while_statement >, boost::recursive_wrapper< print_statement >, boost::recursive_wrapper< reject_statement >, boost::recursive_wrapper< return_statement >, boost::recursive_wrapper< no_op_statement > > statement_t
array_literal & operator=(const array_literal &al)
std::set< std::string > key_set() const
positive_ordered_var_decl()
contains_var(const variable_map &var_map)
base_expr_type get_base_type(const std::string &name) const
bool operator()(const nil &e) const
const int transformed_data_origin
bool is_nil(const expression &e)
double e()
Return the base of the natural logarithm.
bool operator!=(const expr_type &et) const
int size(const std::vector< T > &x)
bool is_ill_formed() const
void remove(const std::string &name)
base_var_decl base_variable_declaration()
bool is_assignable(const expr_type &l_type, const expr_type &r_type, const std::string &failure_message, std::ostream &error_msgs)
expr_type infer_type_indexing(const base_expr_type &expr_base_type, size_t num_expr_dims, size_t num_index_dims)
bool operator()(const nil &x) const
const int function_argument_origin_lp
bool operator>(const expr_type &et) const
unary_op(char op, expression const &subject)
expression & operator+=(const expression &rhs)
expr_type operator()(const nil &e) const
const int transformed_parameter_origin
bool operator<(const expr_type &et) const
cholesky_factor_var_decl()
bool operator<=(const expr_type &et) const
double_literal & operator=(const double_literal &dl)
std::pair< base_var_decl, var_origin > range_t
std::string operator()(const nil &x) const
expr_type get_result_type(const std::string &name, const std::vector< expr_type > &args, std::ostream &error_msgs)
double dist(const std::vector< double > &x, const std::vector< double > &y)
bool exists(const std::string &name) const
std::ostream & write_base_expr_type(std::ostream &o, base_expr_type type)
expr_type expression_type() const
boost::variant< boost::recursive_wrapper< nil >, boost::recursive_wrapper< int_var_decl >, boost::recursive_wrapper< double_var_decl >, boost::recursive_wrapper< vector_var_decl >, boost::recursive_wrapper< row_vector_var_decl >, boost::recursive_wrapper< matrix_var_decl >, boost::recursive_wrapper< simplex_var_decl >, boost::recursive_wrapper< unit_vector_var_decl >, boost::recursive_wrapper< ordered_var_decl >, boost::recursive_wrapper< positive_ordered_var_decl >, boost::recursive_wrapper< cholesky_factor_var_decl >, boost::recursive_wrapper< cholesky_corr_var_decl >, boost::recursive_wrapper< cov_matrix_var_decl >, boost::recursive_wrapper< corr_matrix_var_decl > > var_decl_t
base_var_decl get(const std::string &name) const
bool returns_type(const expr_type &return_type, const statement &statement, std::ostream &error_msgs)
void set_type(const base_expr_type &base_type, size_t num_dims)
const variable_map & var_map_
int_literal & operator=(const int_literal &il)
bool operator()(const nil &st) const
void add_ternary(const ::std::string &name)
std::vector< statement > statements_
bool is_primitive() const
var_origin get_origin(const std::string &name) const
const variable_map & var_map_
bool operator()(const nil &st) const
std::pair< expr_type, std::vector< expr_type > > function_signature_t