#!/usr/bin/python
#
# Copyright (C) 2009  Kipp Cannon, Chad Hanna
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 2 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.


#
# =============================================================================
#
#                                   Preamble
#
# =============================================================================
#


import bisect
from new import instancemethod
from optparse import OptionParser
try:
	import sqlite3
except ImportError:
	# pre 2.5.x
	from pysqlite2 import dbapi2 as sqlite3
import sys


from glue import segments
from glue.ligolw import lsctables
from glue.ligolw import dbtables
from glue.ligolw.utils import search_summary as ligolw_search_summary
from glue.ligolw.utils import segments as ligolw_segments
from pylal import git_version
from pylal import db_thinca_rings
from pylal.xlal.datatypes.ligotimegps import LIGOTimeGPS
from pylal import SnglBurstUtils

# Rebind the LIGOTimeGPS name inside glue.ligolw.lsctables to the class
# imported above from pylal.xlal.datatypes.ligotimegps.
# NOTE(review):  presumably so that times produced by lsctables code use the
# pylal implementation --- confirm against lsctables.
lsctables.LIGOTimeGPS = LIGOTimeGPS


# Module metadata;  the version and date strings are taken from pylal's
# git_version module.
__author__ = "Kipp Cannon <kipp.cannon@ligo.org>"
__version__ = "git id %s" % git_version.id
__date__ = git_version.date


#
# =============================================================================
#
#                                 Command Line
#
# =============================================================================
#


def parse_command_line():
	"""
	Parse the program's command line (taken from sys.argv by
	OptionParser.parse_args()).

	Returns the tuple (options, filenames).  options is the optparse
	values object with one extra attribute, false_alarm_rate_column,
	set to "false_alarm_rate".  filenames is the list of positional
	arguments, or [None] if none were given.

	Raises ValueError if --rank-by is missing or is not one of the
	supported ranking methods, or if no live time program is set.
	"""
	parser = OptionParser(
		version = "Name: %%prog\n%s" % git_version.verbose_msg,
		usage = "%prog [options] [file ...]",
		description = "%prog does blah blah blah."
	)
	parser.add_option("-p", "--live-time-program", default="omega_to_coinc", metavar = "name", help = "Set the name of the program whose entries in the search_summary table will set the search live time.  Required.")
	parser.add_option("--rank-by", metavar = "{\"snr\",\"likelihood\"}", help = "Select the event ranking method.  Required.")
	parser.add_option("--veto-segments-name", help = "Set the name of the segments to extract from the segment tables and use as the veto list.")
	parser.add_option("-t", "--tmp-space", metavar = "path", help = "Path to a directory suitable for use as a work area while manipulating the database file.  The database file will be worked on in this directory, and then moved to the final location when complete.  This option is intended to improve performance when running in a networked environment, where there might be a local disk with higher bandwidth than is available to the filesystem on which the final output will reside.")
	parser.add_option("-v", "--verbose", action = "store_true", help = "Be verbose.")
	options, filenames = parser.parse_args()

	# only the ranking methods that the Summaries class in this file
	# implements are accepted here, so that an unsupported choice is
	# reported with a meaningful message instead of failing later
	if options.rank_by not in ("snr", "likelihood"):
		raise ValueError("missing or unrecognized --rank-by option")

	# FIXME this probably doesn't need to be an option since there is only one column for burst codes
	options.false_alarm_rate_column = "false_alarm_rate"

	# the option has a default value, so this can only trigger if the
	# default is removed;  retained as a safety check
	if options.live_time_program is None:
		raise ValueError("missing required option -p or --live-time-program")

	#
	# done
	#

	return options, (filenames or [None])


#
# =============================================================================
#
#                               Helper Functions
#
# =============================================================================
#


#
# =============================================================================
#
#                                 Book-Keeping
#
# =============================================================================
#


