#!/usr/bin/env python
#
# Copyright (C) 2012  Drew Keppel
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 2 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

import sys
import matplotlib
matplotlib.use("TkAgg")
# inverse of the golden ratio (~0.618); used as the figure's height/width ratio
goldenratio = 2 / (1 + 5**.5)

# Global plot styling.  The subplot margins must be given under their real
# rc names, "figure.subplot.*": the previous "subplots.*" keys are not valid
# rc parameters (old matplotlib stored them silently without effect, current
# matplotlib raises KeyError on them).
matplotlib.rcParams.update({
	"font.size": 8.0,
	"axes.titlesize": 8.0,
	"axes.labelsize": 8.0,
	"xtick.labelsize": 8.0,
	"ytick.labelsize": 8.0,
	"legend.fontsize": 8.0,
	"figure.figsize": (3.3,3.3*goldenratio),
	"figure.dpi": 200,
	"figure.subplot.left": 0.2,
	"figure.subplot.right": 0.75,
	"figure.subplot.bottom": 0.15,
	"figure.subplot.top": 0.75,
	"savefig.dpi": 200,
	"text.usetex": True     # render all text with TeX
})
from matplotlib import pyplot
from glue.ligolw import ligolw
from glue.ligolw import array
from glue.ligolw import param
array.use_in(ligolw.LIGOLWContentHandler)
param.use_in(ligolw.LIGOLWContentHandler)
from glue.ligolw import utils
import pylal.coherent_inspiral_metric as metric
import pylal.coherent_inspiral_metric_detector_details as detector_details
import pylal.coherent_inspiral_metric_plots as metric_plots
from pylal import spawaveform
import scipy
import numpy
from scipy import pi,sin,cos,exp,log
import math
from optparse import OptionParser

__author__ = "Drew Keppel <drew.keppel@ligo.org>"

# Command line interface.  Note that this file has no module docstring, so
# __doc__ is None and the parser is built without a description.
parser = OptionParser(description = __doc__)
parser.add_option("--flow-min", metavar = "int", type = "int", help="Minimum lower frequency cutoff")
parser.add_option("--flow-max", metavar = "int", type = "int", help="Maximum lower frequency cutoff")
parser.add_option("--flow-ref", metavar = "int", type = "int", help="Reference lower frequency cutoff")
parser.add_option("--mismatch-ref", metavar = "float", type = "float", help="Reference mismatch = 1. - minimal match")
parser.add_option("--filter-cost", help="Single filter cost. (fixed|cycles)")
parser.add_option("--output", metavar = "filename", help = "Save the generated figure to file to filename")
parser.add_option("--psd-file", metavar = "psdfilename", help = "Load the psd from psdfilename")
parser.add_option("--psd-name", metavar = "psdname", help = "Extract the psd named psdname from psdfilename")
parser.add_option("--verbose", action = "store_true", help = "Be verbose.")
options, filenames = parser.parse_args()

# sys.stderr.write() behaves the same under Python 2 and Python 3, unlike
# the "print >> sys.stderr" statement form
if options.filter_cost not in ['fixed', 'cycles']:
	sys.stderr.write("--filter-cost must be one of (fixed|cycles)\n")
	sys.exit(1)

# These options have no defaults and are used unconditionally further down;
# report their absence here instead of failing later with an unrelated
# TypeError.
missing = ["--" + name.replace("_", "-") for name in ("flow_min", "flow_max", "flow_ref", "mismatch_ref", "output") if getattr(options, name) is None]
if missing:
	parser.error("missing required option(s): %s" % ", ".join(missing))

# The reference cutoff is looked up among the integer-spaced cutoffs running
# from --flow-min to --flow-max, so it has to lie inside that range.
if not options.flow_min <= options.flow_ref <= options.flow_max:
	parser.error("--flow-ref must lie between --flow-min and --flow-max")

# delayed import from gstlal
from gstlal import reference_psd


# Signal parameters: an equal-mass binary with unit component masses
# (presumably solar masses, since mchirp is later scaled by metric.M_sun).
m1 = m2 = 1.

