
"""
ODE Solver: RK45 version 1.1.

Public domain, Connelly Barnes 2005.
"""

import math
import scipy
from scipy.linalg import norm
from scipy import array as vector

def solve(f, t0, tfinal, y0, tol = 1e-7):
  """
  Solve an ODE numerically using RK45.

  Solves dy/dt = f(t, y) from t0 to tfinal with an adaptive embedded
  Runge-Kutta 4(5) method using the Cash-Karp coefficients.  Returns a
  list of (t, y) tuples: the initial point followed by one entry per
  accepted step, each y being a plain list.

  Arguments:
    f      -- callable f(t, y) returning a sequence of the same length as y.
    t0     -- initial time.
    tfinal -- end time; if tfinal <= t0 no step is taken and only the
              initial point is returned.
    y0     -- initial state, a sequence of numbers.  Integer entries are
              promoted to floating point.
    tol    -- largest accepted per-step error estimate, relative to
              max(norm(y), 1.0).

  Raises ValueError('Singularity in ODE') if the step size shrinks so far
  that a step no longer advances t.

  Reference: Numerical Recipes in C, 2nd ed., section 16.2
  (http://www.library.cornell.edu/nr/bookcpdf/c16-2.pdf).
  """

  def F(*args):
    # Wrap f's result in an array so the k[i] below support arithmetic.
    return vector(f(*args))

  t = t0
  hmax = (tfinal - t0) / 128.0    # Upper bound on the step size.
  h = hmax / 4.0                  # Initial trial step.
  # 1-D state array.  Multiplying by 1.0 makes a copy and promotes integer
  # input to floating point (complex input stays complex).  Without it, an
  # integer y0 makes the in-place update "y += ..." below truncate (old
  # numpy) or raise a casting TypeError (current numpy).
  y = vector(y0) * 1.0
  out = [(t, list(y))]

  # Cash-Karp parameters: nodes a, Butcher-matrix rows b, fifth-order
  # weights c, and dc = c minus the fourth-order weights (error estimate).
  a = [ 0.0, 0.2, 0.3, 0.6, 1.0, 0.875 ]
  b = [[],
       [0.2],
       [3.0/40.0, 9.0/40.0],
       [0.3, -0.9, 1.2],
       [-11.0/54.0, 2.5, -70.0/27.0, 35.0/27.0],
       [1631.0/55296.0, 175.0/512.0, 575.0/13824.0, 44275.0/110592.0, 253.0/4096.0]]
  c  = [37.0/378.0, 0.0, 250.0/621.0, 125.0/594.0, 0.0, 512.0/1771.0]
  dc = [c[0]-2825.0/27648.0, c[1]-0.0, c[2]-18575.0/48384.0,
        c[3]-13525.0/55296.0, c[4]-277.00/14336.0, c[5]-0.25]

  while t < tfinal:
    # Never step past tfinal.
    if t + h > tfinal:
      h = tfinal - t
    if t + h <= t:
      # The step no longer changes t (h underflowed or shrank to zero).
      raise ValueError('Singularity in ODE')

    # Compute k[i] function values.
    k = [None] * 6
    k[0] = F(t, y)
    k[1] = F(t+a[1]*h, y+h*(k[0]*b[1][0]))
    k[2] = F(t+a[2]*h, y+h*(k[0]*b[2][0]+k[1]*b[2][1]))
    k[3] = F(t+a[3]*h, y+h*(k[0]*b[3][0]+k[1]*b[3][1]+k[2]*b[3][2]))
    k[4] = F(t+a[4]*h, y+h*(k[0]*b[4][0]+k[1]*b[4][1]+k[2]*b[4][2]+k[3]*b[4][3]))
    k[5] = F(t+a[5]*h, y+h*(k[0]*b[5][0]+k[1]*b[5][1]+k[2]*b[5][2]+k[3]*b[5][3]+k[4]*b[5][4]))

    # Estimate current error and current maximum error.
    E = norm(h*(k[0]*dc[0]+k[1]*dc[1]+k[2]*dc[2]+k[3]*dc[3]+k[4]*dc[4]+k[5]*dc[5]))
    Emax = tol*max(norm(y), 1.0)

    # Accept the step only if the error is OK; otherwise retry from the
    # same t with the smaller h computed below.
    if E < Emax:
      t += h
      y += h*(k[0]*c[0]+k[1]*c[1]+k[2]*c[2]+k[3]*c[3]+k[4]*c[4]+k[5]*c[5])
      out += [(t, list(y))]

    # Update step size: scale by (Emax/E)**0.2 with a 0.85 safety factor,
    # capped at hmax.  E == 0 leaves h unchanged.
    if E > 0.0:
      h = min(hmax, 0.85*h*(Emax/E)**0.2)

  return out

