/*
   Copyright 2010 Aaron J. Radke

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
*/
package cc.drx

//TODO Move state into a top level drx object and remove the State and RichState draw api
//TODO add vector types to possible states
//TODO add unit types to states
object Integrator{
   /**a derivative function: given a state it returns the rate of change of each value, packaged as another State*/
   type Derivative = State => State
   /**the named scalar values carried by a State*/
   private type X = Map[Symbol,Double]
   // private type Output = State => X
   //private val emptyX:X = Map()
   //private val emptyY:Output = (s:State) => Map()

   /**extra constructors for State that take the values as key/value pairs*/
   object State{
      /**build a state at time `0.s` from the given key/value pairs*/
      def apply(kvs:(Symbol,Double)*):State = new State(0.s, kvs.toMap)
      /**build a state at time `t` from the given key/value pairs*/
      def apply(t:Time,kvs:(Symbol,Double)*):State = new State(t, kvs.toMap)
   }
   /**a time stamp `t` together with named scalar values `xs`
     * the arithmetic operators only touch the values; the result keeps the time of the left hand state
     */
   case class State(t:Time, xs:X){
      /**combine every value of this state with the value stored under the same key in `s`
        * (each key of this state is looked up in `s`, so `s` has to define all of them)
        */
      def dottedOp(f:(Double,Double) => Double)(s:State):State = {
         val combined = xs.map{case (key,mine) => (key, f(mine, s.xs(key)))}
         copy(xs = combined)
      }
      /**apply `f` to every value, leaving keys and time untouched*/
      def scalarOp(f:Double => Double):State = copy(xs = xs.map{case (key,value) => (key, f(value))})

      //--key-wise arithmetic against another state
      def +(s:State):State  = dottedOp((u,v) => u + v)(s)
      def *(s:State):State  = dottedOp((u,v) => u * v)(s)
      def /(s:State):State  = dottedOp((u,v) => u / v)(s)
      def -(s:State):State  = dottedOp((u,v) => u - v)(s)

      //--arithmetic against a scalar applied to every value
      def +(a:Double):State = scalarOp(v => v + a)
      def *(a:Double):State = scalarOp(v => v * a)
      def /(a:Double):State = scalarOp(v => v / a)
      def -(a:Double):State = scalarOp(v => v - a)

      /**square root of the sum of the squared values*/
      def norm:Double = {
         val squares = (this * this).xs.values
         squares.foldLeft(0.0){(acc,sq) => acc + sq}.sqrt
      }

      /**look up the value stored under `k`*/
      def apply(k:Symbol):Double = xs(k)
      /**copy of this state with the time replaced by `t`*/
      def apply(t:Time):State = copy(t = t)
      /**copy of this state with the time advanced by `h`*/
      def step(h:Time) = copy(t = t + h)
      /**copy of this state with the given values replaced; pairs whose key is not already present are ignored*/
      def apply(kvs:(Symbol,Double)*):State = copy(xs = updateIfContains(xs,kvs))
      /**overwrite entries of `m` for the keys it already has; when a key repeats the last pair wins*/
      private def updateIfContains(m:X, kvs:Seq[(Symbol,Double)]):X = {
         val known = kvs.filter{case (key,_) => m contains key}
         m ++ known
      }

      /**space separated `name:value` listing of the values*/
      private def kson = xs.map{case (key,value) => s"${key.name}:$value"}.mkString(" ")
      override def toString = s"State(${t.format} $kson)"
   }
   /**one classic fourth-order Runge-Kutta step of size `h` starting at `x0`
     * the returned state carries the time `x0.t + h`
     */
   def rk4(f:Derivative, h:Time)(x0:State):State = {
      val half = h/2.0
      val mid  = x0 step half  //same values as x0, time advanced by h/2
      val end  = x0 step h     //same values as x0, time advanced by h

      val k1 = f(x0)
      val k2 = f(mid + k1*half.s)
      val k3 = f(mid + k2*half.s)
      val k4 = f(end + k3*h.s)

      val slope = k1 + k2*2 + k3*2 + k4  //1,2,2,1 weighting of the four stage derivatives
      end + slope*h.s/6.0
   }
   /**integrate over the span `t` with `N` equal rk4 steps starting at `x0`
     * (no step is taken when `N` is not positive)
     */
   def fixed(f:Derivative, t:Time, N:Int)(x0:State):State = {
      val dt = t/N
      (0 until N).foldLeft(x0){(state,_) => rk4(f,dt)(state)}
   }
   /**integrate over the span `t` with `N` equal steps of the given explicit `method` starting at `x0`
     * the error value reported by the method is discarded
     */
   def fixed(f:Derivative, t:Time, N:Int, method:Explicit)(x0:State):State = {
      val dt = t/N
      (0 until N).foldLeft(x0){(state,_) => method(f,dt)(state)._1}
   }
   /**integrate over the span `t` choosing the step size from the error value reported by `method`
     *  - the first step is `t/N`
     *  - a step is accepted only when its error is below the tolerance `1E-5`, otherwise the same state is tried again
     *  - the step is shortened so it does not run past the end time `x0.t + t`
     *  - integration gives up, returning the state reached so far, when the step counter exceeds `(N*10).toInt`
     *    or when the step just tried was smaller than `hbound.min`
     *  - the next step is the tried step scaled by `0.84*((tol/e)**0.2).sat(0.1,4.0)` and capped at `hbound.max`
     */
   def adaptive(f:Derivative, t:Time, N:Int, method:Explicit=rk45)(x0:State):State = {
      val tol = 1E-5
      val kp:Double = 10.0
      val h0:Time = t/N
      val hbound:Bound[Time] = Bound(h0/kp, h0*kp) //presumably (min,max) step sizes; verify against Bound
      val maxN = (N*kp).toInt
      val tend = x0.t + t
      @tailrec def _int(prev:State, h:Time, i:Int):State = {
         val (trial,e) = method(f,h)(prev) //always calculate the state and error (TODO: don't always calculate the state)
         val x = if(e < tol) trial else prev //a rejected step keeps the previous state
         val finished = i > maxN || x.t >= tend
         if(finished) x
         else if (x.t + h > tend) _int(x, tend - x.t, i+1) //clip the step to land on tend
         else if (h < hbound.min) x
         else {
            val ratio = (tol/e)**0.2
            val delta = 0.84*ratio.sat(0.1,4.0) //scalation error control
            val scaled = h*delta
            val hnext:Time = if (scaled > hbound.max) hbound.max else scaled
            _int(x, hnext, i+1)
         }
      }
      _int(x0, h0, 0)
   }

