#!/usr/bin/env python
"""
Compute the expected signal-to-noise ratios of precessing inspiral signals
described in an injection file (as generated by lalapps_inspinj).  It overrides
the input waveforms from the injection file and always uses SpinTaylorT4 from
lalsimulation.  Also, full spin interactions are always turned on.

Perfect match between a template and the actual waveform is assumed.

The noise model can be optionally specified with the --noise-model argument.
The detector (default: H1) can be changed with the --instrument argument.

Results are printed as newline-delimited ASCII in the same order as the
injections.
"""
__author__ = "Leo Singer <leo.singer@ligo.org>"

import re
import sys
import os.path
import numpy as np
import swiglal as lal
import swiglalsimulation as lalsimulation
from glue.ligolw import utils, lsctables
from optparse import Option, OptionParser

# Discover the available noise models by inspecting the names exported by
# lalsimulation: every attribute called XLALSimNoisePSD<model>, where <model>
# is one arbitrary character followed by "LIGO" and anything else, contributes
# <model> to the sorted list.
expr = re.compile("^XLALSimNoisePSD(.LIGO.*)$")
noisemodels = sorted(found.group(1)
	for found in map(expr.match, dir(lalsimulation)) if found)

# Collect the frame-detector prefix of every cached detector known to lal,
# in sorted order.
instruments = sorted(
	det.frDetector.prefix for det in lal.lalCachedDetectors)

# Command line interface.  The options are handed to the OptionParser
# constructor as a list (rather than added afterwards) so that they keep
# their position relative to the automatically generated --help option.
cli_options = [
	Option("--instrument",
		choices=instruments,
		default="H1",
		metavar="|".join(instruments),
		help="Instrument (default=H1)"),
	Option("--noise-model",
		choices=noisemodels,
		default="aLIGOZeroDetHighPower",
		metavar="|".join(noisemodels),
		help="Noise model (default=aLIGOZeroDetHighPower)"),
	Option("--sample-rate",
		type=float,
		default=4096.,
		help="Sample rate (default=4096)"),
	Option("--output", "-o",
		help="Output file (default=stdout)"),
]
parser = OptionParser(
	usage="%prog [options] [INPUT] [-o OUTPUT]",
	description=__doc__,
	option_list=cli_options)

# opts holds the parsed option values, args the remaining positional arguments.
opts, args = parser.parse_args()

# At most one positional argument (the input file) is accepted; when it is
# absent, input_filename is None.
if len(args) > 1:
	raise ValueError("Too many command line arguments!")
input_filename = args[0] if args else None

# Results go to standard output unless an output file was requested, in which
# case that file is opened for writing.
output = sys.stdout if opts.output is None else open(opts.output, 'w')

# Read siminspiral table.
# The gz flag is truthy only when a filename was given and its extension
# is exactly '.gz'.
is_gzipped = (input_filename
	and os.path.splitext(input_filename)[1] == '.gz')
xmldoc = utils.load_filename(input_filename, gz = is_gzipped)
siminspiral = lsctables.table.get_table(
	xmldoc, lsctables.SimInspiralTable.tableName)

# Look up the noise model function selected with --noise-model; its name is
# the common prefix followed by the chosen model name.
psdfunc = getattr(lalsimulation, "XLALSimNoisePSD%s" % opts.noise_model)

# Look up the detector selected with --instrument from its prefix.
detector = lalsimulation.XLALDetectorPrefixToLALDetector(opts.instrument)

# Build the unit attached to the noise frequency series: rootHertz is filled
# by XLALUnitSqrt from lal.lalHertzUnit, and inverseRootHertz is then filled
# by XLALUnitInvert from rootHertz.
rootHertz = lal.LALUnit()
lal.XLALUnitSqrt(rootHertz, lal.lalHertzUnit)
inverseRootHertz = lal.LALUnit()
lal.XLALUnitInvert(inverseRootHertz, rootHertz)

# Conversion factor applied to row.distance before it is handed to the
# waveform generator: one million times lal.LAL_PC_SI.
Mpc = lal.LAL_PC_SI * 1e6

def abs2(a):
	"""Return the squared magnitude of a, i.e. a.real**2 + a.imag**2.

	Works for anything exposing .real and .imag attributes that support
	multiplication and addition (Python numbers, numpy scalars and arrays).
	"""
	real_part = a.real
	imag_part = a.imag
	return real_part * real_part + imag_part * imag_part

