#!/usr/bin/python
#
# Copyright (C) 2008  Kipp Cannon
#               2010  Drew Keppel
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 2 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.


#
# =============================================================================
#
#                                   Preamble
#
# =============================================================================
#


import bisect
import numpy
from new import instancemethod
from optparse import OptionParser
try:
	import sqlite3
except ImportError:
	# pre 2.5.x
	from pysqlite2 import dbapi2 as sqlite3
import sys
import copy
import time

from glue import segments
from glue.ligolw import ligolw
from glue.ligolw import lsctables
from glue.ligolw import dbtables
from glue.ligolw.utils import process as ligolw_process
from pylal import git_version
from pylal import rate
from pylal import db_thinca_rings
from pylal.xlal.datatypes.ligotimegps import LIGOTimeGPS
from glue import lal

lsctables.LIGOTimeGPS = LIGOTimeGPS


__author__ = "Kipp Cannon <kipp.cannon@ligo.org>"
__version__ = "git id %s" % git_version.id
__date__ = git_version.date


#
# =============================================================================
#
#                                 Command Line
#
# =============================================================================
#


def parse_command_line():
	"""
	Parse and validate the command line.

	Returns the tuple (options, filenames, unaltered_options):

	options is the optparse values object, with the derived attributes
	populate_column and false_alarm_rate_column added and, for the
	mass-based categorization algorithms, mass_bins converted from the
	comma-delimited string to a rate.IrregularBins object.

	filenames is the list of database files named on the command line
	followed by the paths found in any --input-cache files;  if that
	list is empty, [None] is returned in its place.

	unaltered_options is a deep copy of options taken before mass_bins
	is converted, i.e., it still carries --mass-bins as the original
	string.

	Raises ValueError if a required option is missing or an option has
	an unrecognized or unsupported value.
	"""
	parser = OptionParser(
		version = "Name: %%prog\n%s" % git_version.verbose_msg,
		usage = "%prog [options] [file ...]",
		description = "%prog does blah blah blah."
	)
	parser.add_option("-i", "--input-cache", metavar = "filename", action = "append", help = "Retrieve database files from this LAL cache.  Can be given multiple times.")
	parser.add_option("-p", "--live-time-program", metavar = "name", help = "Set the name of the program whose entries in the search_summary table will set the search live time.  Required.")
	parser.add_option("--veto-segments-name", help = "Set the name of the segments to extract from the segment tables and use as the veto list.")
	parser.add_option("--categories", metavar = "{\"mchirp-ifos-oninstruments\",\"mtotal-ifos-oninstruments\",\"ifos-oninstruments\",\"oninstruments\",\"none\"}", help = "Select the event categorization algorithm.  Required.")
	parser.add_option("-b", "--mass-bins", metavar = "mass,mass[,mass,...]", help = "Set the boundaries of the mass bins in solar masses.  The lowest and highest bounds must be explicitly listed.  Example \"0,5,inf\".  Required if mass-based categorization algorithm has been selected.")
	parser.add_option("--rank-by", metavar = "{\"snr\",\"uncombined-ifar\",\"combined-ifar\",\"likelihood\"}", help = "Select the event ranking method.  Required.")
	parser.add_option("-t", "--tmp-space", metavar = "path", help = "Path to a directory suitable for use as a work area while manipulating the database file.  The database file will be worked on in this directory, and then moved to the final location when complete.  This option is intended to improve performance when running in a networked environment, where there might be a local disk with higher bandwidth than is available to the filesystem on which the final output will reside.")
	parser.add_option("-n", "--extrapolation-num", action = "store", type = "int", metavar = "num", help = "Set the number of background points to use in FAR extrapolation" )
	parser.add_option("-v", "--verbose", action = "store_true", help = "Be verbose.")
	options, filenames = parser.parse_args()

	#
	# categories and ranking
	#

	if options.categories not in ("mchirp-ifos-oninstruments", "mtotal-ifos-oninstruments", "ifos-oninstruments", "oninstruments", "none"):
		raise ValueError("missing or unrecognized --categories option")
	if options.rank_by not in ("snr", "uncombined-ifar", "combined-ifar", "likelihood"):
		raise ValueError("missing or unrecognized --rank-by option")

	# the mass-based algorithms write their result to the
	# false_alarm_rate column, all others write to combined_far
	if options.categories in ("mchirp-ifos-oninstruments", "mtotal-ifos-oninstruments"):
		if options.rank_by != "snr":
			raise ValueError("--rank-by must be \"snr\" if --categories is \"%s\"" % options.categories)
		options.populate_column = "false_alarm_rate"
	else:
		options.populate_column = "combined_far"

	# column from which the (inverse) false-alarm rate used for ranking
	# is read
	if options.categories in ("none",) or options.rank_by in ("combined-ifar",):
		options.false_alarm_rate_column = "combined_far"
	else:
		options.false_alarm_rate_column = "false_alarm_rate"

	if options.extrapolation_num is not None:
		if options.extrapolation_num <= 0:
			# the value must be read from options;  there is no
			# local variable named extrapolation_num
			raise ValueError("if specifying --extrapolation-num it must be > 0: %d given" % options.extrapolation_num)
		if options.rank_by in ("snr", "uncombined-ifar"):
			if options.categories in ("none",):
				raise ValueError("extrapolation for --categories=%s not implemented" % options.categories)
		else:
			raise ValueError("extrapolation not implemented for --rank-by %s" % options.rank_by)

	#
	# save a copy of options before mucking with mass bins
	#

	unaltered_options = copy.deepcopy(options)

	#
	# parse mass bins
	#

	if options.categories in ("mchirp-ifos-oninstruments", "mtotal-ifos-oninstruments"):
		if options.mass_bins is None:
			raise ValueError("--mass-bins required with category algorithm \"%s\"" % options.categories)
		options.mass_bins = sorted(map(float, options.mass_bins.split(",")))
		if len(options.mass_bins) < 2:
			raise ValueError("must set at least two mass bin boundaries (i.e., define at least one mass bin)")
		options.mass_bins = rate.IrregularBins(options.mass_bins)

	#
	# other
	#

	if options.live_time_program is None:
		raise ValueError("missing required option -p or --live-time-program")

	#
	# caches
	#

	if options.input_cache is not None:
		for cache_name in options.input_cache:
			# close each cache file explicitly instead of relying
			# on garbage collection
			cache_file = open(cache_name)
			try:
				filenames += [lal.CacheEntry(line).path for line in cache_file]
			finally:
				cache_file.close()

	#
	# done
	#

	return options, (filenames or [None]), unaltered_options


