#!/usr/bin/env python
#
# Copyright (C) 2013,2014  Kipp Cannon, Chad Hanna, Jacob Peoples
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 2 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

## @file
# A program to compute the signal and noise rate posteriors of a gstlal inspiral analysis


#
# =============================================================================
#
#                                   Preamble
#
# =============================================================================
#


import bisect
try:
	from fpconst import NegInf
except ImportError:
	NegInf = float("-inf")
import h5py
import math
import matplotlib
matplotlib.rcParams.update({
	"font.size": 8.0,
	"axes.titlesize": 10.0,
	"axes.labelsize": 10.0,
	"xtick.labelsize": 8.0,
	"ytick.labelsize": 8.0,
	"legend.fontsize": 8.0,
	"figure.dpi": 600,
	"savefig.dpi": 600,
	"text.usetex": True
})
from matplotlib import figure
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
import numpy
from optparse import OptionParser
import sqlite3
import sys


from glue import segments
from glue.lal import CacheEntry
from glue.ligolw import ligolw
from glue.ligolw import dbtables
from glue.ligolw import lsctables
from glue.ligolw import utils as ligolw_utils
from glue.ligolw.utils import process as ligolw_process
from glue.text_progress_bar import ProgressBar
from pylal import ligolw_thinca


from gstlal import far
from gstlal import rate_estimation


# aspect ratio (width / height) used by plot_rates() to size the figure
golden_ratio = (1. + math.sqrt(5.)) / 2.


# program name recorded when registering this process in the output
# document;  also used as the prefix of the names under which the signal
# and noise PDFs are written
process_name = u"gstlal_inspiral_rate_posterior"


# module metadata.  the version and date strings are placeholders
__author__ = "Kipp Cannon <kipp.cannon@ligo.org>"
__version__ = "git id %s" % ""	# FIXME
__date__ = ""	# FIXME


#
# =============================================================================
#
#                                 Command Line
#
# =============================================================================
#


def parse_command_line():
	"""
	Parse the command line.

	Returns the tuple (options, paramdict, filenames).  options is the
	optparse values object with two derived attributes set on it:
	.confidence_intervals is converted to a list of floats (empty if
	the option was cleared), and .likelihood_filenames is the list of
	files collected from --likelihood-cache and --likelihood-file.
	paramdict is options.__dict__ (the same object, not a copy).
	filenames is the list of positional arguments extended with the
	paths read from --input-cache.

	Raises ValueError if no ranking statistic files are named and the
	sampler has not been disabled with --samples 0.
	"""
	parser = OptionParser(
		version = "Name: %%prog\n%s" % "" # FIXME
	)
	parser.add_option("-c", "--confidence-intervals", metavar = "confidence[,...]", default = "0.68,0.95,0.999999", help = "Compute and report confidence intervals in the signal count for these confidences (default = \".68,.95,.999999\", clear to disable).")
	parser.add_option("-i", "--input-cache", metavar = "filename", help = "Also process the files named in this LAL cache.  See lalapps_path2cache for information on how to produce a LAL cache file.")
	parser.add_option("--chain-file", metavar = "filename", help = "Read chain from this file, save chain to this file.")
	parser.add_option("--likelihood-file", metavar = "filename", action = "append", help = "Load ranking statistic PDFs from this file.  Can be given multiple times.")
	parser.add_option("--likelihood-cache", metavar = "filename", help = "Load ranking statistic PDFs from the files in this LAL cache.")
	parser.add_option("-t", "--tmp-space", metavar = "path", help = "Path to a directory suitable for use as a work area while manipulating the database file.  The database file will be worked on in this directory, and then moved to the final location when complete.  This option is intended to improve performance when running in a networked environment, where there might be a local disk with higher bandwidth than is available to the filesystem on which the final output will reside.")
	parser.add_option("--with-background", action = "store_true", help = "Show background posterior on plot.")
	parser.add_option("--samples", metavar = "count", type = "int", help = "Run this many samples.  Set to 0 to load and plot the contents of a previously-recorded chain file without doing any additional samples.")
	parser.add_option("--threshold", metavar = "log likelihood ratio", type = "float", help = "Derive the rate posterior by considering only events ranked at or above this value of the log likelihood ratio ranking statistic (default = use all events).")
	parser.add_option("-v", "--verbose", action = "store_true", help = "Be verbose.")
	options, filenames = parser.parse_args()

	# NOTE:  this is a reference to the options dictionary, not a copy,
	# so the derived attributes assigned below appear in it as well
	paramdict = options.__dict__

	# skip empty fields so that clearing the option (passing an empty
	# string) yields an empty list, which disables the confidence
	# intervals as promised by the help text, instead of failing in
	# float("")
	options.confidence_intervals = [float(conf) for conf in options.confidence_intervals.split(",") if conf.strip()]

	options.likelihood_filenames = []
	if options.likelihood_cache is not None:
		with open(options.likelihood_cache) as cachefile:
			options.likelihood_filenames += [CacheEntry(line).path for line in cachefile]
	if options.likelihood_file is not None:
		options.likelihood_filenames += options.likelihood_file
	# ranking statistic files are only optional when the sampler has
	# been explicitly disabled with --samples 0.  when --samples is not
	# given its value is None, and the sampler is presumably run with
	# its default sample count, so test for None explicitly instead of
	# relying on the ordering of None and 0
	if not options.likelihood_filenames and (options.samples is None or options.samples > 0):
		raise ValueError("must provide --likelihood-cache and/or one or more --likelihood-file options")

	if options.input_cache:
		with open(options.input_cache) as cachefile:
			filenames += [CacheEntry(line).path for line in cachefile]

	return options, paramdict, filenames


