1 #ifndef STAN__AGRAD__REV__MATRIX__DOT_PRODUCT_HPP
2 #define STAN__AGRAD__REV__MATRIX__DOT_PRODUCT_HPP
5 #include <boost/utility/enable_if.hpp>
6 #include <boost/type_traits.hpp>
22 struct dot_product_store_type;
25 struct dot_product_store_type<var> {
30 struct dot_product_store_type<double> {
34 template<
typename T1,
typename T2>
35 class dot_product_vari :
public vari {
37 typename dot_product_store_type<T1>::type
v1_;
38 typename dot_product_store_type<T2>::type
v2_;
41 inline static double var_dot(vari** v1, vari** v2,
43 Eigen::VectorXd vd1(length), vd2(length);
44 for (
size_t i = 0; i <
length; i++) {
51 inline static double var_dot(
const T1* v1,
const T2* v2,
54 Eigen::VectorXd vd1(length), vd2(length);
55 for (
size_t i = 0; i <
length; i++) {
62 template<
typename Derived1,
typename Derived2>
63 inline static double var_dot(
const Eigen::DenseBase<Derived1> &v1,
64 const Eigen::DenseBase<Derived2> &v2) {
67 Eigen::VectorXd vd1(v1.size()), vd2(v1.size());
68 for (
int i = 0; i < v1.size(); i++) {
74 inline void chain(vari** v1, vari** v2) {
75 for (
size_t i = 0; i <
length_; i++) {
76 v1[i]->adj_ += adj_ *
v2_[i]->val_;
77 v2[i]->adj_ += adj_ *
v1_[i]->val_;
80 inline void chain(
double* v1, vari** v2) {
81 for (
size_t i = 0; i <
length_; i++) {
82 v2[i]->adj_ += adj_ *
v1_[i];
85 inline void chain(vari** v1,
double* v2) {
86 for (
size_t i = 0; i <
length_; i++) {
87 v1[i]->adj_ += adj_ *
v2_[i];
90 inline void initialize(vari** &mem_v,
const var *
inv, vari **shared = NULL) {
93 for (
size_t i = 0; i <
length_; i++)
94 mem_v[i] = inv[i].vi_;
100 template<
typename Derived>
101 inline void initialize(vari** &mem_v,
const Eigen::DenseBase<Derived> &inv, vari **shared = NULL) {
102 if (shared == NULL) {
104 for (
size_t i = 0; i <
length_; i++)
105 mem_v[i] =
inv(i).vi_;
112 inline void initialize(
double* &mem_d,
const double *ind,
double *shared = NULL) {
113 if (shared == NULL) {
115 for (
size_t i = 0; i <
length_; i++)
122 template<
typename Derived>
123 inline void initialize(
double* &mem_d,
const Eigen::DenseBase<Derived> &ind,
double *shared = NULL) {
124 if (shared == NULL) {
126 for (
size_t i = 0; i <
length_; i++)
135 dot_product_vari(
typename dot_product_store_type<T1>::type v1,
136 typename dot_product_store_type<T2>::type v2,
138 : vari(var_dot(v1,v2,length)),
v1_(v1),
v2_(v2),
length_(length) {}
140 dot_product_vari(
const T1* v1,
const T2* v2,
size_t length,
141 dot_product_vari<T1,T2>* shared_v1 = NULL,
142 dot_product_vari<T1,T2>* shared_v2 = NULL) :
143 vari(var_dot(v1, v2, length)),
length_(length) {
144 if (shared_v1 == NULL) {
150 if (shared_v2 == NULL) {
157 template<
typename Derived1,
typename Derived2>
158 dot_product_vari(
const Eigen::DenseBase<Derived1> &v1,
159 const Eigen::DenseBase<Derived2> &v2,
160 dot_product_vari<T1,T2>* shared_v1 = NULL,
161 dot_product_vari<T1,T2>* shared_v2 = NULL) :
163 if (shared_v1 == NULL) {
169 if (shared_v2 == NULL) {
176 template<
int R1,
int C1,
int R2,
int C2>
177 dot_product_vari(
const Eigen::Matrix<T1,R1,C1> &v1,
178 const Eigen::Matrix<T2,R2,C2> &v2,
179 dot_product_vari<T1,T2>* shared_v1 = NULL,
180 dot_product_vari<T1,T2>* shared_v2 = NULL) :
182 if (shared_v1 == NULL) {
188 if (shared_v2 == NULL) {
195 virtual void chain() {
209 template<
typename T1,
int R1,
int C1,
typename T2,
int R2,
int C2>
211 typename boost::enable_if_c<boost::is_same<T1,var>::value ||
212 boost::is_same<T2,var>::value, var>::type
214 const Eigen::Matrix<T2, R2, C2>& v2) {
219 return var(
new dot_product_vari<T1,T2>(v1,v2));
229 template<
typename T1,
typename T2>
231 typename boost::enable_if_c<boost::is_same<T1,var>::value ||
232 boost::is_same<T2,var>::value, var>::type
234 return var(
new dot_product_vari<T1,T2>(v1, v2, length));
245 template<
typename T1,
typename T2>
247 typename boost::enable_if_c<boost::is_same<T1,var>::value ||
248 boost::is_same<T2,var>::value, var>::type
250 const std::vector<T2>& v2) {
253 return var(
new dot_product_vari<T1,T2>(&v1[0], &v2[0], v1.size()));
256 template<
typename T1,
int R1,
int C1,
typename T2,
int R2,
int C2>
258 typename boost::enable_if_c<boost::is_same<T1,var>::value ||
259 boost::is_same<T2,var>::value,
260 Eigen::Matrix<var, 1, C1> >::type
262 const Eigen::Matrix<T2, R2, C2>& v2) {
265 Eigen::Matrix<var, 1, C1> ret(1,v1.cols());
266 for (
size_type j = 0; j < v1.cols(); ++j) {
267 ret(j) =
var(
new dot_product_vari<T1,T2>(v1.col(j),v2.col(j)));
272 template<
typename T1,
int R1,
int C1,
typename T2,
int R2,
int C2>
274 typename boost::enable_if_c<boost::is_same<T1,var>::value ||
275 boost::is_same<T2,var>::value,
276 Eigen::Matrix<var, R1, 1> >::type
278 const Eigen::Matrix<T2, R2, C2>& v2) {
281 Eigen::Matrix<var, R1, 1> ret(v1.rows(),1);
282 for (
size_type j = 0; j < v1.rows(); ++j) {
283 ret(j) =
var(
new dot_product_vari<T1,T2>(v1.row(j),v2.row(j)));
memory::stack_alloc memalloc_
bool check_vector(const char *function, const Eigen::Matrix< T, R, C > &x, const char *name, T_result *result)
T value_of(const fvar< T > &v)
Return the value of the specified variable.
void initialize(T &x, const T &v)
fvar< T > inv(const fvar< T > &x)
double value_of(const T x)
Return the value of the specified scalar argument converted to a double value.
dot_product_store_type< T2 >::type v2_
Eigen::Matrix< double, Eigen::Dynamic, Eigen::Dynamic >::Index size_type
fvar< T > dot_product(const Eigen::Matrix< fvar< T >, R1, C1 > &v1, const Eigen::Matrix< fvar< T >, R2, C2 > &v2)
void * alloc(size_t len)
Return a newly allocated block of memory of the appropriate size managed by the stack allocator...
dot_product_store_type< T1 >::type v1_
Independent (input) and dependent (output) variables for gradients.
int size(const std::vector< T > &x)
Eigen::Matrix< fvar< T >, 1, C1 > columns_dot_product(const Eigen::Matrix< fvar< T >, R1, C1 > &v1, const Eigen::Matrix< fvar< T >, R2, C2 > &v2)
bool check_matching_sizes(const char *function, const T_y1 &y1, const char *name1, const T_y2 &y2, const char *name2, T_result *result)
Eigen::Matrix< fvar< T >, R1, 1 > rows_dot_product(const Eigen::Matrix< fvar< T >, R1, C1 > &v1, const Eigen::Matrix< fvar< T >, R2, C2 > &v2)