//======================================  Omega/matlab headers.
#include "wmeasure.hh"
#include "wtransform.hh"
#include "wtile.hh"
#include "matlab_fcs.hh"

//======================================  DMT headers
#include "DVecType.hh"

//======================================  System / C++ headers
#include <cmath>
#include <iostream>
#include <limits>
#include <utility>

using namespace wpipe;
using namespace std;

//  Significance cut used by wmeasure: only tiles whose normalized energy is
//  strictly greater than this value contribute to the energy-weighted
//  signal properties.
const double normalizedEnergyThreshold(4.5);

//
//  Wmeasure class constructor
//
//  Measures the most significant event in each channel of a Q transform.
//  For every channel, each tile inside the requested time / frequency / Q
//  window is examined:
//   - the single tile with the largest normalized energy is stored as the
//     channel's "peak" properties;
//   - tiles whose normalized energy exceeds normalizedEnergyThreshold are
//     energy-weighted and summed per Q plane, and the sums of the plane
//     holding the peak tile are stored as the channel's "signal" properties.
//  If the resulting signal area is below 1, the peak tile properties are
//  reported as the signal properties too.
//
//  Parameters:
//   transforms - per-channel transform supplying the tile normalized
//                energies and row mean energies
//   tiling     - tiling that defines the Q planes and frequency rows
//   startTime  - only used to default the reference time
//   refTime    - reference time; time offsets are measured from it.  If it
//                is zero, startTime + tiling.duration()/2 is used.
//   tRange     - [tmin tmax] relative to the reference time.  If empty, the
//                tiled duration minus the transient regions is used.
//   fRange     - [fmin fmax].  If empty, all frequencies are used.
//   qRnge      - [Qmin Qmax], or a single Q (the nearest plane Q is used).
//                If empty, all Q planes are used.
//   debugLevel - not used by this constructor.
//  Ranges given in decreasing order are swapped into increasing order.
//
//  This includes an OMEGA_MATLAB_BUG fix, implemented in wprops::Average,
//  that calculates the averaged bandwidth correctly.  The matlab code
//  subtracts the squared average time rather than the squared average
//  frequency.
//
wmeasure::wmeasure(const wtransform& transforms, const wtile& tiling, 
		   const Time& startTime, const Time& refTime, 
		   const dble_vect& tRange, const dble_vect& fRange, 
		   const dble_vect& qRnge, int debugLevel) {

  //////////////////////////////////////////////////////////////////////////////
  //                        process command line arguments                    //
  //////////////////////////////////////////////////////////////////////////////

  // Floating division by zero (1.0/0.0) is undefined in standard C++;
  // ask the library for infinity instead.
  const double dInf = numeric_limits<double>::infinity();

  // apply default arguments
  Time referenceTime = refTime;
  if (!referenceTime) {
    referenceTime = startTime + tiling.duration() / 2;
  }

  dble_vect timeRange = tRange;
  if (timeRange.empty()) {
    timeRange.resize(2);
    timeRange[0] = -0.5 * (tiling.duration() - 2 * tiling.transientDuration());
    timeRange[1] = -timeRange[0];
  }

  dble_vect frequencyRange = fRange;
  if (frequencyRange.empty()) {
    frequencyRange.resize(2);
    frequencyRange[0] = -dInf;
    frequencyRange[1] = dInf;
  }
    
  dble_vect qRange = qRnge;
  if (qRange.empty()) {
    qRange.resize(2);
    qRange[0] = -dInf;
    qRange[1] = dInf;
  }

  // determine number of channels
  int numberOfChannels = transforms.numberOfChannels();

  // if only a single Q is requested, find the nearest Q plane and use it as
  // both bounds.  The vector holds one element here, so it must be grown
  // before the upper bound is written.
  if (qRange.size() == 1) {
    qRange[0] = tiling.nearest_q(qRange[0]);
    qRange.resize(2);
    qRange[1] = qRange[0];
  }

  /////////////////////////////////////////////////////////////////////////////
  //                       validate command line arguments                   //
  /////////////////////////////////////////////////////////////////////////////

  // Check for two component range vectors
  // NOTE(review): the code below assumes error() does not return.
  if (timeRange.size() != 2) {
    error("Time range must be two component vector [tmin tmax].");
  }

  if (frequencyRange.size() != 2) {
    error("Frequency range must be two component vector [fmin fmax].");
  }

  if (qRange.size() > 2) {
    error("Q range must be scalar or two component vector [Qmin Qmax].");
  }

  // force ranges to be monotonically increasing; a reversed range would
  // otherwise select no tiles at all.
  if (timeRange[0] > timeRange[1]) std::swap(timeRange[0], timeRange[1]);
  if (frequencyRange[0] > frequencyRange[1]) {
    std::swap(frequencyRange[0], frequencyRange[1]);
  }
  if (qRange[0] > qRange[1]) std::swap(qRange[0], qRange[1]);

  /////////////////////////////////////////////////////////////////////////////
  //                      initialize measurement structures                  //
  /////////////////////////////////////////////////////////////////////////////
  measurements.resize(numberOfChannels);

  // absolute start and end of the requested time window
  Time tRange0 = referenceTime + timeRange[0];
  Interval tRangeDelta(timeRange[1]-timeRange[0]);
  Time tRangeEnd = tRange0 + tRangeDelta;
  int numberOfPlanes = tiling.numberOfPlanes();

  /////////////////////////////////////////////////////////////////////////////
  //                         begin loop over channels                        //
  /////////////////////////////////////////////////////////////////////////////

  for (int channelNumber=0; channelNumber<numberOfChannels; channelNumber++) {
    int peakPlane = 0;
    wprops peakTile;
    std::vector<wprops> plane_props(numberOfPlanes);

    ///////////////////////////////////////////////////////////////////////////
    //                           begin loop over Q planes                    //
    ///////////////////////////////////////////////////////////////////////////
    for (int plane=0; plane < numberOfPlanes; plane++) {
      wprops sumPlane;

      // skip Q planes outside of requested Q range
      double pQ = tiling.planes(plane).q;
      if ((pQ < qRange[0]) || (pQ > qRange[1])) {
	continue;
      }

      /////////////////////////////////////////////////////////////////////////
      //                      begin loop over frequency rows                 //
      /////////////////////////////////////////////////////////////////////////
      for (int row=0; row < tiling.planes(plane).numberOfRows; row++) {
	const qrow& tilePlaneRow = tiling.planes(plane).row(row);
	double frequency = tilePlaneRow.frequency;

	// skip frequency rows outside of requested frequency range.  Both
	// bounds are always tested; infinite bounds compare correctly, so
	// an unbounded fmax no longer disables the fmin cut.
	if (frequency < frequencyRange[0] || frequency > frequencyRange[1]) {
	  continue;
	}

	// differential time-frequency area for integration
	double differentialArea = tilePlaneRow.timeStep 
	                        * tilePlaneRow.frequencyStep;

	// Locate the samples of this row inside [tRange0, tRangeEnd).  The
	// index bound is tested before getBinT() is called on the index.
	const TSeries& normE = transforms.normE(channelNumber, plane, row);
	size_t nBin = normE.getNSample();
	size_t fBin = normE.getBin(tRange0);
	while (fBin < nBin && normE.getBinT(fBin) < tRange0) fBin++;
	if (fBin >= nBin) continue;
	size_t lBin = normE.getBin(tRangeEnd);
	while (lBin < nBin && normE.getBinT(lBin) < tRangeEnd) lBin++;
	if (lBin <= fBin) continue;
	Time tMin = normE.getBinT(fBin);
	Interval tDelta = normE.getBinT(lBin) - tMin;

	// vector of row tile normalized energies
	TSeries normalizedEnergies = normE.extract(tMin, tDelta);

	// find most significant tile in row
	const DVectD& dvd = dynamic_cast<const DVectD&>(*normalizedEnergies.refDVect());
	int nBins = dvd.size();
	if (!nBins) continue;
	int peakIndex = 0;
	double peakNormalizedEnergy = dvd[0];
	for (int iBin=1; iBin<nBins; iBin++) {
	  if (dvd[iBin] > peakNormalizedEnergy) {
	    peakNormalizedEnergy = dvd[iBin];
	    peakIndex = iBin;
	  }
	}

	double meanEnergy = 
	  transforms[channelNumber].planes(plane).rows(row).meanEnergy;

	///////////////////////////////////////////////////////////////////////
	//                   update peak tile properties                     //
	///////////////////////////////////////////////////////////////////////
	// NOTE(review): peakTile uses the 6-argument wprops constructor, so
	// its NormalizedEnergy is normE * (default dArea) and its energy is not
	// scaled by the plane normalization used for the signal sums below.
	// Confirm both against the reference implementation.
	if (peakNormalizedEnergy > peakTile.NormalizedEnergy) {
	  peakPlane = plane;

	  // Offsets are measured from the (possibly defaulted) reference
	  // time, not the raw refTime argument.
	  peakTile = wprops(referenceTime, normalizedEnergies.getBinT(peakIndex),
			    frequency, pQ, peakNormalizedEnergy, meanEnergy);
	  peakTile.Average();

	  // duration and bandwidth of the peak tile come from the tiling
	  peakTile.Duration = tilePlaneRow.duration;
	  peakTile.Bandwidth = tilePlaneRow.bandwidth;
	}
      
	///////////////////////////////////////////////////////////////////////
	//                update weighted signal properties                  //
	///////////////////////////////////////////////////////////////////////

	// accumulate tiles above the significance threshold
	for (int iBin=0; iBin<nBins; iBin++) {
	  if (dvd[iBin] > normalizedEnergyThreshold) {
	    Time tBin = normalizedEnergies.getBinT(iBin);
	    sumPlane += wprops(referenceTime, tBin, frequency, pQ, dvd[iBin],
			       meanEnergy * tiling.planes(plane).normalization,
			       differentialArea);
	  }
	}

      }	   //      end loop over frequency rows

      // normalize signal properties
      sumPlane.Average();
      plane_props[plane] = sumPlane;
    }      //---------------------------  end loop over Q planes

    ///////////////////////////////////////////////////////////////////////////
    //        report signal properties from plane with peak significance     //
    ///////////////////////////////////////////////////////////////////////////

    // plane_props is empty when the tiling has no planes.
    if (!plane_props.empty()) {
      measurements[channelNumber].signal = plane_props[peakPlane];
    }
    measurements[channelNumber].peak   = peakTile;

    // report peak tile properties for very weak signals
    if ( measurements[channelNumber].signal.Area < 1) {
      measurements[channelNumber].signal = peakTile;
    }
  }        //---------------------------  end loop over channels
}

