#!/usr/bin/python
#
# Copyright (C) 2009  Drew Keppel, Sarah Caudill
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 2 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.


#
# =============================================================================
#
#                                   Preamble
#
# =============================================================================
#

# Program description handed to OptionParser in parse_command_line(); it is
# the text printed at the top of the --help output.
description = \
"Creates an ROC plot for either single databases or caches of databases. To specify a cache, use the --add-cache argument. To specify single files, give them as command line arguments. One curve will be generated for each cache file and for each single file. If you would like a label to appear on the plot for single files, add a colon and the label after each filename, e.g., lalapps_cbc_plotroc file1.sqlite:week1. Otherwise, if no label is given, files will be labeled 'fileN' where N is order in which the file was given. Similar rules apply to cache files."

import math
import matplotlib
# Global plot styling, applied before any figure is created.
# NOTE: "text.usetex" routes every label, title and legend entry through
# LaTeX, which is why underscores are escaped (re.sub(r'_', r'\_', ...))
# before strings are handed to matplotlib elsewhere in this file.
matplotlib.rcParams.update({
	"font.size": 8.0,
	"axes.titlesize": 10.0,
	"axes.labelsize": 10.0,
	"xtick.labelsize": 8.0,
	"ytick.labelsize": 8.0,
	"legend.fontsize": 8.0,
	"figure.dpi": 200,
	"savefig.dpi": 200,
	"text.usetex": True,
	"path.simplify": True
})
from matplotlib import figure
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
try:
	from matplotlib.transforms import offset_copy
except:
	# FIXME: wrong matplotlib version, disable this feature;  figure
	# out how to do this portably later.
	pass
import numpy
from optparse import OptionParser
try:
	import sqlite3
except ImportError:
	# pre 2.5.x
	from pysqlite2 import dbapi2 as sqlite3
import sys, re

from glue import lal
from glue.ligolw import table
from glue.ligolw import dbtables
from pylal import git_version
from pylal import ligolw_sqlutils as sqlutils
from pylal import printutils

# Module metadata; version and date are taken from pylal's git_version module.
__author__ = "Drew Keppel <drew.keppel@ligo.org>"
__version__ = "git id %s" % git_version.id
__date__ = git_version.date


#
# =============================================================================
#
#                                 Command Line
#
# =============================================================================
#


