/*
   Copyright 2010 Aaron J. Radke

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
*/
package cc.drx

//-- arithmetic speed comparisons
//-- 3gHz i7
// op   +    *    /
// i16  344  340  124
// i32  280  280  122
// i64  166  160   35
//
// f32  260   93   70
// f64  160  138   67
//-- 2gHz arm-A17
// op   +    *    /
// i16  188  188  26
// i32  170  170  27
// i64   87   68   7
//
// f32  136  136   54
// f64   79   79   28
//-- 1gHz arm11
// op   +    *    /
// i16  37  37  15
// i32  23  23  13
// i64  13  13   3
//
// f32  15  6    5
// f64  11  5    3


// implicit class FixedPointRichLong(x:Long){
//   def bshift(n:Int):Long = if(n > 0) x << n else x >> -n  //efficient scaling
// }

/** Signed fixed-point number in `Q m.n` form, backed by a `Long`.
  *
  * The represented real number is `q * 2^-n` (see [[value]]). Instances are
  * immutable. An instance also acts as a "format": calling it with a `Double`
  * (see [[apply]]) quantizes that double into the same `m`/`n` layout.
  *
  * @param m integer bit count; the representable range is `[-2^(m-1), 2^(m-1) - 2^-n]`
  * @param n fractional bit count; one raw step is worth `2^-n`
  * @param q raw scaled integer payload
  * @param keepLeftArgSize when true, the result of `+ - * /` is converted back
  *                        to the format of the left operand
  */
class FixedPoint(val m:Int,val n:Int,val q:Long, val keepLeftArgSize:Boolean){

  import FixedPoint.Q

  /** Structural equality over the format (`m`, `n`), the raw value and the `keepLeftArgSize` flag. */
  override def equals(that:Any):Boolean = that match {
    case that:Q =>
      this.m == that.m &&
      this.n == that.n && //must look at `that.n`; comparing `n` with itself always passed
      this.q == that.q &&
      this.keepLeftArgSize == that.keepLeftArgSize
    case _ => false
  }

  /** Hash over exactly the fields used by `equals`, so equal values hash alike. */
  override def hashCode:Int = (m, n, q, keepLeftArgSize).##

  /** Total bit count used for the raw saturation bounds. */
  val N = m+n //TODO need to count sign bit too
  require(N <= 64, s"$N bits is to large for a Long backend $this")
  // require(m >=0)
  // require(n >=0)
  import FixedPoint.b

  /** Smallest representable real value, `-2^(m-1)`. */
  lazy val min:Double = -b(m-1)
  /** Largest representable real value, one resolution step below `-min`. */
  lazy val max:Double = -min - resolution
  /** Real value of one raw step, `2^-n`. */
  lazy val resolution:Double = b(-n)

  /** Smallest raw value that fits in `N` signed bits. */
  lazy val minQ:Long = -(1L << (N-1))
  /** Largest raw value that fits in `N` signed bits. */
  lazy val maxQ:Long = (1L << (N-1)) - 1

  /** Saturate: clamp the raw value into `[minQ, maxQ]`; returns `this` when already in range. */
  def sat:Q =
    if(q > maxQ) new Q(m,n,maxQ,keepLeftArgSize)
    else if (q < minQ) new Q(m,n,minQ,keepLeftArgSize)
    else this
  // def sat:Q = if(value > max) Q(m,n)(max) else if (value < min) Q(m,n)(min) else this
  // lazy val bits:Int = m+n+1 //signed

  /** Real number represented by the raw value. */
  lazy val value:Double = q*b(-n)

  /** Use this instance as a format: quantize `x` (scale by `2^n`, round to nearest) and saturate.
    * The raw value of this instance is ignored; only `m`, `n` and `keepLeftArgSize` are reused.
    */
  //just use the base as a constructor
  def apply(x:Double):Q = new Q(m,n,  math.round(x*b(n)).toLong, keepLeftArgSize).sat

  /** Convert this value to the format (and `keepLeftArgSize` flag) of `q`. */
  def as(q:Q):Q = asQ(q.m, q.n, q.keepLeftArgSize)

