#!/usr/bin/env python
#
# Copyright (C) 2009-2013  Kipp Cannon, Chad Hanna, Drew Keppel
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 2 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

## @file
# A program to produce a variety of plots from a gstlal inspiral analysis, e.g., IFAR plots, missed/found plots, etc.

#
# =============================================================================
#
#                                   Preamble
#
# =============================================================================
#


import math
import matplotlib
matplotlib.rcParams.update({
	"font.size": 16.0,
	"axes.titlesize": 14.0,
	"axes.labelsize": 14.0,
	"xtick.labelsize": 13.0,
	"ytick.labelsize": 13.0,
	"legend.fontsize": 10.0,
	"figure.dpi": 300,
	"savefig.dpi": 300,
	"text.usetex": True,
	"path.simplify": True
})
from matplotlib import figure
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
import scipy
import numpy
from optparse import OptionParser
import sqlite3
import sys

import os

from glue import segments
from glue.ligolw import dbtables
from glue.ligolw import lsctables
from glue.ligolw.utils import segments as ligolw_segments
from glue import lal
from pylal import db_thinca_rings
from pylal import git_version
from pylal import SimBurstUtils
from pylal.xlal.datatypes.ligotimegps import LIGOTimeGPS
from gstlal import far
from gstlal import inspiral_pipe


class SimInspiral(lsctables.SimInspiral):
	"""
	sim_inspiral row type extended with derived mass and spin
	quantities.
	"""
	@property
	def mtotal(self):
		"""Total mass, mass1 + mass2."""
		return self.mass1 + self.mass2

	@property
	def chi(self):
		"""Mass-weighted aligned spin, (m1 s1z + m2 s2z) / mtotal."""
		weighted_spin = self.mass1 * self.spin1z + self.mass2 * self.spin2z
		return weighted_spin / self.mtotal


class SnglInspiral(lsctables.SnglInspiral):
	"""
	sngl_inspiral row type extended with derived mass and spin
	quantities and a chi^2-weighted SNR.
	"""
	@property
	def mtotal(self):
		"""Total mass, mass1 + mass2."""
		return self.mass1 + self.mass2

	@property
	def eta(self):
		"""Symmetric mass ratio, mass1 * mass2 / mtotal**2."""
		total = self.mtotal
		return self.mass1 * self.mass2 / total**2.

	@property
	def mchirp(self):
		"""Chirp mass, mtotal * eta**(3/5)."""
		total = self.mtotal
		return total * self.eta**0.6

	@property
	def chi(self):
		"""Mass-weighted aligned spin, (m1 s1z + m2 s2z) / mtotal."""
		weighted_spin = self.mass1 * self.spin1z + self.mass2 * self.spin2z
		return weighted_spin / self.mtotal

	def get_effective_snr(self, fac):
		"""
		Return snr / sqrt(chisq / chisq_dof).

		NOTE(review):  fac is accepted (presumably for call
		compatibility with the base class) but is not used.
		"""
		reduced_chisq = self.chisq / self.chisq_dof
		return self.snr / reduced_chisq**.5

# Install the extended row classes defined above (presumably used by
# lsctables whenever rows of these tables are constructed, e.g. by
# row_from_cols()) and pylal's LIGOTimeGPS implementation.
lsctables.LIGOTimeGPS = LIGOTimeGPS
lsctables.SimInspiralTable.RowType = SimInspiral
lsctables.SnglInspiralTable.RowType = SnglInspiral


# module metadata;  version and date come from pylal's git_version
__author__ = "Kipp Cannon <kipp.cannon@ligo.org>, Chad Hanna <channa@ligo.caltech.edu>"
__version__ = "git id %s" % git_version.id
__date__ = git_version.date


#
# =============================================================================
#
#                                 Command Line
#
# =============================================================================
#


def parse_command_line():
	"""
	Parse the command line.

	Returns (options, filenames).  filenames lists the databases given
	as positional arguments followed by those named in the
	--input-cache LAL cache file, if one was given.
	options.plot_group is converted to a sorted list of ints (or left as
	None, meaning all plot groups).  options.format defaults to ["png"].
	"""
	parser = OptionParser(
		version = "Name: %%prog\n%s" % git_version.verbose_msg
	)
	parser.add_option("","--input-cache", help = "Also get the list of databases to process from this LAL cache.")
	parser.add_option("--user-tag", metavar = "user-tag", default = "ALL", help = "Set the prefix for output filenames (default = \"ALL\").")
	parser.add_option("--output-dir", metavar = "output-dir", default = ".", help = "Provide an output directory")
	parser.add_option("-f", "--format", metavar = "{\"png\",\"pdf\",\"svg\",\"eps\",...}", action = "append", default = [], help = "Set the output image format.  Can be given multiple times (default = \"png\").")
	parser.add_option("--segments-name", metavar = "name", default = "statevectorsegments", help = "Set the name of the segments that were analyzed (default = \"statevectorsegments\").")
	parser.add_option("--vetoes-name", metavar = "name", default = "vetoes", help = "Set the name of the veto segments (default = \"vetoes\").")
	parser.add_option("--plot-group", metavar = "number", action = "append", default = None, help = """Generate the given plot group.  Can be given multiple times (default = make all plot groups)
 0. Summary Table (top 10 loudest events globally across all zero lag triggers read in)
 1. Missed Found (Scatter plots of missed and found injections on several axes)
 2. Injection Parameter Accuracy Plots
 3. Background Vs Injection Plots (sngl detector triggers from coincs of snr, chisq, bank chisq,...)
 4. Background Vs Injection Plots pairwise (effective snr DET1 Vs. DET2...),
 5. Rate Vs Threshold (SNR histograms, IFAR histograms, ...)
 6. Injection Parameter Distribution Plots (The input parameters that went into inspinj, like mass1 vs mass2...)
""")
	parser.add_option("--far-threshold", metavar = "Hz", default = 1. / (30 * 86400), type = "float", help = "Set the FAR threshold for found injections (default = 1 / 30 days).")
	parser.add_option("-t", "--tmp-space", metavar = "path", help = "Path to a directory suitable for use as a work area while manipulating the database file.  The database file will be worked on in this directory, and then moved to the final location when complete.  This option is intended to improve performance when running in a networked environment, where there might be a local disk with higher bandwidth than is available to the filesystem on which the final output will reside.")
	parser.add_option("-v", "--verbose", action = "store_true", help = "Be verbose.")
	options, filenames = parser.parse_args()

	if options.plot_group is not None:
		options.plot_group = sorted(map(int, options.plot_group))
	if not options.format:
		options.format = ["png"]

	if not filenames:
		filenames = []
	if options.input_cache:
		# close the cache file when done (it used to be leaked) and skip
		# blank lines, which presumably lal.CacheEntry cannot parse
		with open(options.input_cache) as cache_file:
			filenames.extend(lal.CacheEntry(line).path for line in cache_file if line.strip())

	return options, filenames


#
# =============================================================================
#
#                                   Database
#
# =============================================================================
#