def solve_func(f, t0, tfinal, y0, tol = 1e-5):
  """
  Returns y(t) function, uses linear interpolation.

  Integrates dy/dt = f(t, y) with solve() (arguments as for solve(), but
  with a looser default tol).  It then joins the accepted (t, y) points
  by piecewise-linear interpolation.

  Returns a callable func(t) giving the interpolated state vector y(t):
  one entry per component of y0 for a scalar t, or one row per entry for
  an array of t values.  Requires tfinal > t0 so that at least two points
  exist.  With interp1d's defaults, t outside [t0, tfinal] raises
  ValueError.
  """
  # Import the submodule explicitly: a bare "import scipy" does not
  # guarantee that scipy.interpolate has been loaded.
  from scipy import interpolate

  L = solve(f, t0, tfinal, y0, tol)
  T = [t for (t, y) in L]
  Y = [y for (t, y) in L]
  # Interpolate along the time axis; each row of Y is one state vector.
  interp_func = interpolate.interp1d(T, Y, axis=0)
  def func(t):
    # For a scalar t this is already the full state vector.  Indexing it
    # with [0] would return only the first component.
    return interp_func(t)
  return func

def test_rk45(tol, max_error, max_evals):
  """
  Orbiting space shuttle test case.

  Integrates one period of a closed orbit of the planar restricted
  three-body problem and checks that the orbit returns to its start.  The
  equations in yp are written in a rotating frame, hence the +-2*u'
  terms.  The initial data presumably describe the Arenstorf orbit.

  Arguments:
    tol       -- tolerance passed to solve().
    max_error -- bound on norm(y(Period) - y(0)).
    max_evals -- strict upper bound on the number of calls to yp.

  Raises AssertionError, with the measured value in the message, if a
  bound is not met.
  """

  # State is (u1, u1', u2, u2').
  y0 = vector([0.994, 0.0, 0.0, -2.00158510637908252240537862224])
  Period = 17.0652166
  fevals = [0]    # Mutable cell so the nested yp can count its calls.

  def yp(t, y):
    fevals[0] += 1
    mu = 0.012277471    # Mass parameter of the smaller body.
    muhat = 1 - mu

    u1 = y[0]
    u1p = y[1]
    u2 = y[2]
    u2p = y[3]

    # Cubed distances to the two bodies, at (-mu, 0) and (muhat, 0).
    D1 = ((u1 + mu)**2 + u2**2)
    D1 = D1 * math.sqrt(D1)
    D2 = ((u1 - muhat)**2 + u2**2)
    D2 = D2 * math.sqrt(D2)

    return (u1p,
            u1 + 2 * u2p - muhat*(u1+mu) / D1 - mu*(u1-muhat) / D2,
            u2p,
            u2 - 2 * u1p - muhat * u2 / D1 - mu * u2 / D2)

  sol = solve(yp, 0.0, Period, y0, tol)
  ylast = sol[-1][1]

  assert fevals[0] < max_evals, 'Fevals: %d' % fevals[0]
  error = norm(ylast - y0)
  assert error < max_error, 'Error: %g' % error

def test():
  """
  Unit tests.

  Runs the RK45 orbit test at a loose and at a tight tolerance and prints
  a short report.  A failed check propagates the AssertionError raised by
  test_rk45.
  """

  # Single-argument print(...) prints the same text under Python 2 and
  # Python 3; the bare print statement is a syntax error in Python 3.
  print('Testing:')
  test_rk45(tol=1e-6, max_error=0.01, max_evals=2000)
  test_rk45(tol=1e-9, max_error=3e-5, max_evals=5000)
  print('  rk45:        OK')

# Run the unit tests when this file is executed as a script.
if __name__ == '__main__':
  test()
