#!/usr/bin/env python
#
# Copyright (C) 2010--2014  Kipp Cannon, Chad Hanna
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 2 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
#
## @file gstlal_ll_inspiral_create_prior_diststats
# A program to create some prior likelihood data to seed an online analysis
#
# ### Command line interface
#
#		-v, --verbose, action = "store_true", help = "Be verbose."
#		--num-templates, metavar = "N", default = 1000, type = "int", help = "Set the number of templates per bank. default 1000"
#		--num-banks, metavar = "N", default = 1, type = "int", help = "Set the number of banks. default 1"
#		--write-likelihood-basename, metavar = "string", default = "prior.xml.gz", help = "Write likelihood files to disk with this basename: default prior.xml.gz."
#		--write-likelihood-cache, metavar = "filename", help = "Write likelihood files to disk and include the names in this cachefile file."
#		--segment-and-horizon, action = "append", help = "Append to a list of segments and horizon distances for a given instrument.  Argument specified as IFO:start:end:horizon, e.g., H1:1000000000:1000000100:120 "
#		--background-prefactors, metavar = "s,e", default = "0.5,20.", help = "Set the range of prefactors on the chi-squared distribution for the background model: default 0.5,20."
#		--injection-prefactors, metavar = "s,e", default = "0.01,0.25", help = "Set the range of prefactors on the chi-squared distribution for the signal model: default 0.01,0.25"

#
# =============================================================================
#
#                                   Preamble
#
# =============================================================================
#


import sys
import shutil
from gstlal import far
from gstlal import inspiral_pipe
from glue.ligolw import utils as ligolw_utils
from glue.ligolw.utils import process as ligolw_process
from glue.ligolw import ligolw
from glue.text_progress_bar import ProgressBar
from pylal import rate, snglcoinc 
from glue import segments
from glue import iterutils
from glue.lal import LIGOTimeGPS, CacheEntry
import numpy
from optparse import OptionParser


def parse_command_line():
	"""
	Parse the command line and build the fake observing history from the
	--segment-and-horizon options.

	Returns the tuple (options, process_params, instruments, seglistdict,
	horizon_history, bankcachedict) where

	options:  the optparse values object
	process_params:  a copy of the option values as a dictionary
	instruments:  frozenset of the instrument names that were given
	seglistdict:  segments.segmentlistdict with one segment appended per
		--segment-and-horizon option
	horizon_history:  dictionary mapping instrument name to a
		far.NearestLeafTree holding the given horizon distance at both
		the start and the end of each segment
	bankcachedict:  always None (bank cache parsing is disabled)

	Raises ValueError if a --segment-and-horizon argument is malformed,
	if fewer than two distinct instruments are named, or if
	--write-likelihood-cache is not set.  All validation is done before
	any segment or horizon object is built so that bad input fails
	immediately.
	"""
	parser = OptionParser(
		version = "Name: %%prog\n%s" % "" # FIXME
	)
	parser.add_option("-v", "--verbose", action = "store_true", help = "Be verbose.")
	parser.add_option("--num-templates", metavar = "N", default = 1000, type = "int", help = "Set the number of templates per bank. default 1000")
	parser.add_option("--num-banks", metavar = "N", default = 1, type = "int", help = "Set the number of banks. default 1")
	parser.add_option("--write-likelihood-basename", metavar = "string", default = "prior.xml.gz", help = "Write likelihood files to disk with this basename: default prior.xml.gz.")
	parser.add_option("--write-likelihood-cache", metavar = "filename", help = "Write likelihood files to disk and include the names in this cachefile file.")
	parser.add_option("--segment-and-horizon", action = "append", help = "Append to a list of segments and horizon distances for a given instrument.  Argument specified as IFO:start:end:horizon, e.g., H1:1000000000:1000000100:120 ")
	parser.add_option("--background-prefactors", metavar = "s,e", default = "0.5,20.", help = "Set the range of prefactors on the chi-squared distribution for the background model: default 0.5,20.")
	parser.add_option("--injection-prefactors", metavar = "s,e", default = "0.01,0.25", help = "Set the range of prefactors on the chi-squared distribution for the signal model: default 0.01,0.25")
	options, filenames = parser.parse_args()

	# snapshot of the option values, taken before anything else touches
	# them
	process_params = dict(options.__dict__)

	#
	# decode the IFO:start:end:horizon strings into plain values.  the
	# option is None, not an empty list, when it was never given
	#

	specs = []
	for spec in options.segment_and_horizon or ():
		fields = spec.split(":")
		if len(fields) != 4:
			raise ValueError("invalid --segment-and-horizon \"%s\":  expected IFO:start:end:horizon" % spec)
		ifo, start, stop, horizon = fields
		specs.append((ifo, float(start), float(stop), float(horizon)))

	if len(frozenset(ifo for ifo, start, stop, horizon in specs)) < 2:
		raise ValueError("must specify at least two distinct instrument")

	# the cache file is opened for writing at the very end of the
	# program;  refuse to start without it rather than fail after all
	# the work has been done
	if options.write_likelihood_cache is None:
		raise ValueError("must set --write-likelihood-cache")

	#
	# build the segment lists and horizon distance histories
	#

	seglistdict = segments.segmentlistdict()
	horizon_history = {}
	for ifo, start, stop, horizon in specs:
		seglistdict.setdefault(ifo, segments.segmentlist()).append(segments.segment([LIGOTimeGPS(start), LIGOTimeGPS(stop)]))
		# same horizon distance recorded at both ends of the segment
		history = horizon_history.setdefault(ifo, far.NearestLeafTree())
		history[LIGOTimeGPS(start)] = horizon
		history[LIGOTimeGPS(stop)] = horizon

	instruments = frozenset(seglistdict)

	bankcachedict = None #inspiral_pipe.parse_cache_str(options.bank_cache)

	return options, process_params, instruments, seglistdict, horizon_history, bankcachedict


