Stan Math Library  2.12.0
reverse mode automatic differentiation
log_sum_exp.hpp
Go to the documentation of this file.
1 #ifndef STAN_MATH_REV_MAT_FUN_LOG_SUM_EXP_HPP
2 #define STAN_MATH_REV_MAT_FUN_LOG_SUM_EXP_HPP
3 
4 #include <stan/math/rev/core.hpp>
8 #include <limits>
9 
10 namespace stan {
11  namespace math {
12 
13  namespace {
14 
15  // This function and the following class adapt the std::vector
16  // version of log_sum_exp to Eigen::Matrix arguments.
17 
18  template <int R, int C>
19  double log_sum_exp_as_double(const Eigen::Matrix<var, R, C>& x) {
20  using std::numeric_limits;
21  using std::exp;
22  using std::log;
23  double max = -numeric_limits<double>::infinity();
24  for (int i = 0; i < x.size(); ++i)
25  if (x(i) > max)
26  max = x(i).val();
27  double sum = 0.0;
28  for (int i = 0; i < x.size(); ++i)
29  if (x(i) != -numeric_limits<double>::infinity())
30  sum += exp(x(i).val() - max);
31  return max + log(sum);
32  }
33 
34  class log_sum_exp_matrix_vari : public op_matrix_vari {
35  public:
36  template <int R, int C>
37  explicit log_sum_exp_matrix_vari(const Eigen::Matrix<var, R, C>& x) :
38  op_matrix_vari(log_sum_exp_as_double(x), x) {
39  }
40  void chain() {
41  for (size_t i = 0; i < size_; ++i) {
42  vis_[i]->adj_ += adj_ * calculate_chain(vis_[i]->val_, val_);
43  }
44  }
45  };
46  }
47 
53  template <int R, int C>
54  inline var log_sum_exp(const Eigen::Matrix<var, R, C>& x) {
55  return var(new log_sum_exp_matrix_vari(x));
56  }
57 
58  }
59 }
60 #endif
fvar< T > sum(const std::vector< fvar< T > > &m)
Return the sum of the entries of the specified standard vector.
Definition: sum.hpp:20
fvar< T > log(const fvar< T > &x)
Definition: log.hpp:14
Independent (input) and dependent (output) variables for gradients.
Definition: var.hpp:30
fvar< T > log_sum_exp(const std::vector< fvar< T > > &v)
Definition: log_sum_exp.hpp:13
fvar< T > exp(const fvar< T > &x)
Definition: exp.hpp:10
size_t size_
Definition: dot_self.hpp:18
int max(const std::vector< int > &x)
Returns the maximum coefficient in the specified column vector.
Definition: max.hpp:22
double calculate_chain(double x, double val)

     [ Stan Home Page ] © 2011–2016, Stan Development Team.