/*
 * Simplex minimization for multidimensional curve fitting.
 *
 * Reference:
 *	Numerical Recipes in C
 *	Press, Flannery, Saul & Vetterling
 *	Cambridge University Press
 *	1988
 *	ISBN 0-521-35465-X
 *
 */
#include <math.h>		// use fabs()

#include "list.h"
#include "memalloc.h"		// use new()
#include "num.h"
#include "simplex.h"

// Step factors handed to try_move().  try_move() places the trial point at
// c + fac * (pt - c), where pt is the vertex being moved and c is the
// centroid of the remaining vertices, so:
//   FLIP_FACTOR   (-1)  reflects the vertex through the centroid,
//   SHRINK_FACTOR (0.5) contracts it halfway toward the centroid,
//   EXPAND_FACTOR (2)   pushes it twice as far from the centroid.
#define FLIP_FACTOR -1.0
#define SHRINK_FACTOR 0.5
#define EXPAND_FACTOR 2.0
// NOTE(review): MAXDIM is not referenced by any code in this file.
#define MAXDIM	1024

// Internal helpers; see the definitions below for their contracts.
static bool minimize(double **p, double *y, double *pmin, double *pmax,
		     double *psum, double *ptry,
		     int ndim, double ftol, int maxiterations,
		     double (*func)(double *, void *fdata),
		     bool (*stoptrying)(void *fdata), void *fdata, int *best);
static void get_parameter_sum(double **p, int mpts, int ndim, double *psum);
static void get_high_low(double *y, int mpts,
			 int *low, int *high, int *nexthigh);
static bool bound(double *ptry, double *pmin, double *pmax, int ndim);
static double try_move(double *pt, double *psum, double *pmin, double *pmax,
		       double *ptry, double *yvalue, int ndim, double fac,
		       double (*func)(double *, void *), void *fdata);

// ----------------------------------------------------------------------------
//
// Minimize func over the box [pmin, pmax] with the downhill simplex method.
//
//   start          In: initial guess; clipped into [pmin, pmax] in place
//                  before use.  Out: on convergence, overwritten with the
//                  lowest vertex of the final simplex; otherwise it keeps the
//                  clipped initial guess.
//   variation      Per-parameter step used to build the initial simplex.
//   pmin, pmax     Per-parameter lower / upper bounds.
//   ndim           Number of parameters; values < 1 are rejected.
//   ftol           Convergence threshold on the fractional spread
//                  2|y_hi - y_lo| / (|y_hi| + |y_lo|) of the vertex values.
//   maxiterations  Cap on function evaluations made by the minimization
//                  loop (the initial ndim+1 vertex evaluations are not
//                  counted).
//   func           Objective, called as func(point, fdata).
//   stoptrying     Optional (may be NULL); polled after evaluations in the
//                  loop, returning true aborts the fit.
//   fdata          Opaque pointer passed through to func and stoptrying.
//
// Returns true if the fit converged, false if ndim < 1, the evaluation cap
// was reached, or stoptrying requested a stop.
bool simplex(
	double *start,
	double *variation,
	double *pmin,
	double *pmax,
	int ndim,
	double ftol,
	int maxiterations,
	double (*func)(double *, void *fdata),
	bool (*stoptrying)(void *fdata),
	void *fdata)
{
  // A simplex needs at least two vertices; ndim == 0 would make
  // get_high_low() read past the end of y, and ndim < 0 is meaningless.
  if (ndim < 1)
    return false;

  int		mpts = ndim + 1;	// ndim+1 vertices span ndim dimensions

  double **p = new double * [mpts];
  for (int i = 0 ; i < mpts ; ++i)
    p[i] = new double [ndim];
  double *y = new double [mpts];
  double *psum = new double [ndim];
  double *ptry = new double [ndim];

  // Vertex 0 is the (clipped) start point; vertex i differs from it only
  // along axis i-1.
  bound(start, pmin, pmax, ndim);
  for (int i = 0 ; i < mpts ; ++i)
    {
      for (int j = 0 ; j < ndim ; ++j)
	p[i][j] = start[j];
      if (i > 0)
	{
	  int k = i - 1;
	  double v = p[i][k] + variation[k];
	  if (v <= pmin[k] || v >= pmax[k])
	    {
	      // Step would leave the open range: fall back to the midpoint.
	      v = (pmin[k] + pmax[k]) / 2;
	      // If start already sits on the midpoint, vertex i would equal
	      // vertex 0 along this axis and the simplex would collapse,
	      // freezing this parameter.  Step halfway toward pmax instead.
	      // (When pmin == pmax the parameter is pinned and this is moot.)
	      if (v == p[i][k])
		v = (p[i][k] + pmax[k]) / 2;
	    }
	  p[i][k] = v;
	}
      y[i] = (*func)(p[i], fdata);
    }

  int best = 0;
  bool converged = minimize(p, y, pmin, pmax, psum, ptry, ndim,
			    ftol, maxiterations, func, stoptrying, fdata,
			    &best);
  if (converged)
    for (int i = 0 ; i < ndim ; ++i)
      start[i] = p[best][i];

  for (int i = 0 ; i < mpts ; ++i)
    delete [] p[i];
  delete [] p;
  delete [] y;
  delete [] psum;
  delete [] ptry;

  return converged;
}