class CoincDatabase(object):
	"""
	Summary information about one coincidence database.  It records:

	* the tables present;
	* the coinc_definer IDs;
	* the on/participating instrument combinations;
	* the analysed and vetoed segments;
	* the zero-lag and background live times.

	The plot classes' add_contents() methods read these attributes
	(connection, base, sim_inspiral_table, seglists, ...), so instances
	are presumably what is passed to them as "contents".
	"""
	def __init__(self, connection, data_segments_name, veto_segments_name = None, verbose = False, wiki = None, base = None, program_name = "gstlal_inspiral"):
		"""
		Compute and record some summary information about the
		database.

		connection:  sqlite3 connection to the database.
		data_segments_name:  name of the segment lists giving the
		analysed times.
		veto_segments_name:  name of the veto segment lists, or None
		for no vetoes.
		verbose:  if True, print an overview to stderr.
		wiki:  optional file-like object;  if given, the overview is
		also written to it in wiki table markup.
		base:  output file name prefix, stored as self.base.
		program_name:  not used by this method.
		"""

		self.base = base
		self.connection = connection
		xmldoc = dbtables.get_xml(connection)

		cursor = connection.cursor()

		# find the tables;  any table that is missing is recorded as None
		try:
			self.sngl_inspiral_table = lsctables.SnglInspiralTable.get_table(xmldoc)
		except ValueError:
			self.sngl_inspiral_table = None
		try:
			self.sim_inspiral_table = lsctables.SimInspiralTable.get_table(xmldoc)
		except ValueError:
			self.sim_inspiral_table = None
		try:
			self.coinc_def_table = lsctables.CoincDefTable.get_table(xmldoc)
			self.coinc_table = lsctables.CoincTable.get_table(xmldoc)
			self.time_slide_table = lsctables.TimeSlideTable.get_table(xmldoc)
		except ValueError:
			self.coinc_def_table = None
			self.coinc_table = None
			self.time_slide_table = None
		try:
			self.coinc_inspiral_table = lsctables.CoincInspiralTable.get_table(xmldoc)
		except ValueError:
			self.coinc_inspiral_table = None

		# determine a few coinc_definer IDs
		# FIXME:  don't hard-code the numbers
		# (search "inspiral", types 0, 1, 2;  the ii/si/sc prefixes
		# presumably mean inspiral-inspiral, sim-inspiral and
		# sim-coinc -- confirm against the coinc_definer table)
		if self.coinc_def_table is not None:
			try:
				self.ii_definer_id = self.coinc_def_table.get_coinc_def_id("inspiral", 0, create_new = False)
			except KeyError:
				self.ii_definer_id = None
			try:
				self.si_definer_id = self.coinc_def_table.get_coinc_def_id("inspiral", 1, create_new = False)
			except KeyError:
				self.si_definer_id = None
			try:
				self.sc_definer_id = self.coinc_def_table.get_coinc_def_id("inspiral", 2, create_new = False)
			except KeyError:
				self.sc_definer_id = None
		else:
			self.ii_definer_id = None
			self.si_definer_id = None
			self.sc_definer_id = None

		# retrieve the distinct on and participating instruments
		# NOTE(review):  if ii_definer_id is None the "== NULL" test
		# matches no rows, so on_instruments_combos ends up empty
		self.on_instruments_combos = [frozenset(lsctables.instrument_set_from_ifos(x)) for x, in cursor.execute("SELECT DISTINCT(instruments) FROM coinc_event WHERE coinc_def_id == ?", (self.ii_definer_id,))]
		self.participating_instruments_combos = [frozenset(lsctables.instrument_set_from_ifos(x)) for x, in cursor.execute("SELECT DISTINCT(ifos) FROM coinc_inspiral")]

		# get the segment lists
		self.seglists = ligolw_segments.segmenttable_get_by_name(xmldoc, data_segments_name).coalesce()
		self.instruments = set(self.seglists)
		if veto_segments_name is not None:
			self.veto_segments = ligolw_segments.segmenttable_get_by_name(xmldoc, veto_segments_name).coalesce()
		else:
			self.veto_segments = segments.segmentlistdict()
		# remove vetoed times from the analysed segments
		self.seglists -= self.veto_segments

		# Get the live time used for the far calculation.  By convention this is simply the entire interval of the analysis with no regard for segments
		self.farsegs = far.get_live_time_segs_from_search_summary_table(connection)

		# get the live time
		if verbose:
			print >>sys.stderr, "calculating background livetimes: ",
		# NOTE(review):  these offset vectors are stored but are not
		# used below to compute the background live time (see FIXME)
		self.offset_vectors = db_thinca_rings.get_background_offset_vectors(connection)

		if verbose:
			print >>sys.stderr
		self.zerolag_livetime = {}
		self.background_livetime = {}
		for on_instruments in self.on_instruments_combos:
			# time during which every instrument in on_instruments,
			# and no other instrument, was on
			self.zerolag_livetime[on_instruments] = float(abs(self.seglists.intersection(on_instruments) - self.seglists.union(self.instruments - on_instruments)))
		# FIXME:  background livetime hard-coded to be same
		# as zero-lag livetime.  figure out what to do
		self.background_livetime.update(self.zerolag_livetime)

		# verbosity
		if verbose:
			print >>sys.stderr, "database overview:"
			for on_instruments in self.on_instruments_combos:
				print >>sys.stderr, "\tzero-lag livetime for %s: %f s" % ("+".join(sorted(on_instruments)), self.zerolag_livetime[on_instruments])
				print >>sys.stderr, "\tbackground livetime for %s: %f s" % ("+".join(sorted(on_instruments)), self.background_livetime[on_instruments])
			if self.sngl_inspiral_table is not None:
				print >>sys.stderr, "\tinspiral events: %d" % len(self.sngl_inspiral_table)
			if self.sim_inspiral_table is not None:
				print >>sys.stderr, "\tinjections: %d" % len(self.sim_inspiral_table)
			if self.time_slide_table is not None:
				print >>sys.stderr, "\ttime slides: %d" % cursor.execute("SELECT COUNT(DISTINCT(time_slide_id)) FROM time_slide").fetchone()[0]
			if self.coinc_def_table is not None:
				for description, n in cursor.execute("SELECT description, COUNT(*) FROM coinc_definer NATURAL JOIN coinc_event GROUP BY coinc_def_id"):
					print >>sys.stderr, "\t%s: %d" % (description, n)

		# same overview as above, as wiki table rows
		if wiki:
			wiki.write("database overview:\n\n")
			for on_instruments in self.on_instruments_combos:
				wiki.write("||zero-lag livetime for %s||%f s||\n" % ("+".join(sorted(on_instruments)), self.zerolag_livetime[on_instruments]))
				wiki.write("||background livetime for %s ||%f s||\n" % ("+".join(sorted(on_instruments)), self.background_livetime[on_instruments]))
			if self.sngl_inspiral_table is not None:
				wiki.write("||inspiral events|| %d||\n" % len(self.sngl_inspiral_table))
			if self.sim_inspiral_table is not None:
				wiki.write("||injections|| %d||\n" % len(self.sim_inspiral_table))
			if self.time_slide_table is not None:
				wiki.write("||time slides|| %d||\n" % cursor.execute("SELECT COUNT(DISTINCT(time_slide_id)) FROM time_slide").fetchone()[0])
			if self.coinc_def_table is not None:
				for description, n in cursor.execute("SELECT description, COUNT(*) FROM coinc_definer NATURAL JOIN coinc_event GROUP BY coinc_def_id"):
					wiki.write("||%s||%d||\n" % (description, n) )


#
# =============================================================================
#
#                                  Utilities
#
# =============================================================================
#


def sim_end_time(sim, instrument):
	"""
	Return the arrival time of injection sim at the given instrument.

	Relies only on sim.get_time_geocent() and sim.get_ra_dec(), so it
	works for both burst and inspiral injections.  A zero time-slide
	offset is always assumed.  FIXME:  update the call when inspiral
	injections carry offset vector information.
	"""
	zero_offset = {instrument: 0.0}
	return SimBurstUtils.time_at_instrument(sim, instrument, zero_offset)


def roman(i, arabics = (1000,900,500,400,100,90,50,40,10,9,5,4,1), romans = ("m","cm","d","cd","c","xc","l","xl","x","ix","v","iv","i")):
	"""
	Return the lower-case Roman numeral for i.

	Works greedily through the (descending) values in arabics, emitting
	the matching symbol from romans as many times as it fits.  Values
	smaller than the last entry of arabics (e.g. 0) give "".
	"""
	pieces = []
	for idx in range(len(arabics)):
		value = arabics[idx]
		while i >= value:
			pieces.append(romans[idx])
			i -= value
	return "".join(pieces)


#
# width is in mm, default aspect ratio is the golden ratio
#


def create_plot(x_label = None, y_label = None, width = 165.0, aspect = None):
	"""
	Create a matplotlib Figure with an Agg canvas attached, plus its
	gridded axes.

	width is the figure width in mm.  The height is width / aspect,
	where aspect defaults to the golden ratio.  Axis labels are set
	only when given.  Returns (fig, axes).
	"""
	if aspect is None:
		aspect = (1 + math.sqrt(5)) / 2
	width_inches = width / 25.4
	fig = figure.Figure()
	# attach an Agg canvas so the figure can be rendered
	FigureCanvas(fig)
	fig.set_size_inches(width_inches, width_inches / aspect)
	axes = fig.gca()
	axes.grid(True)
	for set_label, label in ((axes.set_xlabel, x_label), (axes.set_ylabel, y_label)):
		if label is not None:
			set_label(label)
	return fig, axes


def create_sim_coinc_view(connection):
	"""
	Construct a sim_inspiral --> best matching coinc_event mapping.

	Creates the temporary table sim_coinc_map(simulation_id,
	coinc_event_id) on connection.  For each injection, the best match
	is the zero-lag coinc (one whose time slide has no non-zero offset)
	with the highest coinc_event.likelihood.  Only injections that match
	at least one such coinc get an entry in this table.
	"""
	#
	# the log likelihood ratio stored in the likelihood column of the
	# coinc_event table is the ranking statistic.  although it has not
	# been true in the past, there is now a one-to-one relationship
	# between the value of this ranking statistic and false-alarm rate,
	# therefore it is OK to order by log likelihood ratio and then,
	# later, impose a "detection" threshold based on false-alarm rate.
	#
	query = """
CREATE TEMPORARY TABLE
	sim_coinc_map
AS
	SELECT
		sim_inspiral.simulation_id AS simulation_id,
		(
			SELECT
				coinc_event.coinc_event_id
			FROM
				coinc_event_map AS a
				JOIN coinc_event_map AS b ON (
					b.coinc_event_id == a.coinc_event_id
				)
				JOIN coinc_event ON (
					b.table_name == 'coinc_event'
					AND b.event_id == coinc_event.coinc_event_id
				)
			WHERE
				a.table_name == 'sim_inspiral'
				AND a.event_id == sim_inspiral.simulation_id
				AND NOT EXISTS (SELECT * FROM time_slide WHERE time_slide.time_slide_id == coinc_event.time_slide_id AND time_slide.offset != 0)
			ORDER BY
				coinc_event.likelihood DESC
			LIMIT 1
		) AS coinc_event_id
	FROM
		sim_inspiral
	WHERE
		coinc_event_id IS NOT NULL
	"""
	cursor = connection.cursor()
	cursor.execute(query)


#
# =============================================================================
#
#                      Summary Table
#
# =============================================================================
#