# For every injection: synthesize the waveform, project it onto the detector,
# transform it to the frequency domain, whiten it with the chosen noise model
# and write one number per injection to the output stream.
for row in siminspiral:

	# End time of the injection, built from the row's seconds and
	# nanoseconds columns.
	t0 = lal.LIGOTimeGPS(row.geocent_end_time, row.geocent_end_time_ns)

	# Generate waveform in time domain
	hplus, hcross = lalsimulation.XLALSimInspiralSpinTaylorT4(
		row.coa_phase,              # GW phase at coalescence (rad)
		1,                          # tail gauge term (default = 1)
		1./opts.sample_rate,        # sampling interval (s)
		row.mass1 * lal.LAL_MSUN_SI,# mass of companion 1 (kg)
		row.mass2 * lal.LAL_MSUN_SI,# mass of companion 2 (kg)
		row.f_lower,                # start GW frequency (Hz)
		row.distance * Mpc,         # distance of source (m)
		row.spin1x,                 # initial value of S1x
		row.spin1y,                 # initial value of S1y
		row.spin1z,                 # initial value of S1z
		row.spin2x,                 # initial value of S2x
		row.spin2y,                 # initial value of S2y
		row.spin2z,                 # initial value of S2z
		np.sin(row.inclination),    # initial value of LNhatx
		0,                          # initial value of LNhaty
		np.cos(row.inclination),    # initial value of LNhatz
		np.cos(row.inclination),    # initial value of E1x
		0,                          # initial value of E1y
		-np.sin(row.inclination),   # initial value of E1z
		0,                          # tidal deformability of mass 1
		0,                          # tidal deformability of mass 2
		lalsimulation.LAL_SIM_INSPIRAL_SPIN_ORDER_ALL, # flags to control spin effects
		lalsimulation.LAL_SIM_INSPIRAL_TIDAL_ORDER_ALL, # flags to control tidal effects
		7,                          # twice PN phase order
		3                           # NOTE(review): presumably the amplitude order -- confirm
	)

	# Set the start time of both polarizations so that the series (whose
	# duration is number of samples times the sample spacing) ends at t0.
	# Both series take their epoch from the length and spacing of hplus.
	epoch = t0 - len(hplus.data.data) * hplus.deltaT
	hplus.epoch = epoch
	hcross.epoch = epoch

	# Project hplus and hcross onto detector's antenna pattern
	h = lalsimulation.XLALSimDetectorStrainREAL8TimeSeries(hplus, hcross,
		row.longitude, row.latitude, row.polarization, detector)

	# Throw away hplus and hcross.
	# Note: no need for these explicit dels, except to limit
	# maximum memory footprint inside the for loop.
	del hplus, hcross

	# Generate FFT plan sized to this injection's strain series
	plan = lal.XLALCreateREAL8FFTPlan(len(h.data.data), True, 0)

	# Create frequency series to store result of FFT.  A real series of
	# N samples is given N // 2 + 1 complex frequency bins.
	hf = lal.XLALCreateCOMPLEX16FrequencySeries("", lal.LIGOTimeGPS(0), 0., 0.,
		lal.lalDimensionlessUnit, len(h.data.data) // 2 + 1)

	# Execute FFT
	lal.XLALREAL8TimeFreqFFT(hf, h, plan)

	# Throw away h
	del h

	# Throw away FFT plan (FIXME: make SWIG do this automatically)
	lal.XLALDestroyREAL8FFTPlan(plan)
	del plan

	# Tabulate the detector noise model on the same frequency grid as hf.
	# Bins below row.f_lower are explicitly set to zero; all others are
	# evaluated with the selected noise model function.
	# NOTE(review): the series is tagged with inverseRootHertz units (an
	# amplitude spectral density) but is filled from a function named
	# XLALSimNoisePSD* -- confirm which units are intended.
	psd = lal.XLALCreateREAL8FrequencySeries("", hf.epoch, hf.f0, hf.deltaF,
		inverseRootHertz, len(hf.data.data))
	freqs = hf.f0 + np.arange(len(psd.data.data)) * hf.deltaF
	for i, f in enumerate(freqs):
		if f < row.f_lower:
			psd.data.data[i] = 0
		else:
			psd.data.data[i] = psdfunc(f)

	# Whiten h(f)
	lal.XLALWhitenCOMPLEX16FrequencySeries(hf, psd)

	# Throw away PSD
	del psd

	# Take quadrature sum of frequency components from row.f_lower and above
	in_band = freqs >= row.f_lower
	snr = np.sqrt(2 * np.sum(abs2(hf.data.data[in_band])))

	# One result per line.  file.write is used instead of the Python 2-only
	# "print >>file" statement; the text written is the same str() form
	# followed by a newline.
	output.write("%s\n" % snr)
