/* -*- mode: c++; c-basic-offset: 3; -*- */
#include "wtransform.hh"
#include "wparameters.hh"
#include "wtile.hh"
#include "matlab_fcs.hh"
//#include "quick_select.hh"

#include "fSeries/DFT.hh"
#include "dv_function.hh"
#include "gds_veclib.hh"
#include "constant.hh"

#include <algorithm>
#include <iomanip>
#include <iostream>
#include <sstream>
#include <stdexcept>

//  This macro enables an alternate evaluation of the quartile levels
//  that calls std::nth_element twice to locate the lower and upper
//  quartiles instead of fully sorting the energies. In theory this should
//  be faster because selection can be almost linear in time, and in
//  practice it does appear to run somewhat faster.
#define USE_QUICK_SELECT 1

// port to c++ by jgz, results checked 2011.06.28
// reproduces omega/matlab results to within ~few parts per thousand.

using namespace wpipe;
using namespace containers;
using namespace std;

//====================================== Single-channel constructor
//  Computes the Q transform of one channel and stores it as a new entry at
//  the end of the transform list. All arguments are forwarded unchanged to
//  the qTransform data constructor.
wtransform::wtransform(const DFT& indata, const wtile& tiling,
                       const DFT& coeffs, double outlierFactor,
                       const string& channelName)
{
   transforms.push_back(qTransform(indata, tiling, coeffs, outlierFactor,
                                   channelName));
}

//====================================== Multi-channel constructor
//  Thin wrapper: every argument is handed unchanged to init(), which does
//  all validation and computation.
wtransform::wtransform(const dft_vect& data, const wtile& tiling,
                       double outlierFactor, const string& analysisMode,
                       const str_vect& channelNames,
                       const dft_vect& coeffs,
                       const dble_vect& coordinate)
{
   this->init(data, tiling, outlierFactor, analysisMode, channelNames,
              coeffs, coordinate);
}

