#!/usr/bin/env python
#
# Copyright (C) 2012 Chris Pankow
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 2 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
"""Visualizer for SnglBurst events"""

import os
import sys

from collections import defaultdict

import matplotlib
#import matplotlib.gridspec
matplotlib.use("Agg")
from matplotlib import pyplot
from matplotlib.collections import PatchCollection

# FIXME: This doesn't seem to work with python-2.7 on macs
#matplotlib.rc('font',**{'family':'sans-serif','sans-serif':['Helvetica']})

from mpl_toolkits.axes_grid.inset_locator import inset_axes
from mpl_toolkits.axes_grid import make_axes_locatable

from scipy.stats import chi2, chi

import numpy
import math

import itertools

from glue.ligolw import utils
from glue.ligolw import lsctables
from glue.ligolw import table
from glue.ligolw import ligolw
lsctables.use_in(ligolw.LIGOLWContentHandler)
from glue import segments
from glue.lal import LIGOTimeGPS
from glue.lal import Cache

from optparse import OptionParser

# Command line interface.  Every option is optional; the positional arguments
# are the trigger XML files to read (see also --input-cache).
parser = OptionParser()

# --- input selection ---
parser.add_option("-c", "--channels", action="append", help="Comma separated list of channels. Will plot triggers only from the channels indicated. Otherwise, all channels will be plotted.")
parser.add_option("-o", "--output-file", action="store", help="File to output image to. Default: trigmap.png")
parser.add_option("-i", "--instruments", action="store", help="Comma separated list of instruments. Will plot channels only from the instruments indicated. Otherwise, no instrument selection is done.")
parser.add_option("-I", "--input-cache", action="store", help="Use this argument as the file list rather than specifying on the command line.")

# --- time / frequency / strength cuts ---
parser.add_option("-s", "--gps-start", action="store", type=float, help="Plot triggers only after indicated gps time. Default is to infer from triggers.")
parser.add_option("-e", "--gps-end", action="store", type=float, help="Plot triggers only before indicated gps time. Default is to infer from triggers.")
parser.add_option("-f", "--low-frequency", action="store", type=float, default=None, help="Plot only triggers with central frequency above this frequency. Default is 0.")
parser.add_option("-F", "--high-frequency", action="store", type=float, default=None, help="Plot only triggers with central frequency below this frequency.")
parser.add_option("-S", "--snr-thresh", action="store", type=float, default=0.0, help="Plot triggers with snr greater than that indicated.")

# --- what to draw and how ---
parser.add_option("-p", "--plot-value", action="store", default="snr", help="What value to plot; choices are 'snr' (default), 'confidence', and 'amplitude'.")
parser.add_option("-m", "--mark-time", action="store", type=float, help="Mark a particular time on the map. Useful for identifying triggers in a file spanning a large time.")
parser.add_option("-M", "--mark-loudest", action="store", default=None, help="Mark a specific event in the title. Two choices are available 'snr' and 'conf'. 'snr' will mark the loudest event by SNR, and 'conf' will mark the most confident event.")
parser.add_option("-r", "--enable-event-hist", action="store_true", help="Enable the subplot with event rate versus strength for the channel.")
parser.add_option("-x", "--overlay-segments", action="store", help="The entries for this argument are parsed and if segments with those definitions exist in the viewsegment, they will be overlaid on the plot.")
parser.add_option("-C", "--coalesce-segments", action="store_true", help="If segments are found to be plotted, this will cause them to be coalesced before adding them to the plot.")
parser.add_option("-t", "--plotting-type", action="store", default="tile", help="Plot tiles, or markers for time-frequency events. Valid choices are 'tile', 'marker', 'both', or 'most_significant'. Default is 'tile'.")
parser.add_option("-l", "--logscale", action="append", help="Set logscale on an axis. Valid options are x, y, and c (for colorbar). Several can be given.")
parser.add_option("-O", "--tile-opacity", type=float, action="store", default=1.0, help="Set the opacity of the plotted tiles. This is useful in the case of non-clustered output to see tiles which may be obscured by others. Set to a value between 0 and 1. Default is 1.")
parser.add_option("-E", "--exclude-outside-segments", action="store_true", help="Only plot triggers within segments indicated by the -x option.")
parser.add_option("-j", "--plot-injections-from-file", action="store", help="Plot all injections from the file indicated. File must contain a sim_burst or sim_inspiral table. If a sim_inspiral table is found, will plot the 3.5 PN time-frequency track. If a sim_burst table is found, will plot the approximate extent of the time frequency area on the map.")