  /** Convert this value to a `Q mNew.nNew` layout, adding a rounding offset before
    * dropping fractional bits, and saturating the result.
    *
    * @param mNew integer bit count of the result
    * @param nNew fractional bit count of the result
    * @param keepLeftArgSizeNew flag carried by the result (defaults to false)
    */
  def asQ(mNew:Int, nNew:Int, keepLeftArgSizeNew:Boolean=false):Q = {
    val dn = n - nNew //fractional bits being dropped (negative when bits are gained)
    // Rounding offset of half an output step. A Long literal keeps the offset in
    // 64 bits like `q`; an Int `1` can wrap once dn-1 reaches 32.
    // NOTE(review): `bshift` is defined outside this file; presumably a left shift
    // for positive counts and a right shift for negative ones - confirm.
    val half = 1L.bshift(dn-1) //round the half in
    new Q(mNew,nNew, (q+half).bshift(-dn), keepLeftArgSizeNew).sat
  }

  private def f(x:Double):String = "% 8.5f".format(x)
  /** Human readable dump: real value, format, raw hex payload, bounds and resolution. */
  override def toString:String = f"${f(value)} = Q$m.$n(0x$q%08X) bound:[${f(min)},${f(max)}] dq:${f(resolution)})"


  /** Negated value in the same format (not saturated). */
  def unary_- :Q = new Q(m,n,-q, keepLeftArgSize)

  /** Apply the `keepLeftArgSize` policy: shrink `res` back to this format when the flag is set. */
  private def keep(res:Q):Q = if(keepLeftArgSize) res as this else res

  /** Sum; the format is chosen by the companion's `add` unless `keepLeftArgSize` is set. */
  def +(that:Q):Q = keep(Q.add(this, that))
  /** Difference, computed as the sum with the negated right operand. */
  def -(that:Q):Q = keep(Q.add(this, -that))
  /** Product; the format is chosen by the companion's `mul` unless `keepLeftArgSize` is set. */
  def *(that:Q):Q = keep(Q.mul(this, that))
  /** Quotient; the format is chosen by the companion's `div` unless `keepLeftArgSize` is set. */
  def /(that:Q):Q = keep(Q.div(this, that))

  /** Copy of this value with `keepLeftArgSize` switched on. */
  def withRounding:Q = new Q(m,n,q, keepLeftArgSize=true)
}

//assumed signed format fixed point
/** Companion of the `FixedPoint` class: format constructors, `Double` syntax,
  * the raw arithmetic kernels used by the class operators, and an ad-hoc self test.
  */
object FixedPoint{
  /** Short alias for the class type. */
  type Q = FixedPoint
  /** Short alias for this companion so `Q(m,n)` reads like a format constructor. */
  val Q = FixedPoint

  /** Syntax for turning a `Double` into a fixed point value. */
  implicit class FixedPointRichDouble(x:Double){
    /** Quantize `x` using `q` as the format. */
    def as(q:Q):Q = q(x)
    /** Quantize `x` into a fresh `Q m.n` value; the format's `apply` already saturates. */
    def asQ(m:Int,n:Int):Q = Q(m,n)(x)
    // def log2:Double = math.log(x)/math.log(2)
    // def abs:Double = math.abs(x)
    // def ceil:Double = math.ceil(x)
  }


  /** Power of two as a `Double`: `2^n` (negative `n` gives fractions). */
  // private def b(n:Int):Double = math.pow(2,n)
  def b(n:Int):Double = math.pow(2,n)

  /** Zero-valued `Q m.n` instance, intended to be used as a format, e.g. `Q(8,8)(1.5)`. */
  // def apply(m:Int, n:Int)(x:Double):Q = new Q(m, n, q = math.round(x*b(n)).toLong )
  def apply(m:Int, n:Int):Q = new Q(m, n, 0, false)

  /** Pick a format with `N` total bits whose range covers `max`, and return `max` quantized in it.
    *
    * NOTE(review): `log2` comes from outside this file (presumably a base-2 logarithm); a `max`
    * of zero is not special-cased here - confirm callers never pass it.
    *
    * @param max largest magnitude that must be representable
    * @param N   total bit budget (`m + n`)
    */
  def bits(max:Double, N:Int):Q = {
    // min = -2^(m-1)         m+n = N
    val m = max.abs.log2.ceil.toInt + 1
    val q = new Q(m, N-m, 0, false)(max)
    if( q.max < max) new Q(q.m+1, q.N-q.m-1, 0, false)(max) else q //not quite enough bits, more are needed for 'm'
  }

  /** Pick a format whose range covers `max` with `n = -ceil(log2(dq))` fractional bits,
    * and return `max` quantized in it.
    *
    * @param max largest magnitude that must be representable
    * @param dq  requested step size
    */
  def resolution(max:Double, dq:Double):Q = {
    val m = max.abs.log2.ceil.toInt + 1
    val n = - dq.log2.ceil.toInt //computed once; it is needed for both candidate formats
    val q = new Q(m,n,0,false)(max)
    if( q.max < max) new Q(q.m+1, n, 0, false)(max) else q //not quite enough bits, more are needed for 'm'
  }