// ----------------------------------------------------------------------------
//
// Run downhill-simplex iterations on an already-initialized simplex.
//
//   p           mpts = ndim+1 vertices, each an array of ndim coordinates;
//               updated in place.
//   y           func value at each vertex; kept in step with p.
//   pmin, pmax  Box bounds, applied by try_move() to reflected/expanded
//               trial points.
//   psum        Scratch: per-coordinate sum over all vertices.
//   ptry        Scratch: trial point used by try_move().
//   ftol        Stop when 2|y_hi - y_lo| / (|y_hi| + |y_lo|) < ftol
//               (a zero denominator counts as converged).
//   maxiterations  Cap on try_count, which counts each try_move() call and
//               each vertex re-evaluation during a whole-simplex shrink.
//   stoptrying  Optional; if non-NULL it is polled after evaluations and a
//               true result aborts.
//   best        Out: index of the lowest vertex; written only on success.
//
// Returns true on convergence, false if the cap is hit or stoptrying fires.
static bool minimize(
	double **p,
	double *y,
	double *pmin,
	double *pmax,
	double *psum,
	double *ptry,
	int ndim,
	double ftol,
	int maxiterations,
	double (*func)(double *, void *fdata),
	bool (*stoptrying)(void *fdata),
	void *fdata,
	int *best)
{
  int		mpts = ndim + 1;

	int try_count = 0;
	get_parameter_sum(p, mpts, ndim, psum);
	for (;;) {
	  int lo, hi, nexthi;
	  get_high_low(y, mpts, &lo, &hi, &nexthi);

	  // Fractional spread between worst and best vertex values.
	  double denom = fabs(y[hi]) + fabs(y[lo]);
	  double rtol = (denom == 0 ? 0 : 2.0 * fabs(y[hi] - y[lo]) / denom);
	  if (rtol < ftol)
	    {
	      *best = lo;
	      break;
	    }
		
	  if (try_count >= maxiterations)
	    return false;	// exceeded max number of iterations.

	  // Reflect the highest vertex through the centroid of the others.
	  double ytry = try_move(p[hi], psum, pmin, pmax, ptry, &y[hi],
				 ndim, FLIP_FACTOR, func, fdata);
	  try_count += 1;
	  if (stoptrying && stoptrying(fdata))
	    return false;

	  if (ytry <= y[lo])	// reflection gave a new best: try going further
	    {
	      ytry = try_move(p[hi], psum, pmin, pmax, ptry, &y[hi],
			      ndim, EXPAND_FACTOR, func, fdata);
	      try_count += 1;
	      if (stoptrying && stoptrying(fdata))
		return false;
	    }
	  else if (ytry >= y[nexthi]) // still no better than 2nd-worst: contract
	    {
	      double ysave = y[hi];
	      ytry = try_move(p[hi], psum, pmin, pmax, ptry, &y[hi],
			      ndim, SHRINK_FACTOR, func, fdata);
	      try_count += 1;
	      if (stoptrying && stoptrying(fdata))
		return false;

	      if (ytry >= ysave)
		{
		  //
		  // Moved flipped point toward center, still is maximum.
		  // Now move all points half way towards current minimum.
		  //
		  for (int i = 0; i < mpts; i++)
		    {
		      if (i == lo)
			continue;
		      for (int j = 0; j < ndim; j++)
			p[i][j] = 0.5*(p[i][j]+p[lo][j]);
		      y[i] = (*func)(p[i], fdata);
		      if (stoptrying && stoptrying(fdata))
			return false;
		      
		      try_count++;
		    }
		  // Every non-best vertex moved, so rebuild psum from scratch.
		  get_parameter_sum(p, mpts, ndim, psum);
		}
	    }
	}

	return true;
}