def parse_command_line():
	"""
	Parse and validate the command line.

	Returns the tuple (options, filenames), where options is the
	optparse values object and filenames is the list of positional
	"file.sqlite[:label]" arguments.

	Defaults that cannot be expressed safely through optparse are filled
	in here:  --far-column defaults to ['combined_far'] and --format
	defaults to ['png'].  (An "append" option whose default is a
	non-empty list keeps the default in addition to whatever the user
	gives, so both options start out empty instead.)

	Raises ValueError if --map-label, --sngl-table, --sim-table or
	--coinc-table is missing.
	"""
	parser = OptionParser(
		version = "Name: %%prog\n%s" % git_version.verbose_msg,
		usage   = "%prog [options] file1.sqlite[:label1] file2.sqlite[:label2] ...",
		description = description
	)
	parser.add_option("-b", "--base", metavar = "base", default = "cbc_plotroc", help = "Set the prefix for output filenames (default = \"cbc_plotroc\")")
	parser.add_option("-f", "--format", metavar = "{\"png\",\"pdf\",\"svg\",\"eps\",...}", action = "append", default = [], help = "Set the output image format.  Can be given multiple times (default = \"png\")")
	parser.add_option("-O", "--output-path", action = "store", default = ".", help = "Redirect output to specified directory. Default is to save output " +
			"to current directory."	)
	parser.add_option("-c", "--add-cache", metavar = "file.cache[:label]", action = "append", default = [], help = "Cache file containing injection sqlite files for plotting ROC curves. All files in each cache file will be grouped together into a single curve. To specify a label for this curve, add a colon and the label name after cache. For example, file_name1.cache:week1. If no label given, the file will be labeled cacheN where N is the order that the file was specified on the command line.")
	parser.add_option("-l", "--map-label", default = None, help = "Required. Name of the mapping that was used for finding injections.")
	parser.add_option("", "--far-column", action = "append", default = [], help = "Column to get false alarm rates from. Can be any column in the coinc table. If specified multiple times, will create a separate curve for each column. Default is to plot combined_far only.")
	parser.add_option("", "--sngl-table", action = "store", type="string", default = None, help = "Required. Should be set to sngl_inspiral or sngl_ringdown.")
	# The hyphenated spelling matches every other option and the error
	# message below; the historical underscore spelling is kept as an
	# alias.  Both store into options.sim_table.
	parser.add_option("", "--sim-table", "--sim_table", action = "store", type="string", default = None, help = "Required. Should be set to the injection table that has mappings to the coinc_event_map table; sim_inspiral or sim_ringdown.")
	parser.add_option("", "--coinc-table", action = "store", type="string", default = None, help = "Required. Should be set to coinc_inspiral or coinc_ringdown.")
	parser.add_option( "", "--param-name", metavar = "PARAMETER", action = "store", default = None,	help = "Can be any parameter in the sim_inspiral table. Specifying this and param-ranges will only select triggers that fall within the given parameter ranges. You can also combine parameters using simple math. For example, to select triggers based on total mass, type 'sim_inspiral.mass1+sim_inspiral.mass2', then give total mass ranges in the param-ranges argument. Note that when combining columns you must type the table name dot each column name. Allowed math operations are +,-,*,/.")
	parser.add_option( "", "--param-ranges", action = "store", default = None, metavar = " [ LOW1, HIGH1 ); ( LOW2, HIGH2]; !VAL3; etc.", help = "Requires --param-name. Specify the parameter ranges to select triggers in. A '(' or ')' implies an open boundary, a '[' or ']' a closed boundary. To specify multiple ranges, separate each range by a ';'. To specify a single value, just type that value with no parentheses or brackets. To specify not equal to a single value, put a '!' before the value. If multiple ranges are specified, a separate curve will be created for each range.")
	parser.add_option( "", "--exclude-coincs", action = "store", type = "string", default = None, metavar = " [COINC_INSTRUMENTS1 + COINC_INSTRUMENTS2 in INSTRUMENTS_ON1];[ALL in INSTRUMENTS_ON2]; etc.",	help = "Exclude coincident types in specified detector times, e.g., '[H2,L1 in H1,H2,L1]'. Some rules: * Coinc-types and detector time must be separated by an ' in '. When specifying a coinc_type or detector time, detectors and/or ifos must be separated by commas, e.g. 'H1,L1' not 'H1L1'. * To specify multiple coinc-types in one type of time, separate each coinc-type by a '+', e.g., '[H1,H2 + H2,L1 in H1,H2,L1]'. * To exclude all coincs in a specified detector time or specific coinc-type in all times, use 'ALL'. E.g., to exclude all H1,H2 triggers, use '[H1,H2 in ALL]' or to exclude all H2,L1 time use '[ALL in H2,L1]'. * To specify multiple exclusions, separate each bracket by a ';'. * Order of the instruments nor case of the letters matter. So if your pinky is broken and you're dyslexic you can type '[h2,h1 in all]' without a problem.")
	parser.add_option( "", "--include-only-coincs", action = "store", type = "string", default = None, metavar = " [COINC_INSTRUMENTS1 + COINC_INSTRUMENTS2 in INSTRUMENTS_ON1];[ALL in INSTRUMENTS_ON2]; etc.", help = "Opposite of --exclude-coincs: only create plots for the specified coinc types. ")
	parser.add_option( "", "--sim-tag", action = "store", type = "string", default = 'ALLINJ', help = "Create ROC curves for a specific simulation tag, e.g., 'BNSINJ'. To plot multiple injection types in a single curve, separate the types by a '+'. To plot multiple injection types as separate curves, separate by a ';'. For example, specifying --sim-tag 'BNSLOGINJ_A+BNSLOGINJ_B;BNSLININJ' will cause BNSLOGINJ_A and BNSLOGINJ_B to be one curve, and BNSLININJ to be a separate curve (both curves will be on the same plot).")
	parser.add_option("-T", "--title", action = "store", default = None, help = "Add a title to the plot.")
	parser.add_option("-v", "--verbose", action = "store_true", help = "Be verbose.")
	parser.add_option("-D", "--debug", action = "store_true", default = False, help = "Turns on sqlite3.enable_callback_tracebacks and prints main sqlquery used.")

	options, filenames = parser.parse_args()

	if options.map_label is None:
		raise ValueError("map-label is a required argument")

	# defaults for the "append" options (see the docstring)
	if options.far_column == []:
		options.far_column = ['combined_far']
	if not options.format:
		options.format = ["png"]

	# check for required options
	if not options.sngl_table:
		raise ValueError("--sngl-table is a required argument.")
	if not options.sim_table:
		raise ValueError("--sim-table is a required argument.")
	if not options.coinc_table:
		raise ValueError("--coinc-table is a required argument.")

	return options, filenames


