#!/usr/bin/python
#
# Copyright (C) 2012  Kipp Cannon, Chad Hanna, Drew Keppel
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 2 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

import scipy
import scipy.stats
from scipy import interpolate
import numpy
try:
	import sqlite3
except ImportError:
	# pre 2.5.x
	from pysqlite2 import dbapi2 as sqlite3

sqlite3.enable_callback_tracebacks(True)

import sys
import glob
import copy
from optparse import OptionParser
import traceback

from glue import segments
from glue.ligolw import ligolw
from glue.ligolw import lsctables
from glue.ligolw import dbtables
from glue.ligolw import utils
from glue.ligolw import table
from glue import segmentsUtils
from glue.ligolw.utils import process
from glue.ligolw.utils import segments as ligolw_segments

from pylal import db_thinca_rings
from pylal import rate
from pylal import SimInspiralUtils
from pylal.xlal.datatypes.ligotimegps import LIGOTimeGPS
from pylal import imr_utils
# Inverse of the golden ratio (~0.618); used below as the height/width
# ratio of the default figure size.
goldenratio = 2 / (1 + 5**.5)
import matplotlib
# Select the non-interactive Agg backend before pylab is imported.
matplotlib.use('Agg')
import pylab
# Plot defaults.  The subplot geometry parameters are named
# "figure.subplot.*"; "subplots.*" is not a valid rc parameter name and is
# rejected with a KeyError by RcParams implementations that validate the
# keys passed to update().
matplotlib.rcParams.update({
        "font.size": 8.0,
        "axes.titlesize": 8.0,
        "axes.labelsize": 8.0,
        "xtick.labelsize": 8.0,
        "ytick.labelsize": 8.0,
        "legend.fontsize": 8.0,
        "figure.figsize": (3.3,3.3*goldenratio),
        "figure.dpi": 200,
        "figure.subplot.left": 0.2,
        "figure.subplot.right": 0.75,
        "figure.subplot.bottom": 0.15,
        "figure.subplot.top": 0.75,
        "savefig.dpi": 600,
        "text.usetex": True     # render all text with TeX
})

# Module metadata; the version string and date are read from the
# pylal.git_version module.
from pylal import git_version
__author__ = "Chad Hanna chad.hanna@ligo.org"
__version__ = "git id %s" % git_version.id
__date__ = git_version.date

def parse_command_line():
	"""
	Parse the command line of this program.

	Returns the tuple (opts, filenames):  opts is the optparse values
	object and filenames is the list of positional arguments (the
	databases to process).  If --instruments was given, opts.instruments
	is replaced by the result of passing the comma separated string to
	lsctables.instrument_set_from_ifos(); otherwise it is left as None.

	Raises ValueError if no filename is given on the command line.
	"""
	parser = OptionParser(
		version = git_version.verbose_msg,
		usage = "%prog [options] [file ...]",
		# NOTE: the description and the help strings below used to be
		# copied from a different program; they now describe the plots
		# this script actually writes.
		description = "%prog makes IFAR and population false-alarm probability plots from one or more databases"
	)
	parser.add_option("--output-name-tag", default = "", metavar = "name", help = "Set the file output name tag, real name is <ifos>-<tag>_<CLOSED|OPEN>_<IFAR|POP_STATEMENT>-<start>-<livetime>.png")
	parser.add_option("--live-time-program", default = "thinca", help = "Set the name of the live time program to use to get segments from the search summary table")
	parser.add_option("--instruments", help = "Restrict the plots to this instrument combination, given as a comma separated list, e.g. H1,H2,L1")
	parser.add_option("--veto-segments-name", help = "Set the name of the veto segments to use from the XML document. Default: don't use vetoes.")
	parser.add_option("-t", "--tmp-space", metavar = "path", help = "Path to a directory suitable for use as a work area while manipulating the database file.  The database file will be worked on in this directory, and then moved to the final location when complete.  This option is intended to improve performance when running in a networked environment, where there might be a local disk with higher bandwidth than is available to the filesystem on which the final output will reside.")
	parser.add_option("--verbose", action = "store_true", help = "Be verbose.")

	opts, filenames = parser.parse_args()

	# convert the comma separated string into a set of instrument names
	if opts.instruments is not None:
		opts.instruments = lsctables.instrument_set_from_ifos(opts.instruments)

	if not filenames:
		raise ValueError("must provide at least one database")

	return opts, filenames