// ----------------------------------------------------------------------------
//
static void get_parameter_sum(double **p, int mpts, int ndim, double *psum)
{
  // psum[col] receives the sum of coordinate col over all mpts vertices,
  // accumulated in vertex order starting from zero.
  for (int col = 0 ; col < ndim ; col++)
    {
      double total = 0.0;
      int row = 0;
      while (row < mpts)
	total += p[row++][col];
      psum[col] = total;
    }
}

// ----------------------------------------------------------------------------
//
// Find the indices of the lowest, highest and second-highest values in
// y[0..mpts-1].  Ties keep the earliest index for lowest and highest.
// Requires mpts >= 2 for distinct highest / next-highest; with fewer
// vertices all three outputs are set to 0.
static void get_high_low(double *y, int mpts,
			 int *low, int *high, int *nexthigh)
{
  // The seeding below reads y[1]; don't run past the end of a 1-vertex
  // (ndim == 0) simplex.
  if (mpts < 2)
    {
      *low = *high = *nexthigh = 0;
      return;
    }

  int lo = 0;
  int hi = y[0] > y[1] ? 0 : 1;
  int nexthi = y[0] > y[1] ? 1 : 0;

  for (int i = 0; i < mpts; i++) {
    if (y[i] < y[lo])
      lo = i;
    if (y[i] > y[hi]) {
      nexthi = hi;		// previous highest becomes runner-up
      hi = i;
    }
    else if (y[i] > y[nexthi] && i != hi)
      nexthi=i;
  }

  *low = lo;
  *high = hi;
  *nexthigh = nexthi;
}

// ----------------------------------------------------------------------------
//
static bool bound(double *ptry, double *pmin, double *pmax, int ndim)
{
  // Clamp each coordinate of ptry into [pmin, pmax] in place (upper bound
  // tested first).  Returns true if any coordinate had to be moved.
  bool changed = false;
  for (int k = 0 ; k < ndim ; ++k)
    {
      double clamped;
      if (ptry[k] > pmax[k])
	clamped = pmax[k];
      else if (ptry[k] < pmin[k])
	clamped = pmin[k];
      else
	continue;		// already inside the box
      ptry[k] = clamped;
      changed = true;
    }
  return changed;
}

// ----------------------------------------------------------------------------
//
static double try_move(double *pt, double *psum, double *pmin, double *pmax,
		       double *ptry, double *yvalue, int ndim, double fac,
		       double (*func)(double *, void *), void *fdata)
{
  // Trial point: with c = (psum - pt) / ndim, the centroid of the other
  // vertices (psum must be the sum over ALL vertices, pt included), the
  // candidate is c + fac * (pt - c).  Expanded into the linear form below.
  const double w_sum = (1.0 - fac) / ndim;
  const double w_vertex = w_sum - fac;
  for (int k = 0 ; k < ndim ; ++k)
    ptry[k] = psum[k] * w_sum - pt[k] * w_vertex;

  // Reflections (fac < 0) and expansions (fac > 1) may leave the box, so
  // clip them; contractions lie between the vertex and the centroid.
  if (fac < 0 || fac > 1)
    bound(ptry, pmin, pmax, ndim);

  const double ycand = (*func)(ptry, fdata);
  if (!(ycand < *yvalue))
    return ycand;		// not an improvement: simplex left unchanged

  // Accept: replace the vertex and its value, keeping psum consistent.
  *yvalue = ycand;
  for (int k = 0 ; k < ndim ; ++k)
    {
      psum[k] = psum[k] + ptry[k] - pt[k];
      pt[k] = ptry[k];
    }
  return ycand;
}
