#!/usr/bin/env python
#
# Copyright (C) 2011, 2012 Kipp Cannon, Chad Hanna, Drew Keppel
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 2 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

import sys
import numpy
from scipy import interpolate, random
from scipy.stats import poisson
from glue import segments
from glue.ligolw import ligolw
from glue.ligolw import lsctables, param, array
array.use_in(ligolw.LIGOLWContentHandler)
param.use_in(ligolw.LIGOLWContentHandler)
lsctables.use_in(ligolw.LIGOLWContentHandler)
from glue.ligolw import utils
from glue.ligolw.utils import process as ligolw_process
from glue.ligolw.utils import segments as ligolw_segments
from glue.segmentsUtils import vote
from glue import iterutils
from glue.ligolw.utils import search_summary as ligolw_search_summary
from pylal import rate
from optparse import OptionParser
from gstlal import far
from gstlal.svd_bank import read_bank
from gstlal import far

try:
	import sqlite3
except ImportError:
	# pre 2.5.x
	from pysqlite2 import dbapi2 as sqlite3

sqlite3.enable_callback_tracebacks(True)


def parse_command_line():
	"""
	Parse the command line and return the (options, filenames) pair.

	--background-bins-file, --non-injection-db and --injection-dbs may each
	be given more than once and are collected into lists.  Only the first
	--background-bins-file is read by this script.

	--background-bins-file and --non-injection-db are mandatory: the script
	indexes and iterates them unconditionally.  If either is missing this
	exits via parser.error() (status 2, usage message) instead of failing
	later with an opaque TypeError on None.
	"""
	parser = OptionParser()
	parser.add_option("--background-bins-file", metavar = "filename", action = "append", help = "Set the name of the xml file containing the marginalized likelihood")
	parser.add_option("--tmp-space", metavar = "dir", help = "Set the name of the tmp space if working with sqlite")
	parser.add_option("--verbose", "-v", action = "store_true", help = "Be verbose.")
	parser.add_option("--non-injection-db", metavar = "filename", action = "append", help = "single file for non injections run")
	parser.add_option("--injection-dbs", action = "append", default=[], help = "append to the list of possible injection files, may be empty if no injections were done. Databases are assumed to be over the same time period as the non injection databases using the same templates.  If not the results will be nonsense.")
	options, filenames = parser.parse_args()

	# "append" options without a default are None when never given
	if not options.background_bins_file:
		parser.error("--background-bins-file is required")
	if not options.non_injection_db:
		parser.error("--non-injection-db is required")

	return options, filenames

#
# Parse command line
#

options, filenames = parse_command_line()

#
# Pull out background and injections distribution and set up the FAR class
#
# Build a trials table over the instruments H1, H2, L1 and V1 with count and
# count_below_thresh set to zero.  Load the ranking data (and its process ID,
# which is not used afterwards) from the FIRST --background-bins-file only;
# any further --background-bins-file arguments are ignored.  Then merge the
# zeroed table into the loaded one.
# NOTE(review): presumably the merge ensures every instrument combination has
# an entry without altering existing counts -- confirm against
# far.TrialsTable.
#

newtrials = far.TrialsTable()
newtrials.initialize_from_sngl_ifos(("H1", "H2", "L1", "V1"), count = 0, count_below_thresh = 0)
global_ranking, procid = far.RankingData.from_xml(utils.load_filename(options.background_bins_file[0], contenthandler = ligolw.LIGOLWContentHandler, verbose = options.verbose))
global_ranking.trials_table += newtrials

# late import for DB manipulations
# NOTE(review): the reason this import is deferred until after the XML
# document has been loaded is not visible here; keep this ordering unless it
# is confirmed to be safe to move.
from glue.ligolw import dbtables

# Precompute the joint CDFs held by the ranking data, presumably consumed by
# the far.set_far() / far.set_fap() calls that follow (internals live in
# gstlal.far).
global_ranking.compute_joint_cdfs()


#
# First pass at FARs: rank every non-injection database with the unscaled
# ranking data.  The per-coincidence combined FARs this presumably writes
# (coinc_inspiral.combined_far) are what the high-event-limit calibration
# reads back.
#

for db_filename in options.non_injection_db:
	far.set_far(global_ranking, db_filename, tmp_path = options.tmp_space, verbose = options.verbose)


#
# Calibrate to the high event limit (100 events) to get the number of
# independent trials.
#
# For every instrument combination with a positive trials count, collect the
# 100 smallest combined FARs from each non-injection database and merge them.
# The merged list gives the number of events kept (at most 100) and the
# largest kept FAR, which becomes the threshold.
#

top100 = {}

for bkdb in options.non_injection_db:
	working_filename = dbtables.get_connection_filename(bkdb, tmp_path = options.tmp_space, verbose = options.verbose)
	connection = sqlite3.connect(working_filename)
	try:
		cursor = connection.cursor()
		for ifos in global_ranking.trials_table:
			# only combinations that actually had trials are calibrated
			if global_ranking.trials_table[ifos].count > 0:
				top100.setdefault(ifos, []).extend(cursor.execute('SELECT combined_far FROM coinc_inspiral WHERE ifos == ? ORDER BY combined_FAR LIMIT 100', (lsctables.ifos_from_instrument_set(ifos),)).fetchall())
	finally:
		# release the database and any scratch copy even if the query fails
		connection.close()
		dbtables.discard_connection_filename(bkdb, working_filename, verbose = options.verbose)

