#!/usr/bin/env python
#
# Copyright (C) 2015  Chad Hanna
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 2 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.


import sys
import os
import subprocess
import re
import time
from optparse import OptionParser
import numpy
import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot
import multiprocessing


# Plot styling applied to every figure made by this script: font sizes,
# 600 dpi for both on-screen and saved figures, and LaTeX text rendering.
_rc_overrides = {
	"text.usetex": True,
	"figure.dpi": 600,
	"savefig.dpi": 600,
	"font.size": 10.0,
	"axes.titlesize": 10.0,
	"axes.labelsize": 10.0,
	"legend.fontsize": 8.0,
	"xtick.labelsize": 8.0,
	"ytick.labelsize": 8.0,
}
matplotlib.rcParams.update(_rc_overrides)


import lal
from glue.ligolw import ligolw
from glue.ligolw import table
from glue.ligolw import lsctables
from glue.ligolw import utils as ligolw_utils
from glue import lal as gluelal
from gstlal import lvalert_helper
from ligo.gracedb import rest as gracedb
from pylal.datatypes import LIGOTimeGPS
from pylal import rate


# define a content handler
class LIGOLWContentHandler(ligolw.LIGOLWContentHandler):
	"""Content handler passed to ligolw_utils.load_filename() when this script reads trigger files."""
# NOTE(review): presumably this makes the handler aware of the lsctables
# table classes -- confirm against the glue.ligolw documentation.
lsctables.use_in(LIGOLWContentHandler)


def now():
	"""
	Return the current time as a LIGOTimeGPS.

	The current UTC time (time.gmtime()) is converted with
	lal.UTCToGPS() and used as the first constructor argument; the
	second constructor argument is 0.
	"""
	utc_now = time.gmtime()
	gps_seconds = lal.UTCToGPS(utc_now)
	return LIGOTimeGPS(gps_seconds, 0)


def parse_command_line():
	"""
	Parse the command line and work out which GraceDB event to follow up.

	If exactly one positional argument is given it is used as the
	graceid.  If none is given, an lvalert payload is read from stdin as
	JSON; the event id is taken from its "uid" entry, and the program
	exits unless the file named by its "file" entry is psd.xml.gz.

	Returns the tuple (options, gid), where options is the optparse
	result and gid is the graceid string.

	Raises ValueError if more than one graceid is given.
	"""
	# These modules are needed only on the lvalert (stdin) path and are
	# not imported at file level, so bring them into scope here.
	import json
	import logging
	try:
		from urlparse import urlparse
	except ImportError:
		# Python 3 location of the same function
		from urllib.parse import urlparse

	parser = OptionParser()
	parser.add_option("--gracedb-service-url", default="%s" % gracedb.DEFAULT_SERVICE_URL, help = "GraceDb service url to upload to (default: %s)" % gracedb.DEFAULT_SERVICE_URL)

	options, gid_list = parser.parse_args()

	if len(gid_list) > 1:
		raise ValueError("%d graceids specified, no more than one allowed" % len(gid_list))

	if not gid_list:
		# no graceid on the command line: expect an lvalert message on stdin
		lvalert_data = json.loads(sys.stdin.read())
		logging.info("%(alert_type)s-type alert for event %(uid)s" % lvalert_data)
		logging.info("lvalert data: %s" % repr(lvalert_data))
		# only the last path component of the alert's file URL matters
		filename = os.path.split(urlparse(lvalert_data["file"]).path)[-1]
		if filename not in (u"psd.xml.gz",):
			logging.info("filename is not 'psd.xml.gz'.  skipping")
			sys.exit()
		gid = str(lvalert_data["uid"])
	else:
		gid = gid_list[0]

	return options, gid
	

# Work out which event to follow up
options, graceid = parse_command_line()


# Connect to gracedb and extract the coinc document.  The document must
# hold exactly one coinc_inspiral row (the unpacking below fails
# otherwise); its end time is taken as the event time.
gracedb_client = gracedb.GraceDb(service_url = options.gracedb_service_url)
coinc_xmldoc = lvalert_helper.get_coinc_xmldoc(gracedb_client, graceid)
[coinc_inspiral] = lsctables.CoincInspiralTable.get_table(coinc_xmldoc)
gps = coinc_inspiral.get_end()


# FIXME there must be a better way to do this
# If the event is not yet five hours old, wait out the remainder.
sleep = 5 * 3600 - abs(now() - gps)
if sleep > 0:
	sys.stderr.write("sleeping for %d seconds" % sleep + "\n")
	time.sleep(sleep)


def parse(c, gps = gps):
	"""
	Load the XML document at c.path and return, as a list, the rows of
	its sngl_inspiral table whose end time t satisfies
	gps - 600 <= t < gps + 600.

	gps defaults to the event time found at module level.
	"""
	xmldoc = ligolw_utils.load_filename(c.path, contenthandler = LIGOLWContentHandler, verbose = True)
	window_start = gps - 600
	window_stop = gps + 600
	nearby = []
	for row in lsctables.SnglInspiralTable.get_table(xmldoc):
		end = row.get_end()
		if end < window_stop and end >= window_start:
			nearby.append(row)
	return nearby


def mchirp(m1,m2):
	"""Return (m1 * m2)**0.6 / (m1 + m2)**0.2, the chirp mass of the pair (m1, m2)."""
	product_term = (m1 * m2) ** .6
	total_term = (m1 + m2) ** .2
	return product_term / total_term


# FIXME assumes we are in the analysis directory
# Create followups/ and, inside it, a working directory named after the
# integer part of the event time.  A directory that already exists is
# fine; any other mkdir failure (e.g. permissions) is re-raised instead
# of being mistaken for "already exists".
try:
	os.mkdir("followups")
except OSError:
	if not os.path.isdir("followups"):
		raise
