/*
   Copyright 2010 Aaron J. Radke

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
*/
package cc.drx

//TODO link to test file for common docs

//TODO rename Filter to a more unique toplevel name, like keywords: StateSpace, TransferFunction, DiffEq, DynamicSystem, Model, System

//TODO try a way to generalize the discrete update without a time step
abstract class Filter{ //TODO def use type parameters to allow this output and input to be of generic types
   /** Advance the internal state by one step of duration `Δt` driven by input `u`.
     * @return the output computed from the updated state
     */
   def update(u:Double,Δt:Time):Double
   /** Initialize the internal state from the value `y0`.
     * @return the output computed from the initialized state, just like `update`
     */
   def init(y0:Double):Double
   /** The output computed from the current internal state. */
   def output:Double

   //---concrete implementations
   /** Run this filter over a stream of timestamped samples.
     *
     * The first sample initializes the filter with `init`. Each later sample advances the filter by the
     * time elapsed since the previous sample. The previous output is held instead, with no state update
     * and the sample's input skipped, when the elapsed time is zero or exceeds `gapDetect` (if defined).
     * A sample earlier than its predecessor yields a negative Δt, which is forwarded to `update` as-is.
     *
     * The filter's own state is mutated as the returned iterator is consumed (`Iterator.map` is lazy).
     *
     * @param gapDetect optional maximum time step; larger steps are treated as gaps
     * @param tu extracts the (time, input) pair from each element
     * @param xs the input elements
     * @return one (time, output) pair per input element
     */
   final def convolve[A](gapDetect:Option[Time]=None)(tu:A=>Tuple2[Time,Double])(xs:Iterator[A]):Iterator[(Time,Double)] = {
      var first = true
      var tLast = Time(0) //placeholder; replaced by the first sample's time before it is ever used
      xs.map{x =>
         val (t,u) = tu(x)
         if(first){
            init(u)
            tLast = t //t0
            first = false
         }
         val Δt:Time = t - tLast
         tLast = t
         //---discretization, update and output (hold the output for repeated timestamps and detected gaps)
         val y = if(Δt == 0.s || gapDetect.exists(Δt > _)) output else update(u,Δt)
         t -> y
      }
   }

   /** Run this filter over a stream of inputs sampled with the fixed time step `dt`.
     *
     * The first input initializes the filter with `init`; every later input is applied with `update(x,dt)`.
     * The filter's own state is mutated as the returned iterator is consumed.
     *
     * @return one output per input
     */
   final def convolve[A](dt:Time)(xs:Iterator[Double]):Iterator[Double] = {
      //a flag rather than a counter: an Int counter wraps after 2^32 samples and would re-run init
      var first = true
      xs.map{x =>
         if(first){ first = false; init(x) }
         else update(x,dt) //---discretization, update and output
      }
   }
}
/** A [[Filter]] that advances its state with a discrete-time model of type `A` computed for each step size.
  *
  * Subclasses supply `discrete(dt)`, which builds the discrete model, and the protected single-step
  * `update(u)`, which reads that model from `d`. The discrete model is cached and only recomputed
  * when the step size Δt changes.
  */
abstract class FilterDiscrete[A] extends Filter{
   //== required fields
   /** The discrete-time dynamic state model used for updates over a step of duration `dt`. */
   def discrete(dt:Time):A
   /** Mutate the internal state by one step with input `u`.
     * Only for use by the filter's own update: it dangerously requires `d` to already hold the
     * discrete model for the current step (see `update(u,Δt)`).
     */
   protected def update(u:Double):Unit

   //== implementation fields
   //step size of the cached model, so a fixed Δt does not need the model to be re-calculated.
   //Initialized to Time(-1), which (presumably negative) can never equal an accepted Δt > 0,
   //so the first update always builds a model.
   private var dtLast:Time = Time(-1)
   protected var d:A = _ //discrete model for dtLast; holds the default value until the first update