class SummaryTable(object):
	"""
	Plot group 0:  wiki-formatted summary tables.

	Accumulates the following across the non-injection databases passed
	to add_contents():

	* the loudest zero-lag ("open box") candidates;
	* the loudest time-slide ("closed box") candidates;
	* the number of zero-lag coincs per participating-instrument set;
	* the zero-lag live time per on-instrument set.

	finish() writes these out as text files whose names are prefixed
	with the base of the last database seen.
	"""
	# number of candidates reported in each loudest-candidate table
	num_loudest = 10

	# query for the loudest candidates, ordered by combined FAR.  the %s
	# is "NOT EXISTS" (zero-lag:  the coinc's time slide has no non-zero
	# offset) or "EXISTS" (time-slide background);  the %d is the row
	# limit.  the last column packs the coinc's single-detector triggers
	# as space-separated "ifo:snr:chisq:mass1:mass2" strings
	_loudest_query = """
SELECT
	coinc_inspiral.combined_far,
	coinc_inspiral.false_alarm_rate,
	coinc_event.likelihood,
	coinc_inspiral.snr,
	coinc_inspiral.end_time + coinc_inspiral.end_time_ns * 1e-9,
	coinc_inspiral.mass,
	coinc_inspiral.mchirp,
	coinc_inspiral.ifos,
	coinc_event.instruments,
	(SELECT
		group_concat(sngl_inspiral.ifo || ":" || sngl_inspiral.snr || ":" || sngl_inspiral.chisq || ":" || sngl_inspiral.mass1 || ":" || sngl_inspiral.mass2, " ")
	FROM
		sngl_inspiral
		JOIN coinc_event_map ON (
			sngl_inspiral.event_id == coinc_event_map.event_id AND coinc_event_map.table_name == "sngl_inspiral"
		)
	WHERE
		coinc_event_map.coinc_event_id == coinc_inspiral.coinc_event_id
	)
FROM
	coinc_inspiral
	JOIN coinc_event ON (
		coinc_event.coinc_event_id == coinc_inspiral.coinc_event_id
	)
WHERE
	%s(
		SELECT
			*
		FROM
			time_slide
		WHERE
			time_slide.time_slide_id == coinc_event.time_slide_id AND time_slide.offset != 0
	)
ORDER BY
	combined_far
LIMIT %d
		"""

	def __init__(self):
		# rows returned by _loudest_query, zero-lag and time-slide
		self.candidates = []
		self.bgcandidates = []
		# on-instrument set (frozenset) --> total zero-lag live time (s)
		self.livetime = {}
		# participating-instrument set (frozenset) --> zero-lag coinc count
		self.num_trigs = {}

	def add_contents(self, contents):
		"""
		Add the loudest candidates, zero-lag coinc counts and live
		times from one database.  Injection databases are skipped;
		only their base is recorded.
		"""
		self.base = contents.base
		# same injection-file test as the other plot classes
		if contents.sim_inspiral_table is not None:
			#For now we only return summary information on non injections
			return
		connection = contents.connection

		# loudest candidates of this database;  finish() merges across
		# databases and keeps the overall loudest num_loudest
		self.candidates += connection.cursor().execute(self._loudest_query % ("NOT EXISTS", self.num_loudest)).fetchall()
		self.bgcandidates += connection.cursor().execute(self._loudest_query % ("EXISTS", self.num_loudest)).fetchall()

		# count zero-lag coincs for each participating-instrument set
		connection.cursor().execute("CREATE TEMPORARY TABLE distinct_ifos AS SELECT DISTINCT(ifos) AS ifos FROM coinc_inspiral")
		for instruments, num in connection.cursor().execute("""
SELECT distinct_ifos.ifos, count(*) FROM coinc_inspiral JOIN distinct_ifos ON (distinct_ifos.ifos==coinc_inspiral.ifos) JOIN coinc_event ON (coinc_event.coinc_event_id == coinc_inspiral.coinc_event_id) WHERE coinc_inspiral.ifos==distinct_ifos.ifos AND NOT EXISTS(SELECT * FROM time_slide WHERE time_slide.time_slide_id == coinc_event.time_slide_id AND time_slide.offset != 0) GROUP BY distinct_ifos.ifos;
"""):
			key = frozenset(lsctables.instrument_set_from_ifos(instruments))
			self.num_trigs[key] = self.num_trigs.get(key, 0) + num
		connection.cursor().execute("DROP TABLE distinct_ifos")

		# every on-instrument set known to either live time dictionary
		# gets an entry;  only zero-lag live time is accumulated
		for on_instruments in set(contents.background_livetime) | set(contents.zerolag_livetime):
			self.livetime.setdefault(on_instruments, 0.0)
		for on_instruments, livetime in contents.zerolag_livetime.items():
			self.livetime[on_instruments] += livetime

	def write_wiki_string(self, l, f, lt):
		"""
		Write the candidate rows l (as returned by _loudest_query) to
		the file object f as a ranked wiki table.

		Each candidate gets one row, followed by one sub-row per
		single-detector trigger.  lt is accepted but not used.
		"""
		f.write("|| Rank || FAR (Hz) || FAP || ln &Lambda; || Combined SNR || GPS End Time || <i>M</i><sub>total</sub> / M<sub>&#x2299;</sub> || <i>M</i><sub>chirp</sub> / M<sub>&#x2299;</sub> || Participating Instruments                   || On Instruments                              ||\n")
		# NOTE(review):  the third sub-row column is labelled chi^2/DOF
		# but the query packs sngl_inspiral.chisq, not chisq/chisq_dof
		f.write("||      ||          ||     ||             ||              || Instrument   || SNR                                             || &chi;<sup>2</sup>/DOF                           || <i>m</i><sub>1</sub> / M<sub>&#x2299;</sub> || <i>m</i><sub>2</sub> / M<sub>&#x2299;</sub> ||\n")
		for rank, values in enumerate(l, 1):
			values = tuple(values)
			f.write('|| %d || %.2e || %.3g || %.3g || %.2f || %.4f || %.2f || %.2f || %s || %s ||\n' % ((rank,) + values[:9]))
			# group_concat() yields NULL (None) when no sngl_inspiral
			# rows are linked to the coinc;  write no sub-rows then
			for ifo_row in (values[9] or "").split():
				ifo_row = ifo_row.split(":")
				ifo_row[1:] = map(float, ifo_row[1:])
				f.write('|| || || || || || %s || %.2f || %.2f || %.2f || %.2f ||\n' % tuple(ifo_row))

	def finish(self):
		"""
		Write the following files, each prefixed with self.base:

		* summary_table.txt
		* num_trigs_table.txt
		* live_time_table.txt
		* bgsummary_table.txt

		Then yield a single (None, None, None) placeholder.
		"""
		n = self.num_loudest
		# stable, human-friendly row order for the instrument tables
		instrument_sets = sorted(self.livetime, key = lambda inst: "".join(sorted(inst)))

		# candidate tuples start with combined_far, so this sorts
		# loudest (lowest FAR) first;  keep exactly n (was n + 1)
		self.candidates.sort()
		with open(self.base + 'summary_table.txt', 'w') as f:
			f.write("=== Open box loudest %d summary table ===\n" % n)
			self.write_wiki_string(self.candidates[:n], f, self.livetime)

		with open(self.base + 'num_trigs_table.txt', 'w') as f:
			f.write("||<b>DETECTORS</b>||<b># COINC EVENTS</b>||\n")
			for inst in instrument_sets:
				f.write("||%s||%d||\n" % ("".join(sorted(inst)), self.num_trigs.get(inst, 0)))

		with open(self.base + 'live_time_table.txt', 'w') as f:
			f.write("||<b>DETECTORS ON</b>||<b>LIVETIME (s) (d) (yr)</b>||\n")
			for inst in instrument_sets:
				livetime = self.livetime[inst]
				# 31556926 s per year
				f.write("||%s||%.2f %.2f %.2f||\n" % ("".join(sorted(inst)), livetime, livetime / 86400.0, livetime / 31556926.0))

		self.bgcandidates.sort()
		with open(self.base + 'bgsummary_table.txt', 'w') as f:
			f.write("=== Closed box loudest %d summary table ===\n" % n)
			self.write_wiki_string(self.bgcandidates[:n], f, self.livetime)
		yield None, None, None

#
# =============================================================================
#
#                      Injection Parameter Distributions
#
# =============================================================================
#


class InjectionParameterDistributionPlots(object):
	"""
	Plot group 6:  scatter plots of the injected parameters, one set
	of plots per injection waveform.
	"""
	def __init__(self):
		# waveform name --> list of sim_inspiral rows
		self.injections = {}

	def add_contents(self, contents):
		"""
		Collect every sim_inspiral row from contents, grouped by
		waveform.  Databases without a sim_inspiral table are ignored.
		"""
		if contents.sim_inspiral_table is None:
			# no injections
			return
		# (a per-injection search of the segment lists used to be done
		# here whose result was never used;  it has been removed)
		for values in contents.connection.cursor().execute("""
SELECT
	*
FROM
	sim_inspiral
			"""):
			sim = contents.sim_inspiral_table.row_from_cols(values)
			# drop columns not used by the plots below
			del sim.process_id, sim.source, sim.simulation_id
			self.injections.setdefault(sim.waveform, []).append(sim)

	def finish(self):
		"""
		Yield (figure, filename, is_open_box) for each parameter
		scatter plot of each waveform.  is_open_box is always False.
		"""
		for waveform, sims in self.injections.items():
			plot_specs = [
				([sim.mass1 for sim in sims], [sim.mass2 for sim in sims], r"$M_{1}$ ($\mathrm{M}_{\odot}$)", r"$M_{2}$ ($\mathrm{M}_{\odot}$)", "sim_dist_m1_m2_%s", 1),
				([sim.geocent_end_time for sim in sims], [math.log10(sim.distance) for sim in sims], r"Time (s)", r"$\log_{10} (\mathrm{distance} / 1\,\mathrm{Mpc})$", "sim_dist_time_distance_%s", None),
				# longitude in radians --> hours (2 pi rad = 24 h)
				([sim.longitude * 12 / math.pi for sim in sims], [math.sin(sim.latitude) for sim in sims], r"RA (h)", r"$\sin \mathrm{dec}$", "sim_dist_ra_dec_%s", None),
				([math.cos(sim.inclination) for sim in sims], [sim.polarization for sim in sims], r"$\cos $Inclination (rad)", r"Polarization (rad)", "sim_dist_inc_pol_%s", None),
				([sim.spin1z for sim in sims], [sim.spin2z for sim in sims], r"Spin 1 z", r"Spin 2 z", "sim_dist_spin1z_spin2z_%s", None),
			]
			for col1, col2, ax1, ax2, name, aspect in plot_specs:
				fig, axes = create_plot(ax1, ax2, aspect = aspect)
				axes.set_title(r"Injection Parameter Distribution (%s Injections)" % waveform)
				# use single-pixel markers for very large point sets
				if len(col1) > 16383:
					axes.plot(col1, col2, "k,")
				else:
					axes.plot(col1, col2, "k.")
				minx, maxx = axes.get_xlim()
				miny, maxy = axes.get_ylim()
				if aspect == 1:
					# square plot:  use the same range on both axes
					axes.set_xlim((min(minx, miny), max(maxx, maxy)))
					axes.set_ylim((min(minx, miny), max(maxx, maxy)))
				yield fig, name % (waveform), False


