#!/usr/bin/env python
#
# Copyright (C) 2015  Chad Hanna
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 2 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

## @file

import sys
import os
import subprocess
import re
import time
import logging
import urlparse
import json
from optparse import OptionParser
import numpy
import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot
import multiprocessing


matplotlib.rcParams.update({
        "font.size": 10.0,
        "axes.titlesize": 10.0,
        "axes.labelsize": 10.0,
        "xtick.labelsize": 8.0,
        "ytick.labelsize": 8.0,
        "figure.dpi": 600,
        "savefig.dpi": 600,
        "legend.fontsize": 8.0,
        "text.usetex": True
})


import lal
from glue.ligolw import ligolw
from glue.ligolw import table
from glue.ligolw import lsctables
from glue.ligolw import utils as ligolw_utils
from glue import lal as gluelal
from gstlal import lvalert_helper
from ligo.gracedb import rest as gracedb
from pylal.datatypes import LIGOTimeGPS
from pylal import rate


# Content handler for the LIGO_LW XML documents read by this script; it is
# the class passed as ``contenthandler`` to ligolw_utils.load_filename().
# It adds nothing of its own to ligolw.LIGOLWContentHandler.
class LIGOLWContentHandler(ligolw.LIGOLWContentHandler):
	pass
# NOTE(review): presumably this makes the handler build lsctables table
# classes while parsing (the script later calls
# lsctables.SnglInspiralTable.get_table on loaded documents) -- confirm
# against the glue.ligolw documentation.
lsctables.use_in(LIGOLWContentHandler)


def now():
	"""
	Return the current time as a LIGOTimeGPS.

	The current UTC time from time.gmtime() is converted with
	lal.UTCToGPS() and wrapped in a LIGOTimeGPS whose second constructor
	argument (presumably the nanoseconds part) is 0.
	"""
	utc_now = time.gmtime()
	gps_seconds = lal.UTCToGPS(utc_now)
	return LIGOTimeGPS(gps_seconds, 0)


def parse_command_line():
	"""
	Parse the command line and determine which GraceDb event to follow up.

	At most one graceid may be given as a positional argument.  When none
	is given, an lvalert JSON payload is read from stdin: the process
	exits unless the alert's "file" entry names psd.xml.gz, and the
	alert's "uid" entry is used as the graceid.

	Returns the tuple (options, gid).  Raises ValueError if more than one
	graceid is given on the command line.
	"""
	parser = OptionParser()
	parser.add_option("--gracedb-service-url", default="%s" % gracedb.DEFAULT_SERVICE_URL, help = "GraceDb service url to upload to (default: %s)" % gracedb.DEFAULT_SERVICE_URL)
	options, gid_list = parser.parse_args()

	n_gids = len(gid_list)
	if n_gids > 1:
		raise ValueError("%d graceids specified, no more than one allowed" % n_gids)

	# an explicit graceid on the command line needs no further checks
	if n_gids == 1:
		return options, gid_list[0]

	# no graceid given: the lvalert payload is expected on stdin as JSON
	lvalert_data = json.loads(sys.stdin.read())
	logging.info("%(alert_type)s-type alert for event %(uid)s" % lvalert_data)
	logging.info("lvalert data: %s" % repr(lvalert_data))

	# only act on alerts announcing the psd.xml.gz file
	alert_path = urlparse.urlparse(lvalert_data["file"]).path
	filename = os.path.basename(alert_path)
	if filename not in (u"psd.xml.gz",):
		logging.info("filename is not 'psd.xml.gz'.  skipping")
		sys.exit()

	return options, str(lvalert_data["uid"])
	

# Work out which event to follow up, either from the command line or
# from an lvalert payload on stdin.
options, graceid = parse_command_line()


# Connect to gracedb and extract the coinc document
gracedb_client = gracedb.GraceDb(service_url = options.gracedb_service_url)
coinc_xmldoc = lvalert_helper.get_coinc_xmldoc(gracedb_client, graceid)
# the tuple unpacking requires the coinc_inspiral table to hold exactly
# one row; anything else raises ValueError here
coinc_inspiral, = lsctables.CoincInspiralTable.get_table(coinc_xmldoc)
# time of the event; every later time window and path is built from it
gps = coinc_inspiral.get_end()