   /** Advance the state by `Δt` with input `u`, rebuilding the discrete model only if Δt changed.
     * @return the output after the update
     * @throws IllegalArgumentException if `Δt` is not positive
     */
   final def update(u:Double,Δt:Time):Double = {
      require(Δt > 0.s, s"state updates require Δt=$Δt > 0 ")
      //---model update
      if(dtLast != Δt){ d = discrete(Δt); dtLast = Δt }
      update(u)
      output
   }
}
/** Factory functions and concrete implementations of [[Filter]].
  *
  * Continuous-time models are written as transfer functions ([[TF]], coefficients in descending
  * powers of s). They are realized in observable canonical state-space form (SS0..SS3) and discretized
  * for the time step passed to `update`. Rise-time parameterized factories use ω = 4/t_r.
  */
object Filter{
   import Matrix._

   //---integrator
   /** Integrator followed by a first-order low-pass: ω/(s² + ωs) with ω = 4/t_r. */
   def int(t_r:Time):SS2 = {val ω = 4/t_r.s; TF(ω)(1,ω,0).ss.asInstanceOf[SS2]}  //1/s * w/(s+w) =>  w/(s^2 + sw + 0)
   /** Ideal integrator 1/s. */
   def int:SS1             = {TF(1)(1,0).ss.asInstanceOf[SS1]}
   //---differentiator
   /** Band-limited differentiator ω²s/(s² + 2ζωs + ω²) with ω = 4/t_r and ζ = 0.72. */
   def diff(t_r:Time):SS2 = diff2Angular(4/t_r.s)
   //---lpf
   /** First-order low-pass ω/(s + ω) with ω = 4/t_r. */
   def lpf(t_r:Time):SS1 = lpf1Angular(4/t_r.s)
   /** First-order low-pass ω/(s + ω) with ω = `f.omega`. */
   def lpf(f:Frequency)(implicit d:DummyImplicit):SS1 = lpf1Angular(f.omega)
   //---ω parameterized filters
   //ω = 2πf = 4/tr
   //tr = 4/ω  = 4/(2πf)
   //2ζω uses the default ζ=0.72 for better physically focused butterworth matched BW concept
   private def diff2Angular(ω:Double, ζ:Double=0.72):SS2 = TF(ω*ω,0)(1,2*ζ*ω,ω*ω).ss.asInstanceOf[SS2]
   private def lpf2Angular(ω:Double, ζ:Double=0.72):SS2  = TF(ω*ω)(1, 2*ζ*ω,  ω*ω).ss.asInstanceOf[SS2]
   private def lpf1Angular(ω:Double):SS1  = TF(ω)(1,ω).ss.asInstanceOf[SS1]
   // def lpf(f:Frequency):SS1 = {val ω = tau*f.hz; TF(ω)(1,ω).ss.asInstanceOf[SS1]}
   /** Low-pass with rise time `t_r`: `order` 2 selects the second-order `lpf2`, anything else the first-order `lpf`. */
   def lpf(t_r:Time,order:Int):Filter = order match {
      case 2 => lpf2(t_r)
      case _ => lpf(t_r)
   }
   /** Second-order low-pass ω²/(s² + 2ζωs + ω²) with ω = 4/t_r and ζ = 0.72. */
   def lpf2(t_r:Time):SS2 = lpf2Angular(4/t_r.s)
   /** Second-order low-pass ω²/(s² + 2ζωs + ω²) with ω = `f.omega` and ζ = 0.72. */
   def lpf2(f:Frequency)(implicit d:DummyImplicit):SS2 = lpf2Angular(f.omega)

   //---angle filters
   /** First-order low-pass for an angle signal given in radians. */
   def lpfAngle(t_r:Time):LPF1Angle = new LPF1Angle(t_r)
   /** Second-order low-pass for an angle signal given in radians. */
   def lpfAngle2(t_r:Time):LPF2Angle = new LPF2Angle(t_r)

   /** Angle low-pass with rise time `t_r`: `order` 2 selects [[LPF2Angle]], anything else [[LPF1Angle]]. */
   def lpfAngle(t_r:Time,order:Int):Filter = order match {
      case 2 => new LPF2Angle(t_r)
      case _ => new LPF1Angle(t_r)
   }

   /** Static gain y = D u, with no dynamics. */
   class SS0(val D:Double=1.0) extends FilterDiscrete[SS0]{
      private var uLast = 0.0
      /** Seed the held input with `y0`, exactly as `update` would, and return the output D*y0. */
      def init(y0:Double):Double = {uLast = y0; output}
      /** A static gain has no dynamics, so the same model serves every step size. */
      def discrete(Δt:Time):SS0 = this
      protected def update(u:Double):Unit = uLast = u
      def output:Double = D*uLast
      override def toString = s"y = ${D} u"
   }