class Summaries(object):
	"""
	Book-keeping for background coincidences.

	For each category (currently the set of on instruments) this class
	collects the "reputations" of the background events, that is the
	value of the ranking statistic selected by rank_by, together with
	the accumulated zero-lag and background live times.  rate() then
	converts the reputation of an event into the rate of background
	events in the same category ranked at or above it.

	Intended call order:  add_livetime() and add_coinc() while the
	background is being collected, index() once, then rate().
	"""

	def __init__(self, rank_by, category_algorithm="oninstruments"):
		"""
		rank_by selects the ranking statistic and must be "snr" or
		"likelihood";  any other value raises ValueError.

		category_algorithm is accepted for interface compatibility
		but is not used.
		"""
		#
		# The helper functions below are stored as plain callables
		# on the instance.  None of them needs access to the
		# instance, so there is no need to bind them as methods
		# (which used to be done with the deprecated "new"
		# module).
		#
		# reputation_func(snr, likelihood) gives the number by
		# which an event is ranked
		#

		if rank_by == "snr":
			self.reputation_func = lambda snr, likelihood: snr
		elif rank_by == "likelihood":
			self.reputation_func = lambda snr, likelihood: likelihood
		else:
			raise ValueError(rank_by)
		self.rank_by = rank_by

		#FIXME make categories cooler if needed, for now there is only "on_instruments" as a category
		self.category_func = lambda on_instruments: on_instruments
		self.time_func = lambda on_instruments: on_instruments
		# NOTE(review):  not called anywhere in this class, and the
		# frozenset categories built by add_coinc() and rate() do
		# not support indexing --- confirm whether this is still
		# needed
		self.category_to_time = lambda category: category[0]

		#self.mass_bins = mass_bins
		# category --> list of background reputations
		self.reputations = {}
		# category --> accumulated live time
		self.zero_lag_live_time = {}
		self.background_live_time = {}
		# categories and times seen by add_coinc(), in order of
		# first appearance
		self.categories = []
		self.times = []

	def add_livetime(self, connection, live_time_program, veto_segments_name = None, verbose = False):
		"""
		Add the live time recorded in the database behind
		connection to the running totals, using the method
		appropriate for live_time_program ("omega_to_coinc" or
		"waveburst").

		Raises ValueError for any other program name;  silently
		recording no live time would only resurface later as a
		zero-livetime failure in rate().
		"""
		if live_time_program == "omega_to_coinc":
			return self.add_livetime_nonring(connection, live_time_program, veto_segments_name, verbose = verbose)
		if live_time_program == "waveburst":
			return self.add_livetime_ring(connection, live_time_program, veto_segments_name, verbose = verbose)
		raise ValueError("don't know how to compute the live time for program %r" % (live_time_program,))

	def add_livetime_nonring(self, connection, live_time_program, veto_segments_name, verbose = False):
		"""
		Compute the live time from the search_summary segments of
		live_time_program, less the named veto segments if
		veto_segments_name is not None, and add it to the zero-lag
		and background totals of the category given by the set of
		instruments that have segments.
		"""
		# get the segment lists and live time
		xmldoc = dbtables.get_xml(connection)
		zero_lag_time_slides, background_time_slides = SnglBurstUtils.get_time_slides(connection)
		seglists = ligolw_search_summary.segmentlistdict_fromsearchsummary(xmldoc, live_time_program).coalesce()
		if veto_segments_name is not None:
			veto_segs = ligolw_segments.segmenttable_get_by_name(xmldoc, veto_segments_name).coalesce()
			seglists -= veto_segs
		instruments = frozenset(seglists.keys())
		category = self.category_func(instruments)
		self.zero_lag_live_time.setdefault(category, 0)
		self.background_live_time.setdefault(category, 0)

		self.zero_lag_live_time[category] += SnglBurstUtils.time_slides_livetime(seglists, zero_lag_time_slides.values(), verbose = verbose)
		self.background_live_time[category] += SnglBurstUtils.time_slides_livetime(seglists, background_time_slides.values(), verbose = verbose)

	def add_livetime_ring(self, connection, live_time_program, veto_segments_name=None, verbose = False):
		"""
		Compute the background live time with the db_thinca_rings
		helpers and add it to the per-category background totals.

		NOTE(review):  on this path the zero-lag live time entry is
		only initialized to 0, nothing is accumulated into it.
		"""
		if veto_segments_name is not None:
			if verbose:
				print >>sys.stderr, "\tretrieving veto segments \"%s\" ..." % veto_segments_name
			veto_segments = db_thinca_rings.get_veto_segments(connection, veto_segments_name)
		else:
			veto_segments = segments.segmentlistdict()

		if verbose:
			print >>sys.stderr, "\tcomputing livetimes:",

		for on_instruments, livetimes in db_thinca_rings.get_thinca_livetimes(db_thinca_rings.get_thinca_rings_by_available_instruments(connection, program_name = live_time_program), veto_segments, db_thinca_rings.get_background_offset_vectors(connection), verbose = verbose).items():

			#on_instruments = lsctables.ifos_from_instrument_set(on_instruments)
			category = self.category_func(on_instruments)
			self.zero_lag_live_time.setdefault(category, 0)
			self.background_live_time.setdefault(category, 0)
			self.background_live_time[category] += sum(livetimes)

			if verbose:
				print >>sys.stderr

	def add_coinc(self, on_instruments, snr, likelihood):
		"""
		Record one background event.  on_instruments is the
		instruments string of the event;  it is converted to a
		frozenset to determine the event's category.  The event's
		reputation is appended to that category's list, and the
		category and time are remembered if not seen before.
		"""
		on_instruments = frozenset(lsctables.instrument_set_from_ifos(on_instruments))
		category = self.category_func(on_instruments)
		self.reputations.setdefault(category, []).append(self.reputation_func(snr, likelihood))
		if category not in self.categories:
			self.categories.append(category)
		time = self.time_func(on_instruments)
		if time not in self.times:
			self.times.append(time)

	def index(self):
		"""
		Sort every reputation list in place.  Must be called after
		the last add_coinc() and before the first rate(), which
		relies on the lists being sorted.
		"""
		for reputations in self.reputations.values():
			reputations.sort()

	def rate(self, on_instruments, participating_instruments, snr, likelihood):
		"""
		Return the number of recorded background events in the
		event's category whose reputation is greater than or equal
		to this event's, divided by the category's background live
		time.  participating_instruments is not used.

		Raises ZeroDivisionError, after writing a message to
		stderr, if the category's background live time is 0.
		"""
		#
		# retrieve the appropriate reputation list (create an empty
		# one if there are no reputations for this category)
		#

		on_instruments = frozenset(lsctables.instrument_set_from_ifos(on_instruments))
		category = self.category_func(on_instruments)
		reputations = self.reputations.setdefault(category, [])

		#
		# len(x) - bisect.bisect_left(x, reputation) = number of
		# elements in list >= reputation
		#

		n = len(reputations) - bisect.bisect_left(reputations, self.reputation_func(snr, likelihood))

		#
		# retrieve the livetime
		#

		t = self.background_live_time.setdefault(category, 0.0)

		#
		# return the rate of events above the given reputation.
		# the count is converted to a float so that an integer
		# live time cannot turn this into a floor division
		#

		try:
			return float(n) / t
		except ZeroDivisionError:
			print >>sys.stderr, "found an event in category %s that has a livetime of 0 s" % repr(category)
			# bare raise keeps the original traceback
			raise