def poisson_sf_with_trials_factor(idxs, Ns, offset = 0):
	"""
	Return the false-alarm probability of seeing at least idxs[i] events
	when Ns[i] are expected, corrected for a trials factor.

	For each entry the single-trial probability p is the Poisson survival
	function P(X >= idxs[i]) for mean Ns[i].  The trials factor is the
	last element of idxs.  Where p * trials_factor < 1e-3 the linear
	approximation p * trials_factor is returned, otherwise the value
	1 - (1 - p)**trials_factor is returned.

	idxs and Ns are array-like and must be broadcastable against each
	other; idxs must be at least one dimensional.  Empty input yields an
	empty array.  The offset argument is accepted for backwards
	compatibility and is not used.

	Returns a numpy array with the same shape as the survival function.
	"""
	# accept plain sequences as well as arrays
	idxs = numpy.asarray(idxs)
	Ns = numpy.asarray(Ns)

	# nothing to do, and no last element to take the trials factor from
	if idxs.size == 0:
		return numpy.zeros(numpy.shape(idxs), dtype = float)

	trials_factor = idxs[-1]

	# poisson.sf(k, mu) is P(X > k), so k = idx - 1 gives P(X >= idx)
	sf = scipy.stats.poisson.sf(idxs - 1, Ns)

	linear = sf * trials_factor
	exact = 1. - (1. - sf)**trials_factor

	# for tiny probabilities 1 - (1 - p)**n suffers from round-off, use
	# the leading order term n * p there instead
	return numpy.where(linear < 1e-3, linear, exact)


def process_fars(fars1, T, minfap = None, maxsigma = 5):
	"""
	Turn a list of false-alarm rates into the quantities needed for the
	IFAR and population statement plots.

	fars1 is an array-like of false-alarm rates (the caller passes them
	sorted in increasing order), T is the observation time they are to
	be multiplied with.  If minfap is not None, rates below minfap / T
	are raised to that value.  maxsigma is the number of "sigma" levels
	for which probabilities are returned.  The input is never modified;
	all work is done on a copy.

	Returns the tuple

	(stop_far, idxs, Ns1, full_fars1, idx_end, faps1, P_sigmas, cfaps1)

	where

	full_fars1:  float copy of fars1 with the minfap clipping applied
	idx_end:  one more than the index of the first event whose expected
		count FAR * T is >= 1, or len(fars1) - 1 if there is no such
		event
	idxs:  the event counts 1, 2, ..., idx_end
	Ns1:  the expected counts FAR * T of the first idx_end events
	stop_far:  full_fars1[idx_end], or the last rate if idx_end is past
		the end of the array
	faps1:  1 - exp(-Ns1)
	P_sigmas:  list of erfc(n / sqrt(2)) for n = maxsigma, ..., 1
	cfaps1:  poisson_sf_with_trials_factor(idxs, Ns1)

	Raises IndexError if fars1 is empty.
	"""
	# only scipy.stats is imported at file scope, don't rely on it
	# pulling in scipy.special as a side effect
	from scipy.special import erfc

	# private float copy so the clipping below cannot alter the caller's
	# array
	full_fars1 = numpy.array(fars1, dtype = float)
	if minfap is not None:
		minfar = minfap / float(T)
		full_fars1[full_fars1 < minfar] = minfar

	# expected number of events at or below each false-alarm rate
	Ns_all = full_fars1 * T

	# keep events up to and including the first one expected at least
	# once;  if there is none, keep all but the last
	at_least_one = numpy.flatnonzero(Ns_all >= 1)
	if len(at_least_one):
		idx_end = int(at_least_one[0]) + 1
	else:
		idx_end = len(Ns_all) - 1

	# idx_end equals len(full_fars1) when the first event with N >= 1 is
	# the last one in the array;  fall back to the last rate rather than
	# indexing past the end
	stop_far = full_fars1[min(idx_end, len(full_fars1) - 1)]

	idxs = numpy.arange(1, len(full_fars1) + 1)[:idx_end]
	Ns1 = Ns_all[:idx_end]

	# 1 - exp(-N), evaluated with expm1 so that small N do not round to 0
	faps1 = -numpy.expm1(-Ns1)

	# two-sided Gaussian tail probabilities, largest sigma first
	P_sigmas = [erfc(nsigma / 2**.5) for nsigma in range(maxsigma, 0, -1)]

	cfaps1 = poisson_sf_with_trials_factor(idxs, Ns1)

	return stop_far, idxs, Ns1, full_fars1, idx_end, faps1, P_sigmas, cfaps1


#
# Main
#

options, filenames = parse_command_line()

# Collect the segments and false-alarm rates from the input databases.
IMR = imr_utils.DataBaseSummary(filenames, tmp_path = options.tmp_space, veto_segments_name = options.veto_segments_name, live_time_program = options.live_time_program, verbose = options.verbose)