# --- tile shape restrictions (each may be repeated) ---
parser.add_option("-d", "--restrict-dof", action="append", help="Only plot tiles with these DOF. This option can be given several times.")
parser.add_option("-D", "--restrict-duration", action="append", help="Only plot tiles with these durations. This option can be given several times.")
parser.add_option("-b", "--restrict-bandwidth", action="append", help="Only plot tiles with these bandwidths. This option can be given several times.")

# opts holds the parsed options; args is the list of trigger files.
opts, args = parser.parse_args()

def get_sim_inspirals_from_file(filename):
	"""
	Load the XML document stored in filename and return its sim_inspiral
	table.

	The document is read with utils.load_filename() and the table is
	looked up by lsctables.SimInspiralTable.tableName.  The caller in this
	script treats a ValueError as "no such table in the file".
	"""
	document = utils.load_filename(filename)
	wanted = lsctables.SimInspiralTable.tableName
	return table.get_table(document, wanted)

def get_sim_burst_from_file( filename ):
	"""
	Load the XML document stored in filename and return its sim_burst
	table.

	The document is read with utils.load_filename() and the table is
	looked up by lsctables.SimBurstTable.tableName.  The caller in this
	script treats a ValueError as "no such table in the file".
	"""
	document = utils.load_filename(filename)
	wanted = lsctables.SimBurstTable.tableName
	return table.get_table(document, wanted)
	
def plot_wnb( sim_burst ):
	"""
	Return a rectangular outline of a band / time limited white noise burst.

	sim_burst is a row providing get_time_geocent(), duration, frequency
	and bandwidth.  The rectangle is centred in time on the geocentric
	time, spans sim_burst.duration horizontally, and extends upwards from
	sim_burst.frequency by sim_burst.bandwidth.  The returned patch is not
	added to any axes.
	"""
	# Coerce the GPS time to a plain float, as plot_csg() does, so the patch
	# is built from ordinary numbers rather than a GPS time object.
	start = float(sim_burst.get_time_geocent()) - sim_burst.duration/2.0
	return matplotlib.patches.Rectangle(
		# lower left point
		( start, sim_burst.frequency ),
		# tile extent
		sim_burst.duration,
		sim_burst.bandwidth,
		linewidth=1.2,
		linestyle='dashdot',
		facecolor='none'
	)

def plot_csg( sim_burst ):
	"""
	Return an elliptical outline of a cosine / sine gaussian burst.

	sim_burst is a row providing get_time_geocent(), frequency and q.  The
	ellipse is centred on (geocentric time, frequency); its width is half
	the derived duration and its height half the derived bandwidth.  The
	row itself is left untouched.  The returned patch is not added to any
	axes.
	"""
	# Gaussian in time domain with width w has width 1/w in frequency
	# domain
	# tau = Q / sqrt(2) pi f_0, so 1/tau = sqrt(2) pi f_0 / Q
	# Kept in locals: writing these back onto the row would clobber whatever
	# duration / bandwidth the table held for this injection.
	duration = sim_burst.q / math.sqrt(2) / math.pi / sim_burst.frequency
	bandwidth = 1.0 / duration
	center = ( float(sim_burst.get_time_geocent()), sim_burst.frequency )
	return matplotlib.patches.Ellipse(
		xy=center,
		width=duration/2.0,
		height=bandwidth/2.0,
		facecolor='none',
		linewidth=1.2,
		linestyle='dashed'
	)

