/*
   Copyright 2010 Aaron J. Radke

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
*/
package cc.drx

object FFT{
   /** Transform the complex samples `in`, spaced `dt` apart. */
   def apply(in: Iterable[Complex],dt:Time):FFT = new FFT(in,dt)

   /** Transform the real samples `in`, spaced `dt` apart (each becomes Complex(x, 0)).
     * The DummyImplicit only exists so this overload does not erase to the same JVM signature
     * as the Iterable[Complex] one. */
   def apply(in: Iterable[Double], dt:Time)(implicit d: DummyImplicit):FFT = new FFT(in map {Complex(_,0)},dt)

   //--windowing functions
   /** Transform the real samples `in` (spaced `dt` apart) after applying a Hann window, which gives
     * a better power-spectrum SNR on impure signals, as suggested by this excellent answer
     * https://stackoverflow.com/a/15153614/622016
     * (wiki:Hann_function w(n) = (1-cos(2πn/(N-1)))/2 )
     *
     * A single sample has no window shape: N-1 == 0 would make the angular step infinite and the
     * weight NaN, so it is treated as the degenerate window w(0) = 1 (the numpy.hanning(1) convention).
     * An empty input yields an empty transform. */
   def hann(in:Iterable[Double],dt:Time):FFT = {
     val ds = in.toArray
     val N = ds.size
     val ws =
       if (N <= 1) ds.map{d => Complex(d, 0)} //degenerate window: weight 1, avoids tau/0
       else {
         val k = tau/N.prev
         //d * (1-cos(k*n)) / 2, i.e. the sample scaled by the Hann weight w(n)
         ds.zipWithIndex.map{case (d,n) => Complex(d*(k*n).cos.not/2, 0)}
       }
     new FFT(ws, dt)
   }

   /** One frequency bin of a transform.
     * @param freq    frequency of this bin
     * @param value   complex transform coefficient for this bin
     * @param fftSize total length of the transform the bin came from (used to scale `power`)
     */
   case class Bin(freq:Frequency, value:Complex, fftSize:Int){
     /** value.normSq scaled by the transform length */
     lazy val power:Double = value.normSq / fftSize
     /** magnitude (norm) of the complex coefficient */
     lazy val abs:Double = value.abs
     /** phase angle of the complex coefficient */
     lazy val phase:Angle = value.phase
   }
}

/** Radix-2 (Cooley-Tukey) discrete Fourier transform of the samples `in`, taken `dt` apart.
  *
  * The length of `in` must be a power of two; any other length > 1 eventually fails the `assume`
  * in the recursion (an AssertionError, unless assertions are elided). An empty or single-sample
  * input transforms to itself. Iterating an `FFT` yields one [[FFT.Bin]] per frequency index
  * 0 until fftSize/2, i.e. from DC up to but not including the Nyquist bin.
  *
  * NOTE(review): the forward kernel is exp(+2πi·k/N) (inherited from the Rosetta Code version), the
  * opposite sign to the exp(-2πi·k/N) used by e.g. numpy/FFTW. Magnitudes and power agree with those
  * libraries, but for real input `phase` comes out negated. Left as-is to preserve behavior.
  */
class FFT(val in: Iterable[Complex], val dt:Time) extends Iterable[FFT.Bin]{
   //inspired by Rosetta code fft for scala
   //TODO base internal storage of fft on Array's instead of scala.Seq

   /** Recursive radix-2 transform.
     * @param in        samples; length must be 0, 1 or a power of two
     * @param direction Complex(0, 2) for the forward transform, Complex(0, -2) for the inverse;
     *                  the twiddle factor for index k is exp(direction * π * k / N)
     * @param scalar    divisor applied at every recursion level: 1 for forward, 2 for inverse so the
     *                  log2(N) levels divide the result by N overall
     */
   private def _fft(in: Vector[Complex], direction: Complex, scalar: Int): IndexedSeq[Complex] = {
       val N = in.length
       //base case includes N == 0: splitting an empty vector would otherwise recurse forever
       if (N <= 1) in
       else {
           assume(N % 2 == 0, "The Cooley-Tukey FFT algorithm only works when the length of the input is even. (recursion implies starting with 2^n)")

           val half  = N/2
           val evens = _fft(Vector.tabulate(half)(i => in(2*i)),   direction, scalar)
           val odds  = _fft(Vector.tabulate(half)(i => in(2*i+1)), direction, scalar)

           //butterfly: output k is base + offset, output k + N/2 is base - offset
           def leftRightPair(k: Int):(Complex, Complex) = {
               val base = evens(k) / scalar
               val offset = (direction * (pi * k / N)).exp * odds(k) / scalar
               (base + offset, base - offset)
           }

           val pairs = (0 until half) map leftRightPair
           pairs.map(_._1) ++ pairs.map(_._2)
       }
   }

   /** forward transform of `in` */
   lazy val fft:Seq[Complex]  = _fft(in.toVector, Complex(0,  2), 1)
   /** inverse transform of `in` (includes the 1/N normalization) */
   lazy val ifft:Seq[Complex] = _fft(in.toVector, Complex(0, -2), 2)

   /** Frequencies of indices 0 to fftSize/2 *inclusive*: one more entry than power/phase/abs/bins.
     * The extra last entry (the Nyquist frequency for an even fftSize) is the upper end of freqBound. */
   lazy val freq:Seq[Frequency] = 0 to powerSize map {_ / dt / fftSize}
   private lazy val fftSize = fft.size
   private lazy val powerSize = fftSize/2
   /** normSq / fftSize of the first fftSize/2 coefficients */
   lazy val power:Seq[Double] = fft take powerSize map {_.normSq / fftSize} //TODO ?? N/2 - 1: the Nyquist index fftSize/2 is currently excluded
   lazy val phase:Seq[Angle] = fft take powerSize map {_.phase}
   lazy val abs:Seq[Double] = fft take powerSize map {_.abs} //norm of a complex number
   // lazy val amplitude:Seq[Double] = power map {_.norm / fftSize}  //how to scale this one properly?

   lazy val freqBound = Bound(freq.head, freq.last)
   //powerBound/absBound read the lowest- and highest-power bins; with no bins (fewer than two
   //samples) peaks is empty and peaks.last throws NoSuchElementException
   lazy val powerBound = Bound(peaks.last.power, peaks.head.power)
   lazy val absBound = Bound(peaks.last.abs, peaks.head.abs)

   def iterator: Iterator[FFT.Bin] = bins.iterator

   /** one bin per index 0 until fftSize/2 */
   lazy val bins: Seq[FFT.Bin] = fft.take(powerSize).zipWithIndex.map{ case (c,i) =>
      FFT.Bin(
         freq  = i / dt / fftSize,
         value = c,
         fftSize = fftSize
      )
   }

   /** bins ordered from highest to lowest power */
   lazy val peaks: Seq[FFT.Bin] = bins.sortBy{-_.power}
}