wd = "followups/%d" % gps
try:
	os.mkdir(wd)
except OSError:
	if not os.path.isdir(wd):
		raise
	sys.stdout.write("The directory for this event already exists, continuing anyway." + "\n")


# FIXME hardcodes gps digits
# FIXME add a check that the gps time isn't near a boundary, we might have to
# read the gps directory before or after this one
# List everything under the directory named by the first five characters
# of the event time and pipe the listing through lalapps_path2cache into
# <wd>/<gps>.cache.  The exit status of the pipeline is not checked.
gps_string = str(gps)
cache_command = "ls %s/* | lalapps_path2cache > %s/%s.cache" % (gps_string[:5], wd, gps_string)
subprocess.call([cache_command], shell=True)


# Parse the files in many worker processes
# FIXME don't hard code this regex
pattern = re.compile("0[0-9]{3}_LLOID$")
# Read the cache entries; the with-block closes the cache file once it
# has been consumed.
with open("%s/%s.cache" % (wd, str(gps))) as cache_file:
	ce = [gluelal.CacheEntry(l) for l in cache_file]
# Keep only the entries whose segment contains the event time and whose
# description matches the pattern above.
to_parse = [c for c in ce if gps in c.segment and pattern.match(c.description)]
sngl = []
pool = multiprocessing.Pool(8)
try:
	for res in pool.map(parse, to_parse):
		sngl.extend(res)
finally:
	# shut the workers down whether or not parsing succeeded
	pool.close()
	pool.join()
	

# Extremes of chirp mass, SNR and chi-squared over all collected triggers.
# The three sequences built here are temporary; only the min/max values
# are needed afterwards.
mchirps, snrs, chisqs = [], [], []
for r in sngl:
	mchirps.append(mchirp(r.mass1, r.mass2))
	snrs.append(r.snr)
	chisqs.append(r.chisq)
mchirps = numpy.array(mchirps)
maxchisq = max(chisqs)
minchisq = min(chisqs)
maxsnr = max(snrs)
minsnr = min(snrs)
maxmchirp = max(mchirps)
minmchirp = min(mchirps)


# FIXME don't hardcode
# Get various data by IFO.  For each detector this collects the trigger
# times relative to the event, SNRs, chi-squareds and chirp masses, and
# fills a (time, chirp mass) binned array with the largest SNR seen in
# each bin; every bin starts at 4.0.
times, snrs, mchirps, chisqs, ba = {}, {}, {}, {}, {}
for ifo in ("H1", "L1"):
	rows = [r for r in sngl if r.ifo == ifo]
	times[ifo] = numpy.array([float(r.get_end() - gps) for r in rows])
	snrs[ifo] = numpy.array([r.snr for r in rows])
	chisqs[ifo] = numpy.array([r.chisq for r in rows])
	mchirps[ifo] = numpy.array([mchirp(r.mass1, r.mass2) for r in rows])
	time_bins = rate.LinearBins(-600,600,1201)
	mchirp_bins = rate.LogarithmicBins(minmchirp, maxmchirp, 200)
	grid = rate.BinnedArray(rate.NDBins((time_bins, mchirp_bins)))
	grid.array[:] = 4.0
	for t,s,m in zip(times[ifo], snrs[ifo], mchirps[ifo]):
		# keep the loudest trigger in each bin
		if grid[t, m] < s:
			grid[t, m] = s
	ba[ifo] = grid


# SNR heatmap
# One panel per detector, stacked vertically: binned SNR against time
# relative to the event and chirp mass, on a logarithmic colour scale
# running from 4 to the largest SNR found.
fig = pyplot.figure(figsize=(9,3.0))
for n, ifo in enumerate(["H1", "L1"]):
	axes = pyplot.subplot(2,1,n+1)
	time_centres = ba[ifo].centres()[0]
	mchirp_centres = ba[ifo].centres()[1]
	snr_norm = matplotlib.colors.LogNorm(vmin=4, vmax=maxsnr)
	pyplot.pcolormesh(time_centres, mchirp_centres, ba[ifo].array.T, cmap='coolwarm', norm=snr_norm)
	cb = pyplot.colorbar()
	cb.set_label("$\mathrm{SNR}$")
	axes.set_yscale('log')
	axes.set_ylim([minmchirp, maxmchirp])
	# only the bottom panel carries the time axis label
	if n==1:
		axes.set_xlabel("$\mathrm{Time \, since \, %d}$" % int(gps))
	axes.set_ylabel("$\mathrm{%s} \, \mathcal{M}$ $(M_{\odot})$" % ifo)
# called after the loop, so this acts on the current (last) axes
pyplot.minorticks_on()
lvalert_helper.upload_fig(fig, gracedb_client, graceid, filename = "snrheatmap.png", log_message = "SNR heat map", tagname = "background")


# chisq vs SNR
# One panel per detector, side by side: chi-squared against SNR on
# log-log axes, with both panels sharing the same global axis limits.
fig = pyplot.figure(figsize=(9, 4.0))
for n, ifo in enumerate(["H1", "L1"]):
	axes = pyplot.subplot(1,2,n+1)
	axes.loglog(snrs[ifo], chisqs[ifo], 'k.')
	# only the left panel carries the y axis label
	if n==0:
		axes.set_ylabel("$\chi^2$")
	axes.set_xlabel("$\mathrm{%s \, SNR}$" % ifo)
	axes.set_xlim(minsnr, maxsnr)
	axes.set_ylim(minchisq, maxchisq)
	axes.grid()
# called after the loop, so this acts on the current (last) axes
pyplot.minorticks_on()
lvalert_helper.upload_fig(fig, gracedb_client, graceid, filename = "chisqvsSNR.png", log_message = "Chi-squared vs. SNR", tagname = "background")