options, process_params, instruments, segs, horizon_history, bankcachedict = parse_command_line()

if options.verbose:
	print "Livetime: ", abs(segs)
	print "Extent: ", segs.extent_all()

#
# quantities derived from input
#

#
# Number of background events in each detector
#
# This is calculated assuming the following
# 1) There are options.num_templates in the analysis
# 2) Each template produces exactly 1 trigger for every second that a given
# detector is on according to the user provided segments
#

n_background = dict(((ifo, float(abs(seg)) * options.num_templates) for ifo, seg in segs.items())) 

#
# Number of zero lag events in each detector
#
# This is calculated assuming the following
# 1) Only coincident events go into the zero_lag histograms
# 2) The coincidence rate is 100 times lower than background rate for each detector
#

n_zerolag = dict(((ifo, float(abs(seg)) * options.num_templates / 100.) for ifo, seg in segs.items()))

#
# Start from an empty ThincaCoincParamsDistributions object
#

diststats = far.ThincaCoincParamsDistributions()

#
# Seed the SNR, chi-squared plane with prior distributions.  The background
# and zero-lag priors use the same prefactor range but different event
# counts;  the injection (signal) prior uses its own prefactor range and a
# count of 1e8 for every instrument.
#

background_prefactors = tuple(float(x) for x in options.background_prefactors.split(","))
diststats.add_background_prior(n = n_background, ba = "background_rates", prefactors_range = background_prefactors, verbose = options.verbose)
diststats.add_background_prior(n = n_zerolag, ba = "zero_lag_rates", prefactors_range = background_prefactors, verbose = options.verbose)

injection_prefactors = tuple(float(x) for x in options.injection_prefactors.split(","))
diststats.add_foreground_snrchi_prior(n = dict.fromkeys(segs, 1e8), prefactors_range = injection_prefactors, verbose = options.verbose)

#
# Install the fake, user provided, horizon distance history
#

diststats.horizon_history.update(horizon_history)

#
# Fake a zero-lag count for every combination of two or more instruments:
# the smallest single-detector background count among the participants,
# divided by 100 once for each instrument beyond the first.
# FIXME Derive these from the segment lists!!! Use Kipp's code??
#

all_instruments = tuple(instruments)
for i in range(2, len(instruments) + 1):
	suppression = 100.**(i-1)
	for ifos in map(frozenset, iterutils.choices(all_instruments, i)):
		fake_count = min(n_background[ifo] for ifo in ifos) / suppression
		diststats.zero_lag_rates["instruments"][diststats.instrument_categories.category(ifos),] = fake_count

diststats.add_instrument_combination_counts(segs = segs, verbose = options.verbose)

#
# joint SNR PDF
#
# For every combination of two or more instruments, compute the joint SNR
# PDF at the mean of the user-provided horizon distances and store it, along
# with an interpolator for its natural logarithm, in the diststats object's
# SNR PDF cache.
#
# NOTE:  the loop variable is deliberately not called "instruments".  That
# name holds the frozenset of instruments returned by parse_command_line()
# and is used again below when the process is registered and the
# RankingData object is created;  re-using it as a loop variable here would
# replace it with whatever tuple the iteration happened to end on.
#