# FIXME there must be a better way to do this
# If the event is not older than five hours, wait till it is.
# "sleep" is the remaining wait in seconds: five hours minus the absolute
# difference between the current time and the event time.
sleep = 3600 * 5 -  abs(now() - gps)
if sleep > 0:
	print >> sys.stderr, "sleeping for %d seconds" % sleep
	time.sleep(sleep)


def parse(c, gps = gps):
	"""
	Load the document at c.path and return its sngl_inspiral rows whose
	end time lies in the half-open window [gps - 600, gps + 600).

	c is an object with a .path attribute (the script passes cache
	entries); gps defaults to the event time bound when this function
	was defined.  Returns a list of rows.
	"""
	xmldoc = ligolw_utils.load_filename(c.path, contenthandler = LIGOLWContentHandler, verbose = True)
	window_start, window_stop = gps - 600, gps + 600
	nearby = []
	for row in lsctables.SnglInspiralTable.get_table(xmldoc):
		end = row.get_end()
		if end < window_stop and end >= window_start:
			nearby.append(row)
	return nearby


def mchirp(m1,m2):
	"""
	Return the chirp mass (m1 * m2)**0.6 / (m1 + m2)**0.2 of a binary
	with component masses m1 and m2.  The result is in the same units
	as the inputs.
	"""
	product = m1 * m2
	total = m1 + m2
	return product ** 0.6 / total ** 0.2


# FIXME assumes we are in the analysis directory
# Create followups/ and the per-event working directory followups/<gps>.
# An already-existing directory is fine, but any other failure (for
# example a permissions problem) is re-raised instead of being silently
# ignored or misreported as "already exists".
try:
	os.mkdir("followups")
except OSError:
	if not os.path.isdir("followups"):
		raise
wd = "followups/%d" % gps
try:
	os.mkdir(wd)
except OSError:
	if not os.path.isdir(wd):
		raise
	print("The directory for this event already exists, continuing anyway.")


# FIXME hardcodes gps digits to 5 to match current
# Build <wd>/<gps>.cache by listing the directory named after the first
# five characters of the event time, then appending the listings of the
# directories for the times 100000 s before and after it.
cache_args = (wd, str(gps))
subprocess.call(["ls %s/* | lalapps_path2cache > %s/%s.cache" % ((str(gps)[:5],) + cache_args)], shell=True)
for shifted_gps in (gps - 100000, gps + 100000):
	subprocess.call(["ls %s/* | lalapps_path2cache >> %s/%s.cache" % ((str(shifted_gps)[:5],) + cache_args)], shell=True)


# Parse the files in parallel worker processes
# FIXME don't hard code this regex
pattern = re.compile("0[0-9]{3}_LLOID$")
pool = multiprocessing.Pool(8)
# read the cache inside a with-block so the file handle is closed as soon
# as the entries have been built
with open("%s/%s.cache" % (wd, str(gps))) as cache_file:
	ce = [gluelal.CacheEntry(l) for l in cache_file]
sngl = []
# only files whose segment contains the event time and whose description
# matches the pattern are parsed; the per-file results are concatenated
for res in pool.map(parse, [c for c in ce if gps in c.segment and pattern.match(c.description)]):
	sngl.extend(res)
# every result has been collected, so shut the workers down now instead of
# leaving them alive for the rest of the script
pool.close()
pool.join()
	

# Global extrema over all triggers (both detectors), used below for the
# binning and the axis limits.  Without any triggers there is nothing to
# plot, and max()/min() would only fail with an uninformative message, so
# stop here with an explicit error of the same type.
if not sngl:
	raise ValueError("no single detector triggers found within 600 seconds of %s, nothing to plot" % str(gps))
mchirps = numpy.array([mchirp(r.mass1, r.mass2) for r in sngl])
snrs = [r.snr for r in sngl]
chisqs = [r.chisq for r in sngl]
maxchisq, minchisq = max(chisqs), min(chisqs)
maxsnr, minsnr = max(snrs), min(snrs)
maxmchirp, minmchirp = max(mchirps), min(mchirps)


