#!/usr/bin/env python
#
# Copyright (C) 2013  Kipp Cannon
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 2 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

## @file gstlal_inspiral_lvalert_psd_plotter
# A program to listen to lvalerts, download the psd from gstlal gracedb events, plot it, and upload the results
#
# ### Command line interface
#
#	+ `--no-upload`: Write plots to disk instead of uploading them.
#	+ `--skip-404`: Skip events that give 404 (file not found) errors (default is to abort).
#	+ `--verbose`: Be verbose.
#
# =============================================================================
#
#                                   Preamble
#
# =============================================================================
#


import httplib
import logging
import math
import matplotlib
matplotlib.rcParams.update({
	"font.size": 8.0,
	"axes.titlesize": 10.0,
	"axes.labelsize": 10.0,
	"xtick.labelsize": 8.0,
	"ytick.labelsize": 8.0,
	"legend.fontsize": 8.0,
	"figure.dpi": 100,
	"savefig.dpi": 100,
	"text.usetex": True,
	"path.simplify": True
})
from matplotlib import figure
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
import numpy
from optparse import OptionParser
import os.path
import StringIO
import sys
import time
import urlparse


from glue.ligolw import ligolw
from glue.ligolw import array as ligolw_array
from glue.ligolw import param as ligolw_param
from glue.ligolw import lsctables
from glue.ligolw import utils as ligolw_utils
from gstlal.reference_psd import horizon_distance
try:
	from ligo.gracedb import cli as gracedb
except ImportError:
	from ligo import gracedb
from ligo.lvalert.utils import get_LVAdata_from_stdin
from pylal import series as lal_series


golden_ratio = (1 + math.sqrt(5)) / 2


class LIGOLWContentHandler(ligolw.LIGOLWContentHandler):
	"""
	Content handler passed to ligolw_utils.load_fileobj() when this
	program parses the PSD and coinc XML documents it downloads.
	"""
	pass
# hand the handler class to the use_in() hook of each glue.ligolw module
# this program relies on.  NOTE(review):  presumably this is what lets the
# parser build Array, Param and lsctables objects from the documents ---
# confirm against the glue.ligolw documentation.
ligolw_array.use_in(LIGOLWContentHandler)
ligolw_param.use_in(LIGOLWContentHandler)
lsctables.use_in(LIGOLWContentHandler)


#
# =============================================================================
#
#                                   Library
#
# =============================================================================
#


def get_filename(gracedb_client, graceid, filename, retries = 3, retry_delay = 10.0, ignore_404 = False):
	"""
	Ask gracedb_client for the file named filename attached to the
	event graceid, trying up to retries times.

	Returns the response object as soon as one comes back with status
	200 (OK).  If ignore_404 is true and a response comes back with
	status 404 (not found), None is returned immediately without
	further attempts.  Any other outcome is logged and, if attempts
	remain, retried after sleeping retry_delay seconds.

	Raises ValueError if retries is less than 1, and Exception
	(carrying the status and reason of the last response) if every
	attempt fails.
	"""
	if retries < 1:
		# without this check the loop body never runs and the
		# final raise would trip over an unbound "response"
		raise ValueError("retries must be at least 1 (got %r)" % (retries,))
	for attempt in range(retries):
		logging.info("retrieving \"%s\" for %s" % (filename, graceid))
		response = gracedb_client.files(graceid, filename)
		if response.status == httplib.OK:
			return response
		if response.status == httplib.NOT_FOUND and ignore_404:
			logging.warning("retrieving \"%s\" for %s: (%d) %s.  skipping ..." % (filename, graceid, response.status, response.reason))
			return None
		# only pause when another attempt will follow;  sleeping
		# after the final failure merely delays the exception
		if attempt + 1 < retries:
			logging.warning("retrieving \"%s\" for %s: (%d) %s.  pausing ..." % (filename, graceid, response.status, response.reason))
			time.sleep(retry_delay)
	raise Exception("retrieving \"%s\" for %s: (%d) %s" % (filename, graceid, response.status, response.reason))


def get_psds(gracedb_client, graceid, filename = "psd.xml.gz", ignore_404 = False):
	"""
	Download filename for the event graceid and return whatever
	lal_series.read_psd_xmldoc() extracts from the parsed document.
	Returns None when get_filename() returns None (a 404 skipped
	because ignore_404 is set).
	"""
	fileobj = get_filename(gracedb_client, graceid, filename = filename, ignore_404 = ignore_404)
	if fileobj is not None:
		# load_fileobj()'s result is indexable;  element 0 is the
		# document that gets handed to the PSD reader
		loaded = ligolw_utils.load_fileobj(fileobj, contenthandler = LIGOLWContentHandler)
		return lal_series.read_psd_xmldoc(loaded[0])
	return None