   /**explicit Runge-Kutta integrator driven by a Butcher table
      * https://en.wikipedia.org/wiki/Runge-Kutta_methods
      * c,a,b,bp are the Butcher table coefficients
      *  - `c` and `a` describe the stages after the first one (one entry/row per extra stage)
      *  - `b` weights the stages for the returned state
      *  - `bp` is a second set of weights; `be = b - bp` weights the stages for the returned error value
      */
   trait Explicit{
      //required
      protected val c:Vector[Double]
      protected val a:Vector[Vector[Double]]
      protected val b:Vector[Double]
      protected val bp:Vector[Double]
      protected lazy val be = (b zipTo bp){_-_}

      /**take one step of size `h` from `x0`, returning the new state (at time `x0.t + h`) and an error value*/
      def apply(f:Derivative, h:Time)(x0:State):(State,Double) = {
         //--stage derivatives: start with f(x0) and append one stage per (c,a) row
         val first = f(x0)
         val stages = (c zip a).foldLeft(Vector(first)){case (done,(ci,row)) =>
            val offset = (done zipTo row){_*_}.reduce(_+_)*h.s
            val stage = f( x0.step(h*ci) + offset )
            done :+ stage
         }
         //NOTE(review): the error is |sum(be_i * norm(k_i))| and is not scaled by h; confirm this is the intended estimate
         val err  = (stages zipTo be){_.norm * _}.reduce{_+_}.abs           //error
         val next = x0.step(h) + ((stages zipTo b){_*_}.reduce{_+_}*h.s)    //state
         (next,err)
      }
   }
   /**rk45 dormand-prince returns state and error
     * seven stages: `c` and `a` list the six stages that follow the initial `f(x0)`
     * `b` (the weights used to advance the state) repeats the last row of `a` with a trailing 0.0
     * `bp` is the companion weight set; the two correspond to the 5th and 4th order rows of the published Dormand-Prince table
     */
   object rk45 extends Explicit{
      protected val c = Vector(1.0/5, 3.0/10,  4.0/5,  8.0/9, 1.0,  1.0)
      protected val a = Vector(
         Vector( 1.0/5),
         Vector( 3.0/40,          9.0/40),
         Vector(44.0/45,         -56.0/15,         32.0/9),
         Vector(19372.0/6561,    -25360.0/2187,    64448.0/6561,    -212.0/729),
         Vector(9017.0/3168,     -355.0/33,        46732.0/5247,    49.0/176,           -5103.0/18656),
         Vector(35.0/384,        0.0,              500.0/1113,      125.0/192,          -2187.0/6784,     11.0/84)
      )
      protected val b  = Vector(35.0/384,        0.0,             500.0/1113,       125.0/192,       -2187.0/6784,       11.0/84,          0.0)
      protected val bp = Vector(5179.0/57600,    0.0,             7571.0/16695,     393.0/640,       -92097.0/339200,    187.0/2100,       1.0/40)
   }
   /**rk45a fehlberg returns state and error
     * six stages: `c` and `a` list the five stages that follow the initial `f(x0)`
     * `b` (the weights used to advance the state) and `bp` correspond to the 5th and 4th order rows
     * of the published Runge-Kutta-Fehlberg table
     */
   object fehlberg extends Explicit{
      protected val c = Vector(1.0/4, 3.0/8,  12.0/13,  1.0, 1.0/2)
      protected val a = Vector(
         Vector( 1.0/4),
         Vector( 3.0/32,          9.0/32),
         Vector(1932.0/2197,     -7200.0/2197,       7296.0/2197),
         Vector(439.0/216,       -8.0,               3680.0/513,    -845.0/4104),
         Vector(-8.0/27,          2.0,              -3544.0/2565,   1859.0/4104,   -11.0/40)
      )
      protected val b  = Vector(16.0/135,    0.0,    6656.0/12825,   28561.0/56430,  -9.0/50,  2.0/55)
      protected val bp = Vector(25.0/216,    0.0,    1408.0/2565,     2197.0/4104,   -1.0/5.0,  0.0)
   }
   /**classic fourth-order Runge-Kutta written as a Butcher table for the Explicit integrator
     * `bp` is all zeros, so the derived `be` equals `b`: the error value returned by `apply`
     * is not an embedded error estimate for this table
     */
   object rk4e extends Explicit{
      protected val c = Vector(1.0/2, 1.0/2, 1.0)
      protected val a = Vector(
         Vector( 1.0/2),  //was 1.0/5: each row of `a` must sum to the matching `c` entry (1/2 here)
         Vector( 0.0   ,    1.0/2),
         Vector( 0.0   ,    0.0,          1.0)
      )
      protected val b  = Vector(1.0/6,   1.0/3,    1.0/3,   1.0/6)
      protected val bp = Vector(0.0,       0.0,      0.0,     0.0)
   }

}