def plot_chirp( mtot, mchirp, end_time, f_low=10.0, f_high=2048.0 ):
	"""
	Return the time-frequency track of a non-spinning 3.5 PN TaylorT2 expression for t(v). Reference BIOPS (arXiv:0907.0700) Eq. 3.8b
	Courtesy of Andy Lundgren (andrew.lundgren@ligo.org)

	mtot and mchirp are the total and chirp mass (presumably in solar
	masses, given the mtsun conversion constant below -- confirm against
	the caller's table units).  end_time is the time assigned to the last
	frequency sample.  f_low and f_high bound the frequency grid, which is
	sampled in 1 Hz steps from f_low up to, but excluding, f_high; the
	defaults reproduce the original fixed 10 - 2048 Hz grid.

	Returns (times, freqs), two numpy arrays of equal length, with times
	shifted so that times[-1] equals end_time.
	"""
	mtsun = 4.9254923218988638e-06
	# float() keeps integer mass arguments from being truncated by Python 2
	# integer division.
	eta = pow(float(mchirp)/mtot, 5./3.)
	freqs = numpy.arange(f_low, f_high)
	v = pow(numpy.pi*mtsun*mtot*freqs, 1./3.)

	# Leading order term, then the series in v that multiplies it
	newt_chirp = (-5.*mtsun*mtot)/(256.*eta)*pow(v, -8.)
	pn_series = (
		1.
		+ (2.94841 + 3.66667*eta) * pow(v, 2)
		+ (-20.1062) * pow(v, 3)
		+ (6.02063 + 10.7718*eta + 8.56944*eta*eta) * pow(v, 4)
		+ (-96.3546 + 13.6136*eta) * pow(v, 5)
		+ (120.87 + 661.664*eta - 8.80266*eta*eta + 19.7261*eta*eta*eta
			+ 65.219*numpy.log(v)) * pow(v, 6)
		+ (-381.403 - 314.587*eta + 123.079*eta*eta) * pow(v, 7)
	)
	times = newt_chirp * pn_series
	# Re-reference the track so that its last sample falls on end_time
	times = end_time + times - times[-1]

	return times, freqs

#
# Validate and normalise the command line options.  Every rejection goes
# through sys.exit(message), which writes the message to stderr and exits
# with a non-zero status.
#

# Axes requested in log scale: any of 'x', 'y' and 'c' (colorbar)
logaxes = opts.logscale or []
for ax in logaxes:
	if ax not in ['x', 'y', 'c']:
		sys.exit( "Invalid log axis choice %s." % ax )

plotting_types = ["tile", "marker", "both", "most_significant"]
if opts.plotting_type not in plotting_types:
	sys.exit( "Invalid plot type choice %s, valid choices are %s" % (opts.plotting_type, ", ".join(plotting_types)) )

# An absent frequency bound means "unbounded" on that side
if opts.low_frequency is None:
	opts.low_frequency = 0

if opts.high_frequency is None:
	opts.high_frequency = float("inf")

if opts.mark_loudest not in [None, 'snr', 'conf']:
	sys.exit( "Invalid argument to '--mark-loudest', please choose one of 'snr' or 'conf'" )

# --channels may be repeated and each occurrence may itself be a comma
# separated list; flatten everything into one list of channel names.
if opts.channels is None:
	channels = None
else:
	channels = []
	for chan_arg in opts.channels:
		channels.extend( [c.strip() for c in chan_arg.split(",") if c.strip()] )

# --instruments is a single comma separated list
if opts.instruments is not None:
	instruments = [i.strip() for i in opts.instruments.split(",") if i.strip()]
else:
	instruments = None

if opts.overlay_segments is not None:
	segment_defs = opts.overlay_segments.split(",")
else:
	segment_defs = None

# Files named in the cache are appended to those given on the command line
if opts.input_cache is not None:
	with open(opts.input_cache) as cache_file:
		args += Cache.fromfile( cache_file ).pfnlist()

# Reject an unknown --plot-value up front, once, rather than per event
if opts.plot_value not in ('snr', 'confidence', 'amplitude'):
	parser.error( "Do not understand %s as a parameter to --plot-value" % opts.plot_value )

# All events that will be considered for plotting
events = lsctables.SnglBurstTable()

# "out" segments of every search summary row read, keyed as returned by
# get_out_segmentlistdict()
analyzed_segments = segments.segmentlistdict()