#
# =============================================================================
#
#                                   Database
#
# =============================================================================
#


class CoincDatabase(object):
	"""
	Summary information about one coincidence database.

	Attributes set by __init__:

	connection         -- the sqlite connection that was passed in
	sngl_table         -- options.sngl_table after sqlutils.validate_option()
	sim_table          -- options.sim_table after sqlutils.validate_option()
	coinc_table        -- options.coinc_table after sqlutils.validate_option()
	coinc_def_table    -- coinc_definer table object from the database
	coinc_event_table  -- coinc_event table object from the database
	time_slide_table   -- time_slide table object from the database
	ii_definer_id, si_definer_id, sc_definer_id
	                   -- coinc_def_ids of search coinc types 0, 1 and 2
	                      for the selected search, or None if unknown
	"""

	# Search name passed to get_coinc_def_id() for each supported
	# single-detector table.
	_search_names = {
		"sngl_inspiral": "inspiral",
		"sngl_ringdown": "ring",
	}

	def __init__(self, connection, options, verbose = False):
		"""
		Compute and record some summary information about the
		database.

		connection is an open sqlite connection, options the parsed
		command line (sngl_table, sim_table and coinc_table are
		read).  If verbose is true an overview of the database is
		written to stderr.
		"""

		self.connection = connection
		xmldoc = dbtables.get_xml(connection)

		# find the tables
		self.sngl_table = sqlutils.validate_option( options.sngl_table )
		self.sim_table = sqlutils.validate_option( options.sim_table )
		self.coinc_def_table = table.get_table(xmldoc, dbtables.lsctables.CoincDefTable.tableName)
		# The coinc_event table object used to be stored in
		# self.coinc_table and was then immediately overwritten by
		# the table *name* below.  The look-up is kept, under its own
		# attribute, so that both values remain available.
		self.coinc_event_table = table.get_table(xmldoc, dbtables.lsctables.CoincTable.tableName)
		self.time_slide_table = table.get_table(xmldoc, dbtables.lsctables.TimeSlideTable.tableName)
		self.coinc_table = sqlutils.validate_option( options.coinc_table )

		# determine a few coinc_definer IDs
		# FIXME:  don't hard-code the numbers
		search = self._search_names.get(self.sngl_table)
		self.ii_definer_id = self._get_definer_id(search, 0)
		self.si_definer_id = self._get_definer_id(search, 1)
		self.sc_definer_id = self._get_definer_id(search, 2)

		# verbosity
		if verbose:
			self._print_overview()

	def _get_definer_id(self, search, search_coinc_type):
		"""
		Return the coinc_def_id for (search, search_coinc_type), or
		None if the search is not recognised or the database has no
		such coinc_definer row.
		"""
		if self.coinc_def_table is None or search is None:
			return None
		try:
			return self.coinc_def_table.get_coinc_def_id(search, search_coinc_type, create_new = False)
		except KeyError:
			return None

	def _count_rows(self, table_name):
		"""
		Return the number of rows in table_name, or the string
		"unknown" if the database cannot answer (e.g. no such table).
		"""
		# Table names cannot be bound as SQL parameters, so the name
		# is interpolated.  NOTE(review): this relies on the name
		# having come through sqlutils.validate_option() -- confirm
		# that it is sufficient sanitisation.
		try:
			return self.connection.cursor().execute("SELECT COUNT(*) FROM %s" % table_name).fetchone()[0]
		except sqlite3.OperationalError:
			return "unknown"

	def _print_overview(self):
		"""
		Write a human-readable summary of the database to stderr.
		"""
		write = sys.stderr.write
		cursor = self.connection.cursor()

		write("database overview:\n")
		write("\tii_definer_id: %s\n" % (self.ii_definer_id,))
		write("\tsi_definer_id: %s\n" % (self.si_definer_id,))
		write("\tsc_definer_id: %s\n" % (self.sc_definer_id,))
		# sngl_table and sim_table hold table *names*, so the row
		# counts have to come from the database, not from len().
		if self.sngl_table is not None:
			write("\tsingle_table events: %s\n" % (self._count_rows(self.sngl_table),))
		if self.sim_table is not None:
			write("\tinjections: %s\n" % (self._count_rows(self.sim_table),))
		if self.time_slide_table is not None:
			write("\ttime slides: %d\n" % cursor.execute("SELECT COUNT(DISTINCT(time_slide_id)) FROM time_slide").fetchone()[0])
		if self.coinc_def_table is not None:
			for coinc_description, n in cursor.execute("SELECT description, COUNT(*) FROM coinc_definer NATURAL JOIN coinc_event GROUP BY coinc_def_id"):
				write("\t%s: %d\n" % (coinc_description, n))

