/*
   Copyright 2010 Aaron J. Radke

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
*/
package cc.drx

object Matrix{
   implicit object ParsableMatrix2 extends Parsable[Matrix.Matrix2]{ def apply(v:String):Matrix.Matrix2 = Matrix.Matrix2(v) }
   implicit object ParsableMatrix3 extends Parsable[Matrix.Matrix3]{ def apply(v:String):Matrix.Matrix3 = Matrix.Matrix3(v) }

   /**2x2 identity matrix*/
   val I2:Matrix2 = Matrix2(Vec(1,0),Vec(0,1))
   /**3x3 identity matrix*/
   val I3:Matrix3 = Matrix3(Vec(1,0,0),Vec(0,1,0),Vec(0,0,1))

   /**2x2 matrix from its entries in row-major order: [[a,b],[c,d]]*/
   def apply(a:Double,b:Double,c:Double,d:Double):Matrix2 = Matrix2(Vec(a,b),Vec(c,d))
   /**3x3 matrix from its entries in row-major order: [[a,b,c],[d,e,f],[g,h,i]]*/
   def apply(a:Double,b:Double,c:Double,
             d:Double,e:Double,f:Double,
             g:Double,h:Double,i:Double):Matrix3 = Matrix3(Vec(a,b,c),Vec(d,e,f),Vec(g,h,i))
   /**2x2 matrix from its two row vectors*/
   def apply(x:Vec, y:Vec):Matrix2 = Matrix2(x,y)
   /**3x3 matrix from its three row vectors*/
   def apply(x:Vec, y:Vec, z:Vec):Matrix3 = Matrix3(x,y,z)

   /**diagonal 2x2 matrix with x and y on the diagonal*/
   def diag(x:Double, y:Double):Matrix2           = Matrix2(Vec(x,0),   Vec(0,y))
   /**diagonal 3x3 matrix with x, y and z on the diagonal*/
   def diag(x:Double, y:Double, z:Double):Matrix3 = Matrix3(Vec(x,0,0), Vec(0,y,0), Vec(0,0,z))
   /**symmetric 2x2 matrix from its upper-triangular entries (method name keeps its historical spelling)*/
   def symetric(xx:Double, xy:Double, yy:Double):Matrix2 = Matrix2(Vec(xx,xy), Vec(xy,yy))
   /**symmetric 3x3 matrix from its upper-triangular entries (method name keeps its historical spelling)*/
   def symetric(xx:Double, xy:Double, xz:Double,  yy:Double, yz:Double,  zz:Double):Matrix3 = Matrix3(Vec(xx,xy,xz), Vec(xy,yy,yz), Vec(xz,yz,zz))

   /**split on runs of whitespace or any of , ; _ ~ : and parse every token as a Double*/
   private def numbers(str:String) = str.trim.split("""[\s\,\;\_\~\:]+""").map(Parse[Double](_))

   object Matrix2{
      /**parse a 2x2 matrix from exactly 4 numbers given in row-major order, e.g. "1 2; 3 4"*/
      def apply(str:String):Matrix2 = {
         val ns = numbers(str)
         assert(ns.size == 4, s"Matrix2 expects 4 numbers but found ${ns.size} in '$str'")
         Matrix2( Vec(ns(0), ns(1)), Vec(ns(2), ns(3)) )
      }
   }
   object Matrix3{
      /**parse a 3x3 matrix from exactly 9 numbers given in row-major order*/
      def apply(str:String):Matrix3 = {
         val ns = numbers(str)
         assert(ns.size == 9, s"Matrix3 expects 9 numbers but found ${ns.size} in '$str'")
         Matrix3( Vec(ns(0), ns(1), ns(2)), Vec(ns(3), ns(4), ns(5)), Vec(ns(6),ns(7),ns(8)) )
      }
   }

   /**eigenvalue and eigenvector pair; only the direction of `vec` is meaningful (it is not necessarily normalized)*/
   case class Eigen(value:Double, vec:Vec)