   /** First-order state-space model xn = A x + B u, y = C x.
     * NOTE(review): `D` is stored and printed but not used by `output`.
     */
   class SS1(val A:Double,val B:Double,val C:Double,val D:Double=0) extends FilterDiscrete[SS1]{
      private var x = 0.0
      /** Set the state to `y0` and return the output. */
      def init(y0:Double):Double = {x = y0; output}
      /** Zero-order-hold discretization over a step of `Δt`. */
      def discrete(Δt:Time):SS1 = {
         val dt = Δt.s
         if(A==0){
            //pure integrator: holding the input over the step gives Ad = 1, Bd = B dt exactly
            new SS1(1.0, B*dt, C, D)
         }else{
            val Ad = math.exp(A*dt)
            val Bd = 1/A*(Ad-1)*B
            new SS1(Ad,Bd,C,D)
         }
      }
      protected def update(u:Double):Unit = x = (d.A*x) + (d.B*u)
      def output = C*x
      override def toString = s"xn = $A x + $B u;  y = $C x + $D u"
   }

   /** Second-order state-space model xn = A x + B u, y = C·x.
     * NOTE(review): `D` is stored and printed but not used by `output`.
     */
   class SS2(val A:Matrix2,val B:Vec,val C:Vec,val D:Double=0) extends FilterDiscrete[SS2]{
      private var x = Vec(0,0)
      /** Set the state to (y0, 0) and return the output. */
      def init(y0:Double) = {x = Vec(y0,0); output}
      /** Zero-order-hold discretization for invertible A; first-order (forward-Euler) approximation when A is singular. */
      def discrete(Δt:Time):SS2 = {
         val dt = Δt.s
         if(A.det==0){
            //A.inv does not exist: truncate exp(A dt) to I + A dt
            val Ad = Matrix.I2 + A*dt
            val Bd = B*dt
            new SS2(Ad,Bd,C,D)
         }else{
            val Ad = (A*dt).exp
            val Bd = A.inv*(Ad-Matrix.I2)*B
            new SS2(Ad,Bd,C,D)
         }
      }
      //NOTE(review): public, unlike SS1/SS3; calling it directly requires `d` to have been set by update(u,Δt)
      def update(u:Double):Unit = x = (d.A*x) + (d.B*u)
      def output = C*x
      override def toString = s"xn = $A x + $B u;  y = $C x + $D u"
   }

   /** Third-order state-space model xn = A x + B u, y = C·x.
     * NOTE(review): `D` is stored and printed but not used by `output`.
     */
   class SS3(val A:Matrix3,val B:Vec,val C:Vec,val D:Double=0) extends FilterDiscrete[SS3]{
      private var x = Vec(0,0,0)
      /** Set the state to (y0, 0, 0) and return the output. */
      def init(y0:Double) = {x = Vec(y0,0,0); output}
      /** Zero-order-hold discretization for invertible A; first-order (forward-Euler) approximation when A is singular. */
      def discrete(Δt:Time):SS3 = {
         val dt = Δt.s
         if(A.det==0){
            //A.inv does not exist: truncate exp(A dt) to I + A dt
            val Ad = Matrix.I3 + A*dt
            val Bd = B*dt
            new SS3(Ad,Bd,C,D)
         }else{
            val Ad = (A*dt).exp
            val Bd = A.inv*(Ad-Matrix.I3)*B
            new SS3(Ad,Bd,C,D)
         }
      }
      protected def update(u:Double):Unit = x = (d.A*x) + (d.B*u)
      def output = C*x
      override def toString = s"xn = $A x + $B u;  y = $C x + $D u"
   }

   /** Builds a transfer function from coefficients in descending powers of s, e.g. `TF(ω)(1,ω)` is ω/(s+ω). */
   object TF{
      def apply(num:Double*)(den:Double*) = new TF(num,den)
   }

