Stan Math Library  2.11.0
reverse mode automatic differentiation
OperandsAndPartials.hpp
Go to the documentation of this file.
1 #ifndef STAN_MATH_REV_SCAL_META_OPERANDSANDPARTIALS_HPP
2 #define STAN_MATH_REV_SCAL_META_OPERANDSANDPARTIALS_HPP
3 
8 #include <stan/math/rev/core.hpp>
9 
10 namespace stan {
11  namespace math {
12 
13  // These are helpers to the OperandsAndPartials specialization for
14  // stan::math::var
15  namespace {
16  class partials_vari : public vari {
17  private:
18  const size_t N_;
19  vari** operands_;
20  double* partials_;
21  public:
22  partials_vari(double value,
23  size_t N,
24  vari** operands, double* partials)
25  : vari(value),
26  N_(N),
27  operands_(operands),
28  partials_(partials) { }
29  void chain() {
30  for (size_t n = 0; n < N_; ++n)
31  operands_[n]->adj_ += adj_ * partials_[n];
32  }
33  };
34 
35  var partials_to_var(double logp, size_t nvaris,
36  vari** all_varis,
37  double* all_partials) {
38  return var(new partials_vari(logp, nvaris, all_varis,
39  all_partials));
40  }
41 
42  template<typename T,
43  bool is_vec = is_vector<T>::value,
44  bool is_const = is_constant_struct<T>::value>
45  struct set_varis {
46  inline size_t set(vari** /*varis*/, const T& /*x*/) {
47  return 0U;
48  }
49  };
50  template<typename T>
51  struct set_varis<T, true, false> {
52  inline size_t set(vari** varis, const T& x) {
53  for (size_t n = 0; n < length(x); n++)
54  varis[n] = x[n].vi_;
55  return length(x);
56  }
57  };
58  template<>
59  struct set_varis<var, false, false> {
60  inline size_t set(vari** varis, const var& x) {
61  varis[0] = x.vi_;
62  return (1);
63  }
64  };
65  }
66 
89  template<typename T1, typename T2, typename T3,
90  typename T4, typename T5, typename T6>
91  struct OperandsAndPartials<T1, T2, T3, T4, T5, T6, stan::math::var> {
92  size_t nvaris;
94  double* all_partials;
95 
96  VectorView<double,
99  VectorView<double,
102  VectorView<double,
105  VectorView<double,
108  VectorView<double,
111  VectorView<double,
114 
125  OperandsAndPartials(const T1& x1 = 0, const T2& x2 = 0, const T3& x3 = 0,
126  const T4& x4 = 0, const T5& x5 = 0, const T6& x6 = 0)
127  : nvaris(!is_constant_struct<T1>::value * length(x1) +
128  !is_constant_struct<T2>::value * length(x2) +
129  !is_constant_struct<T3>::value * length(x3) +
130  !is_constant_struct<T4>::value * length(x4) +
131  !is_constant_struct<T5>::value * length(x5) +
132  !is_constant_struct<T6>::value * length(x6)),
133  // TODO(carpenter): replace with array allocation fun
134  all_varis(static_cast<vari**>
135  (vari::operator new
136  (sizeof(vari*) * nvaris))),
137  all_partials(static_cast<double*>
138  (vari::operator new
139  (sizeof(double) * nvaris))),
140  d_x1(all_partials),
141  d_x2(all_partials
142  + (!is_constant_struct<T1>::value) * length(x1)),
143  d_x3(all_partials
144  + (!is_constant_struct<T1>::value) * length(x1)
145  + (!is_constant_struct<T2>::value) * length(x2)),
146  d_x4(all_partials
147  + (!is_constant_struct<T1>::value) * length(x1)
148  + (!is_constant_struct<T2>::value) * length(x2)
149  + (!is_constant_struct<T3>::value) * length(x3)),
150  d_x5(all_partials
151  + (!is_constant_struct<T1>::value) * length(x1)
152  + (!is_constant_struct<T2>::value) * length(x2)
153  + (!is_constant_struct<T3>::value) * length(x3)
154  + (!is_constant_struct<T4>::value) * length(x4)),
155  d_x6(all_partials
156  + (!is_constant_struct<T1>::value) * length(x1)
157  + (!is_constant_struct<T2>::value) * length(x2)
158  + (!is_constant_struct<T3>::value) * length(x3)
159  + (!is_constant_struct<T4>::value) * length(x4)
160  + (!is_constant_struct<T5>::value) * length(x5)) {
161  size_t base = 0;
163  base += set_varis<T1>().set(&all_varis[base], x1);
165  base += set_varis<T2>().set(&all_varis[base], x2);
167  base += set_varis<T3>().set(&all_varis[base], x3);
169  base += set_varis<T4>().set(&all_varis[base], x4);
171  base += set_varis<T5>().set(&all_varis[base], x5);
173  set_varis<T6>().set(&all_varis[base], x6);
174  std::fill(all_partials, all_partials+nvaris, 0);
175  }
176 
186  return partials_to_var(value, nvaris, all_varis,
187  all_partials);
188  }
189  };
190 
191  }
192 }
193 #endif
T_return_type value(double value)
Returns a T_return_type with the value specified with the partial derivatives.
VectorView< double, is_vector< T4 >::value, is_constant_struct< T4 >::value > d_x4
The variable implementation base class.
Definition: vari.hpp:30
size_t length(const std::vector< T > &x)
Definition: length.hpp:10
stan::math::var value(double value)
Returns a T_return_type with the value specified with the partial derivatives.
Independent (input) and dependent (output) variables for gradients.
Definition: var.hpp:31
vari ** operands_
Metaprogram to determine if a type has a base scalar type that can be assigned to type double...
double * partials_
VectorView< double, is_vector< T1 >::value, is_constant_struct< T1 >::value > d_x1
const size_t N_
This class builds partial derivatives with respect to a set of operands.
OperandsAndPartials(const T1 &x1=0, const T2 &x2=0, const T3 &x3=0, const T4 &x4=0, const T5 &x5=0, const T6 &x6=0)
Constructor.
VectorView< double, is_vector< T5 >::value, is_constant_struct< T5 >::value > d_x5
VectorView< double, is_vector< T2 >::value, is_constant_struct< T2 >::value > d_x2
void fill(std::vector< T > &x, const S &y)
Fill the specified container with the specified value.
Definition: fill.hpp:22
VectorView< double, is_vector< T3 >::value, is_constant_struct< T3 >::value > d_x3
VectorView is a template expression that is constructed with a container or scalar, which it then allows to be used as an array using operator[].
Definition: VectorView.hpp:48
VectorView< double, is_vector< T6 >::value, is_constant_struct< T6 >::value > d_x6

     [ Stan Home Page ] © 2011–2016, Stan Development Team.