//======================================  Initialize the transform
//  Validates the arguments, fills in defaults for omitted ones and computes
//  the Q transform of every channel, storing the results in `transforms`.
//
//  data          - one DFT per channel; each must hold
//                  int(sampleFrequency * duration) / 2 + 1 elements.
//  tiling        - tiling handed to each per-channel transform.
//  outlierFactor - forwarded to the per-channel transform.
//  analysisMode  - "independent", or "coherent" unless NO_COHERENT is
//                  defined; anything else is reported through error().
//  chNames       - channel names; if empty, names "X<n>" are generated in
//                  independent mode.
//  coeffs        - one coefficient DFT per channel; if empty, default
//                  coefficients are generated in independent mode.
//  coord         - sky coordinate [theta phi]; defaults to {pi/2, 0}.
//
//  All failures are reported by calling error().
//
//  NOTE(review): the coherent branch further down still contains
//  untranslated MATLAB syntax and cannot compile as C++; presumably
//  NO_COHERENT is defined by a header or by the build -- confirm.
void
wtransform::init(const dft_vect& data, const wtile& tiling, 
		 double outlierFactor, const string& analysisMode, 
		 const str_vect& chNames, const dft_vect& coeffs, 
		 const dble_vect& coord) {

   // infer analysis type from missing arguments and
   // construct default arguments
   if (coeffs.empty() && analysisMode != "independent") {
      error("further inputs required for coherent analysis modes");
   }

   // work on a local copy so a default can be substituted for an omitted
   // coordinate
   dble_vect coordinate = coord;
   if (coordinate.empty()) {
      double def_thet_phi[] = {pi/2, 0.0};
      coordinate.assign(def_thet_phi, def_thet_phi+2);
   }

   // reject analysis modes this build does not support
#ifdef NO_COHERENT
   if (analysisMode != "independent") {
      error(string("unknown analysis mode ") + analysisMode);
   }
#else
   if (analysisMode != "independent" && analysisMode != "coherent") {
      error(string("unknown analysis mode ") + analysisMode);
   }
#endif

   // determine number of channels
   size_t _numberOfChannels = data.size();

   // check channel names exist
   str_vect channelNames = chNames;
   if (channelNames.empty()) {
      if (analysisMode == "independent") {
	 // provide default channel names: "X0", "X1", ...
	 channelNames.resize(_numberOfChannels);
	 for (size_t chanNumber=0; chanNumber < _numberOfChannels; 
	      chanNumber++) {
	    ostringstream ostr;
	    ostr << "X" << chanNumber;
	    channelNames[chanNumber] = ostr.str();
	 }
      } else {
	 // must supply channel names for coherent analyses that need them for
	 // antenna patterns
	 error("must provide channelNames for coherent analysis");
      }
   }

   // check coefficients exist
   dft_vect coefficients;
   if (!coeffs.empty()) {
      coefficients = coeffs;
   } else {
      if (analysisMode == "independent") {
	 // provide default coefficients: start from a copy of the data DFT
	 // (same length and metadata), zero its contents and add 1.0
	 // (presumably yielding all-ones coefficients -- depends on
	 // replace_with_zeros semantics, confirm)
	 coefficients.resize(_numberOfChannels);
	 for (size_t chanNumber=0; chanNumber<_numberOfChannels; chanNumber++){
	    coefficients[chanNumber] = data[chanNumber];
	    int n =coefficients[chanNumber].size();
	    coefficients[chanNumber].refDVect().replace_with_zeros(0, n, n);
	    coefficients[chanNumber] += 1.0;
	 }
      } else {
	 // must supply coefficients for coherent analyses that need them for
	 // response matrix
	 error("must provide coefficients for coherent analysis");
      }
   }

   // determine required data lengths
   int dataLength = int(tiling.sampleFrequency() * tiling.duration());
   size_t halfDataLength = dataLength / 2 + 1;

   // validate data length against the tiling
   for (size_t chanNumber=0; chanNumber < _numberOfChannels; chanNumber++) {
      if (data[chanNumber].size() != halfDataLength) {
	 error("data length not consistent with tiling");
      }
   }

   // validate number of coefficients vectors
   if (coefficients.size() != _numberOfChannels) {
      error("coefficients are inconsistent with number of channels");
   }

   // validate coefficients length against the tiling
   for (size_t chanNumber=0; chanNumber < _numberOfChannels; chanNumber++) {
      if (coefficients[chanNumber].size() != halfDataLength) {
	 error("coefficients length not consistent with tiling");
      }
   }
  
   // validate channel names
   if (!channelNames.empty() && channelNames.size() != _numberOfChannels) {
      error("channel names inconsistent with number of transform channels");
   }

   //-----------------------------------   Ensure collocated network if it was 
   //                                      implied by omitting coordinate.
   // NOTE(review): this test runs unconditionally, i.e. also when the caller
   // did supply a coordinate, so any network spanning more than one site is
   // rejected here -- confirm whether it should apply only when coord is
   // empty.
   string sites = wparameters::buildNetworkString(channelNames);
   int numberOfSites = sites.size();
   if (numberOfSites > 1) {
      error("coordinate must be provided for non-collocated networks");
   }

#ifndef NO_COHERENT
   if (analysisMode == "coherent") {
      // NOTE(review): the message says ">2 required" but the test only
      // rejects fewer than 2 channels.
      if (_numberOfChannels < 2) {
	 error("not enough channels for a coherent analysis (>2 required)");
      }
   }

   // validate coordinate vector
   if (length(coordinate) != 2) {
      error("coordinates must be a two component vector [theta phi]");
   }

#if 0
   //-----------------------------------  Marked as unused
   // extract spherical coordinates
   theta = coordinate[0];
   phi =   coordinate[1];

   // validate spherical coordinates
   if (theta < 0 || theta > pi) {
      error("theta outside of range [0, pi]");
   }
   if (phi < 0 || phi >= 2 * pi) {
      error("phi outside of range [0, 2 pi)");
   } 
#endif  // ***** end currently unused *****

#endif

   ////////////////////////////////////////////////////////////////////////////
   //                          independent analysis                          //
   ////////////////////////////////////////////////////////////////////////////
   // one Q transform per input channel, each computed on its own
   if (analysisMode == "independent") {
      transforms.resize(_numberOfChannels);
      for (size_t chanNum=0; chanNum < _numberOfChannels; chanNum++) {
	 transforms[chanNum].transform(data[chanNum], tiling, 
					  coefficients[chanNum], outlierFactor, 
					  channelNames[chanNum]);
      }
   }

#ifndef NO_COHERENT
   ////////////////////////////////////////////////////////////////////////////
   //                        setup coherent analysis                         //
   ////////////////////////////////////////////////////////////////////////////
   // NOTE(review): several statements in this branch are MATLAB, not C++
   // (multiple-return assignment, `[a; b]` matrix literals, `{}` cell
   // indexing, `(i,:)` slices, `.*`), and `data` is const here but evolve()
   // is called on its elements. The branch needs a real port before
   // NO_COHERENT can be left undefined.
   else if (analysisMode =="coherent") {
      str_vect outputChannelNames;
      dft_vect intermediateData(_numberOfChannels);

      // determine detector antenna functions and time delays
      [fplus, fcross, deltat] = wresponse(coordinate, channelNames);
      
      /////////////////////////////////////////////////////////////////////////
      //                    time shift detector data                         //
      /////////////////////////////////////////////////////////////////////////

      // use first-listed detector as time reference (this has the advantage of
      // making collocated work naturally)
      deltat = deltat - deltat(1);

      // time shift data by frequency domain phase shift
      for (int chanNumber=0; chanNumber < _numberOfChannels; chanNumber++) {
	 data[chanNumber].evolve(deltat(chanNumber));
      }

      // concatenated list of detector identifiers
      // (first two characters of each channel name)
      str_vect detectors;
      for (int chanNumber=0; chanNumber < _numberOfChannels; chanNumber++) {
	 detectors.push_back(channelNames[chanNumber].substr(0,2));
      }

      /////////////////////////////////////////////////////////////////////////
      //                       construct new basis                           //
      /////////////////////////////////////////////////////////////////////////

      // form the response matrix
      responseMatrix = [fplus; fcross];
    
      // Simple basis (not taking into account power spectrum) is useful tool
      // to understand structure of the SVD
      //
      // [u,s,v] = svd(responseMatrix);
      //
      // If s(2,2) does not exist or is zero we are insensitive to the 
      // second polarization and we can compute only the primary signal
      // component and N - 1 null streams
      //
      // If s(2,2) exists and is nonzero, we can compute the primary 
      // and secondary signal components and N - 2 null streams

      // preallocate the coefficient structure
      basis = cell(_numberOfChannels);
      for (int i = 1; i<=_numberOfChannels; i++) {
	 for (int j = 1; j<=_numberOfChannels; j++) {
	    basis{i,j} = zeros(size(coefficients{1}));
	 }
      }

      // preallocate the responseMatrix for a given frequency
      f = zeros(size(responseMatrix));

      //for each frequency bin
      for (int frequencyNumber=0; frequencyNumber < halfDataLength;
	   frequencyNumber++) {
	 // for each channel form the response matrix including the noise
	 // coefficients
	 for (int chanNumber=0; chanNumber < _numberOfChannels; chanNumber++) {
	    f(chanNumber,:) = responseMatrix(chanNumber,:) .* 
	       coefficients{chanNumber}(frequencyNumber);
	 }

	 // compute the singular value decomposition
	 [u, s, v] = svd(f);

	 // repack the orthonormal basis coefficients into the output 
	 // structure
	 for (int i = 1; i<=_numberOfChannels; i++) {
	    for (int j = 1; j<=_numberOfChannels; j++) {
	       basis(i,j)(frequencyNumber) = u(i,j);
	    }
	 }
      }
      
      /////////////////////////////////////////////////////////////////////////
      //                           setup coherent outputs                    //
      /////////////////////////////////////////////////////////////////////////
      // NOTE(review): inxij runs up to _numberOfChannels^2 - 1 but
      // intermediateData was sized to _numberOfChannels above.
      int inxij = 0;
      for (int i=0; i < _numberOfChannels; i++) {
	 for (int j=0; j < _numberOfChannels; j++) {
	    intermediateData[inxij]  = data[i];
	    intermediateData[inxij] *= basis[inxij];
	    inxij++;
	 }
      }
    
      //setup output metadata
      // NOTE(review): outputChannelNames is still empty here and is indexed
      // from 1 (MATLAB style); `detectors` is a str_vect, so `detectors +
      // "..."` also needs porting.
      size_t numberOfIntermediateChannels=_numberOfChannels*_numberOfChannels;
      size_t numberOfOutputChannels = 2;
      outputChannelNames[1] = detectors + ":SIGNAL-COHERENT";
      outputChannelNames[2] = detectors + ":SIGNAL-INCOHERENT";
    
      // output null streams if the network allows
      if ((numberOfSites >= 3) || (_numberOfChannels > numberOfSites)) {
	 numberOfOutputChannels = 4;
	 outputChannelNames[3] = detectors + ":NULL-COHERENT";
	 outputChannelNames[4] = detectors + ":NULL-INCOHERENT";
      }

      /////////////////////////////////////////////////////////////////////////
      //                      initialize Q transform structures              //
      /////////////////////////////////////////////////////////////////////////
  
      // validate tiling structure
      size_t numberOfPlanes =  tiling.numberOfPlanes();

      // create empty cell array of Q transform structures
      transforms.resize(numberOfOutputChannels);

      // begin loop over channels
      for (size_t outputChanNumber=0;
	   outputChanNumber < numberOfOutputChannels;
	   outputChanNumber++) {

	 // create empty vector of Q plane structures
	 transforms[outputChanNumber]._planes.reserve(numberOfPlanes);

	 // begin loop over Q planes
	 for (size_t plane=0; plane < numberOfPlanes; plane++) {

	    // create empty vector of frequency row structures
	    int nRows = tiling.planes(plane).numberOfRows;
	    transforms[outputChanNumber].addPlane(nRows);
	    // end loop over Q planes
	 }

	 // end loop over channels
      }

      /////////////////////////////////////////////////////////////////////////
      //                           begin loop over Q planes                  //
      /////////////////////////////////////////////////////////////////////////

      // begin loop over Q planes
      for (int plane=0; plane < numberOfPlanes; plane++) {

	 //////////////////////////////////////////////////////////////////////
	 //                      begin loop over frequency rows              //
	 //////////////////////////////////////////////////////////////////////

	 // begin loop over frequency rows
	 for (int row=0; row < tiling.planes(plane).numberOfRows; row++) {
	    const qrow& tilePlaneRow(tiling.planes(plane).row(row));

	    ///////////////////////////////////////////////////////////////////
	    //             extract and window frequency domain data          //
	    ///////////////////////////////////////////////////////////////////

	    // begin loop over channels
	    tser_vect energies(numberOfIntermediateChannels);
	    tser_vect tileCoefficients(numberOfIntermediateChannels);
	    for (size_t intermediateChanNumber=0; 
		 intermediateChanNumber <numberOfIntermediateChannels; 
		 intermediateChanNumber++) {
	       const DFT& intData = intermediateData[intermediateChanNumber];
	       tileCoefficients[intermediateChanNumber] = 
		  tilePlaneRow.tileCoeffs(intData);
	    }
      
	    // compute coherent and incoherent energies indirectly from
	    // intermediate data
	    for (int outerChanNumber=0; outerChanNumber < _numberOfChannels;
		 outerChanNumber++) {
	       int outerIndex = outerChanNumber * _numberOfChannels;

	       // coherent stream energy: sum the tile coefficients first,
	       // then take the squared modulus
	       TSeries accumTileCoeffs = tileCoefficients[outerIndex];
	       for (int chanNumber=1; chanNumber < _numberOfChannels;
		    chanNumber++) {
		  accumTileCoeffs += tileCoefficients[chanNumber + outerIndex];
	       }
	       energies[outerChanNumber * 2] = 
		  TSeries(accumTileCoeffs.getStartTime(),
			  accumTileCoeffs.getTStep(),
			  dv_modsq(*accumTileCoeffs.refDVect()));

	       // incoherent stream energy: take the squared modulus of each
	       // term, then sum
	       // NOTE(review): `chanid` is used in the next statement before
	       // it is declared inside the loop below; outerIndex looks like
	       // the intended index.
	       energies[1 + outerChanNumber * 2] =
		  TSeries (tileCoefficients[outerIndex].getStartTime(),
			   tileCoefficients[outerIndex].getTStep(),
			   dv_modsq(*tileCoefficients[chanid].refDVect()));
	       for (int chanNumber=1; chanNumber < _numberOfChannels;
		    chanNumber++) {
		  int chanid = chanNumber + outerIndex;
		  energies[1 + outerChanNumber * 2] +=
		     TSeries (tileCoefficients[chanid].getStartTime(),
			      tileCoefficients[chanid].getTStep(),
			      dv_modsq(*tileCoefficients[chanid].refDVect()));
	       }
	    }

	    // accumulate in corresponding channels
     
	    // sum all the null energies into a single channel
	    // NOTE(review): energies[4] and energies[5] are read
	    // unconditionally, which requires at least 3 channels.
	    energies[2] = energies[4];
	    energies[3] = energies[5];
	    for (int chanNumber=3; chanNumber<_numberOfChannels; chanNumber++){
	       energies[2] += energies[    chanNumber * 2];
	       energies[3] += energies[1 + chanNumber * 2];
	    }

	    ///////////////////////////////////////////////////////////////////
	    //      exclude outliers and filter transients from statistics   //
	    ///////////////////////////////////////////////////////////////////
 
	    // begin loop over channels
	    for (size_t outputChanNumber=0; 
		 outputChanNumber < numberOfOutputChannels;
		 outputChanNumber++) {
	    
	       trow& t_pr=transforms[outputChanNumber]._planes[plane]._rows[row];
	    
	       t_pr.normalizeEnergies(energies[outputChanNumber], 
				      tiling.transientDuration(),
				      outlierFactor, tilePlaneRow);
	    
	    }  // end loop over channels

	 }	    // end loop over frequency rows

      }        // end loop over Q planes

      /////////////////////////////////////////////////////////////////////////
      //                    return discrete Q transform structure            //
      /////////////////////////////////////////////////////////////////////////
      for (size_t chanNumber=0; chanNumber < numberOfOutputChannels; 
	   chanNumber++) {
	 transforms[chanNumber]._channelName = outputChannelNames[chanNumber];
      }
   }  // end if (analysisMode == "coherent")
#endif

   ///////////////////////////////////////////////////////////////////////////
   //                              otherwise error                          //
   /////////////////////////////////////////////////////////////////////////// 
   // Defensive only: the mode was already validated near the top of the
   // function.
   else {
      error(string("unknown analysis mode: ") + analysisMode); 
   }

   // return to calling function
}