#
# =============================================================================
#
#                                     Main
#
# =============================================================================
#


#
# command line
#


options, filenames = parse_command_line()


#
# initialize book-keeping
#


background = Summaries(options.rank_by)


#
# iterate over database files accumulating background statistics
#


if options.verbose:
	print >>sys.stderr, "collecting background statistics ..."


for n, filename in enumerate(filenames):
	#
	# open the database
	#

	if options.verbose:
		print >>sys.stderr, "%d/%d: %s" % (n + 1, len(filenames), filename)
	working_filename = dbtables.get_connection_filename(filename, tmp_path = options.tmp_space, verbose = options.verbose)
	connection = sqlite3.connect(working_filename)

	#
	# if the database contains a sim_inspiral or sim_burst table then
	# it is assumed to represent an injection run.  its rings must not
	# be added to the livetime, and it cannot provide background
	# coincs, so it is just skipped altogether in this first pass.
	# the table names are retrieved once and tested for both tables.
	#

	table_names = dbtables.get_table_names(connection)
	if "sim_inspiral" in table_names or "sim_burst" in table_names:
		if options.verbose:
			print >>sys.stderr, "\tdatabase contains sim_inspiral or sim_burst table, skipping ..."

		#
		# close the database.  nothing was modified, so the working
		# copy is discarded rather than moved back
		#

		connection.close()
		dbtables.discard_connection_filename(filename, working_filename, verbose = options.verbose)
		continue

	#
	# compute and record background livetime
	#

	background.add_livetime(connection, options.live_time_program, options.veto_segments_name, verbose = options.verbose)

	#
	# count background coincs
	#

	if options.verbose:
		print >>sys.stderr, "\tcollecting background statistics ..."
	for on_instruments, participating_instruments, snr, likelihood in connection.cursor().execute("""
SELECT
	coinc_event.instruments,
	multi_burst.ifos,
	multi_burst.snr,
	coinc_event.likelihood
FROM
	coinc_event
	JOIN multi_burst ON (
		multi_burst.coinc_event_id == coinc_event.coinc_event_id
	)
WHERE
	-- require coinc to be background (= at least one of its time slide offsets is non-zero)
	EXISTS (
		SELECT
			*
		FROM
			time_slide
		WHERE
			time_slide.time_slide_id == coinc_event.time_slide_id
			AND time_slide.offset != 0
	)
	"""):
		background.add_coinc(on_instruments, snr, likelihood)

	#
	# close the database.  this pass only reads from it, so the working
	# copy is discarded
	#

	connection.close()
	dbtables.discard_connection_filename(filename, working_filename, verbose = options.verbose)