# process_id -> search summary rows / sngl_burst rows
search_map = defaultdict(list)
event_map = defaultdict(list)
for arg in args:
	xmldoc = utils.load_filename( arg )

	for tbl in lsctables.getTablesByType( xmldoc, lsctables.SnglBurstTable ):
		for sb in tbl:
			event_map[sb.process_id].append( sb )

	for tbl in lsctables.getTablesByType( xmldoc, lsctables.SearchSummaryTable ):
		# search segment data and event filtering
		analyzed_segments.extend( tbl.get_out_segmentlistdict() )
		for ss in tbl:
			search_map[ss.process_id].append( ss )

# Gather the events of every process.
# NOTE: events are kept whether or not their process has a search summary
# row, and they are not checked against the search "out" segment.
# FIXME: Clustering will sometimes take an event past the out segment, so
# filtering on it would discard such events.
for pid, eventlist in event_map.iteritems():
	for e in eventlist:
		# NaN never compares equal to anything, itself included, so it has
		# to be detected with isnan().  Fall back to a DOF derived from
		# the tile's time-frequency area.
		if e.chisq_dof is not None and math.isnan(e.chisq_dof):
			e.chisq_dof = e.duration*e.bandwidth*2
		# FIXME: Don't do this, this is the easy way out
		# The plotting code only ever reads .snr, so the requested
		# quantity is copied over it.
		if opts.plot_value == 'confidence':
			e.snr = e.confidence
		elif opts.plot_value == 'amplitude':
			e.snr = e.amplitude

	events += eventlist

if len(events) == 0:
	print("No events remain.")
	sys.exit()

# Work out the instrument list when none was requested on the command line
if instruments is None:
	if channels is None:
		# nothing was selected at all: take every instrument that
		# contributed at least one event
		instruments = dict.fromkeys( event.ifo for event in events ).keys()
	else:
		# only channels were selected: the instrument is the part of the
		# channel name before the first ":"
		instruments = dict.fromkeys( channel.split(":")[0] for channel in channels ).keys()

print("Examining instruments: %s" % str(instruments))

# Segment information
# segs maps (segment definer name, ifos) -> list of segments to overlay
segs = segments.segmentlistdict()
searches = lsctables.SearchSummaryTable()

# The input files only need to be read again when segment overlays were
# actually requested with --overlay-segments.
if segment_defs is not None:
	for arg in args:
		xmldoc = utils.load_filename( arg )

		# segment_def_id -> (name, ifos) for the requested definers of the
		# selected instruments found in this file
		# TODO: Reject if ifo isn't in acceptable list
		def_id_pair = {}
		for tbl in lsctables.getTablesByType( xmldoc, lsctables.SegmentDefTable ):
			for r in tbl:
				ifos = unicode(r.ifos)
				if r.name in segment_defs and ifos in instruments:
					def_id_pair[r.segment_def_id] = (r.name, ifos)

		if len(def_id_pair) == 0:
			continue

		# Create a list only for keys not seen yet, so that segments
		# gathered from earlier files are kept rather than reset.
		for ids in def_id_pair.itervalues():
			if ids not in segs:
				segs[ids] = segments.segmentlist()

		for tbl in lsctables.getTablesByType( xmldoc, lsctables.SegmentTable ):
			for seg in tbl:
				if seg.segment_def_id not in def_id_pair:
					continue
				try:
					segm = seg.get()
				except AttributeError:
					# row type without a get() helper: build it by hand
					segm = segments.segment( LIGOTimeGPS(seg.start_time), LIGOTimeGPS(seg.end_time) )
				segs[def_id_pair[seg.segment_def_id]].append( segm )

# Time window shown on the map: an explicit bound if one was given, otherwise
# the earliest / latest peak time among the events.
if opts.gps_start is None:
	view_start = min( e.peak_time for e in events )
else:
	view_start = opts.gps_start
if opts.gps_end is None:
	view_end = max( e.peak_time for e in events )
else:
	view_end = opts.gps_end
viewseg = segments.segment( view_start, view_end )

# Clip every overlay list to the view window
view_list = segments.segmentlist([viewseg])
for key in segs.keys():
	try:
		segs[key] = segs[key] & view_list
	except ValueError:
		# keep the value type consistent: an empty list, not a dict
		segs[key] = segments.segmentlist()

