#!/usr/bin/env python
#
# Copyright (C) 2011 Chad Hanna
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 2 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

from optparse import OptionParser, Option
from scipy import random
import numpy
from glue.ligolw import lsctables
from glue.ligolw import utils
from glue.ligolw import ligolw
from glue.lal import LIGOTimeGPS as GPS
from glue.ligolw.utils import process
from pylal.antenna import response
import sys

def uniform_dec(num):
	"""
	Return num declination samples, in radians, in [-pi/2, pi/2).

	Each sample is pi/2 minus the arccosine of a uniform variate on
	[-1, 1), so the sine of the result is uniformly distributed.
	"""
	cos_polar = 2 * random.random_sample(num) - 1
	polar_angle = numpy.arccos(cos_polar)
	# convert the polar angle (measured from the pole) to a latitude
	return (numpy.pi / 2.) - polar_angle

def uniform_theta(num):
	"""
	Return num polar-angle samples, in radians, in (0, pi].

	The cosine of each sample is a uniform variate on [-1, 1).
	"""
	cos_theta = 2 * random.random_sample(num) - 1
	return numpy.arccos(cos_theta)

def uniform_phi(num):
	"""
	Return num angles, in radians, drawn uniformly from [0, 2 pi).
	"""
	fraction_of_turn = random.random_sample(num)
	return fraction_of_turn * 2 * numpy.pi

def uniform_interval(interval, num):
	"""
	Return num samples drawn uniformly between interval[0] and interval[1].

	interval -- indexable pair (low, high)
	num -- number of samples to draw
	"""
	# http://docs.scipy.org/doc/numpy/reference/generated/numpy.random.random_sample.html#numpy.random.random_sample
	low = interval[0]
	width = interval[1] - low
	# stretch the unit-interval samples to the requested width, then shift
	return width * random.random_sample(num) + low

class spin(object):
	"""
	Cartesian components (x, y, z) of a spin given in spherical coordinates.
	"""
	def __init__(self, r, theta, phi):
		"""
		r -- magnitude
		theta -- polar angle from the z axis, radians
		phi -- azimuthal angle, radians

		Scalars or numpy arrays are accepted; the attributes x, y and z
		follow the usual numpy broadcasting of the inputs.
		"""
		sin_theta = numpy.sin(theta)
		self.x = r * numpy.cos(phi) * sin_theta
		self.y = r * numpy.sin(phi) * sin_theta
		self.z = r * numpy.cos(theta)