#
# =============================================================================
#
#                              Support Functions
#
# =============================================================================
#


def load_ranking_data(filenames, ln_likelihood_ratio_threshold = None, verbose = False):
	"""
	Load the ranking statistic data from each of the named files, sum
	them, and return the result after calling its .finish() method.

	If ln_likelihood_ratio_threshold is not None, the bins below the
	bin containing the threshold are set to 0 in the background,
	signal and zero-lag likelihood rate arrays before .finish() is
	called.

	Raises ValueError if filenames is empty or if one of the files
	does not provide ranking statistic data.
	"""
	if not filenames:
		raise ValueError("no likelihood files!")

	combined = None
	for n, filename in enumerate(filenames, start = 1):
		if verbose:
			print >>sys.stderr, "%d/%d:" % (n, len(filenames)),
		xmldoc = ligolw_utils.load_filename(filename, contenthandler = far.ThincaCoincParamsDistributions.LIGOLWContentHandler, verbose = verbose)
		# only the middle element of the result is wanted here
		contents = far.parse_likelihood_control_doc(xmldoc)
		ignored, loaded, ignored = contents
		xmldoc.unlink()
		if loaded is None:
			raise ValueError("'%s' does not contain ranking statistic PDFs" % filename)
		if combined is not None:
			combined += loaded
		else:
			combined = loaded

	# apply the threshold by zeroing the histograms below it before
	# running .finish().  the slice end is the index of the bin
	# containing the threshold, so that bin itself is left untouched
	if ln_likelihood_ratio_threshold is not None:
		for attrname in ("background_likelihood_rates", "signal_likelihood_rates", "zero_lag_likelihood_rates"):
			for binnedarray in getattr(combined, attrname).values():
				first_kept = binnedarray.bins[0][ln_likelihood_ratio_threshold]
				binnedarray.array[:first_kept] = 0.

	combined.finish(verbose = verbose)

	return combined