for ifos, fars in top100.items():
	if not fars:
		# No coincidences with this instrument combination in any database,
		# so there is nothing to calibrate against.  Leave the trials-table
		# entry as loaded instead of indexing into an empty list.
		continue
	# each row is a 1-tuple (combined_far,)
	fars.sort()
	n_kept = min(len(fars), 100)
	global_ranking.trials_table[ifos].count_below_thresh = n_kept
	global_ranking.trials_table[ifos].thresh = fars[n_kept - 1][0]


#
# Adjust the fap to the high event limit to include the dependence of trials
#
# The fap must be adjusted by the high event limit calibration, the number of
# time slides and the number of ifo combinations (since we make a single IFAR
# plot).
#

# Get the number of time slides from the first non-injection database.
# NOTE(review): this assumes every non-injection database shares the same
# time slide table -- confirm.
working_filename = dbtables.get_connection_filename(options.non_injection_db[0], tmp_path = options.tmp_space, verbose = options.verbose)
connection = sqlite3.connect(working_filename)
try:
	num_slides, = connection.cursor().execute("SELECT COUNT(DISTINCT(time_slide_id)) FROM time_slide").fetchone()
finally:
	# release the database and any scratch copy even if the query fails
	connection.close()
	dbtables.discard_connection_filename(options.non_injection_db[0], working_filename, verbose = options.verbose)


# Per instrument combination:
#   scale = count_below_thresh / thresh / livetime
#           * trials_table.num_nonzero_count() / num_slides
# num_nonzero_count() is presumably the number of combinations with non-zero
# trials counts.
for ifos in global_ranking.trials_table:
	try:
		global_ranking.scale[ifos] = global_ranking.trials_table[ifos].count_below_thresh / global_ranking.trials_table[ifos].thresh / float(abs(global_ranking.livetime_seg)) * global_ranking.trials_table.num_nonzero_count() / num_slides
	except TypeError:
		# A non-numeric operand (presumably an unset None threshold for a
		# combination that was never calibrated) means no scaling.
		global_ranking.scale[ifos] = 1


#
# Set the FAP of every non-injection database, presumably using the
# calibrated trials table and scale factors.
#


for db_filename in options.non_injection_db:
	far.set_fap(global_ranking, db_filename, tmp_path = options.tmp_space, verbose = options.verbose)

#
# An injection always implies one more event, so every trials count is
# incremented by one BEFORE FAPs are assigned to the injection databases.
# The calls below must stay in this order.
#

global_ranking.trials_table.increment_count(1)
for inj_filename in options.injection_dbs:
	far.set_fap(global_ranking, inj_filename, tmp_path = options.tmp_space, verbose = options.verbose)

# Second (final) pass at FARs for the non-injection databases, this time
# with scale = True.
for db_filename in options.non_injection_db:
	far.set_far(global_ranking, db_filename, tmp_path = options.tmp_space, scale = True, verbose = options.verbose)

# NOTE(review): injection FARs are set without scale = True, unlike the
# non-injection pass above -- confirm this asymmetry is intended.
for inj_filename in options.injection_dbs:
	far.set_far(global_ranking, inj_filename, tmp_path = options.tmp_space, verbose = options.verbose)

#
# Write out marginalized likelihood file after the scale factor has been computed
#

import os

# Drop trials-table entries whose count is exactly 1.  Presumably these are
# the combinations that had no trials before the increment by one that
# precedes this step.
for k in tuple(global_ranking.trials_table):
	if global_ranking.trials_table[k].count == 1:
		del global_ranking.trials_table[k]
xmldoc = ligolw.Document()
node = xmldoc.appendChild(ligolw.LIGO_LW())
node.appendChild(lsctables.New(lsctables.ProcessTable))
node.appendChild(lsctables.New(lsctables.ProcessParamsTable))
node.appendChild(lsctables.New(lsctables.SearchSummaryTable))
process = ligolw_process.register_to_xmldoc(xmldoc, u"gstlal_inspiral_marginalize_likelihood", options.__dict__)
search_summary = ligolw_search_summary.append_search_summary(xmldoc, process)
search_summary.set_out(global_ranking.livetime_seg)
xmldoc.childNodes[-1].appendChild(global_ranking.to_xml(process, search_summary))
ligolw_process.set_process_end_time(process)
# Prefix only the file name with "post_", keeping the input's directory.
# Formatting the whole path would turn "dir/bins.xml.gz" into
# "post_dir/bins.xml.gz", a directory that normally does not exist.  For a
# bare file name the result is unchanged.
bins_dir, bins_name = os.path.split(options.background_bins_file[0])
outname = os.path.join(bins_dir, "post_%s" % bins_name)
utils.write_filename(xmldoc, outname, gz = outname.endswith(".gz"), verbose = options.verbose)