#
# =============================================================================
#
#                               Helper Functions
#
# =============================================================================
#


class ExtrapolationCoeffs(object):
	"""
	Coefficients used to extrapolate the false-alarm number (FAN) of
	an event beyond the range covered by the measured background.

	Attributes set by __init__():
	rank_by: the name of the ranking statistic
	npoints: event count below which extrapolation is to be applied (0
		if extrapolation is not possible for this category)
	A, B: the coefficients of the extrapolation equation
	"""
	def __init__(self, rank_by, N, reputations, Ncat):
		"""
		calculate the FAR extrapolation coefficients
		@rank_by: the statistic we are using to rank triggers, "snr" or "uncombined-ifar"
		@N: the number of triggers to compute the extrapolation coefficients
		@reputations: a list of the triggers' rank_by statistic, in increasing order (the last N entries are used as the loudest)
		@Ncat: the number of categories being combined;  only used when rank_by is "uncombined-ifar", where it sets A and scales npoints

		raises ValueError if at least two reputations are supplied
		and rank_by is not one of the supported statistics
		"""

		self.rank_by = rank_by

		#
		# check we have enough triggers to extrapolate
		#

		if len(reputations) < 2:
			print >>sys.stderr,"Warning: found a category with less than 2 background triggers, cannot perform extrapolation for this category"
			self.npoints = 0
			self.A = 0
			self.B = 0
			return
		if len(reputations) < N:
			print >>sys.stderr,"Warning: found a category with less than N=%i background triggers, using the number available: %i"%(N,len(reputations))

		#
		# get just the loudest N in reputations
		#

		x = reputations[-N:]

		if rank_by == "snr":
			#
			# compute the extrapolation coefficients
			# extrapolation equation is FAN = A*e^{B*snr^2}
			#
			# counts[i] is the number of the selected triggers
			# whose statistic is at or above x[i]:  len(x) for
			# the quietest of them down to 1 for the loudest.
			# a descending range is used instead of reversing a
			# list in place so this does not rely on range()
			# returning a list.
			#

			counts = range(len(x), 0, -1)
			x0_sq = x[0]**2.

			Y0 = 0.
			X0 = 0.

			for snr, count in zip(x, counts):
				# each point is weighted by 1 / sigmasq = count
				sigmasq = 1. / count
				delta = snr**2. - x0_sq
				Y0 += numpy.log(1. * count / counts[0]) * delta / sigmasq
				X0 += delta**2. / sigmasq

			# NOTE(review):  X0 is zero when every selected
			# snr has the same magnitude, in which case B is
			# undefined;  that case is not handled here.
			self.npoints = len(x)
			self.B = Y0 / X0
			self.A = counts[0] * numpy.exp(-self.B * x0_sq)

		elif rank_by == "uncombined-ifar":
			#
			# extrapolation equation is cFAN = A*FAN
			# number of points to extrapolate is N times Ncat
			#

			self.npoints = N*Ncat
			self.A = Ncat
			self.B = 0

		else:
			raise ValueError(rank_by)


	def extrapolate_reputation(self, reputation, t):
		"""
		calculate the FAN from the reputation and the extrapolation coefficients
		@reputation: the reputation of the trigger to extrapolate
		@t: the livetime that this trigger is from (only used when ranking by "uncombined-ifar")

		raises ValueError if the ranking statistic this object was
		created with is not supported
		"""
		if self.rank_by == "snr":
			#
			# extrapolation equation is FAN = A*e^{B*snr^2}
			#

			return self.A * numpy.exp(self.B * reputation**2.)
		if self.rank_by == "uncombined-ifar":
			#
			# extrapolation equation is cFAN = A*FAN
			# remember FAN = t*FAR
			# and reputation = IFAR
			# so cFAN = A*t/reputation
			#

			return self.A * t / reputation
		# the statistic is stored on the instance;  there is no local
		# variable named rank_by in this method
		raise ValueError(self.rank_by)