# derived mass parameters
M = m1 + m2                 # total mass
mu = (m1 * m2) / M          # reduced mass
eta = mu / M                # symmetric mass ratio
mchirp = M * pow(eta, .6)   # chirp mass, M * eta**(3/5)

# set frequency parameters
flow_bank = options.flow_ref
fNyq = 2048.

# load the requested psd from the XML document
new_psds = reference_psd.read_psd_xmldoc(utils.load_filename(options.psd_file, verbose=options.verbose, contenthandler=ligolw.LIGOLWContentHandler))
psd_REAL8FrequencySeries = new_psds[options.psd_name]
deltaF = psd_REAL8FrequencySeries.deltaF
f0 = psd_REAL8FrequencySeries.f0

# Array shapes and slice bounds must be integers: older NumPy truncated
# floats implicitly, current NumPy rejects them.  Truncate explicitly.
nyq_bin = int(fNyq/deltaF)
f0_bin = int(f0/deltaF)

# drop any psd samples beyond the Nyquist frequency
f = numpy.arange(fNyq/deltaF+1)*deltaF + f0
psd_REAL8FrequencySeries.data = psd_REAL8FrequencySeries.data[:len(f)]

# Build an array of 2*nyq_bin samples, initialised to inf.  The loaded psd
# fills bins f0_bin..nyq_bin, and its reverse (without the first and last
# samples) fills the remaining upper bins -- presumably the negative
# frequencies in the ordering produced by f_for_fft() below (TODO confirm).
PSD = numpy.zeros(2*nyq_bin, dtype='float') + numpy.inf
PSD[f0_bin:nyq_bin+1] = psd_REAL8FrequencySeries.data
if f0_bin == 0:
	PSD[nyq_bin+1:] = psd_REAL8FrequencySeries.data[-2:0:-1]
else:
	PSD[nyq_bin+1:-f0_bin] = psd_REAL8FrequencySeries.data[-2:0:-1]
psd_REAL8FrequencySeries.data = PSD
f = detector_details.f_for_fft(flow_bank, fNyq, psd_REAL8FrequencySeries.deltaF)

# Mask everything below the reference cutoff on a *copy*: the stored series
# is re-masked for every trial cutoff later in this script, and writing the
# reference mask into it would make all cutoffs below flow_bank behave as if
# they were flow_bank.
psd = PSD.copy()
psd[abs(f) < flow_bank] = numpy.inf
LHO = detector_details.make_LHO(fLow=flow_bank, fNyq=fNyq)
LHO.set_psd(f, psd)
LHO.set_required_moments()
# Reference quantities at the bank's lower frequency cutoff: the '-7' moment
# of the detector (kept as snr2) and the square root of the determinant of
# the single-detector mass metric at this point of parameter space.
snr2 = LHO.I_n['-7']
mchirp_scaled = mchirp*metric.M_sun
g = metric.single_detector_mass_metric(LHO, mchirp_scaled, eta)
detg0 = scipy.linalg.det(g)
rootdetg0 = detg0**.5

# Repeat the metric and moment computation for every trial lower frequency
# cutoff, in integer steps from --flow-min to --flow-max.
# (numpy.* is used instead of the deprecated scipy.* aliases of the same
# NumPy functions.)
flows = numpy.linspace(options.flow_min, options.flow_max, options.flow_max - options.flow_min + 1)
rootdetgs = []
snr2s = []
for flow in flows:
	# Mask a private copy so that the mask of one cutoff is never written
	# into the stored series and carried over to other cutoffs.
	psd = numpy.array(psd_REAL8FrequencySeries.data)
	psd[abs(f) < flow] = numpy.inf
	LHO.set_psd(f, psd)
	LHO.set_required_moments()
	g = metric.single_detector_mass_metric(LHO, mchirp*metric.M_sun, eta)
	# sqrt(det g) and the '-7' moment for this cutoff
	rootdetgs.append((scipy.linalg.det(g))**.5)
	snr2s.append(LHO.I_n['-7'])
rootdetgs = numpy.array(rootdetgs)
snr2s = numpy.array(snr2s)