class source(object):
	"""
	One population of binary injections drawn at a fixed local rate.

	The constructor converts the rate, maximum distance and time span
	into a whole number of injections (expnum).  Each method then draws
	expnum samples of one group of parameters.
	"""
	def __init__(self, waveform, m1range, m2range, s1range, s2range, rate, distance, tstart, tstop):
		"""
		waveform -- stored unchanged as self.waveform
		m1range, m2range -- (min, max) pairs for the two component masses
		s1range, s2range -- (min, max) pairs for the two spin magnitudes
		rate -- mergers per unit volume per Myr
		distance -- requested maximum distance; respected approximately,
			see below
		tstart, tstop -- bounds of the time span, in seconds

		The expected number of mergers inside the sphere of radius
		distance is rounded up to an integer, and self.distance is then
		enlarged to the radius at which that integer is the exact
		expectation.
		"""
		self.waveform = waveform
		self.m1range = m1range
		self.m2range = m2range
		self.s1range = s1range
		self.s2range = s2range
		self.rate = rate
		self.tstart = tstart
		self.tstop = tstop
		self.Myrs = (tstop - tstart) / 365.25 / 86400 / 1.e6 # Julian years
		volume = 4./3. * numpy.pi * distance**3
		expnum = numpy.ceil(volume * self.Myrs * self.rate)
		#redo the max distance based on an integer number of injections
		self.distance = (expnum / self.Myrs / self.rate  / (4./3. * numpy.pi))**(1./3.)
		# numpy.ceil returns a float, but expnum is used as the size
		# argument of the numpy random generators, which must be an integer
		self.expnum = int(expnum)

	def uniform_sky(self):
		"""
		Return the arrays (ra, dec, pol, inc, phase), each of length expnum.

		ra, pol and phase come from uniform_phi(), dec from uniform_dec()
		and inc from uniform_theta().
		"""
		expnum = self.expnum
		ra = uniform_phi(expnum)
		dec = uniform_dec(expnum)
		pol = uniform_phi(expnum)
		inc = uniform_theta(expnum)
		phase = uniform_phi(expnum)
		return ra, dec, pol, inc, phase

	def uniform_time(self):
		"""
		Return expnum times drawn uniformly from [tstart, tstop).

		Each time is an integer second in [tstart, tstop) plus a uniform
		fraction in [0, 1).
		"""
		return random.randint(self.tstart, self.tstop, self.expnum) + random.rand(self.expnum)

	def uniform_masses(self):
		"""
		Return (mass1, mass2), each expnum samples uniform in its range.
		"""
		mass1 = uniform_interval(self.m1range, self.expnum)
		mass2 = uniform_interval(self.m2range, self.expnum)
		return mass1, mass2

	def volume_distributed_distances(self):
		"""
		Return expnum distances in [0, self.distance], uniform in volume.

		The power distribution with exponent parameter 3 has density
		proportional to x**2 on [0, 1].
		"""
		return random.power(3, self.expnum) * self.distance # 3 because it is this number minus 1 according to scipy doc

	def uniform_spins(self):
		"""
		Return two spin objects with independently drawn directions.

		Magnitudes are uniform in s1range and s2range; directions come
		from uniform_theta() and uniform_phi().
		"""
		# FIXME uniform in r not in the sphere, what is desired?
		r1 = uniform_interval(self.s1range, self.expnum)
		# FIXME, lal actually expects standard spherical coordinates in radians?
		theta1 = uniform_theta(self.expnum)
		phi1 = uniform_phi(self.expnum)
		r2 = uniform_interval(self.s2range, self.expnum)
		theta2 = uniform_theta(self.expnum)
		phi2 = uniform_phi(self.expnum)
		return spin(r1, theta1, phi1), spin(r2, theta2, phi2)

	def aligned_spins(self):
		"""
		Return two spin objects pointing along +z.

		Magnitudes are uniform in s1range and s2range; theta = phi = 0, so
		the x and y components are zero and z equals the magnitude.
		"""
		# FIXME uniform in r not in the sphere, what is desired?
		r1 = uniform_interval(self.s1range, self.expnum)
		# FIXME, lal actually expects standard spherical coordinates in radians?
		r2 = uniform_interval(self.s2range, self.expnum)
		return spin(r1, 0, 0), spin(r2, 0, 0)