#
# =============================================================================
#
#                                  Utilities
#
# =============================================================================
#


def create_plot(x_label = None, y_label = None, title = None, width = 165.0, aspect = (1 + math.sqrt(5)) / 2):
	"""
	Create an empty figure with a single gridded set of axes.

	x_label, y_label and title are applied only when they are not None;
	underscores in the title are backslash-escaped first.  width is
	divided by 25.4 to obtain inches (i.e. it is given in millimetres),
	and the height is the width divided by aspect.

	Returns the tuple (figure, axes).
	"""
	width_in_inches = width / 25.4

	fig = figure.Figure()
	# attaches an Agg canvas to the figure; the canvas object itself is
	# not needed afterwards
	FigureCanvas(fig)
	fig.set_size_inches(width_in_inches, width_in_inches / aspect)

	axes = fig.gca()
	axes.grid(True)
	for apply_label, text in ((axes.set_xlabel, x_label), (axes.set_ylabel, y_label)):
		if text is not None:
			apply_label(text)
	if title is not None:
		escaped_title = re.sub(r'_', r'\_', title)
		axes.set_title(escaped_title)

	return fig, axes



#
# =============================================================================
#
#                              ROC Plots
#
# =============================================================================
#

class ROCPlots(object):
	"""
	Accumulates (false-alarm rate, weight) pairs per curve label and
	turns them into a single ROC figure.

	self.stats maps each curve label to a list of (far, weight) tuples.
	"""

	# Value of process_params.program that identifies the injection
	# generator for each supported sim table.
	_injection_programs = {
		"sim_ringdown": "rinj",
		"sim_inspiral": "inspinj",
	}

	def __init__(self):
		self.stats = {}

	def _build_query(self, contents, far_column, filter, match_map_label):
		"""
		Return the SQL that yields one (far, weight) row per found
		injection in contents.sim_table.

		far is the smallest far_column value among the coincs mapped
		to the injection.  weight is distance**3 for a "log10"
		injection distance distribution, distance**2 for "uniform"
		and 1.0 otherwise.  When match_map_label is true the query
		also joins coinc_definer and carries one "?" placeholder for
		the coinc_definer description to match.
		"""
		if filter:
			filter_clause = 'AND ' + filter
		else:
			filter_clause = ''
		if match_map_label:
			definer_join = """
    JOIN coinc_event, coinc_definer ON (coinc_event.coinc_event_id == mapB.coinc_event_id AND coinc_definer.coinc_def_id == coinc_event.coinc_def_id)"""
			definer_clause = """ AND
	coinc_definer.description == ?"""
		else:
			definer_join = ''
			definer_clause = ''

		return """
SELECT
	MIN(%(coinc_table)s.%(far_column)s),
	-- Work out the correction factor for injection population distances
	CASE (SELECT value FROM process_params WHERE program =="%(program)s" AND param =="--d-distr" AND process_id == %(sim_table)s.process_id)
		WHEN "log10" THEN  %(sim_table)s.distance * %(sim_table)s.distance * %(sim_table)s.distance
		WHEN "uniform" THEN  %(sim_table)s.distance * %(sim_table)s.distance
		ELSE 1.0 END
FROM
	%(coinc_table)s
%(experiment_join)s
	JOIN coinc_event_map AS mapA ON (mapA.event_id == %(coinc_table)s.coinc_event_id AND mapA.table_name == "coinc_event")
	JOIN coinc_event_map AS mapB ON (mapB.coinc_event_id == mapA.coinc_event_id AND mapB.table_name == "%(sim_table)s")
	JOIN %(sim_table)s ON (mapB.event_id == %(sim_table)s.simulation_id)%(definer_join)s
WHERE
	experiment_summary.datatype != "slide"%(definer_clause)s
	%(filter_clause)s
GROUP BY
	mapB.event_id """ % {
			"coinc_table": contents.coinc_table,
			"far_column": far_column,
			"program": self._injection_programs[contents.sim_table],
			"sim_table": contents.sim_table,
			"experiment_join": sqlutils.join_experiment_tables_to_coinc_table( contents.coinc_table ),
			"definer_join": definer_join,
			"definer_clause": definer_clause,
			"filter_clause": filter_clause,
		}

	def add_contents(self, label, contents, far_column = 'combined_far', filter = '', debug = False, map_label = None):
		"""
		Query the database described by contents and append one
		(far, weight) pair per found injection to self.stats[label].

		filter is an optional SQL condition that is ANDed onto the
		WHERE clause.  If debug is true the query is written to
		stderr.  map_label is the coinc_definer description to match;
		it is only used for sim_ringdown databases and, when left at
		None, is taken from the script's parsed command line
		(options.map_label).

		Databases whose sim table is neither sim_ringdown nor
		sim_inspiral are ignored.
		"""
		if not contents.sim_table or contents.sim_table not in self._injection_programs:
			return

		# only the ringdown query restricts on the mapping label
		match_map_label = contents.sim_table == "sim_ringdown"
		parameters = ()
		if match_map_label:
			if map_label is None:
				try:
					map_label = options.map_label
				except NameError:
					raise ValueError("a map_label is required to select sim_ringdown coincidences")
			parameters = (map_label,)

		sqlquery = self._build_query(contents, far_column, filter, match_map_label)
		if debug:
			sys.stderr.write("SQLite query used to retrieve triggers:\n")
			sys.stderr.write("%s\n" % sqlquery)

		# setdefault() stays inside the loop on purpose:  a label
		# with no rows must not create an (empty) curve.
		for far, weight in contents.connection.cursor().execute(sqlquery, parameters):
			self.stats.setdefault(label, []).append((far, weight))

	def _normalize_stats(self):
		"""
		Divide every weight by the largest per-label weight sum and
		sort each label's points.  The result replaces self.stats.
		If the largest sum is zero the weights are left unscaled
		rather than dividing by zero.
		"""
		norm = 0.0
		for points in self.stats.values():
			norm = max(norm, float(sum([weight for far, weight in points])))
		for label in list(self.stats):
			points = self.stats[label]
			if norm > 0:
				points = [(far, weight / norm) for far, weight in points]
			else:
				points = list(points)
			points.sort()
			self.stats[label] = points

	def finish(self, title):
		"""
		Normalize the accumulated statistics and plot, for every
		label, the cumulative weight against false-alarm rate on a
		logarithmic x axis.

		Returns the tuple (figure, "ROC"); the string is the
		fragment used to build the output file name.
		"""
		self._normalize_stats()
		fig, axes = create_plot(r"False Alarm Rate", r"Relative $V \times T$", title = title, aspect = 1.0, width=100)
		# Axes.hold() only exists in older matplotlib releases; newer
		# ones always draw onto the existing axes.
		hold = getattr(axes, "hold", None)
		if hold is not None:
			hold(1)
		for label in sorted(self.stats.keys()):
			points = self.stats[label]
			axes.semilogx([far for far, weight in points], numpy.cumsum([weight for far, weight in points]), label=label, alpha = .9)
		axes.set_ylim(0,1)
		axes.legend(loc='lower right')
		if hold is not None:
			hold(0)
		return fig, "ROC"