//========================================  Display transform data
void 
wtransform::dump(std::ostream& out) const {
   for (size_t i=0; i<transforms.size(); i++) {
      out << "Q transform for: " << transforms[i]._channelName << endl;
      transforms[i].dump(out);
   }
}

//========================================  qTransform default constructor
//  Nothing to set up here: members are default-constructed.
qTransform::qTransform(void)
{}

//========================================  qTransform data constructor
//  Constructs the object and immediately computes the transform by passing
//  every argument on to transform().
qTransform::qTransform(const containers::DFT& data, const wtile& tiling,
                       const containers::DFT& coeffs, double outlierFactor,
                       const std::string& chanName)
{
   this->transform(data, tiling, coeffs, outlierFactor, chanName);
}

//========================================  qTransform default constructor
void
qTransform::addPlane(int nRows) {
   _planes.push_back(tplane(nRows));
}

//========================================  Display the planes
//  Prints a "Plane: <index>" heading for each plane followed by that
//  plane's own dump output.
void
qTransform::dump(std::ostream& out) const {
   const size_t planeCount = _planes.size();
   size_t inx = 0;
   while (inx < planeCount) {
      out << "  Plane: " << inx << endl;
      _planes[inx].dump(out);
      ++inx;
   }
}

//======================================  Perform the q-transform.
//  Computes the normalized tile energies of one channel for every Q plane
//  and frequency row of the tiling and stores them in _planes. Any planes
//  left over from an earlier call are discarded first.
//
//  data          - input DFT; a private copy is unfolded before the tile
//                  coefficients are computed, the argument is not modified.
//  tiling        - supplies the planes, rows and transient duration.
//  whitenCoef    - used to set each row's conditioningCoef to the squared
//                  magnitude at the row frequency; if empty the
//                  coefficient is set to 1.
//  outlierFactor - forwarded to trow::getMeanEnergy.
//  chanName      - saved as _channelName.
void
qTransform::transform(const containers::DFT& data, const wtile& tiling,
                      const containers::DFT& whitenCoef, double outlierFactor,
                      const std::string& chanName) {

   //----------------------------------  Save the channel name.
   _channelName = chanName;

   //----------------------------------  Unfold the input DFT
   DFT intermediateData = data;
   intermediateData.unfold();
   double transientDt = tiling.transientDuration();

   //----------------------------------  Create plane and row structures.
   //  Start from an empty list: the loops below address planes by their
   //  tiling index, so planes appended after stale ones from a previous
   //  call would never be filled.
   int nPlanes = tiling.numberOfPlanes();
   _planes.clear();
   _planes.reserve(nPlanes);
   for (int plane=0; plane < nPlanes; plane++) {
      addPlane(tiling.planes(plane).numberOfRows);
   }

   //----------------------------------  begin loop over Q planes
   for (int plane=0; plane < nPlanes; plane++) {

      //------------------------------  begin loop over frequency rows
      for (int row=0; row < tiling.planes(plane).numberOfRows; row++) {
         const qrow& tilePlaneRow(tiling.planes(plane).row(row));
         trow& t_pr = _planes[plane]._rows[row];

         // Set the conditioning coefficient before anything is allocated
         // with new, so a failure in the lookup cannot leak.
         if (whitenCoef.empty()) {
            t_pr.conditioningCoef = 1.0;
         } else {
            t_pr.conditioningCoef = whitenCoef(tilePlaneRow.frequency).MagSq();
         }

         // compute energies directly from intermediate data
         TSeries coeffs = tilePlaneRow.tileCoeffs(intermediateData);
         DVectD* dvd = new DVectD(dv_modsq(*coeffs.refDVect()));

         // dvd is a raw owning pointer until setData() receives it
         // (presumably setData adopts the vector -- confirm against the
         // TSeries documentation). Free it if normalization throws.
         try {
            t_pr.getMeanEnergy(*dvd, transientDt, outlierFactor, tilePlaneRow);
            *dvd *= 1./t_pr.meanEnergy;
         } catch (...) {
            delete dvd;
            throw;
         }
         t_pr.normalizedEnergies.setData(coeffs.getStartTime(),
                                         coeffs.getTStep(), dvd);
      }      // end loop over frequency rows

   }      // end loop over Q planes
}