#
# =============================================================================
#
#                                 Book-Keeping
#
# =============================================================================
#


class Summaries(object):
	"""
	Book-keeping for the background:  collects the ranking statistic
	("reputation") of background coincs and the live time, both grouped
	by event category, and converts an event's reputation into the rate
	of background events in its category ranked at least as high.
	"""
	def __init__(self, category_algorithm, rank_by, mass_bins = None):
		"""
		@category_algorithm: "mchirp-ifos-oninstruments", "mtotal-ifos-oninstruments", "ifos-oninstruments", "oninstruments" or "none"
		@rank_by: "snr", "uncombined-ifar", "combined-ifar" or "likelihood"
		@mass_bins: object indexed with a mass to obtain its bin;  only consulted by the two mass-based category algorithms

		raises ValueError if category_algorithm or rank_by is not
		one of the values listed above
		"""
		# look-ups shared by several of the category algorithms
		same_instruments = lambda self, on_instruments: on_instruments
		leading_element = lambda self, category: category[0]
		whole_category = lambda self, category: category

		# select (category_func, time_func, category_to_time):
		# category_func maps an event to its category key, time_func
		# maps the on instruments to the live time key, and
		# category_to_time maps a category key to its live time key
		if category_algorithm == "mchirp-ifos-oninstruments":
			chosen = (
				lambda self, on_instruments, participating_instruments, mchirp, mtotal: (on_instruments, participating_instruments, self.mass_bins[mchirp]),
				same_instruments,
				leading_element
			)
		elif category_algorithm == "mtotal-ifos-oninstruments":
			chosen = (
				lambda self, on_instruments, participating_instruments, mchirp, mtotal: (on_instruments, participating_instruments, self.mass_bins[mtotal]),
				same_instruments,
				leading_element
			)
		elif category_algorithm == "ifos-oninstruments":
			chosen = (
				lambda self, on_instruments, participating_instruments, mchirp, mtotal: (on_instruments, participating_instruments),
				same_instruments,
				leading_element
			)
		elif category_algorithm == "oninstruments":
			chosen = (
				lambda self, on_instruments, participating_instruments, mchirp, mtotal: on_instruments,
				same_instruments,
				whole_category
			)
		elif category_algorithm == "none":
			chosen = (
				lambda self, on_instruments, participating_instruments, mchirp, mtotal: None,
				lambda self, on_instruments: None,
				whole_category
			)
		else:
			raise ValueError(category_algorithm)

		# bind the selected functions to this instance
		owner = self.__class__
		self.category_func = instancemethod(chosen[0], self, owner)
		self.time_func = instancemethod(chosen[1], self, owner)
		self.category_to_time = instancemethod(chosen[2], self, owner)

		# select the function that picks an event's reputation
		if rank_by == "snr":
			ranker = lambda self, snr, uncombined_ifar, likelihood: snr
		elif rank_by in ("uncombined-ifar", "combined-ifar"):
			# cast to float so this does the right thing when
			# we get "inf" as a string
			ranker = lambda self, snr, uncombined_ifar, likelihood: float(uncombined_ifar)
		elif rank_by == "likelihood":
			ranker = lambda self, snr, uncombined_ifar, likelihood: likelihood
		else:
			raise ValueError(rank_by)
		self.rank_by = rank_by
		self.reputation_func = instancemethod(ranker, self, owner)

		self.mass_bins = mass_bins
		# category key --> list of reputations
		self.reputations = {}
		# live time key --> accumulated live time
		self.cached_livetime = {}
		# on instruments --> set of participating instruments seen
		self.participating_instruments = {}
		# extrapolation is off until rate_extrapolation() is called
		self.extrapolate = False
		# category key --> ExtrapolationCoeffs
		self.extrapolation_coeffs = {}
		# filled in by get_previous_newcorse_info()
		self.prev_num_mass_bins = None
		self.prev_newcorse_time = None

	def add_livetime(self, connection, live_time_program, veto_segments_name = None, verbose = False):
		"""
		Add the live times computed from the database behind
		connection to the running totals in cached_livetime.
		@connection: database connection
		@live_time_program: program name passed on to the thinca ring look-up
		@veto_segments_name: name of the veto segments to retrieve;  if None an empty segmentlistdict is used
		@verbose: print progress to stderr
		"""
		if veto_segments_name is None:
			veto_segments = segments.segmentlistdict()
		else:
			if verbose:
				print >>sys.stderr, "\tretrieving veto segments \"%s\" ..." % veto_segments_name
			veto_segments = db_thinca_rings.get_veto_segments(connection, veto_segments_name)
		if verbose:
			print >>sys.stderr, "\tcomputing livetimes:",
		rings = db_thinca_rings.get_thinca_rings_by_available_instruments(connection, program_name = live_time_program)
		offset_vectors = db_thinca_rings.get_background_offset_vectors(connection)
		all_livetimes = db_thinca_rings.get_thinca_livetimes(rings, veto_segments, offset_vectors, verbose = verbose)
		for instrument_set, livetimes in all_livetimes.items():
			key = self.time_func(lsctables.ifos_from_instrument_set(instrument_set))
			if key in self.cached_livetime:
				self.cached_livetime[key] += sum(livetimes)
			else:
				self.cached_livetime[key] = sum(livetimes)
		if verbose:
			print >>sys.stderr

	def add_coinc(self, on_instruments, participating_instruments, mchirp, mtotal, snr, uncombined_ifar, likelihood):
		"""
		Record one background coinc:  append its reputation to the
		list for its category, and note that
		participating_instruments has been seen for on_instruments.
		"""
		bucket = self.reputations.setdefault(self.category_func(on_instruments, participating_instruments, mchirp, mtotal), [])
		bucket.append(self.reputation_func(snr, uncombined_ifar, likelihood))
		self.participating_instruments.setdefault(on_instruments, set()).add(participating_instruments)

	def index(self):
		"""
		Sort every category's reputation list into increasing
		order, in place.  rate() uses bisect on these lists.
		"""
		for category in self.reputations:
			self.reputations[category].sort()

	def rate_extrapolation(self, N, Nbins):
		"""
		calculate the FAR extrapolation coefficients for the different categories and switch extrapolation on
		@N: the number of triggers included in the extrapolation fit
		@Nbins: the number of mass bins used in previous newcorse run (0 is default from command line)
		"""

		self.extrapolate = True
		for category, reputations in self.reputations.items():
			# number of categories = mass bins times the number
			# of participating instrument combinations recorded
			# for this category's on instruments
			seen = self.participating_instruments[self.category_to_time(category)]
			self.extrapolation_coeffs[category] = ExtrapolationCoeffs(self.rank_by, N, reputations, Nbins*len(seen))

	def rate(self, on_instruments, participating_instruments, mchirp, mtotal, snr, uncombined_ifar, likelihood):
		"""
		Return the number of background events in this event's
		category whose reputation is >= this event's (or the
		extrapolated count, when extrapolation is on and the count
		is below the category's npoints), divided by the category's
		live time.

		A ZeroDivisionError raised by the division is reported on
		stderr and re-raised.
		"""
		#
		# look up the category's reputations and live time, creating
		# an empty list / a live time of 0.0 for a category or live
		# time key that has not been seen
		#

		category = self.category_func(on_instruments, participating_instruments, mchirp, mtotal)
		reputations = self.reputations.setdefault(category, [])
		t = self.cached_livetime.setdefault(self.category_to_time(category), 0.0)

		#
		# len(x) - bisect.bisect_left(x, reputation) = number of
		# elements in list >= reputation
		#

		reputation = self.reputation_func(snr, uncombined_ifar, likelihood)
		n = len(reputations) - bisect.bisect_left(reputations, reputation)

		#
		# swap in the extrapolated event count if needed
		#

		if self.extrapolate and category in self.extrapolation_coeffs:
			coeffs = self.extrapolation_coeffs[category]
			if n < coeffs.npoints:
				n = coeffs.extrapolate_reputation(reputation, t)

		#
		# the rate of events above the given reputation
		#

		try:
			return n / t
		except ZeroDivisionError:
			print >>sys.stderr, "found an event in category %s that has a livetime of 0 s" % repr(category)
			raise

	def get_previous_newcorse_info(self, connection):
		"""
		Read the start time and --mass-bins parameter of the most
		recent lalapps_newcorse process recorded in the database
		and remember them in prev_newcorse_time and
		prev_num_mass_bins.

		raises ValueError if the values found disagree with those
		remembered from an earlier call
		"""
		# this should only return at most one thing because of max(), but if it returns nothing it is a no-op
		for process_id, start_time in connection.cursor().execute('SELECT process.process_id, start_time FROM process WHERE process.program == "lalapps_newcorse" and start_time == (SELECT max(start_time) FROM process WHERE process.program == "lalapps_newcorse")'):
			if self.prev_newcorse_time is None:
				self.prev_newcorse_time = start_time
			if self.prev_newcorse_time != start_time:
				raise ValueError("found file that has wrong newcorse process time, exiting")
			for (mass_bins,) in connection.cursor().execute('SELECT value from process_params WHERE process_id == ? AND param = "--mass-bins"', (process_id,)):
				if self.prev_num_mass_bins is None:
					self.prev_num_mass_bins = mass_bins
				if self.prev_num_mass_bins != mass_bins:
					raise ValueError("previous mass bins do not agree between files, exiting")

	def get_prev_num_mass_bins(self):
		"""
		Return the number of mass bins described by the remembered
		--mass-bins string (one fewer than the number of
		comma-delimited boundaries), or 0 if none was found.
		"""
		if self.prev_num_mass_bins is None:
			return 0
		return max(len(self.prev_num_mass_bins.split(',')) - 1, 0)

