#!/usr/bin/env python
#
# Copyright (C) 2012 Ian Harry
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 2 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
"""Estimate median of multiple PSDs for the recolouring pipeline."""


#
# parse command line
#

# FIXME: Move some common options to a module.
from optparse import OptionParser
from glue import segments
from glue.ligolw import ligolw
from glue.ligolw import array
from glue.ligolw import param
array.use_in(ligolw.LIGOLWContentHandler)
param.use_in(ligolw.LIGOLWContentHandler)
from glue.ligolw import utils
from pylal.datatypes import LIGOTimeGPS
from gstlal.pipeutil import gst
from gstlal import reference_psd

def _ramp(data, start, stop, start_val, stop_val, only_lower=False):
    """Write a straight line into data[start..stop] (both ends inclusive).

    The line takes the value start_val at index start and stop_val at index
    stop.  stop may be smaller than start, in which case the indices are
    walked downwards.  When only_lower is true a sample is replaced only if
    the line lies strictly below its current value.
    """
    span = stop - start
    step = 1 if span > 0 else -1
    for j in range(start, stop + step, step):
        value = start_val - (start_val - stop_val) * (j - start) / span
        if not only_lower or value < data[j]:
            data[j] = value


def _link_anchors(data, indices, require_drop, only_lower=False):
    """Join successive anchor samples, visited in the given order, by lines.

    The first visited sample becomes the anchor.  Every later sample that
    qualifies (always, or only when it is strictly below the anchor value if
    require_drop is true) is connected to the anchor with _ramp and then
    becomes the new anchor.

    An anchor value of exactly 0 means "no anchor yet", so a zero-valued
    sample is replaced as anchor by the next sample visited and never starts
    a line.
    """
    anchor_val = 0
    anchor_pos = 0
    for i in indices:
        if anchor_val == 0:
            anchor_val = data[i]
            anchor_pos = i
        elif not require_drop or data[i] < anchor_val:
            _ramp(data, anchor_pos, i, anchor_val, data[i], only_lower)
            # Re-read the sample: the ramp has just rewritten it.
            anchor_val = data[i]
            anchor_pos = i


def _is_window_min(data, i, lo, hi):
    """Return True when no sample in data[lo:hi] is strictly below data[i]."""
    return not any(data[j] < data[i] for j in range(lo, hi))


def smooth_psd(psd, smooth_psd=False, rescale_factor=1):
    """Rescale a PSD and optionally smooth it for NINJA2 recolouring.

    Every sample of psd.data is first divided by rescale_factor.  If
    smooth_psd is false nothing else happens.  Otherwise the samples are
    reshaped in these steps:

    1. Find the lowest sample at frequencies strictly between 20 and 1900
       (the bottom of the "bowl"); the first one wins on ties.
    2. Walk upwards from 20 to the bowl frequency, and downwards from 1900
       to the bowl frequency.  Each time a new running minimum is met, the
       samples between the previous minimum and the new one are replaced by
       a straight line.
    3. Below 20 and above 1900, collect "join points": the zero-frequency
       bin, the last bin, and every sample that is not exceeded downwards by
       any neighbour up to three bins either side (the window is clipped to
       that side region and to the array).
    4. Connect consecutive join points by straight lines.  Below 20 a line
       only ever lowers samples; above 1900 it overwrites them.

    Parameters:
      psd: object exposing `data` (an indexable, assignable sequence of
        samples) and `deltaF` (the frequency step between samples).
        NOTE(review): presumably a frequency series as returned by
        reference_psd.read_psd_xmldoc, with frequencies in Hz -- confirm.
      smooth_psd: when true, run the smoothing steps above.
      rescale_factor: divisor applied to every sample before smoothing.

    Returns:
      The same psd object, with `data` reassigned to the processed samples.
    """
    lower_bowl_f = 20
    upper_bowl_f = 1900

    data = psd.data
    length = len(data)
    delta_f = psd.deltaF

    for i in range(length):
        data[i] = data[i] / rescale_factor

    if not smooth_psd:
        psd.data = data
        return psd

    # Step 1: locate the bottom of the bowl.
    bowl_min = 1E20
    bowl_f = 0
    for i in range(length):
        f = delta_f * i
        if lower_bowl_f < f < upper_bowl_f and data[i] < bowl_min:
            bowl_min = data[i]
            bowl_f = f

    # Step 2: smooth down into the bowl from low F, then from high F.
    _link_anchors(
        data,
        (i for i in range(length) if lower_bowl_f < delta_f * i < bowl_f),
        require_drop=True,
    )
    _link_anchors(
        data,
        (i for i in range(length - 1, -1, -1)
         if bowl_f < delta_f * i < upper_bowl_f),
        require_drop=True,
    )

    # Step 3: pick the join points on either side of the bowl range.
    join_points_lower = []
    join_points_upper = []
    for i in range(length):
        f = delta_f * i
        if f < lower_bowl_f:
            # The window is also clipped to the array length so that a PSD
            # ending below the lower bowl edge does not index past its end.
            if f == 0 or _is_window_min(
                    data, i,
                    max(0, i - 3),
                    min(int(lower_bowl_f / delta_f), i + 4, length)):
                join_points_lower.append(i)
        if f > upper_bowl_f:
            if i == length - 1 or _is_window_min(
                    data, i,
                    max(int(upper_bowl_f / delta_f), i - 3),
                    min(length, i + 4)):
                join_points_upper.append(i)

    # Step 4: connect the join points.
    _link_anchors(data, join_points_lower, require_drop=False, only_lower=True)
    _link_anchors(data, join_points_upper, require_drop=False)

    psd.data = data
    return psd



# Command-line interface.  All three value options are mandatory; optparse
# has no notion of a required option, so their presence is checked by hand
# after parsing.
parser = OptionParser(description = __doc__)
parser.add_option("--input-psd", metavar = "filename", help = "Input PSD to be smoothed, should be a LIGO light-weight XML file (required).")
parser.add_option("--output-psd", metavar = "filename", help = "Write smoothed noise spectrum to this LIGO light-weight XML file (required).")
# The value is an instrument name, not a file, so the help placeholder says so.
parser.add_option("--instrument", metavar = "name", help = "Which ifo's PSD are we looking at (required).")
parser.add_option("-v", "--verbose", action = "store_true", help = "Be verbose (optional).")

# Positional arguments are collected in `filenames` but not used below.
options, filenames = parser.parse_args()

required_options = ["input_psd", "output_psd", "instrument"]

# An option that was not given on the command line is left as None.
missing_options = [option for option in required_options if getattr(options, option) is None]
if missing_options:
    # Call form of raise: valid on both Python 2 and Python 3.
    raise ValueError("missing required option(s) %s" % ", ".join("--%s" % option.replace("_", "-") for option in sorted(missing_options)))

#
# =============================================================================
#
#                                     Main
#
# =============================================================================
#

# Load the input document, then pick the entry for the requested instrument
# out of what read_psd_xmldoc returns.
xmldoc = utils.load_filename(
    options.input_psd,
    verbose = options.verbose,
    contenthandler = ligolw.LIGOLWContentHandler
)
psd = reference_psd.read_psd_xmldoc(xmldoc)[options.instrument]

# Smoothing enabled, rescale factor 1.
psd = smooth_psd(psd, smooth_psd = True, rescale_factor = 1)

# Write the result back out, keyed by the same instrument.
smoothed_by_instrument = { options.instrument: psd }
reference_psd.write_psd(options.output_psd, smoothed_by_instrument, verbose = options.verbose)