   /** Transfer function num(s)/den(s), with coefficients in descending powers of s.
     *
     * Both polynomials are normalized by the leading nonzero denominator coefficient, so `cden` is monic.
     * `ss` realizes the system in observable canonical state-space form for orders 0 to 3.
     *
     * NOTE(review): for order >= 1 `ss` assumes a strictly proper system (no feed-through term is
     * extracted from a numerator as long as the denominator). Orders above 3 fall back to a static
     * gain SS0. Neither case is rejected.
     *
     * @throws IllegalArgumentException if every denominator coefficient is zero
     */
   class TF(num:Seq[Double],den:Seq[Double]){
      //strip leading zeros and normalize by the first *nonzero* coefficient (den.head may be 0)
      private val denTrimmed = den.dropWhile(_ == 0)
      require(denTrimmed.nonEmpty, "TF denominator must have at least one nonzero coefficient")
      private val lead = denTrimmed.head

      val cden = denTrimmed.map{c => c/lead}
      val order = cden.size - 1
      val cnum = {
         val num_tmp = num.dropWhile(_ == 0).map{c => c/lead}
         Seq.fill(order - num_tmp.size)(0.0) ++ num_tmp //pad with leading zeros up to `order` coefficients
      }

      val strictlyProper = cnum.size < cden.size

      //wikipedia/wiki/State_space_representation -> observable canonical form
      val ss:Filter = order match {
         case 0 => new SS0(cnum(0)/cden(0))
         case 1 => {
            val A = -cden(1)
            val B = cnum(0)
            val C = 1.0
            new SS1(A,B,C,0)
         }
         case 2 => {
            val A = Matrix2(
               Vec(-cden(1), 1),
               Vec(-cden(2), 0)
            )
            val B = Vec(cnum(0),cnum(1))
            val C = Vec(1,0)
            new SS2(A,B,C,0)
         }
         case 3 => {
            val A = Matrix3(
               Vec(-cden(1), 1, 0),
               Vec(-cden(2), 0, 1),
               Vec(-cden(3), 0, 0)
            )
            val B = Vec(cnum(0),cnum(1), cnum(2))
            val C = Vec(1,0,0)
            new SS3(A,B,C,0)
         }
         case _ => new SS0(cnum(0)/cden(0)) //NOTE(review): orders > 3 are not realized; degrades to a static gain
      }

      //override def toString = s"TF[${num mkString " "}]/[${den mkString " "}]"
      override def toString = s"TF[${cnum mkString " "}]/[${cden mkString " "}]"
   }

   /** First-order low-pass for an angle in radians.
     * The state is a 2-vector: each step re-normalizes it, scales it by the discrete first-order
     * coefficients and adds the input angle's unit vector. The output is the state's heading in radians.
     * The discrete model is recomputed on every update (no caching).
     */
   class LPF1Angle(t_r:Time) extends Filter{
      private var x = Vec.angle(0.rad)

      //val ω = 4/t_r.s

      def init(y0:Double) = {x = Vec.angle(y0.rad); output}

      val lpf = Filter.lpf(t_r)

      def update(u:Double,Δt:Time):Double= {
         //---discretization
         val d = lpf discrete Δt
         x = x.unit*d.A + Vec.angle(u.rad)*d.B
         //--output
         output
      }
      def output = x.heading.rad  //TODO def use type parameters, or base value types to allow this output and input to be of Angle type 
   }

   /** Second-order low-pass for an angle in radians.
     * The cosine and sine of the input are filtered by separate second-order low-pass states and
     * recombined into a heading for the output. The discrete model is recomputed on every update.
     * @param ζ damping ratio of the underlying second-order low-pass
     */
   class LPF2Angle(t_r:Time,ζ:Double=0.71) extends Filter{
      //val ω = 4/t_r.s

      private var cosx = Vec(0,0)
      private var sinx = Vec(0,0)

      def init(y0:Double) = {
         cosx = Vec(math.cos(y0),0)
         sinx = Vec(math.sin(y0),0)
         output
      }

      //same rise-time parameterization as lpf2, but honoring the requested damping ratio ζ
      val lpf = Filter.lpf2Angular(4/t_r.s, ζ)

      def update(u:Double,Δt:Time):Double ={
         //---discretization
         val d = lpf discrete Δt
         //---update
         cosx = (d.A*cosx) + (d.B*u.cos)
         sinx = (d.A*sinx) + (d.B*u.sin)
         //--output
         output
      }
      def output = Vec(lpf.C*cosx, lpf.C*sinx).unit.heading.rad
   }
}