odesolver.cpp
1 
10 #include <target/odesolver.hpp>
11 #include <target/utils.hpp>
12 
13 using namespace arma;
14 
15 namespace target {
16 
17 
18  arma::mat RK4::solve(const arma::mat &input,
19  arma::mat init,
20  arma::mat theta) {
21  unsigned n = input.n_rows;
22  unsigned p = init.n_elem;
23  mat res(n, p);
24  rowvec y = arma::conv_to<arma::rowvec>::from(init);
25  res.row(0) = y;
26  for (unsigned i=0; i < n-1; i++) {
27  rowvec dinput = input.row(i+1)-input.row(i);
28  double tau = dinput(0);
29  rowvec f1 = tau*F(input.row(i), y, theta);
30  rowvec f2 = tau*F(input.row(i) + dinput/2, y + f1/2, theta);
31  rowvec f3 = tau*F(input.row(i) + dinput/2, y + f2/2, theta);
32  rowvec f4 = tau*F(input.row(i) + dinput, y + f3, theta);
33  y += (f1+2*f2+2*f3+f4)/6;
34  res.row(i+1) = y;
35  }
36  return( res );
37  }
38 
39  arma::mat Solver::solveint(const arma::mat &input,
40  arma::mat init,
41  arma::mat theta,
42  double tau, bool reduce) {
43  mat newinput = interpolate(input, tau, true);
44  mat value = solve(newinput, init, theta);
45  if (reduce) {
46  uvec idx = target::fastapprox(newinput.col(0), input.col(0), false, 0);
47  value = value.rows(idx);
48  }
49  return( value );
50  }
51 
52 } // namespace target
Various utility functions and constants.
Classes for Ordinary Differential Equation Solvers.