package cc.drx

//-- arithmetic speed comparisons
//-- 3 GHz i7
// op   +    *    /
// i16  344  340  124
// i32  280  280  122
// i64  166  160   35
// f32  260   93   70
// f64  160  138   67
//-- 2 GHz ARM A17
// op   +    *    /
// i16  188  188  26
// i32  170  170  27
// i64   87   68   7
// f32  136  136   54
// f64   79   79   28
//-- 1 GHz ARM11
// op   +    *    /
// i16  37  37  15
// i32  23  23  13
// i64  13  13   3
// f32  15  6    5
// f64  11  5    3

// implicit class FixedPointRichLong(x:Long){
//   def bshift(n:Int):Long = if(n > 0) x << n else x >> -n  //efficient scaling
// }

/** A signed fixed-point number in Qm.n format, stored as a raw `Long` `q`.
  *
  * The represented value is `q * 2^-n`. The format has `N = m + n` bits in total. The raw range
  * `[minQ, maxQ]` is that of an `N`-bit two's-complement integer, so the sign bit is counted inside `m`.
  *
  * @param m                number of integer bits (including the sign bit)
  * @param n                number of fractional bits; may be negative for a resolution coarser than 1
  * @param q                raw integer representation
  * @param keepLeftArgSize  when true, results of `+ - * /` are converted back (with rounding) into
  *                         this operand's format instead of the widened result format
  */
class FixedPoint(val m:Int,val n:Int,val q:Long, val keepLeftArgSize:Boolean){

  import FixedPoint.Q
  import FixedPoint.b

  /** Equal when the format (m, n), the raw value and the rounding mode all match. */
  override def equals(that:Any):Boolean = that match {
    case that:Q =>
      this.m == that.m && this.n == that.n && this.q == that.q && this.keepLeftArgSize == that.keepLeftArgSize
    case _ => false
  }

  /** Built from exactly the fields `equals` compares, so equal values hash equally. */
  override def hashCode:Int = (m, n, q, keepLeftArgSize).##

  /** Total number of bits in the format. */
  val N:Int = m+n
  require(N <= 64, s"$N bits is too large for a Long backend $this")

  lazy val min:Double = -b(m-1)
  lazy val max:Double = -min - resolution
  lazy val resolution:Double = b(-n)

  lazy val minQ:Long = -(1L << (N-1))
  lazy val maxQ:Long = (1L << (N-1)) - 1

  /** Clamps the raw value into `[minQ, maxQ]`, keeping format and rounding mode. */
  def sat:Q =
    if(q > maxQ) new Q(m,n,maxQ,keepLeftArgSize)
    else if(q < minQ) new Q(m,n,minQ,keepLeftArgSize)
    else this

  lazy val value:Double = q*b(-n)

  /** Uses this value's format as a converter: rounds `x * 2^n` to the nearest raw value, then saturates. */
  def apply(x:Double):Q = new Q(m,n, math.round(x*b(n)), keepLeftArgSize).sat

  /** Converts into the format and rounding mode of `q`. */
  def as(q:Q):Q = asQ(q.m, q.n, q.keepLeftArgSize)

  /** Converts into Q`mNew`.`nNew`, rounding to nearest (ties up) when fractional bits are dropped, then saturating.
    *
    * NOTE(review): if `bshift` is a plain `<<`, widening by many fractional bits can wrap before `sat` sees it.
    */
  def asQ(mNew:Int, nNew:Int, keepLeftArgSizeNew:Boolean=false):Q = {
    val dn = n - nNew // > 0: fractional bits are dropped; <= 0: bits are added (exact)
    // Half an LSB of the target format, added before the right shift so the shift rounds to nearest.
    // It is built as a Long and only when bits are dropped: an Int `1 << k` wraps for k >= 32.
    val half = if(dn > 0) 1L << (dn-1) else 0L
    new Q(mNew,nNew, (q+half).bshift(-dn), keepLeftArgSizeNew).sat
  }

  private def f(x:Double):String = "% 8.5f".format(x)
  override def toString:String = f"${f(value)} = Q$m.$n(0x$q%08X) bound:[${f(min)},${f(max)}] dq:${f(resolution)}"

  /** Negates the raw value. Not saturated: negating `minQ` gives `maxQ + 1`. */
  def unary_- : Q = new Q(m,n,-q, keepLeftArgSize)

  /** Converts `res` back into this format when this operand asks for it. */
  private def keep(res:Q):Q = if(keepLeftArgSize) res as this else res
  def +(that:Q):Q = keep(Q.add(this, that))
  def -(that:Q):Q = keep(Q.add(this, -that))
  def *(that:Q):Q = keep(Q.mul(this, that))
  def /(that:Q):Q = keep(Q.div(this, that))

  /** The same value with `keepLeftArgSize` enabled, so arithmetic results round back into this format. */
  def withRounding:Q = new Q(m,n,q, keepLeftArgSize=true)
}

/** Companion for [[FixedPoint]]: format constructors, format-selection helpers and raw arithmetic.
  *
  * All formats are assumed to be signed fixed point.
  */
object FixedPoint{
  type Q = FixedPoint
  val Q = FixedPoint