# Command-line interface.  Every option except the GPS bounds, --output and
# --flower has a default; the help strings state the defaults explicitly.
parser = OptionParser(description = __doc__)
parser.add_option("--gps-start-time", metavar = "seconds", type = "int", help = "Set the start time of the segment to analyze in GPS seconds (required).  Must be an integer number of seconds.")
parser.add_option("--gps-end-time", metavar = "seconds", type = "int", help = "Set the end time of the segment to analyze in GPS seconds (required).  Must be an integer number of seconds.")
parser.add_option("--bns-waveform", metavar = "str", default = "SpinTaylorT4threePointFivePN", help = "Set the name of the bns waveform, default = SpinTaylorT4threePointFivePN")
parser.add_option("--nsbh-waveform", metavar = "str", default = "SpinTaylorT4threePointFivePN", help = "Set the name of the nsbh waveform, default = SpinTaylorT4threePointFivePN")
parser.add_option("--bbh-waveform", metavar = "str", default = "EOBNRv2HMpseudoFourPN", help = "Set the name of the bbh waveform, default = EOBNRv2HMpseudoFourPN")
parser.add_option("--bns-local-rate", metavar = "mergers / Mpc^3 / Myr", default = 1.0, type = "float", help = "set the local merger rate in mergers / Mpc^3 / Myr. Setting to 0.0 or less will disable the BNS distribution - default = 1.0")
parser.add_option("--nsbh-local-rate", metavar = "mergers / Mpc^3 / Myr", default = .03, type = "float", help = "set the local merger rate in mergers / Mpc^3 / Myr. Setting to 0.0 or less will disable the NSBH distribution - default = 0.03")
parser.add_option("--bbh-local-rate", metavar = "mergers / Mpc^3 / Myr", default = .005, type = "float", help = "set the local merger rate in mergers / Mpc^3 / Myr. Setting to 0.0 or less will disable the BBH distribution - default = 0.005")
parser.add_option("--output", metavar = "filename", help = "Set the name of the LIGO light-weight XML file to output")
parser.add_option("--bns-max-distance", metavar = "Mpc", type = "float", default = 1000., help = "set the bns maximum distance (respected approximately) default = 1000")
parser.add_option("--nsbh-max-distance", metavar = "Mpc", type = "float", default = 5000., help = "set the nsbh maximum distance (respected approximately) default = 5000")
parser.add_option("--bbh-max-distance", metavar = "Mpc", type = "float", default = 5000., help = "set the bbh maximum distance (respected approximately) default = 5000")
parser.add_option("--ns-min-mass", metavar = "Msun", type = "float", default = 1.0, help = "set the minimum mass of the neutron star, default = 1.0")
parser.add_option("--ns-max-mass", metavar = "Msun", type = "float", default = 2.0, help = "set the maximum mass of the neutron star, default = 2.0")
parser.add_option("--ns-min-spin", metavar = "dimensionless", type = "float", default = 0., help = "set the min neutron star spin, default = 0")
parser.add_option("--ns-max-spin", metavar = "dimensionless", type = "float", default = 0., help = "set the max neutron star spin, default = 0")
parser.add_option("--bh-min-mass", metavar = "Msun", type = "float", default = 5.0, help = "set the minimum mass of the black hole, default = 5.0")
parser.add_option("--bh-max-mass", metavar = "Msun", type = "float", default = 20.0, help = "set the maximum mass of the black hole, default = 20.0")
parser.add_option("--bh-min-spin", metavar = "dimensionless", type = "float", default = 0., help = "set the min black hole spin, default = 0")
parser.add_option("--bh-max-spin", metavar = "dimensionless", type = "float", default = 1., help = "set the max black hole spin, default = 1")
parser.add_option("--aligned-spin", action="store_true", default = False, help = "Align the bh spin direction with the z-direction.")
parser.add_option("--flower", metavar = "float", type = "float", help = "set the minimum frequency")
parser.add_option("--seed", metavar = "int", type = "int", default = 1, help = "set the random seed default = 1")

#
# Parse options
#

options, filenames = parser.parse_args()

# The GPS bounds have no defaults.  Fail here with a usage message instead of
# an opaque TypeError or ValueError later, when the bounds are used to compute
# the time span and to draw the injection times.
if options.gps_start_time is None or options.gps_end_time is None:
	parser.error("--gps-start-time and --gps-end-time are required")
if options.gps_end_time <= options.gps_start_time:
	parser.error("--gps-end-time must be greater than --gps-start-time")

#
# Setup the output document: a LIGO_LW element holding an empty
# sim_inspiral table, plus the process metadata recording this program's
# name and its command-line options
#

xmldoc = ligolw.Document()
lw = ligolw.LIGO_LW()
xmldoc.appendChild(lw)
sim = lw.appendChild(lsctables.New(lsctables.SimInspiralTable))
procrow = process.register_to_xmldoc(
	xmldoc,
	"gstlal_injections_by_local_rate",
	vars(options)
)

# Setup some global parameters
start = options.gps_start_time
stop = options.gps_end_time
random.seed(options.seed)

# (min, max) ranges shared by the populations
ns_mass = (options.ns_min_mass, options.ns_max_mass)
bh_mass = (options.bh_min_mass, options.bh_max_mass)
ns_spin = (options.ns_min_spin, options.ns_max_spin)
bh_spin = (options.bh_min_spin, options.bh_max_spin)