#
# =============================================================================
#
#                                     Main
#
# =============================================================================
#


#
# Parse command line
#


def _split_label(argument, default_label):
	"""
	Split a "name[:label]" command-line argument at its last colon.

	Returns (name, label); default_label is used when there is no colon
	or when the text after it is empty.
	"""
	pieces = argument.rsplit(':', 1)
	if len(pieces) == 2 and pieces[1]:
		return pieces[0], pieces[1]
	return pieces[0], default_label


def _curve_label(base_label, sim_table, sim_tag, sim_tags, param_range, param_ranges, far_stat, far_columns, param_name):
	"""
	Build the legend label of one curve.

	The simulation tag, the parameter range and the FAR column are only
	included when more than one of them is being plotted, i.e. when they
	are needed to tell curves apart.  Underscores are escaped because
	the labels are rendered by LaTeX.
	"""
	if len(sim_tags) > 1 and sim_tag != 'ALLINJ':
		tag_part = ' ' + sim_tag
	else:
		tag_part = ''
	if len(param_ranges) > 1 and param_range != '':
		# drop the "<sim table>." prefix from the parameter name
		short_name = re.sub(re.escape(sim_table) + '[.]', '', param_name or '')
		range_part = ' ' + short_name + ' ' + param_range
	else:
		range_part = ''
	if len(far_columns) > 1:
		far_part = far_stat
	else:
		far_part = ''
	label = "%s%s%s %s" % (base_label, tag_part, range_part, far_part)
	# add escape characters to underscores in label as they can cause LaTeX problems
	return re.sub(r'_', r'\_', label)