def get_coinc_xmldoc(gracedb_client, graceid, filename = "coinc.xml"):
	"""
	Download filename for the event graceid and return the parsed
	XML document (element 0 of load_fileobj()'s result).
	"""
	fileobj = get_filename(gracedb_client, graceid, filename = filename)
	loaded = ligolw_utils.load_fileobj(fileobj, contenthandler = LIGOLWContentHandler)
	return loaded[0]


def _psd_band(psd, f_low, f_high):
	"""
	Return the samples of psd.data whose bin index lies between the
	indices computed for f_low and f_high from psd.f0 and psd.deltaF.
	The indices are clamped at 0 so that a PSD starting above f_low
	is sliced from its first sample instead of wrapping around to the
	end of the array through a negative index.
	"""
	first = max(0, int((f_low - psd.f0) / psd.deltaF))
	last = max(first, int((f_high - psd.f0) / psd.deltaF))
	return psd.data[first:last]


def plot_psds(psds, coinc_xmldoc, plot_width = 640, colours = {"H1": "r", "H2": "b", "L1": "g", "V1": "m"}):
	"""
	Plot the spectral densities in psds on log-log axes and return the
	matplotlib Figure.

	psds maps instrument name to a PSD object providing .data, .f0 and
	.deltaF;  entries whose value is None are skipped.  coinc_xmldoc
	must contain exactly one coinc row and one coinc_inspiral row, plus
	the sngl_inspiral rows of the event;  mass1 and mass2 are taken
	from one of the sngl_inspiral rows.  plot_width is the figure width
	in pixels, the height follows from the golden ratio.  colours maps
	instrument name to the matplotlib colour used for it (the default
	dictionary is only read, never modified);  an instrument missing
	from colours raises KeyError.

	Each PSD's legend entry carries the value returned by
	horizon_distance() for the event's masses.  For instruments that
	have a sngl_inspiral row, the spectrum that horizon_distance()
	stores in inspiral_spectrum is drawn dashed and labelled with that
	trigger's SNR.
	"""
	coinc_event, = lsctables.CoincTable.get_table(coinc_xmldoc)
	coinc_inspiral, = lsctables.CoincInspiralTable.get_table(coinc_xmldoc)
	# NOTE:  offset_vector is not used by the plot below
	offset_vector = lsctables.TimeSlideTable.get_table(coinc_xmldoc).as_dict()[coinc_event.time_slide_id] if coinc_event.time_slide_id is not None else None
	# FIXME:  MBTA uploads are missing process table
	#process, = lsctables.ProcessTable.get_table(coinc_xmldoc)
	sngl_inspirals = dict((row.ifo, row) for row in lsctables.SnglInspiralTable.get_table(coinc_xmldoc))

	# take both masses from the same row with a single lookup
	template = list(sngl_inspirals.values())[0]
	mass1 = template.mass1
	mass2 = template.mass2
	end_time = coinc_inspiral.get_end()
	logging.info("%g Msun -- %g Msun event in %s at %.2f GPS" % (mass1, mass2, ", ".join(sorted(sngl_inspirals)), float(end_time)))

	fig = figure.Figure()
	FigureCanvas(fig)
	fig.set_size_inches(plot_width / float(fig.get_dpi()), int(round(plot_width / golden_ratio)) / float(fig.get_dpi()))
	axes = fig.gca()
	axes.grid(True)

	min_psds, max_psds = [], []
	for instrument, psd in sorted(psds.items()):
		if psd is None:
			continue
		psd_data = psd.data
		f = psd.f0 + numpy.arange(len(psd_data)) * psd.deltaF
		logging.info("found PSD for %s spanning [%g Hz, %g Hz]" % (instrument, f[0], f[-1]))
		# the 4th argument of horizon_distance() is an SNR (the call
		# below passes the trigger's SNR there).  NOTE(review):  the
		# 5th argument, 10, is presumably a low-frequency cut-off in
		# Hz --- confirm against gstlal.reference_psd
		axes.loglog(f, psd_data, color = colours[instrument], alpha = 0.8, label = "%s (%.4g Mpc)" % (instrument, horizon_distance(psd, mass1, mass2, 8, 10)))
		if instrument in sngl_inspirals:
			logging.info("found %s event with SNR %g" % (instrument, sngl_inspirals[instrument].snr))
			# horizon_distance() fills this two-element list in
			# place;  element 0 is plotted as x, element 1 as y
			inspiral_spectrum = [None, None]
			horizon_distance(psd, mass1, mass2, sngl_inspirals[instrument].snr, 10, inspiral_spectrum = inspiral_spectrum)
			axes.loglog(inspiral_spectrum[0], inspiral_spectrum[1], color = colours[instrument], dashes = (5, 2), alpha = 0.8, label = "SNR = %.3g" % sngl_inspirals[instrument].snr)
		# record the minimum from within the range 10 Hz -- 1 kHz
		min_psds.append(_psd_band(psd, 10.0, 1000).min())
		# record the maximum from within the range 1 Hz -- 1 kHz
		max_psds.append(_psd_band(psd, 1.0, 1000).max())

	axes.set_xlim((1.0, 3000.0))
	if min_psds:
		# round the y range outwards to whole decades
		axes.set_ylim((10**math.floor(math.log10(min(min_psds))), 10**math.ceil(math.log10(max(max_psds)))))
	axes.set_title(r"Strain Noise Spectral Density for $%.3g\,\mathrm{M}_{\odot}$--$%.3g\,\mathrm{M}_{\odot}$ Merger at %.2f GPS" % (mass1, mass2, float(end_time)))
	axes.set_xlabel(r"Frequency (Hz)")
	axes.set_ylabel(r"Spectral Density ($\mathrm{strain}^2 / \mathrm{Hz}$)")
	axes.legend(loc = "lower left")

	return fig