for instruments in IMR.instruments:
	# same output as "print instruments, options.instruments", but also
	# valid syntax for Python 3
	sys.stdout.write("%s %s\n" % (instruments, options.instruments))

	# --instruments restricts the plots to one instrument set;  without
	# it every instrument set is plotted (comparing against None used to
	# skip all of them)
	if options.instruments is not None and instruments != options.instruments:
		continue

	# Live Time FIXME only approximate for time slides
	segs = IMR.segments.intersection(instruments)
	ifos = "".join(sorted(instruments))

	# (label, sorted false-alarm rates, observation time):  time slide
	# events with the live time scaled by the number of slides, and
	# zero-lag events with the plain live time
	possibilities = (
		('CLOSED', numpy.array(sorted(IMR.ts_fars_by_instrument_set[instruments])), float(abs(segs) * IMR.numslides)),
		('OPEN', numpy.array(sorted(IMR.zerolag_fars_by_instrument_set[instruments])), float(abs(segs)))
	)

	for (label, fars1, T) in possibilities:
		# nothing can be plotted without events (the code below would
		# divide by zero and index an empty array)
		if len(fars1) == 0:
			sys.stderr.write("no %s events for %s, skipping plots\n" % (label, ifos))
			continue

		NumEvents = min(100, len(fars1))
		stop_far, idxs, Ns1, full_fars1, stop_N, faps1, P_sigmas, cfaps1 = process_fars(fars1, T)

		# fields shared by the names of both output files
		name_fields = (ifos, options.output_name_tag+"_"+label, min(segs)[0], T)

		#
		# Standard IFAR plot
		#

		fig = pylab.figure()
		ax = fig.add_axes((.15,.165,.8,.8))
		ax.loglog((1. / full_fars1)[:NumEvents], numpy.arange(NumEvents) + 1, color = 'k', label = r'\rm Observed')

		# expected background:  T / IFAR events at or above each IFAR
		bg_ifars = numpy.logspace(numpy.log10(T / NumEvents), 16, 1000, base = 10.)
		bg_counts = T / bg_ifars
		ax.loglog(bg_ifars, bg_counts, color = 'k', label = r'\rm\begin{center}Expected Background\end{center}', ls = '-.')

		# shaded Poisson bands around the expected background, one per
		# sigma level;  zeros are replaced by 0.1 so they can be drawn on
		# the logarithmic axis
		for P_sigma in P_sigmas:
			up = scipy.stats.poisson.isf(P_sigma, bg_counts)
			up[up == 0] = 1e-1
			down = scipy.stats.poisson.ppf(P_sigma, bg_counts)
			down[down == 0] = 0.1
			ax.fill_between(bg_ifars, down, up, color = 'k', alpha = .1)

		#ax.scatter(1./stop_far, [stop_N], color = 'k', marker = 'o', s = 49, facecolors = 'none', label = r'$N_{\max}$')

		ax.set_xlabel(r'\rm IFAR (Hz$^{-1}$)')
		ax.set_ylabel(r'\rm Count')
		ax.set_ylim(0.9, NumEvents)
		ax.set_xlim(bg_ifars.min(), bg_ifars.max())
		ax.legend(loc = 'upper right', scatterpoints = 1, numpoints = 4)
		ax.grid()
		fig.savefig('%s-%s_IFAR-%d-%d.png' % name_fields)
		# release the figure, otherwise one accumulates per iteration
		pylab.close(fig)

		#
		# Population statement plot
		#

		fig = pylab.figure()
		ax = fig.add_axes((.2,.15,.75,.8))
		ax.semilogy(idxs, cfaps1, color = 'k', label = r'\rm Population')
		ax.semilogy(idxs, faps1, color = 'k', ls='--', label = r'\rm Individual')
		# shade the sigma levels across the whole x range shown (the
		# right edge used to be fixed at 30 even when stop_N is larger)
		for P_sigma in P_sigmas:
			ax.fill_between([1, max(30, stop_N)], P_sigma, 1, color = 'k', alpha = .1)
		ax.set_xlabel(r'\rm Count')
		ax.set_ylabel(r'\rm FAP')
		ax.legend(loc = 'lower right', numpoints = 4)
		ax.grid()
		ax.set_xlim(1, stop_N)
		ax.set_ylim(min(min(faps1), min(cfaps1)), 1)
		fig.savefig('%s-%s_POP_STATEMENT-%d-%d.png' % name_fields)
		pylab.close(fig)