#
# =============================================================================
#
#                              Missed/Found Plot
#
# =============================================================================
#


class MissedFoundPlots(object):
	"""
	Plot group 1:  missed/found injection plots.

	One MissedFound accumulator is kept per on-instrument combination.
	Each one produces scatter plots of found injections (grouped by
	participating instruments) and missed injections.  It also records
	the missed/found counts in <base>injection_summary.txt.
	"""
	class MissedFound(object):
		"""
		Missed/found accumulator for one on-instrument combination.
		"""
		def __init__(self, on_instruments, far_thresh):
			self.on_instruments = on_instruments
			self.far_thresh = far_thresh
			# participating instruments (frozenset) --> found
			# injections;  the key None collects missed injections
			self.found_in = {}

		def add_contents(self, contents):
			"""
			Classify the injections in contents that were made while
			exactly self.on_instruments were operating.  An injection
			is found when its best matching coinc has combined_far
			below far_thresh (no threshold if None), and missed
			otherwise.
			"""
			self.base = contents.base
			# times when exactly self.on_instruments were on
			zerolag_segments = contents.seglists.intersection(self.on_instruments) - contents.seglists.union(contents.instruments - self.on_instruments)
			for values in contents.connection.cursor().execute("""
SELECT
	sim_inspiral.*,
	(
		SELECT
			coinc_inspiral.ifos
		FROM
			sim_coinc_map
			JOIN coinc_inspiral ON (
				coinc_inspiral.coinc_event_id == sim_coinc_map.coinc_event_id
			)
		WHERE
			sim_coinc_map.simulation_id == sim_inspiral.simulation_id
			AND coinc_inspiral.combined_far < ?
	)
FROM
	sim_inspiral
			""", (self.far_thresh if self.far_thresh is not None else float("+inf"),)):
				sim = contents.sim_inspiral_table.row_from_cols(values)
				del sim.process_id, sim.source, sim.simulation_id
				if sim.get_time_geocent() in zerolag_segments:
					# last column:  ifos of the matching coinc, NULL if missed
					participating_instruments = lsctables.instrument_set_from_ifos(values[-1])
					if participating_instruments is not None:
						participating_instruments = frozenset(participating_instruments)
					self.found_in.setdefault(participating_instruments, []).append(sim)

		def finish(self):
			"""
			Append this combination's missed/found counts to the
			injection summary file.  Then yield (figure, filename
			fragment, is_open_box) for each missed/found plot.
			"""
			on_instruments_str = ", ".join(sorted(self.on_instruments))
			on_instruments_tag = "".join(sorted(self.on_instruments))
			missed = self.found_in.pop(None, [])
			found = sorted(self.found_in.items(), key = (lambda x: lsctables.ifos_from_instrument_set(x[0])))

			# write the counts once, up front;  the with statement
			# closes the file even if a plot later fails
			with open(self.base + "injection_summary.txt", "a") as f:
				for participating_instruments, sims in found:
					f.write("||%s||%s||FOUND: %d||\n" % (on_instruments_tag, "".join(sorted(participating_instruments)), len(sims)))
				if missed:
					f.write("||%s||%s||MISSED: %d||\n" % (on_instruments_tag, "---", len(missed)))

			# "decisive" distance:  the second-smallest of the
			# per-instrument distances.  NOTE(review):  raises
			# IndexError for fewer than two instruments
			def decisive_eff_dist(sim, instruments):
				return sorted(sim.get_eff_dist(instrument) for instrument in instruments)[1]

			def decisive_chirp_dist(sim, instruments):
				return sorted(sim.get_chirp_eff_dist(instrument) for instrument in instruments)[1]

			plots = (
				(r"Distance vs.\ Chirp Mass (With %s Operating)" % on_instruments_str, r"$M_{\mathrm{chirp}}$ ($\mathrm{M}_{\odot}$)", lambda sim: sim.mchirp, r"$D$ ($\mathrm{Mpc}$)", lambda sim, instruments: sim.distance, "d_vs_mchirp"),
				(r"Decisive Distance vs.\ Chirp Mass (With %s Operating)" % on_instruments_str, r"$M_{\mathrm{chirp}}$ ($\mathrm{M}_{\odot}$)", lambda sim: sim.mchirp, r"$\mathrm{Decisive} D_{\mathrm{eff}}$ ($\mathrm{Mpc}$)", decisive_eff_dist, "deff_vs_mchirp"),
				(r"Chirp Decisive Distance vs.\ Chirp Mass (With %s Operating)" % on_instruments_str, r"$M_{\mathrm{chirp}}$ ($\mathrm{M}_{\odot}$)", lambda sim: sim.mchirp, r"$\mathrm{Decisive} D_{\mathrm{chirp}}$ ($\mathrm{Mpc}$)", decisive_chirp_dist, "chirpdist_vs_mchirp"),
				(r"Decisive Distance vs.\ Total Mass (With %s Operating)" % on_instruments_str, r"$M_{\mathrm{total}}$ ($\mathrm{M}_{\odot}$)", lambda sim: sim.mass1 + sim.mass2, r"$\mathrm{Decisive} D_{\mathrm{eff}}$ ($\mathrm{Mpc}$)", decisive_eff_dist, "deff_vs_mtotal"),
				(r"Decisive Distance vs.\ Effective Spin (With %s Operating)" % on_instruments_str, r"$\chi$", lambda sim: (sim.spin1z*sim.mass1 + sim.spin2z*sim.mass2)/(sim.mass1 + sim.mass2), r"$\mathrm{Decisive} D_{\mathrm{eff}}$ ($\mathrm{Mpc}$)", decisive_eff_dist, "deff_vs_chi"),
				(r"Decisive Distance vs.\ Time (With %s Operating)" % on_instruments_str, r"GPS Time (s)", lambda sim: sim.get_time_geocent(), r"$\mathrm{Decisive} D_{\mathrm{eff}}$ ($\mathrm{Mpc}$)", decisive_eff_dist, "deff_vs_t"),
			)
			for title, x_label, x_func, y_label, y_func, filename_fragment in plots:
				fig, axes = create_plot(x_label, y_label)
				legend = []
				for participating_instruments, sims in found:
					legend.append("Found in %s" % ", ".join(sorted(participating_instruments)))
					axes.semilogy([x_func(sim) for sim in sims], [y_func(sim, participating_instruments) for sim in sims], ".")
				if missed:
					legend.append("Missed")
					# missed injections use all on instruments
					axes.semilogy([x_func(sim) for sim in missed], [y_func(sim, self.on_instruments) for sim in missed], "k.")
				if legend:
					axes.legend(legend)
				axes.set_title(title)
				yield fig, filename_fragment, False

	def __init__(self, far_thresh):
		# combined FAR threshold (Hz) for "found";  None means no threshold
		self.far_thresh = far_thresh
		# on-instrument set (frozenset) --> MissedFound
		self.plots = {}

	def add_contents(self, contents):
		"""
		Feed an injection database to the MissedFound accumulator of
		each of its on-instrument combinations.
		"""
		self.base = contents.base
		if contents.sim_inspiral_table is None:
			# no injections
			return
		for on_instruments in contents.on_instruments_combos:
			if on_instruments not in self.plots:
				self.plots[on_instruments] = MissedFoundPlots.MissedFound(on_instruments, self.far_thresh)
			self.plots[on_instruments].add_contents(contents)

	def finish(self):
		"""
		Start a fresh injection summary file and write its header (each
		MissedFound appends its own rows).  Then yield every
		MissedFound plot, with the on-instrument tag appended to the
		filename.
		"""
		with open(self.base + "injection_summary.txt", "w") as f:
			# closing </b> tag was previously malformed ("</b||")
			f.write("||<b>ON INSTRUMENTS</b>||<b> PARTICIPATING INSTRUMENTS</b>||<b>MISSED/FOUND</b>||\n")
		for on_instruments, plot in self.plots.items():
			for fig, filename_fragment, is_open_box in plot.finish():
				yield fig, "%s_%s" % (filename_fragment, "".join(sorted(on_instruments))), is_open_box


#
# =============================================================================
#
#                              Parameter Accuracy
#
# =============================================================================
#