def add_process(connection, opts, process, comment=u""):
	"""
	Append the attributes of opts, as process parameters of process, to
	the document obtained from connection, then unlink that document.
	The comment argument is accepted but not used.
	"""
	document = dbtables.get_xml(connection)
	ligolw_process.append_process_params(document, process, ligolw_process.process_params_from_dict(opts.__dict__))
	document.unlink()


#
# =============================================================================
#
#                                     Main
#
# =============================================================================
#


#
# Get the process table rolling, the XMLDoc here is a dummy.
# newcorse_process will be inserted into the DB correctly in add_process()
#

# build a throw-away document holding only a LIGO_LW element, and append
# this program's process entry to it.  the resulting newcorse_process is
# what gets passed to add_process() for every database below.
xmldoc = ligolw.Document()
xmldoc.appendChild(ligolw.LIGO_LW())
newcorse_process = ligolw_process.append_process(xmldoc, program = u"lalapps_newcorse", version = __version__, cvs_repository = u"lscsoft", cvs_entry_time = __date__, comment = u"")


#
# command line
#


# unaltered_options is the copy of the options taken before --mass-bins
# was converted from a string;  it is the version recorded by add_process()
options, filenames, unaltered_options = parse_command_line()


#
# initialize book-keeping
#

# a single Summaries object accumulates the background from all input files
background = Summaries(options.categories, options.rank_by, mass_bins = options.mass_bins)


