/* -*- mode: c++; c-basic-offset: 4; -*- */
#include "Coherence.hh"
#include "Hamming.hh"
#include "DecimateBy2.hh"
#include "Bits.hh"
#include "fSeries/DFT.hh"
#include <stdexcept>
#include <iostream>

using namespace std;

//======================================  Default constructor.
//
//  Builds an unconfigured object: the stride, overlap fraction and sample
//  rate are all zero, as are the recorded start time and current time.
//  add() fills in a zero stride and a zero sample rate from its input
//  series the first time it is called.
Coherence::Coherence()
    : mStride(0),
      mOverlap(0),
      mSampleRate(0),
      mStartTime(0),
      mCurrent(0)
{
}

//======================================  Data constructor.
//
//  \param stride       Length of each analysis stride.
//  \param overlap      Overlap fraction between successive strides. The
//                      value is stored as given (no range check here).
//  \param w            Optional window; copied into mWindow when non-null.
//  \param sample_rate  Sample rate used by add(); zero lets add() pick
//                      the rate from its input series.
Coherence::Coherence(Interval stride, double overlap, const window_api* w,
                     double sample_rate)
    : mStride(stride),
      mOverlap(overlap),
      mSampleRate(sample_rate),
      mStartTime(0),
      mCurrent(0)
{
    //----------------------------------  A null pointer means "no window".
    if (w != 0) {
        mWindow.set(*w);
    }
}

//======================================  Destructor.
//
//  Nothing to release explicitly here; the members clean up after
//  themselves in their own destructors.
Coherence::~Coherence()
{
}

//======================================  Add one or more strides to the 
//                                        accumulated coherence
void
Coherence::add(const TSeries& x, const TSeries& y) {

    //------------------------------------  Check that the stride is calculated.
    if (!mStride) {
	set_stride(x, y, 1.0);
    }

    //------------------------------------  Set up sample rate
    if (mSampleRate == 0) {
	if (x.getTStep() < y.getTStep()) {
	    mSampleRate = 1.0 / y.getTStep();
	}
	else if (x.getTStep() == Interval(0.0)) {
	    throw runtime_error("Coherence: Invalid sample rate. ");
	}
	else {
	    mSampleRate = 1.0 / x.getTStep();
	}
    }

    //-------------------------------------  Resample data, append to history
    resample(mXDecim, x, mXHistory);
    resample(mYDecim, y, mYHistory);

    //------------------------------------  Make sure start times are equal.
    if (mXHistory.getStartTime() != mXHistory.getStartTime()) {
	if (!mXHistory.getStartTime() || !mXHistory.getStartTime()) return;
	if (mXHistory.getStartTime() > mYHistory.getStartTime()) {
	    Interval dt = mXHistory.getStartTime() - mYHistory.getStartTime();
	    mYHistory.eraseStart(dt);
	} else {
	    Interval dt = mYHistory.getStartTime() - mXHistory.getStartTime();
	    mXHistory.eraseStart(dt);
	}
    }

    //------------------------------------  Record first data start time.
    if (!mStartTime) {
	mStartTime = mXHistory.getStartTime();
	mCurrent   = mStartTime;
    }

    //------------------------------------  Loop over overlapping strides.
    while (mXHistory.getInterval() >= mStride &&
	   mYHistory.getInterval() >= mStride) {
	containers::DFT xDft(mWindow(mXHistory.extract(mCurrent, mStride)));
	containers::DFT yDft(mWindow(mYHistory.extract(mCurrent, mStride)));

	//-------------------------------- First time - set accumulators.
	if (mXYSum.empty()) {
	    mXYSum = containers::CSD(xDft, yDft);
	    mXXSum = containers::PSD(xDft);
	    mYYSum = containers::PSD(yDft);
	}

	//------------------------------  Subsequently - Add to accumulators.
	else {
	    mXYSum += containers::CSD(xDft, yDft);
	    mXXSum += containers::PSD(xDft);
	    mYYSum += containers::PSD(yDft);
	}

	//----------------------------------  Advance history and current time.
	Interval DtErase = mStride * (1.0 - mOverlap);
	mXHistory.eraseStart(DtErase);
	mYHistory.eraseStart(DtErase);    
	mCurrent += DtErase;
    }
}

//======================================  Calculate the coherence from the 
//                                        accumulated CSD and PSDs
//
//  Returns the squared modulus of the accumulated cross-spectrum (mXYSum)
//  divided bin-by-bin by both accumulated power spectra (mXXSum, mYYSum).
//  The accumulators themselves are left untouched.
//
//  \return The coherence as a containers::PSD.
containers::PSD
Coherence::get_coherence(void) const {
    //----------------------------------  fill a PSD with the CSD modsq.
    containers::PSD r;
    static_cast<containers::fSeries&>(r) = mXYSum.modsq();

    //----------------------------------  Single sided PSDs have a factor 
    //                                    of 2 for +/- freqs in bins (1...N-1)
    //  The bin count passed to scale() is size()-2. With fewer than two
    //  bins (e.g. nothing has been accumulated yet) that expression would
    //  drop below zero -- and wrap around if size() is unsigned -- so the
    //  rescaling is skipped in that case.
    if (r.single_sided() && mXYSum.size() >= 2) {
        r.refDVect().scale(1, 4.0, mXYSum.size()-2);
    }

    r /= mXXSum;
    r /= mYYSum;
    return r;
}