class ParameterAccuracyPlots(object):
	"""
	Plot group 2:  injection parameter accuracy.

	For each (waveform, instrument) pair, compares the parameters of the
	single-detector triggers (SNR > 8) in each injection's best matching
	coinc (from sim_coinc_map) with the injected values.
	"""
	def __init__(self):
		# (waveform, instrument) --> list of (sim, sngl) row pairs
		self.sim_sngl_pairs = {}

	@staticmethod
	def _hist_bins(arr, nbins = 100):
		"""
		Return nbins histogram bin edges spanning the 1st to 99th
		percentile of arr (scipy.stats.mstats.mquantiles convention),
		so outliers do not stretch the axis.

		If that range is degenerate (e.g. all values equal), it is
		widened by +/- 0.5 so the edges are strictly increasing.
		"""
		# import the submodule explicitly:  with older scipy releases
		# "import scipy" alone does not load scipy.stats
		from scipy.stats import mstats
		# unpack to scalars:  newer numpy would turn 1-element array
		# endpoints into (nbins, 1)-shaped edges
		start, end = (float(q) for q in mstats.mquantiles(arr, prob = [0.01, 0.99]))
		if not end > start:
			start, end = start - 0.5, end + 0.5
		return numpy.linspace(start, end, nbins)

	def add_contents(self, contents):
		"""
		Collect (sim, sngl) pairs from an injection database.
		Non-injection databases are ignored.
		"""
		if contents.sim_inspiral_table is None:
			# not an injections file
			return
		# the query returns the sim_inspiral columns followed by the
		# sngl_inspiral columns;  this is where the split happens
		n_simcolumns = len(contents.sim_inspiral_table.columnnames)
		for values in contents.connection.cursor().execute("""
SELECT
	sim_inspiral.*,
	sngl_inspiral.*
FROM
	sim_inspiral
	JOIN sim_coinc_map ON (
		sim_coinc_map.simulation_id == sim_inspiral.simulation_id
	)
	JOIN coinc_event_map ON (
		coinc_event_map.coinc_event_id == sim_coinc_map.coinc_event_id
	)
	JOIN sngl_inspiral ON (
		coinc_event_map.table_name == 'sngl_inspiral'
		AND coinc_event_map.event_id == sngl_inspiral.event_id
	)
	WHERE sngl_inspiral.snr > 8.0
		"""):
			sim = contents.sim_inspiral_table.row_from_cols(values)
			sngl = contents.sngl_inspiral_table.row_from_cols(values[n_simcolumns:])
			del sim.process_id, sim.source, sim.simulation_id
			del sngl.process_id, sngl.event_id
			self.sim_sngl_pairs.setdefault((sim.waveform, sngl.ifo), []).append((sim, sngl))

	def finish(self):
		"""
		Yield (figure, filename, is_open_box) for each accuracy plot
		of each (waveform, instrument) pair.  is_open_box is always
		False.
		"""

		def hist(arr, axes):
			# histogram over the central 98% of the values
			axes.hist(arr, self._hist_bins(arr))

		for (waveform, instrument), pairs in self.sim_sngl_pairs.items():
			fig, axes = create_plot(r"Injected $M_{\mathrm{chirp}}$ ($\mathrm{M}_{\odot}$)", r"Recovered $M_{\mathrm{chirp}}$ - Injected $M_{\mathrm{chirp}}$ ($\mathrm{M}_{\odot}$)")
			axes.set_title(r"Absolute $M_{\mathrm{chirp}}$ Accuracy in %s (%s Injections)" % (instrument, waveform))
			axes.plot([sim.mchirp for sim, sngl in pairs], [sngl.mchirp - sim.mchirp for sim, sngl in pairs], "kx")
			yield fig, "mchirp_acc_abs_%s_%s" % (waveform, instrument), False

			fig, axes = create_plot(r"Recovered $M_{\mathrm{chirp}}$ - Injected $M_{\mathrm{chirp}}$ ($\mathrm{M}_{\odot}$)", "Number")
			axes.set_title(r"Absolute $M_{\mathrm{chirp}}$ Accuracy in %s (%s Injections)" % (instrument, waveform))
			hist(numpy.array([sngl.mchirp - sim.mchirp for sim, sngl in pairs]), axes)
			yield fig, "mchirp_acc_abs_hist_%s_%s" % (waveform, instrument), False

			fig, axes = create_plot(r"Injected $M_{\mathrm{chirp}}$ ($\mathrm{M}_{\odot}$)", r"(Recovered $M_{\mathrm{chirp}}$ - Injected $M_{\mathrm{chirp}}$) / Injected $M_{\mathrm{chirp}}$")
			axes.set_title(r"Fractional $M_{\mathrm{chirp}}$ Accuracy in %s (%s Injections)" % (instrument, waveform))
			axes.plot([sim.mchirp for sim, sngl in pairs], [(sngl.mchirp - sim.mchirp) / sim.mchirp for sim, sngl in pairs], "kx")
			yield fig, "mchirp_acc_frac_%s_%s" % (waveform, instrument), False

			fig, axes = create_plot(r"(Recovered $M_{\mathrm{chirp}}$ - Injected $M_{\mathrm{chirp}}$ ($\mathrm{M}_{\odot}$)) / Injected $M_{\mathrm{chirp}}$ ($\mathrm{M}_{\odot}$)", "Number")
			axes.set_title(r"Fractional $M_{\mathrm{chirp}}$ Accuracy in %s (%s Injections)" % (instrument, waveform))
			hist(numpy.array([(sngl.mchirp - sim.mchirp) / sim.mchirp for sim, sngl in pairs]), axes)
			yield fig, "mchirp_acc_frac_hist_%s_%s" % (waveform, instrument), False

			fig, axes = create_plot(r"Injected $\eta$", r"Recovered $\eta$ - Injected $\eta$")
			axes.set_title(r"Absolute $\eta$ Accuracy in %s (%s Injections)" % (instrument, waveform))
			axes.plot([sim.eta for sim, sngl in pairs], [sngl.eta - sim.eta for sim, sngl in pairs], "kx")
			yield fig, "eta_acc_abs_%s_%s" % (waveform, instrument), False

			fig, axes = create_plot(r"Recovered $\eta$ - Injected $\eta$", "Number")
			axes.set_title(r"Absolute $\eta$ Accuracy in %s (%s Injections)" % (instrument, waveform))
			hist(numpy.array([sngl.eta - sim.eta for sim, sngl in pairs]), axes)
			yield fig, "eta_acc_abs_hist_%s_%s" % (waveform, instrument), False

			fig, axes = create_plot(r"Injected $\eta$", r"(Recovered $\eta$ - Injected $\eta$) / Injected $\eta$")
			axes.set_title(r"Fractional $\eta$ Accuracy in %s (%s Injections)" % (instrument, waveform))
			axes.plot([sim.eta for sim, sngl in pairs], [(sngl.eta - sim.eta) / sim.eta for sim, sngl in pairs], "kx")
			yield fig, "eta_acc_frac_%s_%s" % (waveform, instrument), False

			fig, axes = create_plot(r"(Recovered $\eta$ - Injected $\eta$) / Injected $\eta$", "Number")
			axes.set_title(r"Fractional $\eta$ Accuracy in %s (%s Injections)" % (instrument, waveform))
			hist(numpy.array([(sngl.eta - sim.eta) / sim.eta for sim, sngl in pairs]), axes)
			yield fig, "eta_acc_frac_hist_%s_%s" % (waveform, instrument), False

			fig, axes = create_plot(r"Injection End Time (GPS s)", r"Recovered End Time - Injection End Time (s)")
			axes.set_title(r"End Time Accuracy in %s (%s Injections)" % (instrument, waveform))
			axes.plot([sim_end_time(sim, instrument) for sim, sngl in pairs], [sngl.get_end() - sim_end_time(sim, instrument) for sim, sngl in pairs], "kx")
			yield fig, "t_acc_%s_%s" % (waveform, instrument), False

			fig, axes = create_plot(r"Recovered End Time - Injection End Time (s)", "Number")
			axes.set_title(r"End Time Accuracy in %s (%s Injections)" % (instrument, waveform))
			hist(numpy.array([float(sngl.get_end()) - float(sim_end_time(sim, instrument)) for sim, sngl in pairs]), axes)
			yield fig, "t_acc_hist_%s_%s" % (waveform, instrument), False

			fig, axes = create_plot(r"Injection $D_{\mathrm{eff}}$ ($\mathrm{Mpc}$)", r"(Recovered $D_{\mathrm{eff}}$ - Injection $D_{\mathrm{eff}}$) / Injection $D_{\mathrm{eff}}$")
			axes.set_title(r"Fractional Effective Distance Accuracy in %s (%s Injections)" % (instrument, waveform))
			axes.semilogx([sim.get_eff_dist(instrument) for sim, sngl in pairs], [(sngl.eff_distance - sim.get_eff_dist(instrument)) / sim.get_eff_dist(instrument) for sim, sngl in pairs], "kx")
			yield fig, "deff_acc_frac_%s_%s" % (waveform, instrument), False

			fig, axes = create_plot(r"(Recovered $D_{\mathrm{eff}}$ - Injection $D_{\mathrm{eff}}$) / Injection $D_{\mathrm{eff}}$", "Number")
			axes.set_title(r"Fractional Effective Distance Accuracy in %s (%s Injections)" % (instrument, waveform))
			hist(numpy.array([(sngl.eff_distance - sim.get_eff_dist(instrument)) / sim.get_eff_dist(instrument) for sim, sngl in pairs]), axes)
			yield fig, "deff_acc_frac_hist_%s_%s" % (waveform, instrument), False

			fig, axes = create_plot(r"(Recovered $1/D_{\mathrm{eff}}$ - Injection $1/D_{\mathrm{eff}}$) / Injection $1/D_{\mathrm{eff}}$", "Number")
			axes.set_title(r"Fractional Effective Amplitude Accuracy in %s (%s Injections)" % (instrument, waveform))
			hist(numpy.array([(1. / sngl.eff_distance - 1. / sim.get_eff_dist(instrument)) / (1. / sim.get_eff_dist(instrument)) for sim, sngl in pairs]), axes)
			yield fig, "deff_acc_frac_inv_hist_%s_%s" % (waveform, instrument), False

			fig, axes = create_plot(r"Injected $\chi$", r"Recovered $\chi$")
			axes.set_title(r"Effective Spin Accuracy in %s (%s Injections)" % (instrument, waveform))
			axes.plot([sim.chi for sim, sngl in pairs], [sngl.chi for sim, sngl in pairs], "kx")
			yield fig, "chi_acc_%s_%s" % (waveform, instrument), False