# FIXME don't hardcode the instrument list
# Get various data by IFO: trigger times relative to the event, SNRs,
# chi-squareds and chirp masses, plus a (time, chirp mass) binned array
# holding the largest SNR seen in each bin (bins start at 4.0).
times, snrs, mchirps, chisqs, ba = {}, {}, {}, {}, {}
for ifo in ("H1", "L1"):
	ifo_rows = [r for r in sngl if r.ifo == ifo]
	times[ifo] = numpy.array([float(r.get_end() - gps) for r in ifo_rows])
	snrs[ifo] = numpy.array([r.snr for r in ifo_rows])
	chisqs[ifo] = numpy.array([r.chisq for r in ifo_rows])
	mchirps[ifo] = numpy.array([mchirp(r.mass1, r.mass2) for r in ifo_rows])
	time_bins = rate.LinearBins(-600,600,401)
	mchirp_bins = rate.LogarithmicBins(minmchirp, maxmchirp, 100)
	snr_map = rate.BinnedArray(rate.NDBins((time_bins, mchirp_bins)))
	snr_map.array[:] = 4.0
	ba[ifo] = snr_map
	# keep the loudest SNR that falls into each bin
	for t,s,m in zip(times[ifo], snrs[ifo], mchirps[ifo]):
		if snr_map[t, m] < s:
			snr_map[t, m] = s


# SNR heatmap: one panel per detector showing the per-bin maximum SNR as
# a function of time relative to the event and chirp mass, on a log
# colour scale from 4 to the loudest SNR.  The result is uploaded to
# GraceDb.
# The LaTeX labels are raw strings so that their backslashes are never
# interpreted as Python escape sequences.
fig = pyplot.figure(figsize=(9,3.0))
for n, ifo in enumerate(["H1", "L1"]):
	axes = pyplot.subplot(2,1,n+1)
	centres = ba[ifo].centres()
	pyplot.pcolormesh(centres[0], centres[1], ba[ifo].array.T, cmap='coolwarm', norm=matplotlib.colors.LogNorm(vmin=4, vmax=maxsnr))
	cb = pyplot.colorbar()
	cb.set_label(r"$\mathrm{SNR}$")
	axes.set_yscale('log')
	pyplot.ylim([minmchirp, maxmchirp])
	# only the bottom panel carries the shared time-axis label
	if n==1:
		pyplot.xlabel(r"$\mathrm{Time \, since \, %d}$" % int(gps))
	pyplot.ylabel(r"$\mathrm{%s} \, \mathcal{M}$ $(M_{\odot})$" % ifo)
pyplot.minorticks_on()
lvalert_helper.upload_fig(fig, gracedb_client, graceid, filename = "snrheatmap.png", log_message = "SNR heat map: Computed from single detector triggers plus/minus 600 seconds around the event", tagname = "background")


# chisq vs SNR: one log-log scatter panel per detector, with both panels
# sharing the global SNR and chi-squared ranges.  The result is uploaded
# to GraceDb.
# The LaTeX labels are raw strings so that their backslashes are never
# interpreted as Python escape sequences.
fig = pyplot.figure(figsize=(9, 4.0))
for n, ifo in enumerate(["H1", "L1"]):
	axes = pyplot.subplot(1,2,n+1)
	pyplot.loglog(snrs[ifo], chisqs[ifo], 'k.')
	# only the left panel carries the y-axis label
	if n==0:
		pyplot.ylabel(r"$\chi^2$")
	pyplot.xlabel(r"$\mathrm{%s \, SNR}$" % ifo)
	pyplot.xlim(minsnr, maxsnr)
	pyplot.ylim(minchisq, maxchisq)
	pyplot.grid()
pyplot.minorticks_on()
lvalert_helper.upload_fig(fig, gracedb_client, graceid, filename = "chisqvsSNR.png", log_message = "Chi-squared vs. SNR: Computed from single detector triggers plus/minus 600 seconds around the event", tagname = "background")