//======================================  Construct a q-plane with nRows rows
//  The row container is created with nRows default-constructed rows.
tplane::tplane(int nRows) : _rows(nRows) {
}


//======================================  Display a plane
//  Prints a "Row: <index>" heading for each row followed by that row's own
//  dump output.
void
tplane::dump(std::ostream& out) const {
   const size_t rowCount = _rows.size();
   for (size_t rowInx = 0; rowInx != rowCount; ++rowInx) {
      out << "    Row: " << rowInx << endl;
      _rows[rowInx].dump(out);
   }
}

//======================================  Construct a q-transform row structure
//  Creates an empty row: zero mean energy, default-constructed normalized
//  energies and a conditioning coefficient of 1, the same neutral value
//  qTransform::transform assigns when no whitening coefficients are given.
trow::trow(void)
   : meanEnergy(0)
{
   // Assigned in the body rather than the initializer list so the result
   // does not depend on the member declaration order.
   conditioningCoef = 1.0;
}

//===================================  Estimate the mean tile energy
//  Computes an outlier-resistant mean of the row's tile energies and
//  stores it in meanEnergy.
//
//  energies      - tile energies of the row.
//  transientDt   - duration excluded at each end of the row; converted to
//                  a tile count with tilePlaneRow.timeStep.
//  outlierFactor - tiles with energy at or above
//                  upperQuartile + outlierFactor * interQuartileRange are
//                  left out of the mean. For factors below 100 the mean is
//                  scaled by a correction factor that depends only on
//                  outlierFactor.
//  tilePlaneRow  - supplies timeStep and numberOfTiles.
//
//  Throws std::runtime_error if fewer than two tiles remain after the
//  transient exclusion: the quartile selection needs at least two, and an
//  unsigned length would otherwise wrap around to a huge value.
//
//  If no tile passes the outlier cut, meanCount is zero and the final
//  division yields a non-finite meanEnergy; that behaviour is unchanged.
void
trow::getMeanEnergy(const DVector& energies, double transientDt,
                    double outlierFactor, const qrow& tilePlaneRow)
{
   //--------------------------------  exclude outliers and filter
   //                                  transients from statistics
   int valid_index = int(transientDt/tilePlaneRow.timeStep) + 1;

   //  Do the subtraction in a signed type so a too-short row is detected
   //  instead of wrapping around.
   long nValid = long(tilePlaneRow.numberOfTiles) - 2L * long(valid_index);
   if (nValid < 2) {
      throw std::runtime_error("trow::getMeanEnergy: too few tiles remain "
                               "after excluding transients");
   }
   size_t valid_len = size_t(nValid);

   dble_vect sortedEnergies(valid_len);
   dble_iter itr25 = sortedEnergies.begin() + int(0.25*double(valid_len)-0.5);
   dble_iter itr75 = sortedEnergies.begin() + int(0.75*double(valid_len)-0.5);
   energies.getData(valid_index, valid_len, &sortedEnergies[0]);

#ifndef USE_QUICK_SELECT
   //  Full sort: both quartile positions end up holding their order
   //  statistics.
   sort(sortedEnergies.begin(), sortedEnergies.end());
#else
   //  Partial ordering only. After the first call everything before itr25
   //  is <= *itr25; the second call orders the remainder around itr75.
   //  valid_len >= 2 guarantees itr75 lies inside [itr25+1, end).
   nth_element(sortedEnergies.begin(), itr25, sortedEnergies.end());
   nth_element(itr25+1, itr75, sortedEnergies.end());
#endif

   // identify lower and upper quartile energies
   double lowerQuartile = *itr25;
   double upperQuartile = *itr75;

   // determine inter quartile range
   double interQuartileRange = upperQuartile - lowerQuartile;

   // energy threshold of outliers
   double outlierThreshold = upperQuartile + outlierFactor*interQuartileRange;

   //-----------------------------------  Mean of valid tile energies. Start
   //                                     with a sum of lower 3 quartiles.
   //  If the element at the 3/4 position is within the threshold, the
   //  elements before it are summed in one call without testing each one.
   //  Otherwise every element is tested individually.
   int meanCount = 0;
   int upper_quartile_index = (valid_len * 3) / 4;
   if (sortedEnergies[upper_quartile_index] <= outlierThreshold) {
      meanCount = upper_quartile_index;
      meanEnergy = vsum(&sortedEnergies.front(), meanCount);
   }
   else {
      upper_quartile_index = 0;
      meanEnergy = 0.0;
   }

   //-----------------------------------  Sum over remainder
   for (size_t vindex=upper_quartile_index; vindex < valid_len; vindex++) {
      if (sortedEnergies[vindex] < outlierThreshold) {
         meanEnergy += sortedEnergies[vindex];
         meanCount++;
      }
   }

   //-----------------------------------  default for large outlier factors
   double meanCorrectionFactor = 1.0;
   if (outlierFactor < 100) {
      double fac = 4.0 * pow(3.0, outlierFactor) - 1.0;
      meanCorrectionFactor = fac / (fac - (outlierFactor * log(3.0) + log(4.0)));
   }
   meanEnergy *= meanCorrectionFactor/double(meanCount);
}

//===================================  Construct & initialize a q-transform row
//  mean - initial value of meanEnergy.
//  ts   - time series copied into normalizedEnergies.
//
//  The conditioning coefficient is set to 1, the same neutral value
//  qTransform::transform assigns when no whitening coefficients are given,
//  so the member is never left uninitialized.
trow::trow(double mean, const TSeries& ts)
   : meanEnergy(mean), normalizedEnergies(ts)
{
   // Assigned in the body rather than the initializer list so the result
   // does not depend on the member declaration order.
   conditioningCoef = 1.0;
}

//===================================  Display q-transform row status
//  Prints the row's mean energy on a line of its own.
void
trow::dump(std::ostream& out) const {
   out << "      Mean Energy: ";
   out << meanEnergy;
   out << endl;
}