//======================================  wmeasure destructor.
//  No explicit cleanup is needed; members release their own storage.
wmeasure::~wmeasure() {
}

//======================================  dump wmeasure
//  Write a header line followed by the peak and signal properties for each
//  channel, in channel order.
void
wmeasure::dump(std::ostream& out) const {
  const size_t nChannels = measurements.size();
  size_t ch = 0;
  while (ch < nChannels) {
    out << "Measurements for channel " << ch << endl;
    measurements[ch].dump(out);
    ++ch;
  }
}

//======================================  dump wmeasure
void
chan_props::dump(std::ostream& out) const {
  out << "    Peak value properties: " << endl; 
  peak.dump(out);
  out << "    Signal average properties: " << endl; 
  signal.dump(out);
}


//======================================  wprops default constructor.
//  Every accumulated quantity starts at zero and no terms are pending
//  averaging (nAverageTerms == 0).
wprops::wprops()
  : refTime(0),
    TimeOff(0.0), Frequency(0.0), Q(0.0),
    Duration(0.0), Bandwidth(0.0),
    NormalizedEnergy(0.0), Amplitude(0.0),
    Area(0.0), sumWeight(0.0),
    nAverageTerms(0)
{
}

//======================================  wprops single tile constructor.
//  Build the weighted contribution of one tile.  The weight is the
//  calibrated excess energy (normE - 1) * calFac times the tile area dArea.
//  Time and frequency moments are stored pre-multiplied by that weight so
//  that contributions can be summed with operator+= and normalized by
//  Average().  Time offsets are measured from ref.
wprops::wprops(const Time& ref, const Time& t, double f, double q0, 
	       double normE, double calFac, double dArea) 
  : refTime(ref), nAverageTerms(1)
{
  const double weight = ((normE - 1.0) * calFac) * dArea;
  const double offset = double(t - ref);

  Q                = q0;
  Area             = dArea;
  NormalizedEnergy = normE * dArea;
  sumWeight        = weight;
  Amplitude        = weight;
  Frequency        = f * weight;
  Bandwidth        = f * f * weight;
  TimeOff          = offset * weight;
  Duration         = offset * offset * weight;
}