#
# =============================================================================
#
#               Background vs. Injections --- Single Instrument
#
# =============================================================================
#


class BackgroundVsInjectionPlots(object):
	"""
	Single-instrument chi^2 vs. SNR scatter plots.

	Single-detector triggers belonging to coincs of the
	inspiral-inspiral coinc type in non-injection documents are sorted
	into time-slide background (any non-zero offset in the coinc's time
	slide) and zero-lag.  Triggers matched to injections whose end time
	lies in the instrument's segments are collected separately together
	with the injection's spin magnitude.  finish() yields, for every
	instrument, a closed-box plot (injections + background) and an
	open-box plot that also shows the zero-lag triggers.
	"""
	# spin-magnitude bins used to group the injection points.  each is
	# the interval [start, stop) except the last, which also includes its
	# upper edge so that injections with |s| == 1.0 are not silently
	# dropped from the plots
	SPIN_BINS = ((0, 0.1), (0.1, 0.2), (0.2, 0.3), (0.3, 0.4), (0.4, 0.5), (0.5, 0.6), (0.6, 1.0))

	class Points(object):
		def __init__(self):
			self.snr = []
			self.chi2 = []
			# sngl_inspiral.bank_chisq / bank_chisq_dof
			self.bankveto = []
			# sqrt((|s1|^2 + |s2|^2) / 2) of the injection;  only
			# filled for injections
			self.spin = []

		def append(self, snr, chi2, bankveto):
			"""Record one trigger's SNR, chi^2 and reduced bank chi^2."""
			self.snr.append(snr)
			self.chi2.append(chi2)
			self.bankveto.append(bankveto)

	def __init__(self):
		# each maps instrument name --> Points
		self.injections = {}
		self.background = {}
		self.zerolag = {}

	def add_contents(self, contents):
		"""
		Accumulate single-detector trigger statistics from one
		document.  Non-injection documents feed the background and
		zero-lag collections;  injection documents feed the injection
		collection.
		"""
		if contents.sim_inspiral_table is None:
			# non-injections file
			for instrument, snr, chi2, bankveto, is_background in contents.connection.cursor().execute("""
SELECT
	sngl_inspiral.ifo,
	sngl_inspiral.snr,
	sngl_inspiral.chisq,
	sngl_inspiral.bank_chisq / bank_chisq_dof,
	EXISTS (
		SELECT
			*
		FROM
			time_slide
		WHERE
			time_slide.time_slide_id == coinc_event.time_slide_id
			AND time_slide.offset != 0
	)
FROM
	coinc_event
	JOIN coinc_event_map ON (
		coinc_event_map.coinc_event_id == coinc_event.coinc_event_id
	)
	JOIN sngl_inspiral ON (
		coinc_event_map.table_name == 'sngl_inspiral'
		AND coinc_event_map.event_id == sngl_inspiral.event_id
	)
WHERE
	coinc_event.coinc_def_id == ?
			""", (contents.ii_definer_id,)):
				# any non-zero offset in the coinc's time slide
				# makes it a background coinc
				target = self.background if is_background else self.zerolag
				if instrument not in target:
					target[instrument] = BackgroundVsInjectionPlots.Points()
				target[instrument].append(snr, chi2, bankveto)
		else:
			# injections file
			for instrument, snr, chi2, bankveto, end_time, spin in contents.connection.cursor().execute("""
SELECT
	sngl_inspiral.ifo,
	sngl_inspiral.snr,
	sngl_inspiral.chisq,
	sngl_inspiral.bank_chisq / bank_chisq_dof,
	sngl_inspiral.end_time + sngl_inspiral.end_time_ns * 1e-9,
	sim.spin1x * sim.spin1x + sim.spin1y * sim.spin1y + sim.spin1z * sim.spin1z + sim.spin2x * sim.spin2x + sim.spin2y * sim.spin2y + sim.spin2z * sim.spin2z
FROM
	sim_coinc_map
	JOIN coinc_event_map ON (
		coinc_event_map.coinc_event_id == sim_coinc_map.coinc_event_id
	)
	JOIN sngl_inspiral ON (
		coinc_event_map.table_name == 'sngl_inspiral'
		AND coinc_event_map.event_id == sngl_inspiral.event_id
	)
	JOIN sim_inspiral AS sim ON sim.simulation_id == sim_coinc_map.simulation_id
			"""):
				# only keep triggers inside the instrument's segments
				if end_time in contents.seglists[instrument]:
					if instrument not in self.injections:
						self.injections[instrument] = BackgroundVsInjectionPlots.Points()
					points = self.injections[instrument]
					points.append(snr, chi2, bankveto)
					# the query returns |s1|^2 + |s2|^2;  store
					# sqrt((|s1|^2 + |s2|^2) / 2)
					points.spin.append((spin / 2.)**.5)

	def _plot_injections(self, axes, instrument):
		"""
		Draw the injection points for instrument on axes as one
		series per spin bin, highest bin first, each labelled with the
		bin's lower edge.
		"""
		points = self.injections[instrument]
		spin = numpy.array(points.spin)
		snr = numpy.array(points.snr)
		chi2 = numpy.array(points.chi2)
		last = len(self.SPIN_BINS) - 1
		for i in range(last, -1, -1):
			spinstart, spinstop = self.SPIN_BINS[i]
			# the last bin is closed on the right (see SPIN_BINS)
			upper = (spin <= spinstop) if i == last else (spin < spinstop)
			selected = (spin >= spinstart) & upper
			axes.loglog(snr[selected], chi2[selected], '.', label = "Inj $|s|$=%.1f" % spinstart)

	def finish(self):
		"""
		Yield (figure, filename fragment, is_open_box) tuples:  for
		each instrument a closed-box and an open-box chi^2 vs. SNR plot.
		"""
		# make sure every instrument seen anywhere has an entry in all
		# three collections so the plots below can index them freely
		for instrument in set(self.injections) | set(self.background) | set(self.zerolag):
			self.injections.setdefault(instrument, BackgroundVsInjectionPlots.Points())
			self.background.setdefault(instrument, BackgroundVsInjectionPlots.Points())
			self.zerolag.setdefault(instrument, BackgroundVsInjectionPlots.Points())
		for instrument in self.background:
			# closed box:  injections and background only
			fig, axes = create_plot(r"$\rho$", r"$\chi^{2}$")
			axes.set_title(r"$\chi^{2}$ vs.\ $\rho$ in %s (Closed Box)" % instrument)
			self._plot_injections(axes, instrument)
			axes.loglog(self.background[instrument].snr, self.background[instrument].chi2, "kx", label = "Background")
			axes.legend(loc = "upper left")
			yield fig, "chi2_vs_rho_%s" % instrument, False

			# open box:  additionally show the zero-lag triggers
			fig, axes = create_plot(r"$\rho$", r"$\chi^{2}$")
			axes.set_title(r"$\chi^{2}$ vs.\ $\rho$ in %s" % instrument)
			self._plot_injections(axes, instrument)
			axes.loglog(self.background[instrument].snr, self.background[instrument].chi2, "kx", label = "Background")
			axes.loglog(self.zerolag[instrument].snr, self.zerolag[instrument].chi2, "bx", label = "Zero-lag")
			axes.legend(loc = "upper left")
			yield fig, "chi2_vs_rho_%s" % instrument, True


#
# =============================================================================
#
#               Background vs. Injections --- Multi Instrument
#
# =============================================================================
#