#
# iterate over database files accumulating background statistics
#


# first pass:  visit every database once to accumulate live time and the
# reputations of background coincs in the background object
if options.verbose:
	print >>sys.stderr, "collecting background statistics ..."


for n, filename in enumerate(filenames):
	#
	# open the database
	#

	if options.verbose:
		print >>sys.stderr, "%d/%d: %s" % (n + 1, len(filenames), filename)
	working_filename = dbtables.get_connection_filename(filename, tmp_path = options.tmp_space, verbose = options.verbose)
	connection = sqlite3.connect(working_filename)

	#
	# First pull out process info, then add some to it
	#
	# the previous newcorse information must be read before
	# add_process() is called on the same connection
	# NOTE(review):  the working file is released with
	# discard_connection_filename() at the end of this pass and no
	# commit is issued here --- confirm that what add_process() records
	# is meant to survive
	background.get_previous_newcorse_info(connection)
	add_process(connection, unaltered_options, newcorse_process)

	#
	# if the database contains a sim_inspiral table then it is assumed
	# to represent an injection run.  its rings must not be added to
	# the livetime, and it cannot provide background coincs, so it is
	# just skipped altogether in this first pass.
	#

	if "sim_inspiral" in dbtables.get_table_names(connection):
		if options.verbose:
			print >>sys.stderr, "\tdatabase contains sim_inspiral table, skipping ..."

		#
		# close the database
		#

		connection.close()
		dbtables.discard_connection_filename(filename, working_filename, verbose = options.verbose)
		continue

	#
	# compute and record background livetime
	#

	background.add_livetime(connection, options.live_time_program, veto_segments_name = options.veto_segments_name, verbose = options.verbose)

	#
	# count background coincs
	#
	# both %s in the query are replaced with the name of the
	# false-alarm rate column chosen in parse_command_line().  the CASE
	# expression yields the inverse of that column, with the string
	# "inf" standing in when the column is 0;  the reputation functions
	# that use this value cast it with float().
	#

	if options.verbose:
		print >>sys.stderr, "\tcollecting background statistics ..."
	for on_instruments, participating_instruments, mchirp, mtotal, snr, uncombined_ifar, likelihood in connection.cursor().execute("""
SELECT
	coinc_event.instruments,
	coinc_inspiral.ifos,
	coinc_inspiral.mchirp,
	coinc_inspiral.mass,
	coinc_inspiral.snr,
	CASE coinc_inspiral.%s WHEN 0 THEN "inf" ELSE 1.0 / coinc_inspiral.%s END,
	coinc_event.likelihood
FROM
	coinc_event
	JOIN coinc_inspiral ON (
		coinc_inspiral.coinc_event_id == coinc_event.coinc_event_id
	)
WHERE
	-- require coinc to be background (= at least one of its time slide offsets is non-zero)
	EXISTS (
		SELECT
			*
		FROM
			time_slide
		WHERE
			time_slide.time_slide_id == coinc_event.time_slide_id
			AND time_slide.offset != 0
	)
	""" % (options.false_alarm_rate_column, options.false_alarm_rate_column)):
		background.add_coinc(on_instruments, participating_instruments, mchirp, mtotal, snr, uncombined_ifar, likelihood)

	#
	# close the database
	#

	connection.close()
	dbtables.discard_connection_filename(filename, working_filename, verbose = options.verbose)