def upload_fig(fig, gracedb_client, graceid, filename = "psd.png"):
	"""
	Render fig into an in-memory buffer and attach it to the event
	graceid as a log entry with the message "strain spectral
	densities" and the tag "psd".  The image format is taken from
	filename's extension.

	Raises Exception, carrying the response's status and reason, if
	the response status is not 201 (created).
	"""
	plotfile = StringIO.StringIO()
	# "psd.png" --> ".png" --> "png"
	fig.savefig(plotfile, format = os.path.splitext(filename)[-1][1:])
	logging.info("uploading \"%s\" for %s" % (filename, graceid))
	response = gracedb_client.writeLog(graceid, "strain spectral densities", filename = filename, filecontents = plotfile.getvalue(), tagname = "psd")
	if response.status != httplib.CREATED:
		# report the failure through the response's status and
		# reason attributes, as get_filename() does, rather than
		# by indexing the response object
		raise Exception("upload of \"%s\" for %s failed: (%d) %s" % (filename, graceid, response.status, response.reason))


#
# =============================================================================
#
#                                 Command Line
#
# =============================================================================
#


def parse_command_line():
	"""
	Parse sys.argv and configure the logging module.  Returns the
	tuple (options, graceids) where graceids is the list of positional
	arguments.  When no positional arguments are given, verbose is
	forced on.
	"""
	parser = OptionParser()
	# every option is a simple on/off switch
	switches = (
		(("--no-upload",), "Write plots to disk."),
		(("--skip-404",), "Skip events that give 404 (file not found) errors (default is to abort)."),
		(("-v", "--verbose"), "Be verbose."),
	)
	for flags, description in switches:
		parser.add_option(*flags, action = "store_true", help = description)
	options, graceids = parser.parse_args()

	# FIXME:  lvalert_listen doesn't allow command-line options
	if len(graceids) == 0:
		options.verbose = True

	# can only call basicConfig once (otherwise need to switch to more
	# complex logging configuration), so assemble its arguments first
	log_settings = {"format": "%(asctime)s:%(message)s"}
	if options.verbose:
		log_settings["level"] = logging.INFO
	logging.basicConfig(**log_settings)

	return options, graceids


#
# =============================================================================
#
#                                     Main
#
# =============================================================================
#


# parse the command line;  graceids holds the positional arguments
options, graceids = parse_command_line()


# no event IDs on the command line:  read a single alert from stdin and take
# the event ID from it
if not graceids:
	lvalert_data = get_LVAdata_from_stdin(sys.stdin, as_dict = True)
	logging.info("%(alert_type)s-type alert for event %(uid)s" % lvalert_data)
	logging.info("lvalert data: %s" % repr(lvalert_data))
	# last path component of the URL carried in the alert's "file" entry
	filename = os.path.split(urlparse.urlparse(lvalert_data["file"]).path)[-1]
	# only alerts for a file named coinc.xml, or whose name contains
	# "_CBC_", are acted on;  anything else ends the program here
	if filename not in (u"coinc.xml",) and "_CBC_" not in filename:
		logging.info("filename is not 'coinc.xml'.  skipping")
		sys.exit()
	graceids = [str(lvalert_data["uid"])]
	# pause to give the psd file a chance to get uploaded.
	time.sleep(8)


# one client is shared by all downloads and uploads below
gracedb_client = gracedb.Client()


# for each event:  fetch the PSDs, plot them together with the event's coinc
# document, then either save the plot locally or upload it
for graceid in graceids:
	psds = get_psds(gracedb_client, graceid, ignore_404 = options.skip_404)
	# None means the PSD file gave a 404 and --skip-404 was given
	if psds is None:
		continue
	fig = plot_psds(psds, get_coinc_xmldoc(gracedb_client, graceid))
	if options.no_upload:
		# written to the current working directory
		filename = "psd_%s.png" % graceid
		logging.info("writing %s ..." % filename)
		fig.savefig(filename)
	else:
		upload_fig(fig, gracedb_client, graceid)
	logging.info("finished processing %s" % graceid)
