#!/usr/bin/env python
#
# Copyright (C) 2011  Chad Hanna
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 2 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

from glue.ligolw import lsctables
from glue.ligolw import utils
from glue.ligolw import ligolw
from glue.lal import LIGOTimeGPS as GPS
import sys
import matplotlib
matplotlib.use('Agg')
import pylab
import numpy

# Load the document named by the first command-line argument and pull out
# the two tables that every plot below reads from.
xmldoc = utils.load_filename(sys.argv[1], verbose = True)
sim_inspiral_table = lsctables.table.get_table(
	xmldoc,
	lsctables.SimInspiralTable.tableName,
)
proc_params = lsctables.table.get_table(
	xmldoc,
	lsctables.ProcessParamsTable.tableName,
)

# The GPS interval comes from the recorded process parameters.  The first
# matching row is used; if an option is absent the [0] lookup raises
# IndexError.
gps_start_values = [row.value for row in proc_params if row.param == "--gps-start-time"]
start = int(gps_start_values[0])
gps_end_values = [row.value for row in proc_params if row.param == "--gps-end-time"]
stop = int(gps_end_values[0])

# Distinct waveform names found in the injection table; the per-waveform
# plots iterate over this set.
waveforms = set(sim.waveform for sim in sim_inspiral_table)

#
# Masses
#

# Figure 1: mass1 against mass2, one dotted series per waveform name,
# written to mass.png.
pylab.figure(1)
for waveform in waveforms:
	# Rows of the injection table that carry this waveform name.
	matching = [sim for sim in sim_inspiral_table if sim.waveform == waveform]
	mass1 = [sim.mass1 for sim in matching]
	mass2 = [sim.mass2 for sim in matching]
	pylab.plot(mass1, mass2, '.', label = waveform)

pylab.xlabel('mass1')
pylab.ylabel('mass2')
pylab.grid()
# The legend is built from the label given to each series above.
pylab.legend()
pylab.savefig("mass.png")

#
# sky position
#

# Figure 2: each injection's longitude (axis label 'ra') against the cosine
# of its latitude (axis label 'cos(dec)'), drawn as black dots and written
# to sky.png.
# NOTE(review): if latitude is a declination in [-pi/2, pi/2], cos(dec)
# only covers [0, 1] while the y axis is fixed to [-1, 1]; sin(dec) may have
# been intended -- confirm before changing.
ra = []
dec = []
for sim in sim_inspiral_table:
	ra.append(sim.longitude)
	dec.append(sim.latitude)
cos_dec = numpy.cos(dec)

pylab.figure(2)
pylab.plot(ra, cos_dec, '.k')
pylab.grid()
pylab.xlabel('ra')
pylab.ylabel('cos(dec)')
# Fixed axis ranges: x from 0 to 2*pi, y from -1 to 1.
pylab.xlim([0, 2 * numpy.pi])
pylab.ylim([-1, 1])
pylab.savefig("sky.png")

#
# inclination and polarization
#

# Figure 3: each injection's polarization against the cosine of its
# inclination, drawn as black dots and written to pol_inc.png.
pol = [sim.polarization for sim in sim_inspiral_table]
inc = [sim.inclination for sim in sim_inspiral_table]
cos_inc = numpy.cos(inc)
two_pi = 2 * numpy.pi

pylab.figure(3)
pylab.plot(pol, cos_inc, '.k')
pylab.grid()
pylab.xlabel('pol')
pylab.ylabel('cos(inc)')
# Fixed axis ranges: x from 0 to 2*pi, y from -1 to 1.
pylab.xlim([0, two_pi])
pylab.ylim([-1, 1])
pylab.savefig("pol_inc.png")

#
# distance (cumulative)
#

def rate_to_poly(rate):
	"""Return the cubic polynomial (4/3) * pi * rate * d**3 as a numpy.poly1d.

	Evaluating the result at a distance d gives the volume of a sphere of
	radius d multiplied by `rate`.  The quadratic, linear and constant
	coefficients are all zero.
	"""
	sphere_factor = 4/3. * numpy.pi
	coefficients = [sphere_factor * rate, 0, 0, 0]
	return numpy.poly1d(coefficients)

# Figure 4: cumulative number of injections within a given distance, scaled
# to a per-year count, one curve per waveform name.  Where a local rate was
# recorded in the process parameters, the count predicted by rate_to_poly()
# is overlaid.  The figure is written to dist.png.
pylab.figure(4)
for waveform in waveforms:
	# A leading 0 is prepended so the cumulative curve starts at
	# (distance 0, count 0).
	dist = numpy.array([0.] + sorted(sim.distance for sim in sim_inspiral_table if sim.waveform == waveform))
	# Running count divided by the GPS span in seconds, times the number of
	# seconds in a Julian year (365.25 days).
	num = numpy.arange(len(dist)) / float(stop - start) * 365.25 * 86400
	pylab.plot(dist, num, label=waveform)

# Distances of all injections together, used as the x values for the
# predicted-rate curves below.
dist = numpy.array([0.] + sorted(sim.distance for sim in sim_inspiral_table))

# BNS
bns_rate = [proc.value for proc in proc_params if proc.param == "--bns-local-rate"]
if bns_rate:
	# The recorded value is scaled by 1e-6 (the original comment marks the
	# unit as "Myr"; presumably a per-Myr rate turned into per-year --
	# confirm).  The GPS start time `start` is deliberately left untouched:
	# it used to be overwritten here by a chained assignment.
	bns_rate = float(bns_rate[0]) * 1e-6
	poly = rate_to_poly(bns_rate)
	pylab.plot(dist, poly(dist), label="bns rate %.2e" % (bns_rate,) )

# NSBH
nsbh_rate = [proc.value for proc in proc_params if proc.param == "--nsbh-local-rate"]
if nsbh_rate:
	# Same 1e-6 scaling as the BNS rate; `start` is likewise left untouched.
	nsbh_rate = float(nsbh_rate[0]) * 1e-6
	poly = rate_to_poly(nsbh_rate)
	pylab.plot(dist, poly(dist), label="nsbh rate %.2e" % (nsbh_rate,) )

pylab.legend()
pylab.grid()
# The y axis is capped at 1000 regardless of the data.
pylab.ylim([0, 1000])
pylab.ylabel("mergers per year")
pylab.savefig("dist.png")