options, filenames = parse_command_line()

sqlite3.enable_callback_tracebacks(options.debug)

#
# Process files
#

plot = ROCPlots()

# database file name --> curve label;  database_order remembers the order in
# which the databases were first named so that they are processed in it
labels = {}
database_order = []

# get cache-files
if options.verbose and options.add_cache != []:
	sys.stderr.write("Gathering files from cache files...\n")
for n, cachename in enumerate(options.add_cache):
	cn, label = _split_label(cachename, "%s%d" % ("cache", n))
	cachefile = open(cn)
	try:
		for line in cachefile:
			fn = lal.CacheEntry(line).path
			if fn.endswith('sqlite'):
				if fn not in labels:
					database_order.append(fn)
				labels[fn] = label
	finally:
		cachefile.close()

# get files
for n, fnl in enumerate(filenames):
	if fnl in options.add_cache:
		default_prefix = "cache"
	else:
		default_prefix = "file"
	fn, label = _split_label(fnl, "%s%d" % (default_prefix, n))
	if fn not in labels:
		database_order.append(fn)
	labels[fn] = label

if options.verbose:
	sys.stderr.write("Processing files...\n")

sim_tags = options.sim_tag.split(';')
if options.param_ranges is None:
	param_ranges = ['']
else:
	param_ranges = options.param_ranges.split(';')

for m, filename in enumerate(database_order):
	if options.verbose:
		sys.stderr.write("%d/%d: %s\n" % (m + 1, len(database_order), filename))
	working_filename = dbtables.get_connection_filename(filename, verbose = options.verbose)
	connection = sqlite3.connect(working_filename)
	# the finally clause releases the connection and the working copy
	# even if one of the queries fails
	try:
		contents = CoincDatabase(connection, options, verbose = options.verbose)
		if options.debug:
			sys.stderr.write("Sim table: %s\n" % (contents.sim_table,))
			sys.stderr.write("Coinc table: %s\n" % (contents.coinc_table,))
			sys.stderr.write("Sngl table: %s\n" % (contents.sngl_table,))
			sys.stderr.write("Options far_column: %s\n" % (options.far_column,))
		if not contents.sim_table:
			if options.verbose:
				sys.stderr.write("No sim table found in database %s; skipping...\n" % (filename,))
		else:
			for far_stat in options.far_column:
				for sim_tag in sim_tags:
					for param_range in param_ranges:
						sql_filter = printutils.create_filter( connection, contents.coinc_table, param_name = options.param_name, param_ranges = param_range,
							exclude_coincs = options.exclude_coincs, include_only_coincs = options.include_only_coincs,
							sim_tag = sim_tag, verbose = options.verbose)
						label = _curve_label(labels[filename], contents.sim_table, sim_tag, sim_tags,
							param_range, param_ranges, far_stat, options.far_column, options.param_name)
						if options.debug:
							sys.stderr.write("Filter: %s\n" % (sql_filter,))
							sys.stderr.write("Label: %s\n" % (label,))
						plot.add_contents(label, contents, far_column = sqlutils.validate_option(far_stat), filter = sql_filter, debug = options.debug)
	finally:
		connection.close()
		dbtables.discard_connection_filename(filename, working_filename, verbose = options.verbose)


#
# Finish the plot and write it once per requested format
#


filename_template = "%s/%s_%s.%s"
fig, filename_fragment = plot.finish( title = options.title )
written_formats = []
for fmt in options.format:
	# a format requested more than once is only written once
	if fmt in written_formats:
		continue
	written_formats.append(fmt)
	filename = filename_template % ( options.output_path, options.base, filename_fragment, fmt)
	if options.verbose:
		sys.stderr.write("writing %s ...\n" % filename)
	fig.savefig(filename)

sys.exit(0)