class BackgroundVsInjectionPlotsMulti(object):
	"""
	Two-instrument scatter plots of effective SNR and effective distance.

	For every ordered pair of instruments whose triggers appear together
	in a coinc (the pair is ordered so the first name sorts after the
	second), the per-trigger values are collected for injections,
	time-slide background and zero-lag, then one instrument is plotted
	against the other in closed-box (no zero-lag) and open-box variants.
	"""
	class Points(object):
		def __init__(self):
			# every list holds (x instrument value, y instrument
			# value) tuples
			self.background_snreff = []
			self.injections_snreff = []
			self.zerolag_snreff = []
			self.background_deff = []
			self.injections_deff = []
			self.zerolag_deff = []

	def __init__(self, snrfactor):
		# passed as fac= to get_effective_snr() and shown in the SNR
		# plot titles
		self.snrfactor = snrfactor
		# (x instrument, y instrument) --> Points
		self.points = {}

	def _unpack_pair(self, contents, values):
		"""
		Decode a result row made of two concatenated sngl_inspiral
		rows into (x, y) row objects and return them together with the
		Points bucket for their (x.ifo, y.ifo) pair, creating the bucket
		on first use.
		"""
		table = contents.sngl_inspiral_table
		x = table.row_from_cols(values)
		y = table.row_from_cols(values[len(table.columnnames):])
		key = (x.ifo, y.ifo)
		bucket = self.points.get(key)
		if bucket is None:
			bucket = self.points[key] = BackgroundVsInjectionPlotsMulti.Points()
		return x, y, bucket

	def add_contents(self, contents):
		"""
		Accumulate pairwise effective SNR and effective distance values
		from one document, as background/zero-lag for non-injection
		documents or as injections otherwise.
		"""
		fac = self.snrfactor
		if contents.sim_inspiral_table is None:
			# non-injections file
			for values in contents.connection.cursor().execute("""
SELECT
	sngl_inspiral_x.*,
	sngl_inspiral_y.*,
	EXISTS (
		SELECT
			*
		FROM
			time_slide
		WHERE
			time_slide.time_slide_id == coinc_event.time_slide_id
			AND time_slide.offset != 0
	)
FROM
	coinc_event
	JOIN coinc_event_map AS coinc_event_map_x ON (
		coinc_event_map_x.coinc_event_id == coinc_event.coinc_event_id
	)
	JOIN sngl_inspiral AS sngl_inspiral_x ON (
		coinc_event_map_x.table_name == 'sngl_inspiral'
		AND coinc_event_map_x.event_id == sngl_inspiral_x.event_id
	)
	JOIN coinc_event_map AS coinc_event_map_y ON (
		coinc_event_map_y.coinc_event_id == coinc_event.coinc_event_id
	)
	JOIN sngl_inspiral AS sngl_inspiral_y ON (
		coinc_event_map_y.table_name == 'sngl_inspiral'
		AND coinc_event_map_y.event_id == sngl_inspiral_y.event_id
	)
	JOIN coinc_inspiral ON (
		coinc_inspiral.coinc_event_id == coinc_event.coinc_event_id
	)
WHERE
	coinc_event.coinc_def_id == ?
	AND sngl_inspiral_x.ifo > sngl_inspiral_y.ifo
			""", (contents.ii_definer_id,)):
				x, y, bucket = self._unpack_pair(contents, values)
				# the final column is the EXISTS() background flag
				is_background, = values[-1:]
				if is_background:
					snreff_list, deff_list = bucket.background_snreff, bucket.background_deff
				else:
					snreff_list, deff_list = bucket.zerolag_snreff, bucket.zerolag_deff
				snreff_list.append((x.get_effective_snr(fac = fac), y.get_effective_snr(fac = fac)))
				deff_list.append((x.eff_distance, y.eff_distance))
		else:
			# injections file
			for values in contents.connection.cursor().execute("""
SELECT
	sngl_inspiral_x.*,
	sngl_inspiral_y.*
FROM
	sim_coinc_map
	JOIN coinc_event_map AS coinc_event_map_x ON (
		coinc_event_map_x.coinc_event_id == sim_coinc_map.coinc_event_id
	)
	JOIN sngl_inspiral AS sngl_inspiral_x ON (
		coinc_event_map_x.table_name == 'sngl_inspiral'
		AND coinc_event_map_x.event_id == sngl_inspiral_x.event_id
	)
	JOIN coinc_event_map AS coinc_event_map_y ON (
		coinc_event_map_y.coinc_event_id == sim_coinc_map.coinc_event_id
	)
	JOIN sngl_inspiral AS sngl_inspiral_y ON (
		coinc_event_map_y.table_name == 'sngl_inspiral'
		AND coinc_event_map_y.event_id == sngl_inspiral_y.event_id
	)
WHERE
	sngl_inspiral_x.ifo > sngl_inspiral_y.ifo
			"""):
				x, y, bucket = self._unpack_pair(contents, values)
				bucket.injections_snreff.append((x.get_effective_snr(fac = fac), y.get_effective_snr(fac = fac)))
				bucket.injections_deff.append((x.eff_distance, y.eff_distance))

	@staticmethod
	def _scatter(x_label, y_label, title, series, legend):
		"""
		Make one equal-aspect log-log scatter plot.  series is a
		sequence of (list of (x, y) tuples, matplotlib format string);
		legend names the series in the same order.
		"""
		fig, axes = create_plot(x_label, y_label, aspect = 1.0)
		axes.set_title(title)
		for pairs, style in series:
			axes.loglog([x for x, y in pairs], [y for x, y in pairs], style)
		axes.legend(legend, loc = "lower right")
		return fig

	def finish(self):
		"""
		Yield (figure, filename fragment, is_open_box) tuples:  for
		each instrument pair, closed- and open-box effective SNR and
		effective distance scatter plots.
		"""
		for (x_instrument, y_instrument), points in self.points.items():
			rho_x = r"$\rho_{\mathrm{eff}}$ in %s" % x_instrument
			rho_y = r"$\rho_{\mathrm{eff}}$ in %s" % y_instrument
			deff_x = r"$D_{\mathrm{eff}}$ in %s" % x_instrument
			deff_y = r"$D_{\mathrm{eff}}$ in %s" % y_instrument
			snr_title_args = (y_instrument, x_instrument, self.snrfactor)
			pair_args = (y_instrument, x_instrument)

			closed_snr = [(points.injections_snreff, "rx"), (points.background_snreff, "kx")]
			closed_deff = [(points.injections_deff, "rx"), (points.background_deff, "kx")]

			yield self._scatter(rho_x, rho_y, r"Effective SNR in %s vs.\ %s (SNR Factor = %g) (Closed Box)" % snr_title_args, closed_snr, ("Injections", "Background")), "rho_%s_vs_%s" % pair_args, False
			yield self._scatter(rho_x, rho_y, r"Effective SNR in %s vs.\ %s (SNR Factor = %g)" % snr_title_args, closed_snr + [(points.zerolag_snreff, "bx")], ("Injections", "Background", "Zero-lag")), "rho_%s_vs_%s" % pair_args, True
			yield self._scatter(deff_x, deff_y, r"Effective Distance in %s vs.\ %s (Closed Box)" % pair_args, closed_deff, ("Injections", "Background")), "deff_%s_vs_%s" % pair_args, False
			yield self._scatter(deff_x, deff_y, r"Effective Distance in %s vs.\ %s" % pair_args, closed_deff + [(points.zerolag_deff, "bx")], ("Injections", "Background", "Zero-lag")), "deff_%s_vs_%s" % pair_args, True


#
# =============================================================================
#
#                           Rate vs. Threshold Plots
#
# =============================================================================
#


def sigma_region(mean, nsigma):
	"""
	Return the outline of the band mean +/- nsigma * sqrt(mean) as a
	single closed polygon ordinate array:  the lower edge in the order
	of mean, followed by the upper edge in reverse order.  Pair it with
	numpy.concatenate((x, x[::-1])) when passing it to Axes.fill().
	"""
	half_width = nsigma * numpy.sqrt(mean)
	lower_edge = mean - half_width
	upper_edge = mean + half_width
	return numpy.concatenate((lower_edge, upper_edge[::-1]))


def create_farplot(axes, zerolag_stats, expected_count_x, expected_count_y, is_open_box, xlim = (None, None), max_events = 1000):
	"""
	Draw a cumulative event-count vs. threshold plot on axes.

	axes:  the matplotlib axes to draw on.
	zerolag_stats:  non-empty numpy array of observed threshold values;
		entry i is plotted at a count of i + 1, so callers pass it
		sorted.  Only the first max_events entries are used.
	expected_count_x, expected_count_y:  numpy arrays giving the
		expected (background) count <N> as a function of threshold.
	is_open_box:  selects the legend label of the observed curve
		("Zero-lag" when True, r"$\pi$ shift" otherwise).
	xlim:  (lower, upper) x limits.  A None lower limit means "the
		smallest plotted statistic";  a given lower limit is raised to
		that value if it is smaller.  A None upper limit means "the next
		power of 2 at or above the largest plotted statistic".
	max_events:  maximum number of observed events to plot;  also the
		ceiling the error bands are clipped to and, rounded up to a
		power of 10, the top of the y axis.
	"""
	#
	# isolate relevant data
	#

	zerolag_stats = zerolag_stats[:max_events]

	#
	# background.  uncomment the two lines to make the background
	# stair-step-style like the observed counts
	#

	#expected_count_x = expected_count_x.repeat(2)[1:]
	#expected_count_y = expected_count_y.repeat(2)[:-1]
	line1, = axes.plot(expected_count_x, expected_count_y, 'k--', linewidth = 1)

	#
	# error bands:  shaded regions <N> +/- k sqrt(<N>) for k = 3, 2, 1,
	# clipped to [0.001, max_events] (0.001 is also the lower y limit
	# set below)
	#

	expected_count_x = numpy.concatenate((expected_count_x, expected_count_x[::-1]))
	line2, = axes.fill(expected_count_x, sigma_region(expected_count_y, 3.0).clip(0.001, max_events), alpha = 0.25, facecolor = [0.75, 0.75, 0.75])
	line3, = axes.fill(expected_count_x, sigma_region(expected_count_y, 2.0).clip(0.001, max_events), alpha = 0.25, facecolor = [0.5, 0.5, 0.5])
	line4, = axes.fill(expected_count_x, sigma_region(expected_count_y, 1.0).clip(0.001, max_events), alpha = 0.25, facecolor = [0.25, 0.25, 0.25])

	#
	# zero-lag:  stair-step curve of count vs. threshold
	#

	N = numpy.arange(1., len(zerolag_stats) + 1., dtype = "double")
	line5, = axes.plot(zerolag_stats.repeat(2)[1:], N.repeat(2)[:-1], 'k', linewidth = 2)

	#
	# legend
	#

	observed_label = "Zero-lag" if is_open_box else r"$\pi$ shift"
	axes.legend((line5, line1, line4, line3, line2), (observed_label, r"$\langle N \rangle$", r"$\pm\sqrt{\langle N \rangle}$", r"$\pm 2\sqrt{\langle N \rangle}$", r"$\pm 3\sqrt{\langle N \rangle}$"), loc = "upper right")

	#
	# adjust bounds of plot.  None limits are handled explicitly rather
	# than relying on Python 2's ordering of None below every number
	# (max(x, None) is a TypeError in Python 3)
	#

	lower = zerolag_stats.min()
	if xlim[0] is not None:
		lower = max(lower, xlim[0])
	if xlim[1] is None:
		upper = 2.**math.ceil(math.log(zerolag_stats.max(), 2.))
	else:
		upper = xlim[1]
	axes.set_xlim((lower, upper))
	axes.set_ylim((0.001, 10.**math.ceil(math.log10(max_events))))