  /** Syntax for turning a `Double` into a fixed-point value. */
  implicit class FixedPointRichDouble(x:Double){
    /** Converts into the same format (and rounding mode) as `q`. */
    def as(q:Q):Q = q(x)
    /** Converts into a fresh Qm.n format; `apply` already saturates out-of-range values. */
    def asQ(m:Int,n:Int):Q = Q(m,n)(x)
  }

  /** `2^n` as a Double. */
  def b(n:Int):Double = math.pow(2,n)

  /** A zero value in Qm.n format, usable as a converter: `Q(m,n)(x)`. */
  def apply(m:Int, n:Int):Q = new Q(m, n, 0, false)

  /** Picks an `N`-bit format whose integer part just holds `max`, and returns `max` converted into it.
    *
    * The format starts with `ceil(log2(|max|)) + 1` integer bits. If `max` still saturates (for example
    * an exact power of two), one more integer bit is taken from the fraction. `m + n == N` holds either way.
    */
  def bits(max:Double, N:Int):Q = {
    val m = max.abs.log2.ceil.toInt + 1
    val q = new Q(m, N-m, 0, false)(max)
    if(q.max < max) new Q(m+1, N-m-1, 0, false)(max) else q
  }

  /** Picks a format that holds `max` with resolution `2^ceil(log2(dq))`, and returns `max` converted into it.
    *
    * That resolution is the smallest power of two not below `dq`, so it can be coarser than `dq`
    * (e.g. `dq = 0.01` gives `n = 6`, a step of 1/64).
    */
  def resolution(max:Double, dq:Double):Q = {
    val m = max.abs.log2.ceil.toInt + 1
    val n = -dq.log2.ceil.toInt
    val q = new Q(m,n,0,false)(max)
    if(q.max < max) new Q(m+1, n, 0, false)(max) else q
  }

  // -- add: `a` has at least as much resolution as `b`.
  //    `b` is aligned to `a`'s fractional bits, then the raw values are added.
  //    One extra integer bit is taken so the sum's range fits.
  private def addN(a:Q,b:Q):Q = {
    val m = math.max(a.m, b.m)
    new Q(m+1, a.n, a.q + b.q.bshift(a.n-b.n), false)
  }
  // Normalizes the argument order so `addN` always receives the finer-resolution operand first.
  private def add(a:Q,b:Q):Q = if(a.n > b.n) addN(a,b) else addN(b,a)

  // -- mul: the raw product in Q(a.m+b.m).(a.n+b.n), which is exact.
  //    Rounding only happens if the result is later converted with `as`/`asQ`.
  private def mul(a:Q,b:Q):Q = new Q(a.m+b.m, a.n+b.n, a.q*b.q, false)

  // -- div: keeps a's format. `a.q` is pre-scaled by `2^b.n` so the quotient lands in a's resolution.
  //    Long division truncates toward zero and throws ArithmeticException when b.q == 0.
  //    The result is not saturated.
  private def div(a:Q,b:Q):Q = new Q(a.m, a.n, a.q.bshift(b.n)/b.q, false) //TODO m? TODO norm if smaller then shift?

  //--TODO move this test logic to the test suite
  /** Runs the built-in self checks and prints a ✓/✗ line per check. */
  def main(args:Array[String]):Unit = test()

  private def test():Unit = {
    val a = 2.2.asQ(8,2)
    val b = 1.3.asQ(10,3)

    def same(a:Any, b:Any, label:String):Unit = if(a != b) println(s"✗ Error $label:\n $a !=\n $b") else println(s"✓ Success $label:\n $a")

    same(a.asQ(10,3), 2.25.asQ(10,3), "upshift")
    same(b.asQ(8,2), 1.25.asQ(8,2),   "downshift")
    same(21.asQ(8,-2), 20.asQ(8,-2),  "negative fractional")

    //best Q
    same(Q.bits(90,16),  90.0.asQ(8,8), "max")
    same(Q.bits(128,16), 128.0.asQ(9,7), "max edge")
    same(Q.resolution(90,0.01), 90.0.asQ(8,6), "resolution")
    same(Q.resolution(90,5), 90.0.asQ(8,-3), "resolution under")

    same(a+b  ,  3.5.asQ(11,3)  , "add")
    same(a-b  ,  1.0.asQ(11,3)  , "sub")
    same(a*b  ,  2.813.asQ(18,5), "mul")
    same(a/b  ,  1.75.asQ(8,2),   "div")

    val q88 = Q.bits(120,16)(100)
    val qover = q88 + q88
    same(qover         , 200.asQ(9,8)  , "add over")
    same(qover.asQ(8,8), 200.asQ(8,8)  , "add sat")

    //-- first-order low-pass filter simulated in a rounding fixed-point format
    val q = Q.bits(100,32).withRounding

    val u = 1.0 as q
    val dt = 1d/80 as q
    same(u/dt, 80 as q, "div dt")
    val x = 0 as q

    same(u-x,  1.0 as q, "sub withRounding")
    same(dt,   0.0125 as q, "dt size")

    // Applies `f` until it returns its own input (a fixed point) or `maxN` applications have been made.
    @annotation.tailrec def fix[A](x0:A, maxN:Int=Int.MaxValue)(f: A => A):A = {
      val next = f(x0)
      if(next == x0 || maxN <= 1) next else fix(next, maxN-1)(f)
    }

    val sim = fix(x, maxN=100000){ y =>
      val e = u - y
      y + e*dt
    }

    same(sim, 0.98096 as q, "lowpass sim")
  }
}