# Cost of a single filter, given by the number of cycles from flow to fisco.
# Keys are the lower frequency cutoffs (10. through 40. in unit steps).
filter_cost_dict = {
	10.: 28459.942721,
	11.: 24292.990888,
	12.: 21023.699209,
	13.: 18405.878657,
	14.: 16274.683630,
	15.: 14511.672718,
	16.: 13037.004411,
	17.: 11787.844083,
	18.: 10720.194326,
	19.: 9798.932266,
	20.: 8998.386507,
	21.: 8297.711844,
	22.: 7680.818917,
	23.: 7133.961077,
	24.: 6646.451169,
	25.: 6210.901952,
	26.: 5818.352048,
	27.: 5464.524287,
	28.: 5144.209545,
	29.: 4852.570085,
	30.: 4586.442188,
	31.: 4342.950681,
	32.: 4119.775802,
	33.: 3914.006141,
	34.: 3724.632136,
	35.: 3549.001642,
	36.: 3386.457720,
	37.: 3235.402843,
	38.: 3094.872198,
	39.: 2963.957300,
	40.: 2841.828222,
}

# Select the per-filter computational cost model requested on the command
# line: a constant cost, or the tabulated cycle count for each cutoff
# expressed relative to the cost at the reference cutoff.
if options.filter_cost == 'fixed':
	filter_cost = 1.
elif options.filter_cost == 'cycles':
	filter_cost = scipy.array([filter_cost_dict[flow] for flow in flows])
	filter_cost /= filter_cost[flows == flow_bank]

# dimensionality of the parameter space
d = 2.

# mismatch obtained at fixed computational cost: the reference mismatch
# scaled by (cost * sqrt(det g) / reference sqrt(det g))**(2/d)
cost_ratio = filter_cost * rootdetgs / rootdetg0
m_bank = options.mismatch_ref * cost_ratio**(2./d)

# fraction of the snr recovered by the bank
rho_bank = 1. - m_bank

# snr available from the integral starting at flow, relative to the
# reference cutoff
rho_int = (snr2s/snr2)**.5

# overall snr: product of the two factors
rho = rho_int * rho_bank

# Pick the best cutoff and the reference cutoff by index.  Selecting with
# boolean masks such as rho == max(rho) returns arrays, which cannot be
# formatted with %i / %.1e once the maximum is tied between several cutoffs;
# argmax always yields exactly one position (the first maximum).
best = numpy.argmax(rho)
ref = numpy.flatnonzero(flows == flow_bank)[0]

# mismatch at the best cutoff rendered as TeX scientific notation,
# e.g. '3.0e-02' -> '3.0\times 10^{-2}'
m_best = ('%.1e'%(m_bank[best])).replace('-0','-').replace('+0','').replace('+','').replace('e',r'\times 10^{')+'}'

# gain of the best cutoff over the reference cutoff, in percent
snr_gain = 100*(rho[best]-rho[ref])

# plot the three snr ratios against the lower frequency cutoff
fig = pyplot.figure()
ax = fig.add_axes((.15,.16,.8,.6))
ax.plot(flows, rho_bank, label=r'$\rho_{\rm bank}(f_{\rm low}) / \rho(%i\rm Hz)$'%(flow_bank))
ax.plot(flows, rho_int, label=r'$\rho_{\rm int}(f_{\rm low}) / \rho(%i\rm Hz)$'%(flow_bank))
ax.plot(flows, rho, label=r'$\rho_{\rm total}(f_{\rm low}) / \rho(%i\rm Hz)$'%(flow_bank))
ax.set_ylabel(r'$\rho(f_{\rm low}) / \rho(%i\rm Hz)$'%(flow_bank))
ax.set_xlabel(r'$f_{\rm low}$')
pyplot.title(r'$\textrm{Maximum sensitivity at }$\\$f_{\rm low} = %i{\rm Hz},\,m = %s,\textrm{SNR gain} = %.1f\%%$'%(flows[best], m_best, snr_gain))
pyplot.legend(loc='best')
pyplot.savefig(options.output)