class RateVsThreshold(object):
	"""
	Ranking-statistic summary plots.

	Collects ln likelihood ratio, combined FAR, the
	coinc_inspiral.false_alarm_rate column and network SNR for zero-lag
	and time-slide background coincs (those with coinc_event.likelihood
	>= 0) from non-injection documents.  finish() produces ln L vs. SNR
	scatter plots and cumulative event-count vs. IFAR and vs. ln L
	threshold plots.
	"""
	def __init__(self):
		self.background_ln_likelihood_ratio = []
		self.zerolag_ln_likelihood_ratio = []
		self.background_far = []
		self.zerolag_far = []
		self.background_fap = []
		self.zerolag_fap = []
		self.background_snr = []
		self.zerolag_snr = []
		# union of every document's farsegs;  used for the live time
		self.farsegs = segments.segmentlistdict()

	def add_contents(self, contents):
		"""
		Accumulate coinc statistics from one document.  Injection
		documents are ignored.
		"""
		if contents.sim_inspiral_table is not None:
			# skip injection documents
			return

		self.farsegs |= contents.farsegs

		# query the document's own connection.  the FAR loop variable is
		# named far_value so it cannot be confused with the gstlal far
		# module used in finish()
		for ln_likelihood_ratio, far_value, fap, snr, is_background in contents.connection.cursor().execute("""
SELECT
	coinc_event.likelihood,
	coinc_inspiral.combined_far,
	coinc_inspiral.false_alarm_rate,
	coinc_inspiral.snr,
	EXISTS (
		SELECT
			*
		FROM
			time_slide
		WHERE
			time_slide.time_slide_id == coinc_event.time_slide_id
			AND time_slide.offset != 0
	)
FROM
	coinc_inspiral
	JOIN coinc_event ON (
		coinc_event.coinc_event_id == coinc_inspiral.coinc_event_id
	)
WHERE
	coinc_event.likelihood >= 0.
		"""):
			if is_background:
				self.background_ln_likelihood_ratio.append(ln_likelihood_ratio)
				self.background_far.append(far_value)
				self.background_fap.append(fap)
				self.background_snr.append(snr)
			else:
				self.zerolag_ln_likelihood_ratio.append(ln_likelihood_ratio)
				self.zerolag_far.append(far_value)
				self.zerolag_fap.append(fap)
				self.zerolag_snr.append(snr)

	def finish(self):
		"""
		Yield (figure, filename fragment, is_open_box) tuples for the
		ln L vs. SNR scatter plots and, for the zero-lag (open box) and
		background (closed box) samples in turn, the event count vs.
		IFAR and event count vs. ln L threshold plots.
		"""
		livetime = far.get_live_time(self.farsegs)

		fig, axes = create_plot(x_label = r"SNR", y_label = r"$\ln \Lambda$")
		axes.loglog(self.background_snr, self.background_ln_likelihood_ratio, "kx", label = "Background")
		axes.legend(loc = "upper left")
		axes.set_title(r"$\ln \Lambda$ vs.\ SNR Scatter Plot (Closed Box)")
		yield fig, "lr_vs_snr", False
		fig, axes = create_plot(x_label = r"SNR", y_label = r"$\ln \Lambda$")
		axes.loglog(self.background_snr, self.background_ln_likelihood_ratio, "kx", label = "Background")
		axes.loglog(self.zerolag_snr, self.zerolag_ln_likelihood_ratio, "bx", label = "Zero-lag")
		axes.legend(loc = "upper left")
		axes.set_title(r"$\ln \Lambda$ vs.\ SNR Scatter Plot")
		yield fig, "lr_vs_snr", True

		for ln_likelihood_ratio, fars, is_open_box in [(self.zerolag_ln_likelihood_ratio, self.zerolag_far, True), (self.background_ln_likelihood_ratio, self.background_far, False)]:
			if fars:
				fig, axes = create_plot(None, r"Number of Events")
				axes.loglog()
				# fars in ascending order --> ifars in descending order
				zerolag_stats = 1. / numpy.array(sorted(fars))
				expected_count_y = numpy.logspace(-7, numpy.log10(len(zerolag_stats)), 1000)
				expected_count_x = livetime / expected_count_y
				create_farplot(axes, zerolag_stats, expected_count_x, expected_count_y, is_open_box, xlim = (None, 2000. * livetime))
				if is_open_box:
					axes.set_title(r"Event Count vs.\ Inverse False-Alarm Rate Threshold")
				else:
					axes.set_title(r"Event Count vs.\ Inverse False-Alarm Rate Threshold (Closed Box)")
				axes.set_xlabel(r"Inverse False-Alarm Rate (s)")
				yield fig, "count_vs_ifar", is_open_box

			if ln_likelihood_ratio:
				fig, axes = create_plot(None, r"Number of Events")
				axes.semilogy()

				zerolag_stats = numpy.array(sorted(ln_likelihood_ratio, reverse = True))

				# we want to plot FAR(ln L) * livetime vs.
				# ln L, but we don't have access to the
				# ranking statistic data file where that
				# function is encoded.  instead, we rely on
				# the FARs stored in each coinc, together
				# with the ln L assigned to each coinc, to
				# provide us with a collection of samples
				# of that function.  to get more points, we
				# combine data from the zero-lag and
				# background coincs.  in the future,
				# perhaps this program could be provided
				# with the marginalized ranking statistic
				# PDF data file

				expected_count_x = self.zerolag_ln_likelihood_ratio + self.background_ln_likelihood_ratio
				# indexes in descending ln L order.  sorted() works on
				# Python 2 and 3 alike (range() is not a list in 3)
				order = sorted(range(len(expected_count_x)), key = lambda i: expected_count_x[i], reverse = True)
				expected_count_x = numpy.array(expected_count_x)[order]
				expected_count_y = numpy.array(self.zerolag_far + self.background_far)[order] * livetime

				create_farplot(axes, zerolag_stats, expected_count_x, expected_count_y, is_open_box, xlim = (None, 23.), max_events = 10000)
				if is_open_box:
					axes.set_title(r"Event Count vs.\ Ranking Statistic Threshold")
				else:
					axes.set_title(r"Event Count vs.\ Ranking Statistic Threshold (Closed Box)")
				axes.set_xlabel(r"$\ln \Lambda$")
				yield fig, "count_vs_lr", is_open_box


#
# =============================================================================
#
#                                     Main
#
# =============================================================================
#


#
# Parse command line
#


options, filenames = parse_command_line()


#
# Initialize plots
#


# total number of plot groups new_plots() knows about (set when it is
# called), so we know how many digits the filenames need
max_plot_groups = None

def new_plots(plots = None, far_thresh = None):
	"""
	Instantiate every plot group and return the selected ones.

	plots:  iterable of indexes into the full list of plot groups, or
		None to select all of them in order.
	far_thresh:  forwarded to MissedFoundPlots.

	Also stores the total number of plot groups in the module-level
	max_plot_groups.
	"""
	global max_plot_groups
	available = (
		SummaryTable(),
		MissedFoundPlots(far_thresh = far_thresh),
		ParameterAccuracyPlots(),
		BackgroundVsInjectionPlots(),
		BackgroundVsInjectionPlotsMulti(snrfactor = 50.0),
		RateVsThreshold(),
		InjectionParameterDistributionPlots(),
	)
	max_plot_groups = len(available)
	if plots is None:
		return list(available)
	return [available[index] for index in plots]

plots = new_plots(options.plot_group, options.far_threshold)
if options.plot_group is None:
	# make the implicit "all groups" selection explicit;  the group
	# numbers are used in progress messages and output filenames
	options.plot_group = range(len(plots))


#
# Process files
#


# summary text file that accompanies the plots;  it is passed to every
# CoincDatabase and a header is written for each input file
wiki = open(os.path.join(options.output_dir, "%s_%s" % (options.user_tag, "plotsummary.txt")),"w")

for n, filename in enumerate(filenames):
	if options.verbose:
		print >>sys.stderr, "%d/%d: %s" % (n + 1, len(filenames), filename)
	wiki.write("=== %d/%d: %s ===\n\n" % (n + 1, len(filenames), filename))
	working_filename = dbtables.get_connection_filename(filename, tmp_path = options.tmp_space, verbose = options.verbose)
	connection = sqlite3.connect(working_filename)
	# always close the database and discard the working copy, even if
	# loading the document or feeding a plot group fails
	try:
		contents = CoincDatabase(connection, options.segments_name, veto_segments_name = options.vetoes_name, verbose = options.verbose, wiki = wiki, base = os.path.join(options.output_dir, options.user_tag))
		if contents.sim_inspiral_table is not None:
			create_sim_coinc_view(connection)
		# group is the plot-group number, not the file counter n
		for group, plot in zip(options.plot_group, plots):
			if options.verbose:
				print >>sys.stderr, "adding to plot group %d ..." % group
			plot.add_contents(contents)
	finally:
		connection.close()
		dbtables.discard_connection_filename(filename, working_filename, verbose = options.verbose)


#
# Finish and write plots, deleting them as we go to save memory
#


# index into options.plot_group of the plot group being finished
n = 0
# the GPS span in the output filenames comes from the segments of the
# last document processed above
seg_extent = contents.seglists.extent_all()
filename_template = inspiral_pipe.T050017_filename("H1L1V1", "GSTLAL_INSPIRAL_PLOTSUMMARY_%s_%02d_%s_%s", seg_extent[0], seg_extent[1], "%s", path = options.output_dir)
while plots:
	# pop each plot group before finishing it so it can be freed once
	# its figures have been written
	for fig, filename_fragment, is_open_box in plots.pop(0).finish():
		# nothing to write for yields without a figure or a name
		if not (filename_fragment and fig):
			continue
		box = "openbox" if is_open_box else "closedbox"
		for fmt in options.format:
			filename = filename_template % (options.user_tag, options.plot_group[n], filename_fragment, box, fmt)
			if options.verbose:
				print >>sys.stderr, "writing %s ..." % filename
			fig.savefig(filename)
	n += 1

# all plot groups are finished;  flush and close the summary file
wiki.close()