  // -- add (keep A)
  // cA = aA + bB
  // c  = a + bB/A
  /** Add with `b` rescaled to the fractional bit count of `a`; the result gets one extra
    * integer bit over the wider operand. Callers go through `add`, which orders the arguments.
    */
  private def addN(a:Q,b:Q):Q = {
    val m = if(a.m > b.m) a.m else b.m
    new Q(m+1, a.n, a.q + b.q.bshift(a.n-b.n), false)
  }
  /** Sum of two values; the operand with more fractional bits sets the result resolution. */
  //normalized so 'a' has more resolution than 'b'
  private def add(a:Q,b:Q):Q = if(a.n > b.n) addN(a,b) else addN(b,a) //normalized so 'a' has more resolution than 'b'

  // --mul case (new C)
  // cC = aA * bB
  /** Product in a widened format: integer and fractional bit counts are each summed.
    * No rounding offset is applied; the raw values are multiplied directly.
    */
  private def mul(a:Q,b:Q):Q = new Q(a.m+b.m, a.n+b.n, a.q*b.q, false) //TODO add 1/2 for better rounding down

  // --div keep A
  // cA  = aA / bB
  // c   = a  / bB
  /** Quotient in the format of `a`. The raw division is an integer division, so a divisor
    * whose raw value is zero raises `ArithmeticException`.
    */
  private def div(a:Q,b:Q):Q = new Q(a.m, a.n, a.q.bshift(b.n)/b.q, false) //TODO m? TODO norm if smaller then shift?


  //--TODO move this test logic to the test suite
  /** Entry point: runs the self test and prints one line group per check. */
  def main(args:Array[String]):Unit = {
    test()
  }

  /** Ad-hoc checks printed to stdout; nothing is thrown on a mismatch. */
  private def test():Unit = {
    // val a = 2.2.asQ(8,2)
    // val b = 1.3.asQ(10,3)
    val a = 2.2.asQ(8,2)
    val b = 1.3.asQ(10,3)


    // prints a success or error report depending on `a == b`
    def same(a:Any, b:Any, label:String):Unit = if(a != b) println(s"✗ Error $label:\n $a !=\n $b") else println(s"✓ Success $label:\n $a")

    //initialize
    same(a.asQ(10,3), 2.25.asQ(10,3), "upshift")
    same(b.asQ(8,2), 1.25.asQ(8,2),   "downshift")
    same(21.asQ(8,-2), 20.asQ(8,-2),  "negative fractional")

    //best Q
    same(Q.bits(90,16),  90.0.asQ(8,8), "max")
    same(Q.bits(128,16), 128.0.asQ(9,7), "max edge")
    same(Q.resolution(90,0.01), 90.0.asQ(8,6), "resolution")
    same(Q.resolution(90,5), 90.0.asQ(8,-3), "resolution under")

    //math
    same(a+b  ,  3.5.asQ(11,3)  , "add")
    same(a-b  ,  1.0.asQ(11,3)  , "sub")
    same(a*b  ,  2.813.asQ(18,5), "mul")
    same(a/b  ,  1.75.asQ(8,2),   "div")

    //sat
    val q88 = Q.bits(120,16)(100)
    val qover = q88 + q88
    same(qover         , 200.asQ(9,8)  , "add over")
    same(qover.asQ(8,8), 200.asQ(8,8)  , "add sat")

    {
      //first order low pass simulation carried out entirely in one fixed format
      val q = Q.bits(100,32).withRounding

      val u = 1.0 as q
      val dt = 1d/80 as q
      val f = u/dt
      same(f, 80 as q, "div dt")
      val x = 0 as q

      // same(u,   1.0.asQ(5,11), "bit set")
      same(u-x,  1.0 as q, "sub withRounding")
      same(dt,   0.0125 as q, "dt size")

      // Iterate `f` from `x0` until the value stops changing or `maxN` applications were made.
      @annotation.tailrec def fix[A](x0:A,maxN:Int=Int.MaxValue)(f: A => A):A = {
        val x = f(x0)
        if(x == x0 || maxN <= 1) x else fix(x, maxN-1)(f) //carry the remaining budget so the cap is honored
      }

      // val xs = Iterator.iterate(x,7*80){x => 
      val sim = fix(x){x =>
        val y = x
        // println(s" y: $y")

        val e = u - y
        x + e*dt
      }

      same(sim, 0.98096 as q, "lowpass sim")
      // xs foreach println

    }
  }
}