//======================================  All in one coherence calculation.
//
//  Clears the accumulated sums, derives a stride from the input (one
//  eighth of the shorter series) if none has been set, accumulates the
//  given data with add() and returns the resulting coherence.
//
//  \param x  First input time series.
//  \param y  Second input time series.
//  \return   Coherence computed by get_coherence().
containers::PSD
Coherence::operator()(const TSeries& x, const TSeries& y)
{
    //----------------------------------  Start from empty accumulators.
    reset_accumulators();

    //----------------------------------  Default stride: split into 8.
    if (!mStride) {
        set_stride(x, y, 8.0);
    }

    //----------------------------------  Accumulate and evaluate.
    add(x, y);
    return get_coherence();
}

//======================================  Resample data and append it to the
//                                        input history series.
//
//  If the time step of the input matches mSampleRate (to within 1e-6 in
//  the product rate * step) the data are appended to the history as they
//  are. Otherwise the data are run through the decimator first. The
//  decimator is created here as long as no start time has been recorded
//  (mStartTime is zero); the decimation factor must be a power of two
//  that is at least 2.
//
//  \param decim  Decimation pipe belonging to this input channel.
//  \param in     Input time series.
//  \param hist   History series the (resampled) data are appended to.
//
//  \throws std::runtime_error if the decimation factor is invalid, if a
//          decimator is needed but has not been set up, or if Append()
//          reports an error.
void
Coherence::resample(auto_pipe& decim, const TSeries& in, TSeries& hist) {

    //------------------------------------  No resampling necessary
    if (fabs(mSampleRate * double(in.getTStep()) - 1.0) < 1e-6) {
        if (hist.empty()) {
            hist = in;
        } else {
            int rc = hist.Append(in);
            if (rc) throw runtime_error("Coherence: Invalid input data.");
        }
    }

    //------------------------------------  Set up resampling?
    else {
        if (!mStartTime) {
            //  step * rate is the inverse of the decimation factor. It has
            //  to be a positive number; zero, negative or NaN would lead to
            //  a division by zero or a meaningless factor.
            const double stepTimesRate = double(in.getTStep() * mSampleRate);
            if (!(stepTimesRate > 0.0))
                throw runtime_error("Coherence: Invalid resample request");

            //  Reject factors that round below 2 or that do not fit in an
            //  int *before* converting: an out-of-range floating point to
            //  int conversion is undefined behaviour.
            const double ratio = 1.0 / stepTimesRate;
            if (!(ratio >= 1.5) || ratio >= 2147483647.0)
                throw runtime_error("Coherence: Invalid resample request");

            int factor = int(ratio + 0.5);
            if (!is_power_of_2(factor))
                throw runtime_error("Coherence: Invalid resample request");

            //  Number of decimate-by-2 stages = log2(factor).
            int N = 0;
            while (factor > 1) {
                factor /= 2;
                N++;
            }
            decim.set(new DecimateBy2(N, 1));
        }

        //------------------------------------  Resample
        if (decim.null()) 
            throw runtime_error("Coherence: Resampling misconfigured.");
        if (hist.empty()) {
            hist = decim(in);
        } else {
            int rc = hist.Append(decim(in));
            if (rc) throw runtime_error("Coherence: Invalid input data.");
        }
    }
}

//======================================  Reset accumulators, history and 
//                                        resamplers
//
//  Clears the accumulated sums, zeroes the recorded start time, empties
//  both history series and releases both decimators. The stride, overlap,
//  sample rate and window settings are not touched.
void
Coherence::reset()
{
    //----------------------------------  Accumulated spectra.
    reset_accumulators();

    //----------------------------------  Start time; add() records a new
    //                                    one once it sees a zero here.
    mStartTime = Time(0);

    //----------------------------------  Input data history.
    mXHistory.Clear();
    mYHistory.Clear();

    //----------------------------------  Decimation pipes.
    mXDecim.set(0);
    mYDecim.set(0);
}
//======================================  Reset accumulators
//
//  Empties the accumulated cross-spectrum (mXYSum) and both accumulated
//  power spectra (mXXSum, mYYSum). Histories and settings are kept.
void
Coherence::reset_accumulators()
{
    mXYSum.clear();
    mXXSum.clear();
    mYYSum.clear();
}

//======================================  Set the overlap value
//
//  \param ovlp  Overlap fraction between successive strides; must satisfy
//               0 <= ovlp < 1.
//
//  \throws std::invalid_argument if ovlp lies outside [0, 1) or is NaN.
void 
Coherence::set_overlap(double ovlp) {
    //  Written as a negated "in range" test so that NaN, for which every
    //  ordered comparison is false, is rejected instead of being stored.
    if (!(ovlp >= 0 && ovlp < 1.0)) 
        throw std::invalid_argument("Coherence: Invalid overlap fraction");
    mOverlap = ovlp;
}

//======================================  Set the sample rate
void Coherence::set_rate(double rate) {
    mSampleRate = rate;
}

//======================================  Set the stride length
//
//  \param dt  Stride length to store. The argument is only read.
void
Coherence::set_stride(Interval& dt)
{
    this->mStride = dt;
}

//======================================  Set the stride length from 
void Coherence::set_stride(const TSeries& x, const TSeries& y, double nSeg) {
    Interval tSeg = x.getInterval();
    if (y.getInterval() < tSeg) tSeg = y.getInterval();
    if (nSeg > 1.0) tSeg /= nSeg;
    if (tSeg == Interval(0.0)) {
	throw runtime_error("Coherence: Invalid segment length");
    }
    set_stride(tSeg);
}

//======================================  Set-up for Welch method
//
//  Configures the given stride, a 50% overlap and a Hamming window.
//
//  \param stride  Stride length to use.
void
Coherence::set_welch(Interval stride)
{
    //----------------------------------  Stride and half-stride overlap.
    set_stride(stride);
    set_overlap(0.5);

    //----------------------------------  Hamming window.
    Hamming hammingWindow;
    set_window(hammingWindow);
}

//======================================  Set the window
void Coherence::set_window(const window_api& w) {
    mWindow.set(w);
}