   /**2x2 specialized matrix stored as its two row vectors: x is the first row, y the second*/
   case class Matrix2(x:Vec, y:Vec){
      /**determinant*/
      lazy val det:Double = x.x*y.y - x.y*y.x
      /**inverse computed as adjugate / determinant (no singularity check is made)*/
      lazy val inv:Matrix2 = Matrix2(Vec(y.y,-x.y),Vec(-y.x,x.x))/det
      /**transpose*/
      lazy val T:Matrix2 = Matrix2(Vec(x.x,y.x),Vec(x.y,y.y))
      /**trace*/
      lazy val tr:Double = x.x+y.y
      /**true when the off-diagonal entries are exactly equal*/
      def isSymetric:Boolean = x.y == y.x

      /**Eigenvalue/eigenvector pairs.
        * Diagonal matrices: ordered from largest to smallest magnitude (ties keep the first diagonal entry first),
        * with the unit axis vectors as eigenvectors.
        * Otherwise: ordered by signed value (first >= second). Complex eigenvalues are not supported;
        * the discriminant's square root is then NaN.
        */
      lazy val eigen:(Eigen,Eigen) = {
         val Vec(a,b,_) = x
         val Vec(c,d,_) = y
         if(b == 0 && c == 0){
            val ex = Eigen(a, Vec(1,0))
            val ey = Eigen(d, Vec(0,1))
            if(a.abs >= d.abs) (ex,ey) else (ey,ex)
         }
         else {
            val amd = a-d
            val apd = a+d
            //square root of the characteristic polynomial's discriminant: (a-d)^2 + 4bc
            val q = math.sqrt(amd*amd + b*c*4)
            val l1 = (apd+q)/2 //l1 >= l2 since q is real and q >= 0
            val l2 = (apd-q)/2
            if(c != 0){
               //second row of (A - l I)v = 0:  c*v.x + (d-l)*v.y = 0  =>  v ~ (l-d, c), scaled by 2
               val c2 = c*2
               (Eigen(l1,Vec(amd+q,c2)), Eigen(l2,Vec(amd-q,c2)))
            }
            else {
               //c == 0 (so b != 0): the second-row formula would give a zero vector, so use the first row instead
               //  (a-l)*v.x + b*v.y = 0  =>  v ~ (b, l-a), scaled by 2
               val b2 = b*2
               (Eigen(l1,Vec(b2,q-amd)), Eigen(l2,Vec(b2,-q-amd)))
            }
         }
      }

      /**eigen ellipse centered at the origin: radii are the eigenvalue magnitudes, rotated to the first eigenvector's heading (assumes a symmetric matrix)*/
      lazy val ellipse:Ellipse = Ellipse(Vec.zero,Vec(eigen._1.value.abs, eigen._2.value.abs), eigen._1.vec.heading)

      /**Matrix exponential e^A in closed form*/
      lazy val exp:Matrix2 = {
         val Vec(a,b,_) = x
         val Vec(c,d,_) = y
         //--upper triangular special case: cheaper closed forms
         if(c == 0.0){
            val ea = math.exp(a)
            if(a == d) Matrix(ea, ea*b, 0.0, ea)
            else{
               val ed = math.exp(d)
               val pb = b*(ea-ed)/(a-d)
               Matrix(ea,pb, 0.0, ed)
            }
         }
         //--general 2x2 closed form from:
         //   Bernstein and So, "Some explicit formulas for the matrix exponential", IEEE Trans. Automatic Control, 1993
         else{
            val condition = (a-d)*(a-d) + 4.0*b*c //discriminant: sign selects the hyperbolic/trigonometric form
            val m = (a-d)/2.0
            val p = (a+d)/2.0
            val ep = math.exp(p)
            if(condition == 0.0){
               //repeated eigenvalue: limit of sinh(del)/del -> 1 and cosh(del) -> 1
               Matrix(1.0+m, b, c, 1.0-m)*ep
            }
            else if(condition > 0.0){
               val del = 0.5*math.sqrt(condition)
               val sd = math.sinh(del)/del
               val cd = math.cosh(del)
               Matrix(cd+m*sd, b*sd, c*sd, cd-m*sd)*ep
            }
            else{ //condition < 0: complex conjugate eigenvalues
               val del = 0.5*math.sqrt(math.abs(condition))
               val sd = math.sin(del)/del
               val cd = math.cos(del)
               Matrix(cd+m*sd, b*sd, c*sd, cd-m*sd)*ep
            }
         }
      }

      /**integer power: identity for n == 0, repeated products of the inverse for n < 0*/
      def pow(n:Int):Matrix2 =
         if(n == 0) I2
         else {
            val base:Matrix2 = if(n > 0) this else inv
            Seq.fill(n.abs)(base) reduce (_*_)
         }

      def unary_- :Matrix2 = Matrix2(-x,-y)

      def *(s:Double):Matrix2 = Matrix2(x*s, y*s)
      def /(s:Double):Matrix2 = Matrix2(x/s, y/s)
      def -(that:Matrix2):Matrix2 = Matrix2(this.x-that.x, this.y-that.y)
      def +(that:Matrix2):Matrix2 = Matrix2(this.x+that.x, this.y+that.y)

      /**matrix product: entry (i,j) is row i of this dotted with column j of that (row j of that.T)*/
      def *(that:Matrix2):Matrix2 = Matrix2(
         Vec( x*that.T.x,  x*that.T.y),
         Vec( y*that.T.x,  y*that.T.y)
      )
      /**matrix-vector product (each row dotted with v)*/
      def *(v:Vec):Vec = Vec(x*v, y*v)

      override def toString =  s"Matrix2(${x.x}  ${x.y} ; ${y.x} ${y.y})"
      /**1-based (row, column) element access; other indices throw NotImplementedError*/
      def apply(i:Int,j:Int):Double = (i,j) match {
         case (1,1) => x.x
         case (1,2) => x.y
         case (2,1) => y.x
         case (2,2) => y.y
         case _ => ???
      }
   }

