#!/usr/bin/env python
#
# Copyright (C) 2013 Chad Hanna, Kipp Cannon
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 2 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

## @file gstlal_inspiral_plot_background
# A program to plot the likelihood distributions in noise of a gstlal inspiral analysis
#
# ### Command line interface
#
#	+ `--database` [filename]: Retrieve search results from this database (optional).  Can be given multiple times.
#	+ `--database-cache` [filename]: Retrieve search results from all databases in this LAL cache (optional).  See lalapps_path2cache.
#	+ `--max-snr` [value] (float): Plot SNR PDFs up to this value of SNR (default = 200).
#	+ `--max-log-lambda` [value] (float): Plot ranking statistic CDFs, etc., up to this value of the natural logarithm of the likelihood ratio (default = 100).
#	+ `--min-log-lambda` [value] (float): Plot ranking statistic CDFs, etc., down to this value of the natural logarithm of the likelihood ratio (default = -5).
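#	+ `--output-dir` [path]: Write output to this directory (default = ".").
#	+ `--output-format` [extension]: Select the output format by choosing the filename extension, one of ".png", ".pdf" or ".svg" (default = ".png").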
#	+ `--tmp-space` [path]: Path to a directory suitable for use as a work area while manipulating the database file.  The database file will be worked on in this directory, and then moved to the final location when complete.  This option is intended to improve performance when running in a networked environment, where there might be a local disk with higher bandwidth than is available to the filesystem on which the final output will reside.
#	+ `--user-tag` [tag]: Set the adjustable component of the description fields in the output filenames (default = "ALL").
#	+ `--verbose`: Be verbose.

import math
import matplotlib
matplotlib.rcParams.update({
	"font.size": 10.0,
	"axes.titlesize": 10.0,
	"axes.labelsize": 10.0,
	"xtick.labelsize": 8.0,
	"ytick.labelsize": 8.0,
	"legend.fontsize": 8.0,
	"figure.dpi": 600,
	"savefig.dpi": 600,
	"text.usetex": True
})
from matplotlib import figure
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
import numpy
from optparse import OptionParser
import sqlite3
import sys


from glue.lal import CacheEntry
from glue.ligolw import dbtables
from glue.ligolw import lsctables
from glue.ligolw import utils as ligolw_utils
from pylal import ligolw_thinca


from gstlal import far
from gstlal import plotfar
from gstlal import inspiral_pipe


def parse_command_line():
	"""
	Parse the command line.

	Returns the tuple (options, filenames) produced by the option
	parser, after the following post-processing:

	- the paths listed in the --database-cache file, if one was given,
	  are appended to options.database;
	- options.user_tag is converted to upper case.

	Raises ValueError if --output-format is not one of ".png", ".pdf"
	or ".svg".
	"""
	parser = OptionParser()
	parser.add_option("-d", "--database", metavar = "filename", action = "append", help = "Retrieve search results from this database (optional).  Can be given multiple times.")
	parser.add_option("-c", "--database-cache", metavar = "filename", help = "Retrieve search results from all databases in this LAL cache (optional).  See lalapps_path2cache.")
	parser.add_option("--max-snr", metavar = "SNR", default = 200., type = "float", help = "Plot SNR PDFs up to this value of SNR (default = 200).")
	parser.add_option("--max-log-lambda", metavar = "value", default = 100., type = "float", help = "Plot ranking statistic CCDFs, etc., up to this value of the natural logarithm of the likelihood ratio (default = 100).")
	# the help text must quote the same default as the default keyword
	parser.add_option("--min-log-lambda", metavar = "value", default = -5., type = "float", help = "Plot ranking statistic CCDFs, etc., down to this value of the natural logarithm of the likelihood ratio (default = -5).")
	parser.add_option("--output-dir", metavar = "output-dir", default = ".", help = "Write output to this directory (default = \".\").")
	parser.add_option("--output-format", metavar = "extension", default = ".png", help = "Select output format by choosing the filename extension (default = \".png\").")
	parser.add_option("-t", "--tmp-space", metavar = "path", help = "Path to a directory suitable for use as a work area while manipulating the database file.  The database file will be worked on in this directory, and then moved to the final location when complete.  This option is intended to improve performance when running in a networked environment, where there might be a local disk with higher bandwidth than is available to the filesystem on which the final output will reside.")
	parser.add_option("--user-tag", metavar = "user-tag", default = "ALL", help = "Set the adjustable component of the description fields in the output filenames (default = \"ALL\").")
	parser.add_option("-v", "--verbose", action = "store_true", help = "Be verbose.")
	options, filenames = parser.parse_args()

	if options.database_cache is not None:
		if options.database is None:
			options.database = []
		# the context manager guarantees the cache file is closed
		# again, even if one of its lines fails to parse
		with open(options.database_cache) as cache_file:
			options.database += [CacheEntry(line).path for line in cache_file]

	if options.output_format not in (".png", ".pdf", ".svg"):
		raise ValueError("invalid --output-format \"%s\"" % options.output_format)

	options.user_tag = options.user_tag.upper()

	return options, filenames


def load_distributions(filenames, verbose = False):
	"""
	Load the likelihood control documents named in filenames and
	combine their contents.

	Each file is parsed with far.parse_likelihood_control_doc().  The
	coincidence parameter distributions and the ranking data from the
	files are accumulated with +=, and the segment lists with |=.

	Returns the tuple (coinc_params_distributions, ranking_data,
	seglists).  Either of the first two may be None if no input file
	provided it, but not both.

	Raises ValueError if a file contains neither kind of data, or if
	nothing was loaded at all (for example, filenames is empty).
	"""
	coinc_params_distributions, ranking_data, seglists = None, None, None
	for n, filename in enumerate(filenames, 1):
		if verbose:
			print >>sys.stderr, "%d/%d:" % (n, len(filenames)),
		# load the file for THIS iteration;  indexing filenames[0]
		# here would re-read the first file once per input file
		xmldoc = ligolw_utils.load_filename(filename, contenthandler = far.ThincaCoincParamsDistributions.LIGOLWContentHandler, verbose = verbose)
		this_coinc_params_distributions, this_ranking_data, this_seglists = far.parse_likelihood_control_doc(xmldoc)
		if this_coinc_params_distributions is None and this_ranking_data is None:
			raise ValueError("%s contains no parameter distribution data" % filename)
		if coinc_params_distributions is None:
			coinc_params_distributions = this_coinc_params_distributions
		else:
			coinc_params_distributions += this_coinc_params_distributions
		if ranking_data is None:
			ranking_data = this_ranking_data
		else:
			ranking_data += this_ranking_data
		if seglists is None:
			seglists = this_seglists
		else:
			seglists |= this_seglists
	if coinc_params_distributions is None and ranking_data is None:
		raise ValueError("no parameter distribution data loaded")
	# FIXME:  hack to fake some segments so we can generate file names
	# if the input files didn't provide segment information
	try:
		seglists.extent_all()
	except ValueError:
		seglists[None] = lsctables.segments.segmentlist([lsctables.segments.segment(0, 0)])
	return coinc_params_distributions, ranking_data, seglists


def load_search_results(filenames, tmp_path = None, verbose = False):
	"""
	Collect the ranking statistic values of the inspiral coincs
	recorded in the SQLite databases named in filenames.

	Returns a pair of dictionaries (background, zero-lag), each mapping
	a frozenset of participating instruments to the list of
	coinc_event.likelihood values found for that instrument set.  A
	coinc is put in the first dictionary when its time slide has at
	least one non-zero offset, otherwise it is put in the second.

	Returns (None, None) if filenames is empty or None.
	"""
	if not filenames:
		return None, None

	# for every coinc of the requested type:  participating
	# instruments, ranking statistic, and whether or not any of the
	# offsets in its time slide is non-zero
	query = """
SELECT
	coinc_inspiral.ifos,
	coinc_event.likelihood,
	EXISTS (
		SELECT
			*
		FROM
			time_slide
		WHERE
			time_slide.time_slide_id == coinc_event.time_slide_id
			AND time_slide.offset != 0
	)
FROM
	coinc_event
	JOIN coinc_inspiral ON (
		coinc_inspiral.coinc_event_id == coinc_event.coinc_event_id
	)
WHERE
	coinc_event.coinc_def_id == ?
		"""

	background = {}
	zerolag = {}

	for n, filename in enumerate(filenames, 1):
		if verbose:
			print >>sys.stderr, "%d/%d: %s" % (n, len(filenames), filename)

		# work on the copy of the database selected by
		# get_connection_filename() (see tmp_path)
		working_filename = dbtables.get_connection_filename(filename, tmp_path = tmp_path, verbose = verbose)
		connection = sqlite3.connect(working_filename)

		# look up the ID of the inspiral coinc type;  create_new
		# is False so the database is not asked to add a new one
		coinc_def_table = lsctables.CoincDefTable.get_table(dbtables.get_xml(connection))
		definer_id = coinc_def_table.get_coinc_def_id(ligolw_thinca.InspiralCoincDef.search, ligolw_thinca.InspiralCoincDef.search_coinc_type, create_new = False)

		cursor = connection.cursor()
		for ifos, ln_likelihood_ratio, is_background in cursor.execute(query, (definer_id,)):
			instruments = frozenset(lsctables.instrument_set_from_ifos(ifos))
			destination = background if is_background else zerolag
			destination.setdefault(instruments, []).append(ln_likelihood_ratio)

		connection.close()
		dbtables.discard_connection_filename(filename, working_filename, verbose = verbose)

	return background, zerolag


class SnrChiColourNorm(matplotlib.colors.Normalize):
	"""
	Colour normalization that passes the natural logarithm of the data
	through an arctangent.

	The mid-point of [vmin, vmax] in log space is mapped to 0.5, and
	the output approaches 0 and 1 asymptotically on either side of it.
	Input values are clipped to [1e-200, 1e+200] before the logarithm
	is taken.
	"""
	def _log_centre_and_halfwidth(self):
		"""
		Return the mid-point and the half-width of the interval
		[ln(vmin), ln(vmax)].
		"""
		vmin = math.log(self.vmin)
		vmax = math.log(self.vmax)
		return (vmax + vmin) / 2., (vmax - vmin) / 2.

	def __call__(self, value, clip = None):
		"""
		Map value (scalar or array) into the open interval (0, 1).
		The clip argument is accepted for compatibility with the
		parent class' signature but is not used.
		"""
		value, is_scalar = self.process_value(value)
		# keep the logarithm finite
		numpy.clip(value, 1e-200, 1e+200, value)

		# fill in whichever of vmin and vmax has not been set
		self.autoscale_None(value)

		xbar, delta = self._log_centre_and_halfwidth()
		pi_2 = math.pi / 2.

		value = (numpy.arctan((numpy.log(value) - xbar) * pi_2 / delta) + pi_2) / math.pi
		return value[0] if is_scalar else value

	def inverse(self, value):
		"""
		Invert the mapping applied by __call__().  The result is
		clipped to [1e-200, 1e+200].
		"""
		xbar, delta = self._log_centre_and_halfwidth()
		pi_2 = math.pi / 2.

		value = numpy.exp(numpy.tan(value * math.pi - pi_2) * delta / pi_2 + xbar)
		# don't clip in place:  for scalar input value is a numpy
		# scalar, which numpy.clip() cannot use as an output array
		return numpy.clip(value, 1e-200, 1e+200)

	def _autoscale(self, vmin, vmax):
		"""
		Compute colour limits from the data extrema vmin and vmax.
		Working in log space, the centre of the data's interval is
		moved up by 2/3 of its half-width and the half-width is
		reduced to 1/4 of its original size.  Returns the new
		(vmin, vmax).
		"""
		vmin = math.log(vmin)
		vmax = math.log(vmax)
		xbar = (vmax + vmin) / 2.
		delta = (vmax - vmin) / 2.
		xbar += 2. * delta / 3.
		delta /= 4.
		vmin = math.exp(xbar - delta)
		vmax = math.exp(xbar + delta)
		return vmin, vmax

	def autoscale(self, A):
		"""
		Set vmin and vmax from the extrema of A.
		"""
		self.vmin, self.vmax = self._autoscale(numpy.ma.min(A), numpy.ma.max(A))

	def autoscale_None(self, A):
		"""
		Set whichever of vmin and vmax is still None from the
		extrema of A.
		"""
		# nothing to do if both limits are already set.  returning
		# early also avoids taking the logarithm of the extrema
		# of A when the result would only be thrown away
		if self.vmin is not None and self.vmax is not None:
			return
		vmin, vmax = self._autoscale(numpy.ma.min(A), numpy.ma.max(A))
		if self.vmin is None:
			self.vmin = vmin
		if self.vmax is None:
			self.vmax = vmax


#
# command line
#


options, filenames = parse_command_line()


#
# load input
#


# the distribution data comes from the files named as positional
# arguments on the command line
coinc_params_distributions, ranking_data, seglists = load_distributions(filenames, verbose = options.verbose)

# fapfar stays None unless ranking data was loaded
fapfar = None

if coinc_params_distributions is not None:
	coinc_params_distributions.finish(verbose = options.verbose)

if ranking_data is not None:
	ranking_data.finish(verbose = options.verbose)
	livetime = far.get_live_time(seglists)
	fapfar = far.FAPFAR(ranking_data, livetime = livetime)


# the search results come from the databases named with --database and
# --database-cache;  both values are None if none were given
background_ln_likelihood_ratios, zerolag_ln_likelihood_ratios = load_search_results(
	options.database,
	tmp_path = options.tmp_space,
	verbose = options.verbose
)


#
# plots
#

# SNR and Chisquared:  one plot for each instrument and each type of
# distribution
for instrument in ("H1","L1","V1"):
	for snr_chi_type in ("background_pdf", "injection_pdf", "zero_lag_pdf", "LR"):
		fig = plotfar.plot_snr_chi_pdf(coinc_params_distributions, instrument, snr_chi_type, options.max_snr)
		# only write a file when a figure was returned
		if fig is not None:
			description = "GSTLAL_INSPIRAL_PLOT_BACKGROUND_%s_%s_SNRCHI2" % (options.user_tag, snr_chi_type.upper())
			extent = seglists.extent_all()
			plotname = inspiral_pipe.T050017_filename(instrument, description, int(extent[0]), int(extent[1]), options.output_format, path = options.output_dir)
			if options.verbose:
				print >>sys.stderr, "writing %s" % plotname
			fig.savefig(plotname)


# Trigger and event rates:  a single plot, with the instruments field of
# the file name fixed to H1L1V1
description = "GSTLAL_INSPIRAL_PLOT_BACKGROUND_%s_RATES" % options.user_tag
extent = seglists.extent_all()
fig = plotfar.plot_rates(coinc_params_distributions, ranking_data)
plotname = inspiral_pipe.T050017_filename("H1L1V1", description, int(extent[0]), int(extent[1]), options.output_format, path = options.output_dir)
if options.verbose:
	print >>sys.stderr, "writing %s" % plotname
fig.savefig(plotname)


# SNR PDFs:  one plot for each (instruments, horizon distances) key in the
# joint SNR PDF cache, visited in order of the sorted horizon distances.
# coinc_params_distributions is None when the input files provided only
# ranking data, in which case there is nothing to plot here
snr_joint_pdf_keys = coinc_params_distributions.snr_joint_pdf_cache.keys() if coinc_params_distributions is not None else ()
# the sort key indexes the (instruments, horizon_distances) pair instead of
# unpacking it in the lambda's parameter list, which Python 3 doesn't allow
for (instruments, horizon_distances) in sorted(snr_joint_pdf_keys, key = lambda cache_key: sorted(cache_key[1])):
	fig = plotfar.plot_snr_joint_pdf(coinc_params_distributions, instruments, horizon_distances, options.max_snr)
	if fig is not None:
		# the horizon distances are encoded in the file name as
		# KEY_VALUE pairs in sorted key order
		hd_dict = dict(horizon_distances)
		horizon_tag = "_".join(["%s_%s" % (k, hd_dict[k]) for k in sorted(hd_dict)])
		plotname = inspiral_pipe.T050017_filename(instruments, "GSTLAL_INSPIRAL_PLOT_BACKGROUND_%s_SNR_PDF_%s" % (options.user_tag, horizon_tag), int(seglists.extent_all()[0]), int(seglists.extent_all()[1]), options.output_format, path = options.output_dir)
		if options.verbose:
			print >>sys.stderr, "writing %s" % plotname
		fig.savefig(plotname)


# Ranking statistic PDFs and CCDF in noise.  These need the ranking data,
# and fapfar, which is only constructed when ranking data was loaded
if ranking_data is not None:
	# one PDF plot for each key of background_likelihood_pdfs.  a key
	# that evaluates to false is labelled COMBINED in the file name
	for instruments, binnedarray in ranking_data.background_likelihood_pdfs.items():
		fig = plotfar.plot_likelihood_ratio_pdf(ranking_data, instruments, (options.min_log_lambda, options.max_log_lambda), "Noise", binnedarray_string = "background_likelihood_pdfs")
		plotname = inspiral_pipe.T050017_filename(instruments or "COMBINED", "GSTLAL_INSPIRAL_PLOT_BACKGROUND_%s_NOISE_LIKELIHOOD_RATIO_PDF" % options.user_tag, int(seglists.extent_all()[0]), int(seglists.extent_all()[1]), options.output_format, path = options.output_dir)
		if options.verbose:
			print >>sys.stderr, "writing %s" % plotname
		fig.savefig(plotname)

	# a single CCDF plot.  NOTE:  the zero-lag ranking statistic
	# values loaded from the databases are not passed to the plotting
	# function, so they are no longer flattened into a list here
	fig = plotfar.plot_likelihood_ratio_ccdf(fapfar, (options.min_log_lambda, options.max_log_lambda), "Noise")
	plotname = inspiral_pipe.T050017_filename("NONE", "GSTLAL_INSPIRAL_PLOT_BACKGROUND_%s_NOISE_LIKELIHOOD_RATIO_CCDF" % options.user_tag, int(seglists.extent_all()[0]), int(seglists.extent_all()[1]), options.output_format, path = options.output_dir)
	if options.verbose:
		print >>sys.stderr, "writing %s" % plotname
	fig.savefig(plotname)


# FIXME
# I am not sure what to do here.
#
# NOTE(review):  the program exits unconditionally on the next line, so
# none of the code below it is ever executed.  That code calls
# plot_likelihood_ratio_pdf() and plot_likelihood_ratio_ccdf() as bare
# names, which are neither defined nor imported in the visible part of
# this file (the noise plots above call functions with those names via the
# plotfar module, and with different arguments), so it would need to be
# updated before the exit could be removed.
sys.exit(0)

# signal model plots (disabled, see above):  for each key of
# signal_likelihood_pdfs write a ranking statistic PDF plot and a CCDF plot
for instruments, binnedarray in ranking_data.signal_likelihood_pdfs.items() if ranking_data is not None else ():
	fig = plot_likelihood_ratio_pdf(instruments, binnedarray, (options.min_log_lambda, options.max_log_lambda), "Signal")
	plotname = inspiral_pipe.T050017_filename(instruments or "NONE", "GSTLAL_INSPIRAL_PLOT_BACKGROUND_%s_signal_likelihood_ratio_pdf" % options.user_tag, int(seglists.extent_all()[0]), int(seglists.extent_all()[1]), options.output_format, path = options.output_dir)
	if options.verbose:
		print >>sys.stderr, "writing %s" % plotname
	fig.savefig(plotname)

	# the interpolators won't accept infinite co-ordinates, so
	# have to discard the last bin
	# NOTE(review):  these are read from background_likelihood_rates
	# although the plot is labelled "Signal" --- confirm which is
	# intended
	ranks = ranking_data.background_likelihood_rates[instruments].bins[0].upper()[:-1]
	counts = ranking_data.background_likelihood_rates[instruments].array[:-1]

	# cumulative distribution function and its complement.
	# it's numerically better to recompute the ccdf by
	# reversing the array of counts than trying to subtract the
	# cdf from 1.
	ccdf = counts[::-1].cumsum()[::-1]
	# normalize so that the first element is 1
	ccdf /= ccdf[0]

	# cdf boundary condition:  cdf = 1/e at the ranking
	# statistic threshold so that
	# self.far_from_rank(threshold) * livetime =
	# observed count of events above threshold.
	ccdf *= 1. - 1. / math.e

	from scipy import interpolate
	fig = plot_likelihood_ratio_ccdf(instruments, interpolate.interp1d(ranks, ccdf), None, (options.min_log_lambda, options.max_log_lambda), "Signal")
	plotname = inspiral_pipe.T050017_filename(instruments or "NONE", "GSTLAL_INSPIRAL_PLOT_BACKGROUND_%s_signal_likelihood_ratio_ccdf" % options.user_tag, int(seglists.extent_all()[0]), int(seglists.extent_all()[1]), options.output_format, path = options.output_dir)
	if options.verbose:
		print >>sys.stderr, "writing %s" % plotname
	fig.savefig(plotname)