#
# sort the collected reputations so that background.rate() can search them
#


background.index()

#
# iterate over database files assigning false-alarm rates to coincs
#


#
# statement that stores in every multi_burst row the value returned by the
# background_rate() SQL function for that row and its coinc_event row
#

assign_far_sql = """
UPDATE
	multi_burst
SET
	false_alarm_rate = (
		SELECT
			background_rate(
				coinc_event.instruments,
				multi_burst.ifos,
				multi_burst.snr,
				coinc_event.likelihood
			)
		FROM
			coinc_event
		WHERE
			coinc_event.coinc_event_id == multi_burst.coinc_event_id
	)
	"""

for file_number, filename in enumerate(filenames):
	#
	# report progress, then open a working copy of the database
	#

	if options.verbose:
		print >>sys.stderr, "%d/%d: %s" % (file_number + 1, len(filenames), filename)
	working_filename = dbtables.get_connection_filename(filename, tmp_path = options.tmp_space, verbose = options.verbose)
	connection = sqlite3.connect(working_filename)

	#
	# expose background.rate() to SQL as the 4-argument function
	# background_rate()
	#

	connection.create_function("background_rate", 4, background.rate)

	#
	# compute the false alarm rate of every coinc and record it in the
	# multi_burst table
	#

	if options.verbose:
		print >>sys.stderr, "\tcalculating and recording false alarm rates ..."
	cursor = connection.cursor()
	cursor.execute(assign_far_sql)
	connection.commit()

	#
	# close the database and move the modified working copy back to its
	# final location
	#

	connection.close()
	dbtables.put_connection_filename(filename, working_filename, verbose = options.verbose)