def load_search_results(filenames, ln_likelihood_ratio_threshold = None, tmp_path = None, verbose = False):
	"""
	Collect the log likelihood ratios of the inspiral coincs in the
	named SQLite database files.

	Returns the tuple (background, zerolag) of two lists of
	coinc_event.likelihood values.  An event is placed in the
	background list if any time_slide row sharing its time_slide_id
	has a non-zero offset, otherwise it is placed in the zero-lag
	list.  Only events whose likelihood is >=
	ln_likelihood_ratio_threshold are retrieved;  when the threshold is
	None, NegInf is used in its place.

	tmp_path and verbose are passed to the dbtables working-copy
	helpers.
	"""
	# selects, for each inspiral coinc at or above threshold, its
	# likelihood and a flag indicating if it is a time-slide event.
	# parameters:  (coinc_def_id, threshold)
	query = """
SELECT
	coinc_event.likelihood,
	EXISTS (
		SELECT
			*
		FROM
			time_slide
		WHERE
			time_slide.time_slide_id == coinc_event.time_slide_id
			AND time_slide.offset != 0
	)
FROM
	coinc_event
	JOIN coinc_inspiral ON (
		coinc_inspiral.coinc_event_id == coinc_event.coinc_event_id
	)
WHERE
	coinc_event.coinc_def_id == ?
	AND coinc_event.likelihood >= ?"""

	background = []
	zerolag = []

	for n, filename in enumerate(filenames, 1):
		if verbose:
			print >>sys.stderr, "%d/%d: %s" % (n, len(filenames), filename)

		# work on the file via dbtables' working-copy mechanism
		working_filename = dbtables.get_connection_filename(filename, tmp_path = tmp_path, verbose = verbose)
		connection = sqlite3.connect(working_filename)

		# look up the ID of the inspiral coinc type in this file
		xmldoc = dbtables.get_xml(connection)
		coinc_def_table = lsctables.CoincDefTable.get_table(xmldoc)
		definer_id = coinc_def_table.get_coinc_def_id(ligolw_thinca.InspiralCoincDef.search, ligolw_thinca.InspiralCoincDef.search_coinc_type, create_new = False)

		if ln_likelihood_ratio_threshold is None:
			threshold = NegInf
		else:
			threshold = ln_likelihood_ratio_threshold

		for ln_likelihood_ratio, is_background in connection.cursor().execute(query, (definer_id, threshold)):
			destination = background if is_background else zerolag
			destination.append(ln_likelihood_ratio)

		connection.close()
		dbtables.discard_connection_filename(filename, working_filename, verbose = verbose)

	return background, zerolag


def plot_rates(signal_rate, confidence_intervals = None):
	"""
	Plot the signal rate posterior on log-log axes and return the
	matplotlib Figure.

	signal_rate must provide .bins[0].centres() (the x co-ordinates)
	and .array (the y values).  confidence_intervals, if not None, is
	a dictionary mapping confidence to a (mode, lo, hi) tuple;  the
	region under the curve between lo and hi is shaded for each
	confidence, with the part already covered by lower confidences
	removed so that the shaded bands are nested and do not overlap.
	"""
	fig = figure.Figure()
	FigureCanvas(fig)
	fig.set_size_inches((4., 4. / golden_ratio))
	signal_axes = fig.gca()

	x = signal_rate.bins[0].centres()
	y = signal_rate.array
	line1, = signal_axes.plot(x, y, color = "k", linestyle = "-", label = "Signal")
	signal_axes.set_title("Event Rate Posterior")
	signal_axes.set_xlabel(r"Event Rate ($\mathrm{signals} / \mathrm{experiment}$)")
	signal_axes.set_ylabel(r"$P(\mathrm{signals} / \mathrm{experiment})$")
	signal_axes.loglog()

	signal_axes.set_ylim((1e-8, 1.))

	if confidence_intervals is not None:
		# shading opacity for each confidence, in order of increasing
		# confidence.  the first three levels use the traditional
		# values;  each additional level is half as opaque as the
		# one before it so that any number of confidences can be
		# shown without a failed lookup
		base_alphas = [.6, .4, .2]
		alphas = {}
		for n, conf in enumerate(sorted(confidence_intervals)):
			if n < len(base_alphas):
				alphas[conf] = base_alphas[n]
			else:
				alphas[conf] = base_alphas[-1] / 2.**(n - len(base_alphas) + 1)

		# convert lo and hi bounds to co-ordinate index segments
		confidence_intervals = sorted((conf, segments.segmentlist([segments.segment(bisect.bisect_left(x, lo), bisect.bisect_right(x, hi))])) for conf, (mode, lo, hi) in confidence_intervals.items())

		# remove from each the indexes spanned by lower confidence regions
		confidence_intervals = [(conf, indexes - sum((segs for conf, segs in confidence_intervals[:i]), segments.segmentlist())) for i, (conf, indexes) in enumerate(confidence_intervals)]

		# shade down to the bottom of the y axis range set above
		for conf, segs in confidence_intervals:
			for lo, hi in segs:
				signal_axes.fill_between(x[lo:hi], y[lo:hi], 1e-8, color = "k", alpha = alphas[conf])

	try:
		fig.tight_layout()
	except AttributeError:
		# old matplotlib. auto-layout not available
		pass
	return fig