print("Summary of segments included:")
for k, v in segs.iteritems():
	print("%s: %f" % (k, abs(v)))
if opts.coalesce_segments:
	segs.coalesce()
	print("Summary of segments included, after coalesce:")
	for k, v in segs.iteritems():
		print("%s: %f" % (k, abs(v)))

# FIXME: Return to this when the ifos situation is sorted out
analyzed_segments.coalesce()
# Invert to get non-analyzed periods
analyzed_segments = ~analyzed_segments

# Channel information
# Without an explicit selection, plot every "ifo:channel" that has an event
# from one of the selected instruments.
if channels is None:
	channels = {}
	for event in events:
		if event.ifo in instruments:
			channels[(event.ifo, event.channel)] = None
	channels = [ ":".join(cl) for cl in channels.keys() ]

channels.sort()
print("Looking at channel list:\n" + "\n".join( channels ))

# The figure height and the grid both scale with the number of channels, so
# there is nothing that can be drawn for an empty list.
if len(channels) == 0:
	sys.exit( "No channels to plot." )

# construct tf map extent
# one row of the grid per channel
fig = pyplot.figure(figsize=(15,len(channels)*7), dpi=160)
gs = matplotlib.gridspec.GridSpec( nrows = len(channels), ncols = 1 )

# index of the next grid row to draw into
pn=0
events.sort(key=lambda e: e.ifo + ":" + e.channel) # Required for groupby
event_iter = itertools.groupby( events, lambda e: e.ifo + ":" + e.channel )

def _restriction( values, label ):
	"""
	Turn the values of a --restrict-* option into a container suitable for
	an "x in container" test.  With no values given this is the segment
	[0, +inf); otherwise it is the list of the values converted to float,
	which is also announced on stdout under the given label.
	"""
	if not values:
		return segments.segment(0, segments.PosInfinity)
	allowed = [ float(v) for v in values ]
	print("Restricting to tiles with %s = %s" % (label, ",".join([ str(a) for a in allowed ])))
	return allowed

only_dof = _restriction( opts.restrict_dof, "DOF" )
only_dur = _restriction( opts.restrict_duration, "duration" )
only_bnd = _restriction( opts.restrict_bandwidth, "bandwidth" )

# Frequency band an event has to overlap in order to be plotted
freq_band = segments.segment(opts.low_frequency, opts.high_frequency)
#
# Main plotting loop: one time-frequency map per channel.
#

# For --exclude-outside-segments: the union of all overlay segments of each
# instrument.  This is built once, in its own container, so that segs itself
# keeps its (definer name, ifo) keys for the overlay drawing further down.
segs_by_ifo = segments.segmentlistdict()
if opts.exclude_outside_segments:
	for (segdef, ifo), slist in segs.iteritems():
		if ifo not in segs_by_ifo:
			segs_by_ifo[ifo] = segments.segmentlist()
		segs_by_ifo[ifo].extend( slist )
	segs_by_ifo.coalesce()

# "most_significant" draws the most significant tile of each event and
# colours it by ms_snr instead of snr
use_ms = opts.plotting_type == "most_significant"
# colour scale in log10
log_color = 'c' in logaxes