   /**3x3 specialized matrix stored as its three row vectors x, y, z*/
   case class Matrix3(x:Vec, y:Vec, z:Vec){
      /**1-based (row, column) element access; other indices throw NotImplementedError*/
      def apply(i:Int,j:Int):Double = (i,j) match {
         case (1,1) => x.x
         case (1,2) => x.y
         case (1,3) => x.z
         case (2,1) => y.x
         case (2,2) => y.y
         case (2,3) => y.z
         case (3,1) => z.x
         case (3,2) => z.y
         case (3,3) => z.z
         case _ => ???
      }

      //see wikipedia.org/wiki/Invertible_matrix (cofactor expansion)

      /**determinant by cofactor expansion along the first row*/
      lazy val det:Double = {
         val A =  (y.y*z.z - y.z*z.y)
         val B = -(y.x*z.z - y.z*z.x)
         val C =  (y.x*z.y - y.y*z.x)

         x.x*A + x.y*B + x.z*C
      }
      /**inverse computed as adjugate (transposed cofactor matrix) / determinant (no singularity check is made)*/
      lazy val inv:Matrix3 = {
         //cofactors of the first row
         val A =  (y.y*z.z - y.z*z.y)
         val B = -(y.x*z.z - y.z*z.x)
         val C =  (y.x*z.y - y.y*z.x)
         //cofactors of the second row
         val D = -(x.y*z.z - x.z*z.y)
         val E =  (x.x*z.z - x.z*z.x)
         val F = -(x.x*z.y - x.y*z.x)
         //cofactors of the third row
         val G =  (x.y*y.z - x.z*y.y)
         val H = -(x.x*y.z - x.z*y.x)
         val I =  (x.x*y.y - x.y*y.x)

         Matrix3(Vec(A,D,G), Vec(B,E,H), Vec(C,F,I))/det
      }

      /**transpose*/
      lazy val T:Matrix3 = Matrix3(Vec(x.x,y.x,z.x),Vec(x.y,y.y,z.y), Vec(x.z,y.z,z.z))

      /**matrix exponential: not implemented yet, accessing it throws NotImplementedError*/
      lazy val exp:Matrix3 = ???
      /**trace*/
      lazy val tr:Double = x.x+y.y+z.z

      /**integer power: identity for n == 0, repeated products of the inverse for n < 0*/
      def pow(n:Int):Matrix3 =
         if(n == 0) I3
         else {
            val base:Matrix3 = if(n > 0) this else inv
            Seq.fill(n.abs)(base) reduce (_*_)
         }

      def unary_- :Matrix3 = Matrix3(-x,-y,-z)

      def *(s:Double):Matrix3 = Matrix3(x*s, y*s, z*s)
      def /(s:Double):Matrix3 = Matrix3(x/s, y/s, z/s)
      def -(that:Matrix3):Matrix3 = Matrix3(this.x-that.x, this.y-that.y, this.z-that.z)
      def +(that:Matrix3):Matrix3 = Matrix3(this.x+that.x, this.y+that.y, this.z+that.z)

      /**matrix product: entry (i,j) is row i of this dotted with column j of that (row j of that.T)*/
      def *(that:Matrix3):Matrix3 = Matrix3(
         Vec( x*that.T.x,  x*that.T.y, x*that.T.z),
         Vec( y*that.T.x,  y*that.T.y, y*that.T.z),
         Vec( z*that.T.x,  z*that.T.y, z*that.T.z)
      )
      /**matrix-vector product (each row dotted with v)*/
      def *(v:Vec):Vec = Vec(x*v, y*v, z*v)

      override def toString =  s"Matrix3(${x.x} ${x.y} ${x.z} ; ${y.x} ${y.y} ${y.z} ; ${z.x} ${z.y} ${z.z})"
   }
}