horizon_distances = dict(((ifo, numpy.mean(h.values())) for ifo, h in horizon_history.items()))
for n in range(2, len(horizon_distances) + 1):
	for ifo_combo in iterutils.choices(horizon_distances.keys(), n):
		# Force the SNR pdf to be generated to be at the actual horizon distance values not the quantized ones
		# (the cache key uses the quantized values, the PDF itself
		# is computed from the unquantized ones)
		key = frozenset(ifo_combo), frozenset(diststats.quantize_horizon_distances(horizon_distances).items())
		if options.verbose:
			print >>sys.stderr, "For horizon distances %s" % ", ".join("%s = %.4g Mpc" % item for item in sorted(horizon_distances.items()))
			progressbar = ProgressBar(text = "%s SNR PDF" % ", ".join(sorted(key[0])))
		else:
			progressbar = None
		binnedarray = diststats.joint_pdf_of_snrs(key[0], horizon_distances, progressbar = progressbar)
		del progressbar

		# interpolator for ln(PDF).  log(0) = -inf is the wanted
		# result for empty bins so the divide-by-zero warning is
		# suppressed
		lnbinnedarray = binnedarray.copy()
		with numpy.errstate(divide = "ignore"):
			lnbinnedarray.array = numpy.log(lnbinnedarray.array)
		pdf = rate.InterpBinnedArray(lnbinnedarray, fill_value = float("-inf"))
		# NOTE(review):  the meaning of the trailing 0 is set by
		# far.ThincaCoincParamsDistributions -- confirm there
		diststats.snr_joint_pdf_cache[key] = pdf, binnedarray, 0

diststats.populate_prob_of_instruments_given_signal(segs)

#
# Finished with this class
#

diststats.finish()

#
# Prep an output XML file
#

xmldoc = ligolw.Document()
xmldoc.appendChild(ligolw.LIGO_LW())
process = ligolw_process.register_to_xmldoc(xmldoc, u"gstlal_ll_inspiral_create_prior_diststats", ifos = instruments, paramdict = process_params)

#
# Instantiate a new RankingData class from our distribution stats
#

ranking_data = far.RankingData(diststats, instruments = instruments, process_id = process.process_id, nsamples = 100000, verbose = options.verbose)

#
# Simulate a measured coincident trigger likelihood histogram by just using
# the background one with a lower normalized count:  each instrument
# combination's background histogram is normalized to unit sum, scaled to
# the fake zero-lag count recorded for that same combination, and rounded to
# whole numbers of events.  The None key is skipped.
#
# The count must be looked up with the combination being processed.  A
# variable left over from an earlier loop would give every combination the
# count of whichever combination that loop ended on.
#

for ifo_combo in ranking_data.background_likelihood_rates:
	if ifo_combo is None:
		continue
	background_array = ranking_data.background_likelihood_rates[ifo_combo].array
	zero_lag_array = ranking_data.zero_lag_likelihood_rates[ifo_combo].array
	fake_count = diststats.zero_lag_rates["instruments"][diststats.instrument_categories.category(ifo_combo),]
	# assigned in place, in two steps, exactly as before:  scale,
	# then round
	zero_lag_array[:] = background_array[:] / background_array.sum() * fake_count
	zero_lag_array[:] = zero_lag_array.round()

ranking_data._compute_combined_rates()
ranking_data.finish()

#
# record results in output file
#

far.gen_likelihood_control_doc(xmldoc, process, diststats, ranking_data, segs, comment = u"background and signal priors (no real data)")

#
# done
#

ligolw_process.set_process_end_time(process)

#
# Write one likelihood file per bank and list them all in the cache file.
# The document is serialized once, for bank 0, and the remaining banks get
# byte-for-byte copies of that file.  The cache file is managed by a with
# statement so that it is closed even if writing or copying a likelihood
# file fails part way through.
#

with open(options.write_likelihood_cache, "w") as cachefile:
	for n in range(options.num_banks):
		this_name = "%04d_%s" % (n, options.write_likelihood_basename)
		c = CacheEntry("".join(sorted(segs.keys())), "PRIORS", segs.extent_all(), this_name)
		if n == 0:
			ligolw_utils.write_filename(xmldoc, this_name, gz = this_name.endswith(".gz"), verbose = options.verbose)
			original = this_name
		else:
			shutil.copyfile(original, this_name)
		print >> cachefile, str(c)