#
# =============================================================================
#
#                                     Main
#
# =============================================================================
#


#
# command line
#


options, paramdict, filenames = parse_command_line()


#
# load ranking statistic PDFs
#


if options.likelihood_filenames:
	ranking_data = load_ranking_data(options.likelihood_filenames, ln_likelihood_ratio_threshold = options.threshold, verbose = options.verbose)
else:
	# no ranking statistic files were named on the command line
	ranking_data = None


#
# load search results
#


background_ln_likelihood_ratios, zerolag_ln_likelihood_ratios = load_search_results(filenames, ln_likelihood_ratio_threshold = options.threshold, tmp_path = options.tmp_space, verbose = options.verbose)


#
# calculate rate posteriors
#


if options.verbose:
	print >>sys.stderr, "calculating rate posteriors using %d likelihood ratios ..." % len(zerolag_ln_likelihood_ratios)
	progressbar = ProgressBar()
else:
	progressbar = None
# only pass the optional arguments that were set on the command line so
# that calculate_rate_posteriors() uses its own defaults for the rest
kwargs = {}
if options.chain_file is not None:
	# the chain file is both read from and written to, and might not
	# exist yet.  request read/write-or-create access explicitly
	# because the default mode of h5py.File differs between h5py
	# versions (it is read-only in 3.0 and later)
	kwargs["chain_file"] = h5py.File(options.chain_file, "a")
if options.samples is not None:
	kwargs["nsamples"] = options.samples
signal_rate_pdf, noise_rate_pdf = rate_estimation.calculate_rate_posteriors(ranking_data, zerolag_ln_likelihood_ratios, progressbar = progressbar, **kwargs)
# drop our reference to the progress bar now that sampling is finished
del progressbar


#
# find confidence intervals
#


# confidence_intervals maps each requested confidence to the (mode, lo, hi)
# tuple computed from the signal rate PDF, or is None if no confidences
# were requested
confidence_intervals = None
if options.confidence_intervals:
	if options.verbose:
		print >>sys.stderr, "determining confidence intervals ..."
	confidence_intervals = {}
	for conf in options.confidence_intervals:
		confidence_intervals[conf] = rate_estimation.confidence_interval_from_binnedarray(signal_rate_pdf, conf)

	# summary report
	if options.verbose:
		print >>sys.stderr, "rate posterior mean = %g signals/experiment" % rate_estimation.mean_from_pdf(signal_rate_pdf)
		# all modes are the same, pick one and report it
		any_interval = confidence_intervals.values()[0]
		print >>sys.stderr, "maximum-likelihood rate = %g signals/experiment" % any_interval[0]
		for conf, (mode, lo, hi) in sorted(confidence_intervals.items()):
			print >>sys.stderr, "%g%% confidence interval = [%g, %g] signals/experiment" % (conf * 100., lo, hi)


#
# save results
#


# write the signal and noise rate PDFs to an XML document along with a
# record of this process and its command line parameters
filename = "rate_posteriors.xml.gz"
xmldoc = ligolw.Document()
xmldoc.appendChild(ligolw.LIGO_LW())
process = ligolw_process.register_to_xmldoc(xmldoc, process_name, paramdict)
xmldoc.childNodes[-1].appendChild(signal_rate_pdf.to_xml(u"%s:signal_pdf" % process_name))
xmldoc.childNodes[-1].appendChild(noise_rate_pdf.to_xml(u"%s:noise_pdf" % process_name))
# the output name is fixed above, so test it directly.  (this expression
# previously fell back to an undefined name for the case of an unset
# filename, which cannot occur here)
ligolw_utils.write_filename(xmldoc, filename, gz = filename.endswith(".gz"), verbose = options.verbose)


# render the signal rate posterior once and save it in each image format
fig = plot_rates(signal_rate_pdf, confidence_intervals = confidence_intervals)
for filename in ("rate_posteriors.png", "rate_posteriors.pdf"):
	if options.verbose:
		print >>sys.stderr, "writing %s ..." % filename
	fig.savefig(filename)

if options.verbose:
	print >>sys.stderr, "done"