for channel, selected_events in event_iter:
	if channel not in channels: continue

	# groupby hands out a one-shot iterator; make it a list so that the
	# events of this channel can be counted and filtered repeatedly
	selected_events = list(selected_events)
	sys.stdout.write( "Filtering %d events... " % len(selected_events) )

	# Cut on frequency band, view window and strength.  NaN and infinite
	# strengths cannot be placed on the colour scale and are dropped too.
	selected_events = [ event for event in selected_events
		if event.get_band().intersects(freq_band)
		and event.get_period().intersects(viewseg)
		and event.snr >= opts.snr_thresh
		and not math.isnan(event.snr)
		and not math.isinf(event.snr) ]

	# Cut on the tile shape restrictions
	selected_events = [ event for event in selected_events
		if event.chisq_dof in only_dof
		and event.bandwidth in only_bnd
		and event.duration in only_dur ]

	if opts.exclude_outside_segments:
		# an instrument without any overlay segment has all of its events
		# outside of the segments
		selected_events = [ event for event in selected_events
			if event.ifo in segs_by_ifo
			and event.peak_time in segs_by_ifo[event.ifo] ]
	print("%d events remain." % len(selected_events))
	if len(selected_events) == 0:
		print("No events remain, skipping.")
		continue

	# Frequency extent of the map; an unbounded upper edge is replaced by
	# the highest central frequency that is going to be drawn
	high_frequency = opts.high_frequency
	if opts.high_frequency == float("inf"):
		high_frequency = max( [e.central_freq for e in selected_events] )
	low_frequency = opts.low_frequency or 0

	# The group key is "ifo:channel"; a key with one more ":"-separated
	# field has its first field ignored
	try:
		inst, channel = channel.split(":")
	except ValueError:
		_, inst, channel = channel.split(":")

	ax = pyplot.subplot(gs[pn])
	# Extent in time
	ax.set_xlim(viewseg[0], viewseg[1])
	# Extent in frequency
	# FIXME: Different channels won't have the same frequency settings
	ax.set_ylim(low_frequency, high_frequency)

	# Ascending strength: the colour scale limits are the first and last
	# element, and the strongest tiles end up drawn last
	selected_events.sort(key=lambda e: e.snr)
	selected_events = numpy.array( selected_events )

	# what's our snr scale
	if log_color:
		snr_norm = matplotlib.colors.Normalize(
			numpy.log10(selected_events[0].snr),
			numpy.log10(selected_events[-1].snr)
		)
	else:
		snr_norm = matplotlib.colors.Normalize(
			selected_events[0].snr, selected_events[-1].snr
		)

	# Make me purty
	ax.grid()
	if opts.mark_time is not None and opts.mark_time in viewseg:
		pyplot.axvline(opts.mark_time, color="k")

	if pn == len(channels)-1:
		pyplot.xlabel( "Time (s)" )
	pyplot.ylabel( "Frequency (Hz)" )

	# construct tiles
	# The progress text is rewritten in place by backspacing over the
	# previous one
	statstr = "0% complete"
	sys.stdout.write( statstr )
	statlen = len(statstr)

	tiles, snrs, markers = [], [], []
	# [value, peak time, central frequency, duration, bandwidth] once an
	# event has been picked; stays at two elements until then
	loudest = [0,-1]
	n_selected = len(selected_events)
	for i, event in enumerate(selected_events):
		sys.stdout.write( "\b"*statlen )
		statstr = "%d%% complete" % numpy.round(float(i)/n_selected*100)
		sys.stdout.write( statstr )
		statlen = len(statstr)

		# 1. plot tile

		if use_ms:
			start = float(event.ms_start)
			band = event.ms_band
			duration = event.ms_duration
			# TODO: Is this correct?
			freq = event.peak_frequency
			strength = event.ms_snr
		else:
			start = float(event.get_start())
			band = event.bandwidth
			duration = event.duration
			freq = event.central_freq
			strength = event.snr

		if opts.plotting_type in ["tile", "both", "most_significant"]:
			tiles.append( matplotlib.patches.Rectangle(
				# lower left point
				( start, (freq - band/2.0) ),
				# tile extent
				duration,
				band
			) )
		if opts.plotting_type in ["marker", "both"]:
			# markers sit at the peak time and, where the event has one,
			# the peak frequency
			try:
				freq = event.peak_frequency
			except AttributeError:
				freq = event.central_freq
			if log_color:
				markers.append( (float(event.get_peak()), freq, numpy.log10(event.snr)) )
			else:
				markers.append( (float(event.get_peak()), freq, event.snr) )

		if log_color:
			snrs.append( numpy.log10(strength) )
		else:
			snrs.append( strength )

		#
		# Keep track of the loudest / most significant
		#
		if opts.mark_loudest == 'snr' and loudest[0] < strength:
			loudest = [strength, float(event.get_peak()), event.central_freq, event.duration, event.bandwidth]
		elif opts.mark_loudest == 'conf' and loudest[0] < event.confidence:
			loudest = [event.confidence, float(event.get_peak()), event.central_freq, event.duration, event.bandwidth]

	# rows: time, frequency, strength
	markers = numpy.array(markers).T
	print("")

	title = "channel: %s, inst %s" % (channel, inst)
	if pn == 0:
		title = "Tile energy time frequency map\n%s" % title
	# loudest only has all five entries once some event exceeded the
	# initial value of 0
	if opts.mark_loudest is not None and len(loudest) == 5:
		print("Loudest event at (%s)" % ", ".join([ str(a) for a in loudest[1:] ]))
		if opts.plotting_type != "marker":
			ax.add_patch( matplotlib.patches.Ellipse( xy=loudest[1:3], width=loudest[3]*1.2, height=loudest[4]*1.2, facecolor = 'none', linewidth=1.2 ) )
		else:
			pyplot.plot( [loudest[1]], [loudest[2]], marker='*', markerfacecolor='yellow', markersize=15.0 )
		title += " loudest event (%s=%.3g): %10.1f" % (opts.mark_loudest, math.sqrt(loudest[0]), loudest[1])
	pyplot.title( title )

	# TODO: See if replacing data speeds up the process
	# Tiles are collected for "tile", "both" and "most_significant", so all
	# three have to be drawn here.
	if opts.plotting_type in ["tile", "both", "most_significant"]:
		patches = PatchCollection( tiles, match_original = False, cmap = matplotlib.cm.jet, edgecolor='none', antialiased = False, alpha=opts.tile_opacity )
		patches.set_array( numpy.array(snrs) )
		ax.add_collection( patches )
	if opts.plotting_type == "marker":
		# Marker size grows in two steps, at strengths of 6**2 and 10**2
		marker_thresh = numpy.array([6, 10])**2
		if log_color:
			marker_thresh = numpy.log10(marker_thresh)
		def sizef( m ):
			if m <= marker_thresh[0]:
				return 1
			elif m < marker_thresh[1]:
				return 5
			else:
				return 15
		size = [ sizef(m) for m in markers[2] ]
		ax.scatter( markers[0], markers[1], c=markers[2], s=size, lw=0, cmap=matplotlib.cm.copper_r ) # with circles
	elif opts.plotting_type == "both":
		ax.scatter( markers[0], markers[1], marker='+' ) # with crosses

	# Plot simulations
	if opts.plot_injections_from_file is not None:
		try:
			sim_insp = get_sim_inspirals_from_file( opts.plot_injections_from_file )
			for si in sim_insp:
				# half hour buffer for BNS
				if si.geocent_end_time < viewseg[0] - 1800 or si.geocent_end_time > viewseg[1] + 1800:
					continue # outside, skip
				# nanoseconds are zero padded so that the printed value
				# reads as a decimal number
				print("plotting chirp of masses (%f, %f) with end time %d.%09d" % (si.mass1, si.mass2, si.geocent_end_time, si.geocent_end_time_ns))
				chirp_t, chirp_f = plot_chirp( si.mass1+si.mass2, si.mchirp, si.geocent_end_time + si.geocent_end_time_ns*1e-9 )
				ax.plot( chirp_t, chirp_f, "k-" )
		except ValueError:
			sys.stderr.write( "No sim_inspiral table found, skipping...\n" )

		try:
			sim_burst = get_sim_burst_from_file( opts.plot_injections_from_file )

			valid_waveforms = ["BTLWNB", "SineGaussian"]

			for sb in sim_burst:
				if sb.waveform not in valid_waveforms:
					continue # Dunno how to do this one.

				if not (sb.time_geocent_gps > viewseg[0] and sb.time_geocent_gps < viewseg[1]):
					continue # outside, skip
				print("plotting %s with params" % sb.waveform)
				if sb.waveform == "BTLWNB":
					print("\tt:%10.2f +/- %4.2f f: %4.2f +/- %4.2f" % (sb.get_time_geocent(), sb.duration/2.0, sb.frequency+sb.bandwidth/2.0, sb.bandwidth/2.0))
					ax.add_patch( plot_wnb( sb ) )
				elif sb.waveform == "SineGaussian":
					print("\tt:%10.2f f: %4.2f q: %4.2f" % (sb.get_time_geocent(), sb.frequency, sb.q))
					ax.add_patch( plot_csg( sb ) )

		except ValueError:
			sys.stderr.write( "No sim_burst table found, skipping...\n" )

	# Overlay the requested segments of this instrument as hatched bands
	lastwrite = viewseg[0]
	# FIXME: This is empircal at best.
	tsize = 0.02*abs(viewseg)
	t = 0
	for (segdef, ifo), slist in segs.iteritems():
		if ifo != inst: continue
		for seg in slist:
			# Indicate the name and if the segment extent continues to left or right
			name = str(ifo+":"+segdef)
			if seg[0] < viewseg[0]:
				# U+2190, leftwards arrow
				name = unichr(0x2190) + name
			if seg[1] > viewseg[1]:
				# U+2192, rightwards arrow
				name = name + unichr(0x2192)

			try:
				seg = viewseg & seg
			except ValueError: # outside the segment
				continue
			# Shouldn't happen, but just in case
			if seg is None: continue
			start, end = seg[0], seg[1]
			tile = matplotlib.patches.Rectangle(
				# lower left point
				( start, 1e-10 ),
				# tile extent
				(end-start),
				high_frequency,
				color = 'k',
				alpha = 0.1, hatch = "/"
			)

			# If multiple segment names would be written to the same area
			# move the text over. FIXME: This is not the definitive way to do
			# this it's a hack at best
			if lastwrite == start:
				lastwrite += tsize*t
				t+=1
			else:
				lastwrite = start
				t = 0
			ax.text( lastwrite, 0.95*high_frequency, name, rotation = "vertical" )
			ax.add_patch( tile )

	# TODO: the non-analyzed periods held in analyzed_segments are not drawn.
	# The overlay for them was switched off in the original code pending a
	# decision on what to do with it; if it comes back it should be drawn
	# the same way as the overlay segments above.

	if 'x' in logaxes and 'y' in logaxes:
		ax.loglog()
	elif 'x' in logaxes:
		ax.semilogx()
	elif 'y' in logaxes:
		ax.semilogy()

	# Add colorbar
	# The colour map has to match the one of whatever was drawn above:
	# copper_r for markers, jet for tiles
	divider = make_axes_locatable(ax)
	cax = divider.append_axes( "right", size="5%", pad=0.1 )
	if opts.plotting_type == "marker":
		cbl = matplotlib.colorbar.ColorbarBase( cax, cmap=matplotlib.cm.copper_r, norm=snr_norm, orientation="vertical" )
	else:
		cbl = matplotlib.colorbar.ColorbarBase( cax, cmap=matplotlib.cm.jet, norm=snr_norm, orientation="vertical" )
	if log_color:
		cbl.set_label( "log10 Tile Energy" )
	else:
		cbl.set_label( "Tile Energy" )

	# If requested, make a histogram of the tile energies and compare to
	# expectation
	if opts.enable_event_hist:
		snrs = numpy.array([e.snr for e in selected_events])
		sax = divider.append_axes( "right", size="50%", pad=0.7 )
		pyplot.hist( snrs*0.62, log=True )
		# NOTE(review): the DOF is read from the first event of the whole
		# run, not from an event of this channel -- confirm that is intended
		dof = events[0].chisq_dof
		tfmap_area = abs(viewseg)*(high_frequency-low_frequency)
		# magic number to get the effective DOF = 0.62
		dist = chi2(dof*0.62).pdf( snrs*0.62 )*tfmap_area/dof
		pyplot.semilogy()
		pyplot.plot( snrs, dist, "r-", label=r"$\chi_{%d}^2$" % dof )
		sax.yaxis.set_ticks_position( "right" )
		pyplot.locator_params(nbins=3, axis="x")
		pyplot.xlabel( "Tile Energies" )
		pyplot.legend()

	# next grid row
	pn += 1

# Write the figure out.  The name comes from --output-file ("trigmap" when it
# is not given); a missing extension becomes "png", and a marked time is
# appended to the name in front of the extension.
stem, suffix = os.path.splitext(opts.output_file or "trigmap")
if suffix:
	# drop the leading "."
	suffix = suffix[1:]
else:
	suffix = "png"
if opts.mark_time:
	stem = stem + "_%d" % int(opts.mark_time)
filename = "%s.%s" % (stem, suffix)
print("Saving " + filename)
fig.savefig( filename )
