1 #ifndef STAN__GM__AST_HPP
2 #define STAN__GM__AST_HPP
10 #include <boost/variant/recursive_variant.hpp>
112 bool is_user_defined(
const std::pair<std::string,function_signature_t>&
114 void add(
const std::string& name,
116 const std::vector<expr_type>& arg_types);
117 void add(
const std::string& name,
119 void add(
const std::string& name,
122 void add(
const std::string& name,
126 void add(
const std::string& name,
131 void add(
const std::string& name,
137 void add(
const std::string& name,
144 void add(
const std::string& name,
152 void add(
const std::string& name,
162 void add_unary(const::std::string& name);
167 const std::vector<expr_type>& sig_args);
169 const std::vector<expr_type>& args,
170 std::ostream& error_msgs);
172 const std::vector<expr_type>& args,
176 std::set<std::string>
key_set()
const;
180 std::map<std::string, std::vector<function_signature_t> > sigs_map_;
181 std::set<std::pair<std::string,function_signature_t> > user_defined_set_;
189 statements(
const std::vector<var_decl>& local_decl,
190 const std::vector<statement>& stmts);
217 typedef boost::variant<boost::recursive_wrapper<nil>,
218 boost::recursive_wrapper<int_literal>,
219 boost::recursive_wrapper<double_literal>,
220 boost::recursive_wrapper<array_literal>,
221 boost::recursive_wrapper<variable>,
222 boost::recursive_wrapper<integrate_ode>,
223 boost::recursive_wrapper<fun>,
224 boost::recursive_wrapper<index_op>,
225 boost::recursive_wrapper<binary_op>,
226 boost::recursive_wrapper<unary_op> >
257 typedef boost::variant<boost::recursive_wrapper<std::string>,
258 boost::recursive_wrapper<expression> >
293 std::vector<expression>
const&
dims);
355 fun(std::string
const& name,
356 std::vector<expression>
const& args);
360 size_t total_dims(
const std::vector<std::vector<expression> >& dimss);
363 size_t num_expr_dims,
364 size_t num_index_dims);
367 size_t num_index_dims);
372 std::vector<std::vector<expression> >
dimss_;
377 const std::vector<std::vector<expression> >& dimss);
389 const std::string&
op,
435 const std::vector<expression>&
dims,
440 typedef std::pair<base_var_decl,var_origin>
range_t;
442 bool exists(
const std::string& name)
const;
447 void add(
const std::string& name,
450 void remove(
const std::string& name);
452 std::map<std::string, range_t>
map_;
459 std::string
const& name,
460 std::vector<expression>
const&
dims);
468 std::string
const& name,
469 std::vector<expression>
const&
dims);
476 std::string
const& name,
477 std::vector<expression>
const&
dims);
484 std::string
const& name,
485 std::vector<expression>
const&
dims);
492 std::string
const& name,
493 std::vector<expression>
const&
dims);
500 std::string
const& name,
501 std::vector<expression>
const&
dims);
510 std::string
const& name,
511 std::vector<expression>
const&
dims);
520 std::string
const& name,
521 std::vector<expression>
const&
dims);
532 std::string
const& name,
533 std::vector<expression>
const&
dims);
542 std::string
const& name,
543 std::vector<expression>
const&
dims);
550 const std::string& name,
551 const std::vector<expression>&
dims);
558 std::string
const& name,
559 std::vector<expression>
const&
dims);
567 std::string
const& name,
568 std::vector<expression>
const&
dims);
571 struct name_vis :
public boost::static_visitor<std::string> {
593 typedef boost::variant<boost::recursive_wrapper<nil>,
594 boost::recursive_wrapper<int_var_decl>,
595 boost::recursive_wrapper<double_var_decl>,
596 boost::recursive_wrapper<vector_var_decl>,
597 boost::recursive_wrapper<row_vector_var_decl>,
598 boost::recursive_wrapper<matrix_var_decl>,
599 boost::recursive_wrapper<simplex_var_decl>,
600 boost::recursive_wrapper<unit_vector_var_decl>,
601 boost::recursive_wrapper<ordered_var_decl>,
602 boost::recursive_wrapper<positive_ordered_var_decl>,
603 boost::recursive_wrapper<cholesky_factor_var_decl>,
604 boost::recursive_wrapper<cholesky_corr_var_decl>,
605 boost::recursive_wrapper<cov_matrix_var_decl>,
606 boost::recursive_wrapper<corr_matrix_var_decl> >
631 std::string
name()
const;
635 typedef boost::variant<boost::recursive_wrapper<nil>,
636 boost::recursive_wrapper<assignment>,
637 boost::recursive_wrapper<sample>,
638 boost::recursive_wrapper<increment_log_prob_statement>,
639 boost::recursive_wrapper<expression>,
640 boost::recursive_wrapper<statements>,
641 boost::recursive_wrapper<for_statement>,
642 boost::recursive_wrapper<conditional_statement>,
643 boost::recursive_wrapper<while_statement>,
644 boost::recursive_wrapper<print_statement>,
645 boost::recursive_wrapper<reject_statement>,
646 boost::recursive_wrapper<return_statement>,
647 boost::recursive_wrapper<no_op_statement> >
692 std::ostream& error_msgs);
710 std::ostream& error_msgs);
772 const std::string& name);
779 const std::string& name,
780 const std::vector<arg_decl>& arg_decls,
798 std::pair<std::vector<var_decl>,std::vector<statement> >
801 std::pair<std::vector<var_decl>,std::vector<statement> >
807 const std::vector<var_decl>& data_decl,
808 const std::pair<std::vector<var_decl>,
809 std::vector<statement> >& derived_data_decl,
810 const std::vector<var_decl>& parameter_decl,
811 const std::pair<std::vector<var_decl>,
812 std::vector<statement> >& derived_decl,
814 const std::pair<std::vector<var_decl>,
815 std::vector<statement> >& generated_decl);
844 const std::vector<expression>& args);
847 const std::vector<expression>& params);
889 const std::string& failure_message,
890 std::ostream& error_msgs);
893 bool ends_with(
const std::string& suffix,
894 const std::string& s);
bool operator==(const expr_type &et) const
Placeholder struct for boost::variant default ctors.
expression & operator/=(const expression &rhs)
void generate_expression(const expression &e, std::ostream &o)
bool is_user_defined(const std::pair< std::string, function_signature_t > &name_sig)
bool is_no_op_statement() const
void add_unary(const ::std::string &name)
bool ends_with(const std::string &suffix, const std::string &s)
std::pair< std::vector< var_decl >, std::vector< statement > > derived_data_decl_
std::vector< expression > conditions_
void add(const std::string &name, const expr_type &result_type, const std::vector< expr_type > &arg_types)
std::vector< expression > dims_
base_expr_type base_type_
void add_nullary(const ::std::string &name)
increment_log_prob_statement()
const int parameter_origin
void add(const std::string &name, const base_var_decl &base_decl, const var_origin &vo)
bool is_user_defined_prob_function(const std::string &name, const expression &variate, const std::vector< expression > ¶ms)
bool is_defined(const std::string &name, const function_signature_t &sig)
static function_signatures & instance()
std::map< std::string, range_t > map_
boost::variant< boost::recursive_wrapper< std::string >, boost::recursive_wrapper< expression > > printable_t
std::vector< std::vector< expression > > dimss_
bool operator>=(const expr_type &et) const
bool is_primitive_int() const
void set_user_defined(const std::pair< std::string, function_signature_t > &name_sig)
bool is_primitive_double() const
std::pair< std::vector< var_decl >, std::vector< statement > > generated_decl_
std::vector< var_decl > data_decl_
void print_var_origin(std::ostream &o, const var_origin &vo)
const int function_argument_origin
bool is_user_defined(const std::string &name, const std::vector< expression > &args)
expression & operator-=(const expression &rhs)
Metaprogram to calculate the base scalar return type resulting from promoting all the scalar types of...
int get_signature_matches(const std::string &name, const std::vector< expr_type > &args, function_signature_t &signature)
expression & operator*=(const expression &rhs)
size_t total_dims(const std::vector< std::vector< expression > > &dimss)
const int void_function_argument_origin_rng
const int function_argument_origin_rng
bool has_rng_suffix(const std::string &s)
std::vector< expression > args_
void add_binary(const ::std::string &name)
bool operator()(const nil &e) const
bool is_ill_formed() const
size_t get_num_dims(const std::string &name) const
std::vector< expression > args_
base_expr_type base_type_
std::ostream & error_msgs_
std::vector< var_decl > parameter_decl_
base_expr_type type() const
bool has_non_param_var(const expression &e, const variable_map &var_map)
std::ostream & operator<<(std::ostream &o, const expr_type &et)
std::vector< statement > bodies_
int num_promotions(const std::vector< expr_type > &call_args, const std::vector< expr_type > &sig_args)
std::pair< std::vector< var_decl >, std::vector< statement > > derived_decl_
expr_type promote_primitive(const expr_type &et)
void dims(const T &x, std::vector< int > &result)
returns_type_vis(const expr_type &return_type, std::ostream &error_msgs)
boost::variant< boost::recursive_wrapper< nil >, boost::recursive_wrapper< int_literal >, boost::recursive_wrapper< double_literal >, boost::recursive_wrapper< array_literal >, boost::recursive_wrapper< variable >, boost::recursive_wrapper< integrate_ode >, boost::recursive_wrapper< fun >, boost::recursive_wrapper< index_op >, boost::recursive_wrapper< binary_op >, boost::recursive_wrapper< unary_op > > expression_t
std::vector< expression > dims_
contains_nonparam_var(const variable_map &var_map)
const int void_function_argument_origin_lp
bool has_var(const expression &e, const variable_map &var_map)
std::vector< arg_decl > arg_decls_
const int void_function_argument_origin
void add_quaternary(const ::std::string &name)
const int model_name_origin
bool has_lp_suffix(const std::string &s)
boost::variant< boost::recursive_wrapper< nil >, boost::recursive_wrapper< assignment >, boost::recursive_wrapper< sample >, boost::recursive_wrapper< increment_log_prob_statement >, boost::recursive_wrapper< expression >, boost::recursive_wrapper< statements >, boost::recursive_wrapper< for_statement >, boost::recursive_wrapper< conditional_statement >, boost::recursive_wrapper< while_statement >, boost::recursive_wrapper< print_statement >, boost::recursive_wrapper< reject_statement >, boost::recursive_wrapper< return_statement >, boost::recursive_wrapper< no_op_statement > > statement_t
array_literal & operator=(const array_literal &al)
std::set< std::string > key_set() const
std::string system_function_name_
positive_ordered_var_decl()
contains_var(const variable_map &var_map)
base_expr_type get_base_type(const std::string &name) const
bool operator()(const nil &e) const
const int transformed_data_origin
bool is_nil(const expression &e)
double e()
Return the base of the natural logarithm.
bool operator!=(const expr_type &et) const
std::vector< function_decl_def > function_decl_defs_
std::vector< printable > printables_
bool is_ill_formed() const
base_var_decl base_variable_declaration()
bool is_assignable(const expr_type &l_type, const expr_type &r_type, const std::string &failure_message, std::ostream &error_msgs)
expr_type infer_type_indexing(const base_expr_type &expr_base_type, size_t num_expr_dims, size_t num_index_dims)
bool operator()(const nil &x) const
const int function_argument_origin_lp
bool operator>(const expr_type &et) const
unary_op(char op, expression const &subject)
expression & operator+=(const expression &rhs)
expr_type operator()(const nil &e) const
std::vector< function_decl_def > decl_defs_
const int transformed_parameter_origin
bool operator<(const expr_type &et) const
cholesky_factor_var_decl()
std::vector< expression > args_
bool operator<=(const expr_type &et) const
double_literal & operator=(const double_literal &dl)
std::pair< base_var_decl, var_origin > range_t
std::string operator()(const nil &x) const
expr_type get_result_type(const std::string &name, const std::vector< expr_type > &args, std::ostream &error_msgs)
double dist(const std::vector< double > &x, const std::vector< double > &y)
bool exists(const std::string &name) const
std::ostream & write_base_expr_type(std::ostream &o, base_expr_type type)
std::vector< var_decl > local_decl_
std::vector< printable > printables_
expr_type expression_type() const
boost::variant< boost::recursive_wrapper< nil >, boost::recursive_wrapper< int_var_decl >, boost::recursive_wrapper< double_var_decl >, boost::recursive_wrapper< vector_var_decl >, boost::recursive_wrapper< row_vector_var_decl >, boost::recursive_wrapper< matrix_var_decl >, boost::recursive_wrapper< simplex_var_decl >, boost::recursive_wrapper< unit_vector_var_decl >, boost::recursive_wrapper< ordered_var_decl >, boost::recursive_wrapper< positive_ordered_var_decl >, boost::recursive_wrapper< cholesky_factor_var_decl >, boost::recursive_wrapper< cholesky_corr_var_decl >, boost::recursive_wrapper< cov_matrix_var_decl >, boost::recursive_wrapper< corr_matrix_var_decl > > var_decl_t
bool returns_type(const expr_type &return_type, const statement &statement, std::ostream &error_msgs)
void set_type(const base_expr_type &base_type, size_t num_dims)
const variable_map & var_map_
int_literal & operator=(const int_literal &il)
bool operator()(const nil &st) const
void add_ternary(const ::std::string &name)
std::vector< statement > statements_
bool is_primitive() const
var_origin get_origin(const std::string &name) const
const variable_map & var_map_
bool operator()(const nil &st) const
std::pair< expr_type, std::vector< expr_type > > function_signature_t