# sort the reputation lists;  Summaries.rate() applies bisect to them
background.index()
# extrapolation is only set up when --extrapolation-num was given.  the
# number of mass bins comes from the previous newcorse run's --mass-bins
# parameter found in the databases (0 if there was none)
if options.extrapolation_num is not None:
	if options.verbose:
		print >>sys.stderr, "\tcalculating FAR extrapolation coefficients ..."
	background.rate_extrapolation(options.extrapolation_num, background.get_prev_num_mass_bins())

#
# iterate over database files assigning false-alarm rates to coincs
#


# second pass:  visit every database again, this time writing a rate into
# the column named by options.populate_column for each row of
# coinc_inspiral
for n, filename in enumerate(filenames):
	#
	# open the database
	#

	if options.verbose:
		print >>sys.stderr, "%d/%d: %s" % (n + 1, len(filenames), filename)
	working_filename = dbtables.get_connection_filename(filename, tmp_path = options.tmp_space, verbose = options.verbose)
	connection = sqlite3.connect(working_filename)

	#
	# prepare the database
	#

	# expose Summaries.rate() to SQL as a 7-argument function
	connection.create_function("background_rate", 7, background.rate)

	#
	# compute the rate for every coinc with background_rate() and store
	# it.  the arguments are the same columns, in the same order, as
	# those collected from the background in the first pass.
	#

	if options.verbose:
		print >>sys.stderr, "\tcalculating and recording false alarm rates ..."
	connection.cursor().execute("""
UPDATE
	coinc_inspiral
SET
	%s = (
		SELECT
			background_rate(
				coinc_event.instruments,
				coinc_inspiral.ifos,
				coinc_inspiral.mchirp,
				coinc_inspiral.mass,
				coinc_inspiral.snr,
				CASE coinc_inspiral.%s WHEN 0 THEN "inf" ELSE 1.0 / coinc_inspiral.%s END,
				coinc_event.likelihood
			)
		FROM
			coinc_event
		WHERE
			coinc_event.coinc_event_id == coinc_inspiral.coinc_event_id
	)
	""" % (options.populate_column, options.false_alarm_rate_column, options.false_alarm_rate_column) )
	connection.commit()

	#
	# close the database
	#

	# unlike the first pass, the working file is handed to
	# put_connection_filename() rather than discard_connection_filename()
	connection.close()
	dbtables.put_connection_filename(filename, working_filename, verbose = options.verbose)