# set up the types of sources, in the order BNS, NSBH, BBH.  Each entry is
# (waveform, m1range, m2range, s1range, s2range, local rate, max distance).
populations = (
	(options.bns_waveform, ns_mass, ns_mass, ns_spin, ns_spin, options.bns_local_rate, options.bns_max_distance),
	(options.nsbh_waveform, ns_mass, bh_mass, ns_spin, bh_spin, options.nsbh_local_rate, options.nsbh_max_distance),
	(options.bbh_waveform, bh_mass, bh_mass, bh_spin, bh_spin, options.bbh_local_rate, options.bbh_max_distance),
)

srcs = []
for wf, m1r, m2r, s1r, s2r, local_rate, max_dist in populations:
	# a rate of 0.0 or less disables the population
	if local_rate > 0.0:
		srcs.append(source(wf, m1r, m2r, s1r, s2r, local_rate, max_dist, start, stop))

# Columns written as integer 0 on every row.  The detector end times
# shouldn't be needed as the injection code recomputes them.
zeroed_end_time_columns = (
	"h_end_time", "h_end_time_ns",
	"l_end_time", "l_end_time_ns",
	"v_end_time", "v_end_time_ns",
	"g_end_time", "g_end_time_ns",
	"t_end_time", "t_end_time_ns",
	"end_time_gmst",
)
zeroed_misc_columns = (
	"alpha", "alpha1", "alpha2", "alpha3", "alpha4", "alpha5", "alpha6",
	"beta", "theta0", "phi0",
	"numrel_mode_min", "numrel_mode_max",
	"bandpass",
)

for src in srcs:

	# injection parameters: one array per parameter, one element per
	# injection.  The order of the draws is fixed so that a given --seed
	# keeps producing the same injections.
	time = src.uniform_time()
	mass1, mass2 = src.uniform_masses()
	dist = src.volume_distributed_distances()
	draw_spins = src.aligned_spins if options.aligned_spin else src.uniform_spins
	spin1, spin2 = draw_spins()
	ra, dec, pol, inc, phase = src.uniform_sky()

	#FIXME only set the "needed" columns. Hopefully these are the needed ones
	for i, t in enumerate(time):
		row = sim.RowType()

		# string parameters
		row.waveform = src.waveform
		row.source = ""
		row.numrel_data = ""
		row.taper = "TAPER_START"

		# time parameters
		row.set_time_geocent(GPS(float(t)))
		for column in zeroed_end_time_columns:
			setattr(row, column, 0)

		# masses and the quantities derived from them
		m1 = mass1[i]
		m2 = mass2[i]
		row.mass1 = m1
		row.mass2 = m2
		row.eta = m1 * m2 / (m1 + m2)**2
		row.mchirp = (m1 * m2)**(3./5.) / (m1 + m2)**(1./5.)
		row.psi0 = row.psi3 = 0.0

		# location / orientation
		row.distance = dist[i]
		row.longitude = ra[i]
		row.latitude = dec[i]
		row.coa_phase = phase[i]
		row.inclination = inc[i]
		row.polarization = pol[i]

		# set effective distances, one column per detector: the column
		# suffix is the lower-case first letter of the detector name
		# (H1, L1, V1, G1, T1)
		for s in ("h", "l", "v", "g", "t"):
			det = s.upper() + "1"
			fplus, fcross, fave, qvalue = response(row.geocent_end_time, row.longitude, row.latitude, row.inclination, row.polarization, "radians", det)
			setattr(row, "eff_dist_%s" % s, row.distance / qvalue)

		# spins
		row.spin1x, row.spin1y, row.spin1z = spin1.x[i], spin1.y[i], spin1.z[i]
		row.spin2x, row.spin2y, row.spin2z = spin2.x[i], spin2.y[i], spin2.z[i]

		# frequencies
		row.f_lower = options.flower
		row.f_final = 0

		# misc
		row.amp_order = -1
		for column in zeroed_misc_columns:
			setattr(row, column, 0)
		row.simulation_id = sim.get_next_id()
		row.process_id = procrow.process_id
		sim.append(row)

# write the finished document to the file named by --output
utils.write_filename(xmldoc, options.output)
