#!/usr/bin/python
import os, sys, re
from glue import lal
from glue import segmentsUtils
from glue.ligolw import utils
from glue.ligolw import lsctables
from glue.ligolw.utils import segments as llwsegments
from glue import iterutils
from optparse import *
import numpy
import matplotlib
matplotlib.use('Agg')
import pylab

from pylal import git_version
__author__ = "Chad Hanna <channa@ligo.caltech.edu>"
__version__ = "git id %s" % git_version.id
__date__ = git_version.date

class IFOdata(object):
	"""
	Plot styling for the detectors H1, H2, L1 and V1: one fixed colour
	per detector plus a marker that rotates on every request.
	"""

	def __init__(self):
		detectors = ('H1', 'H2', 'L1', 'V1')
		# colour code handed to the plotting calls for each detector
		self.color_map = dict(zip(detectors, ('r', 'b', 'g', 'm')))
		# every detector owns its own (unshared) list of three markers
		self.line_map = dict((name, ['o', 'x', 'd']) for name in detectors)
		# how many times line() has been called for each detector
		self.counter = dict.fromkeys(detectors, 0)

	def color(self, ifo):
		"""
		Return the colour code of detector ifo.  Raises KeyError for a
		detector that is not in color_map.
		"""
		return self.color_map[ifo]

	def line(self, ifo):
		"""
		Return the next marker for detector ifo and advance its call
		counter, so successive calls cycle through 'o', 'x', 'd'.
		Raises KeyError for a detector that is not in counter.
		"""
		calls_so_far = self.counter[ifo]
		self.counter[ifo] = calls_so_far + 1
		markers = self.line_map[ifo]
		return markers[calls_so_far % 3]