//  Add the weighted sums held in p to this object.  A p with no pending
//  terms (nAverageTerms == 0) is ignored.  The reference time is adopted
//  from p only if this object has none yet, and Q is copied from p rather
//  than summed.
wprops&
wprops::operator+=(const wprops& p) {
  if (p.nAverageTerms) {
    if (!refTime) refTime = p.refTime;
    Q = p.Q;

    TimeOff          += p.TimeOff;
    Duration         += p.Duration;
    Frequency        += p.Frequency;
    Bandwidth        += p.Bandwidth;
    NormalizedEnergy += p.NormalizedEnergy;
    Amplitude        += p.Amplitude;
    Area             += p.Area;
    sumWeight        += p.sumWeight;
    nAverageTerms    += p.nAverageTerms;
  }
  return *this;
}

//  Convert the weighted sums into averages:
//    TimeOff   -> weighted mean time offset
//    Frequency -> weighted mean frequency
//    Duration  -> weighted r.m.s. spread of the time offset
//    Bandwidth -> weighted r.m.s. spread of the frequency
//    Amplitude -> square root of the summed weights
//  The weights are calibrated energy times area.  If no terms are pending,
//  the call does nothing.  On completion nAverageTerms is reset to 0, so a
//  repeated call is a no-op.
//
//  The variances E[x^2] - E[x]^2 may round to a tiny negative value, for
//  example for a single tile, where they are exactly zero in theory.  They
//  are clamped to zero so the spread is 0 instead of NaN; a NaN variance is
//  still propagated unchanged.
void
wprops::Average(void) {
  if (! nAverageTerms) return;
  TimeOff   /= sumWeight;
  Frequency /= sumWeight;
  Duration  /= sumWeight;
  Bandwidth /= sumWeight;

  double timeVariance = Duration  - TimeOff * TimeOff;
  double freqVariance = Bandwidth - Frequency * Frequency;
  Duration  = sqrt(timeVariance < 0 ? 0.0 : timeVariance);
  Bandwidth = sqrt(freqVariance < 0 ? 0.0 : freqVariance);

  Amplitude = sqrt(Amplitude);
  nAverageTerms = 0;
}

//======================================  dump wprops
//  Write each property on its own labelled, indented line.
void
wprops::dump(std::ostream& out) const {
  out << "        Time offset: " << TimeOff << endl
      << "        Frequency:   " << Frequency << endl
      << "        Duration:    " << Duration << endl
      << "        Q:           " << Q << endl
      << "        Bandwidth:   " << Bandwidth << endl
      << "        Amplitude:   " << Amplitude << endl
      << "        SumWeight:   " << sumWeight << endl
      << "        NormEnergy:  " << NormalizedEnergy << endl
      << "        Area:        " << Area << endl;
}