class InspiralRangeFromSummValue(object):
	"""
	Accumulates the "inspiral_effective_distance" entries of summ_value
	rows, grouped by detector and by component-mass pair.

	Attributes:
		distance -- {ifo: {(mass1, mass2): [float(row.value), ...]}}
		start    -- {ifo: {(mass1, mass2): [float(row.start_time), ...]}}

	The two dictionaries are filled in lock step, so matching lists have
	the same length and ordering.  The masses are kept as the strings
	found in row.comment.
	"""

	# Key of the equal-mass 1.40/1.40 system used for the "NS" curves.
	# FIXME this depends on the exact string formatting ("1.40") used in
	# the comment column.
	NS_MASS_KEY = ('1.40', '1.40')

	def __init__(self):
		self.distance = {}
		self.start = {}

	def append_rows(self, rows):
		"""
		Add every row of rows whose name is "inspiral_effective_distance".
		Each row must provide name, comment, ifo, value and start_time;
		comment must be of the form "<mass1>_<mass2>_<snr>" (ValueError
		otherwise).  Other rows are ignored.
		"""
		for row in rows:
			if row.name != "inspiral_effective_distance":
				continue
			# the third field is not stored, but unpacking it enforces
			# the three-field layout of the comment
			mass1, mass2, snr = row.comment.split('_')
			key = (mass1, mass2)
			# convert both numbers before touching either dictionary so
			# a bad row cannot leave them out of step
			value = float(row.value)
			start_time = float(row.start_time)
			self.distance.setdefault(row.ifo, {}).setdefault(key, []).append(value)
			self.start.setdefault(row.ifo, {}).setdefault(key, []).append(start_time)

	def ns_by_time_ifo(self, ifo):
		"""
		Return (times, distances) for the 1.40/1.40 mass pair of detector
		ifo, or two empty lists when the detector or the mass pair has no
		entries.
		"""
		try:
			time = self.start[ifo][self.NS_MASS_KEY]
			distance = self.distance[ifo][self.NS_MASS_KEY]
		except KeyError:
			# only a missing detector / mass pair means "no data";
			# anything else is a real error and must propagate
			return [], []
		return time, distance

	def mean_std_by_mass_ifo(self, ifo):
		"""
		Return three parallel lists for detector ifo, ordered by
		increasing total mass: the total masses, the mean of the stored
		values and their standard deviation.  Raises KeyError when ifo
		has no entries.
		"""
		items = sorted(
			(float(mass1) + float(mass2), values)
			for ((mass1, mass2), values) in self.distance[ifo].items()
		)
		totals = [total for (total, values) in items]
		means = [numpy.mean(values) for (total, values) in items]
		stds = [numpy.std(values) for (total, values) in items]
		return totals, means, stds

	def get_bins(self):
		"""
		Return histogram bin edges shared by all detectors: evenly spaced
		points spanning the 1.40/1.40 values of every detector.  The
		number of points is a tenth of the smallest per-detector sample
		count, clamped to the range [5, 50].

		Raises ValueError when no 1.40/1.40 values have been collected.
		"""
		dists = []
		counts = []
		for ifo in self.start:
			time, dist = self.ns_by_time_ifo(ifo)
			dists.extend(dist)
			counts.append(len(dist))
		if not dists:
			raise ValueError("no 1.40/1.40 range values available to build histogram bins")
		# floor division: numpy.linspace needs an integer sample count
		numbins = max(5, min(50, min(counts) // 10))
		return numpy.linspace(min(dists), max(dists), numbins)
			
def parse_command_line():
	"""
	Parse the command line (sys.argv) of this script.

	Returns (options, filenames) where filenames is the list of
	positional arguments and options carries:
		pattern -- list of patterns; ['TMPLTBANK_FULL_DATA'] when none given
		base    -- figure basename; when not given it is the basename of
		           the first positional argument with ".cache" removed
		stride  -- the --stride value as a string, or None

	Exits through the parser's error handler when the basename has to be
	derived but no positional argument was supplied.
	"""
	parser = OptionParser(
		version = "%prog",
		description = "Plot number of rows in a single inspiral table from many files"
	)
	parser.add_option("-p", "--pattern", metavar = "pat", action="append", default = [], help = "Set the pattern for parsing the cache file for patterns to find sngl inspiral tables")
	parser.add_option("-b", "--base", metavar = "str", help = "Set the figure basename, default derive from the cache")
	parser.add_option("-s", "--stride", metavar = "num", help = "Set the stride of files to read in (in case you want to debug, or just decimate the dataset)")

	options, filenames = parser.parse_args()
	if not options.pattern:
		options.pattern = ['TMPLTBANK_FULL_DATA']
	if not options.base:
		if not filenames:
			# without a cache file there is nothing to derive the name from
			parser.error("no cache file given; supply a cache file or set --base")
		options.base = os.path.basename(filenames[0]).replace(".cache","")
	return options, filenames

# ---------------------------------------------------------------------------
# Script body: select cache entries matching the requested patterns, collect
# the inspiral range values from their summ_value tables, then write a
# three-panel figure (range vs. time, range histogram, range vs. total mass)
# and a small text table next to it.
# ---------------------------------------------------------------------------

sys.stderr.write("processing cache file for pattern...\n")
opts, cache = parse_command_line()
if not cache:
	sys.exit("no cache file given on the command line")
sys.stderr.write("\t%s\n" % "|".join(opts.pattern))

# a cache line is of interest when any of the patterns matches it
pats = re.compile("|".join(opts.pattern))
with open(cache[0]) as cache_file:
	new_list = [entry for entry in cache_file if pats.search(entry)]

sys.stderr.write("found %d files of interest\n" % len(new_list))

# 0 means "read every file"; this also keeps "--stride 0" from causing a
# modulo-by-zero
stride = int(opts.stride) if opts.stride else 0

inspiral_range = InspiralRangeFromSummValue()
for i, line in enumerate(new_list):
	if stride and i % stride:
		continue
	sys.stderr.write("\tprocessing %d/%d\r" % (i, len(new_list)))
	c = lal.CacheEntry(line)
	XML = utils.load_filename(c.path, gz=c.path.endswith('.gz'))
	try:
		summ = lsctables.table.get_table(XML, "summ_value")
	except ValueError:
		# get_table() reports a document without exactly one such table
		# with ValueError; such files are skipped
		continue
	if len(summ):
		inspiral_range.append_rows(summ)
sys.stderr.write("\n\n")

ifodata = IFOdata()
pylab.figure(figsize=(15,5))

# 1.40/1.40 time series per detector; detectors without such data are left
# out so they cannot break the min() calls below
ns_data = {}
for ifo in inspiral_range.start:
	time, dist = inspiral_range.ns_by_time_ifo(ifo)
	if time:
		ns_data[ifo] = (time, dist)
if not ns_data:
	sys.exit("no 1.40/1.40 inspiral range values found, nothing to plot")

# earliest start time over all detectors, used as the time origin
start = min(min(time) for (time, dist) in ns_data.values())
# the histogram bins are common to all detectors, compute them once
bins = inspiral_range.get_bins()

out_stem = '%s_%s_%s' % (opts.base, os.path.basename(sys.argv[0]), "-".join(opts.pattern).replace(".*",""))

with open(out_stem + '.txt', "w") as outtab:
	outtab.write("||DETECTOR||NS RANGE||STD. DEV.||\n")
	for ifo in inspiral_range.start:
		if ifo in ns_data:
			time, dist = ns_data[ifo]
			# left panel: range as a function of time
			pylab.subplot(131)
			outtab.write("||%s||%.2f||%.2f||\n" % (ifo,numpy.mean(dist), numpy.std(dist)))
			pylab.plot((numpy.array(time)-start) / 86400.0, dist, color=ifodata.color(ifo),marker=ifodata.line(ifo))
			pylab.xlabel("Time (days) relative to %d" % start)
			pylab.ylabel("NS range (Mpc)")
			# middle panel: histogram of the range
			pylab.subplot(132)
			pylab.hist(dist, bins=bins, alpha=0.35, ec=ifodata.color(ifo), fc=ifodata.color(ifo), ls='solid', linewidth=2.0)
			pylab.xlabel("NS range (Mpc)")
			pylab.ylabel("Number")
		# right panel: mean range (with standard deviation) vs. total mass
		pylab.subplot(133)
		mass, mean, std = inspiral_range.mean_std_by_mass_ifo(ifo)
		pylab.errorbar(mass, mean, std, color=ifodata.color(ifo),marker=ifodata.line(ifo))
		pylab.xlabel("Total Mass (Msun)")
		pylab.ylabel("Range (Mpc)")

figname = out_stem + '.png'
sys.stderr.write("Writing output to %s\n" % (figname,))
pylab.savefig(figname)
