#! /usr/bin/env python
#
# Copyright (C) 2012 Stephen Privitera
# Copyright (C) 2011-2014 Chad Hanna
# Copyright (C) 2010 Melissa Frei
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 2 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#   
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.


import itertools
import math
import os
import sys
from optparse import OptionParser
from pylal import spawaveform
from glue.ligolw import ligolw
from glue.ligolw import lsctables
from glue.ligolw import utils as ligolw_utils
from glue.ligolw.utils import process as ligolw_process
from lal import MSUN_SI
from glue import lal as gluelal
from pylal.datatypes import LIGOTimeGPS
from gstlal import templates
from gstlal import inspiral_pipe
from gstlal import chirptime

## @file gstlal_bank_splitter
#
# This program splits template banks into sub banks suitable for singular value decomposition; see gstlal_bank_splitter for more information
# 
# ### Usage examples
#
# - split up bank file for H1; sort by mchirp; add final frequency and specify a maximum frequency
#
#		$ gstlal_bank_splitter --overlap 10 --instrument H1 --n 100 --sort-by mchirp --add-f-final --max-f-final 2048 H1-TMPLTBANK-871147516-2048.xml
#
# - Please add more!
#
# ### Command line interface
#
#	+ `--output-path` [path]: Set the path to the directory where output files will be written.  Default is "."
#	+ `--output-cache` [file]: Set the file name for the output cache (required).
#	+ `--n` [count] (int): Set the number of templates per output file (required).  It will be rounded to make all sub banks approximately the same size.
#	+ `--overlap` [count] (int): Overlap the templates in each file by this amount, must be even.
#	+ `--sort-by` [column]: Select the template sort order column (default: mchirp).
#	+ `--add-f-final`: Select whether to add f_final to the bank.
#	+ `--max-f-final` [max final freq] (float): Max f_final to populate table with; if f_final > max, use max.
#	+ `--instrument` [ifo]: Override the instrument, required
#	+ `--bank-program` [name]: Select name of the program used to generate the template bank (default: tmpltbank).
#	+ `--verbose`: Be verbose.
#	+ `--approximant` [string]: Must specify an approximant
#	+ `--f-low` [frequency] (float): Lower frequency cutoff
#	+ `--group-by-chi`: group templates into chi bins 0.1 in chi - helps with SVD.
#
# ### Review status
#
# Compared original bank with the split banks.  Verified that they are the same, e.g., add sub bank files into test.xml.gz and run (except that lalapps_tmpltbank adds redundant templates):
#
#		ligolw_print -t sngl_inspiral -c mass1 -c mass2 ../H1-TMPLTBANK-871147516-2048.xml | sort -u | wc
#		ligolw_print -t sngl_inspiral -c mass1 -c mass2 test.xml.gz | sort -u | wc
#
# | Names 	                                        | Hash 					                   | Date       | Diff to Head of Master      |
# | ----------------------------------------------- | ---------------------------------------- | ---------- | --------------------------- |
# | Florent, Sathya, Duncan Me., Jolien, Kipp, Chad | 7536db9d496be9a014559f4e273e1e856047bf71 | 2014-04-28 | <a href="@gstlal_inspiral_cgit_diff/bin/gstlal_bank_splitter?id=HEAD&id2=7536db9d496be9a014559f4e273e1e856047bf71">gstlal_bank_splitter</a> |
#
# #### Action
#
# - Consider cleanup once additional bank programs are used and perhaps have additional metadata
#

# Content handler handed to ligolw_utils.load_filename() when the input
# template banks are read.  The subclass adds nothing itself; it exists so
# that it can be passed through lsctables.use_in() on the line after the
# class body (presumably so the LSC table classes are recognised while
# parsing -- confirm against the glue.ligolw documentation).
class LIGOLWContentHandler(ligolw.LIGOLWContentHandler):
	pass
lsctables.use_in(LIGOLWContentHandler)


def group_templates(templates, n, overlap = 0):
	"""
	Generate consecutive slices of templates, each holding roughly n
	entries and widened by overlap // 2 entries on either side.

	If n is at least len(templates) the input sequence itself is yielded
	once and nothing else.  Otherwise n is replaced by len(templates)
	divided by the rounded number of requested sub banks, so the stride
	may be fractional and all slices come out approximately the same
	size.  The first slice is clipped at index 0; the last one simply
	runs off the end of the sequence.
	"""
	total = len(templates)

	# Nothing to split: hand back the whole sequence untouched
	if n >= total:
		yield templates
		return

	# Rescale the stride so that an integral number of strides spans the bank
	n = total / round(total / float(n))
	assert n >= 1

	padding = overlap // 2
	index = 0
	while True:
		lower = int(round(index * n)) - padding
		upper = int(round((index + 1) * n)) + padding
		yield templates[max(lower, 0):upper]
		# Stop as soon as a slice has reached the end of the bank
		if upper >= total:
			return
		index += 1


def parse_command_line():
	"""
	Parse and validate the command line taken from sys.argv.

	Returns the (options, filenames) pair produced by
	OptionParser.parse_args(): the parsed option values and the remaining
	positional arguments.

	Raises ValueError if any of --n, --instrument, --sort-by,
	--output-cache or --approximant is unset, or if --overlap is odd.
	"""
	parser = OptionParser()
	parser.add_option("--output-path", metavar = "path", default = ".", help = "Set the path to the directory where output files will be written.  Default is \".\".")
	parser.add_option("--output-cache", metavar = "file", help = "Set the file name for the output cache.")
	parser.add_option("--n", metavar = "count", type = "int", help = "Set the number of templates per output file (required). It will be rounded to make all sub banks approximately the same size.")
	parser.add_option("--overlap", default = 0, metavar = "count", type = "int", help = "overlap the templates in each file by this amount, must be even")
	parser.add_option("--sort-by", metavar = "column", default="mchirp", help = "Select the template sort column, default mchirp")
	parser.add_option("--add-f-final", action = "store_true", help = "Select whether to add f_final to the bank.")
	parser.add_option("--max-f-final", metavar = "float", type="float", help = "Max f_final to populate table with; if f_final over mx, use max.")
	parser.add_option("--instrument", metavar = "ifo", type="string", help = "override the instrument, required")
	parser.add_option("--bank-program", metavar = "name", default = "tmpltbank", type="string", help = "Select name of the program used to generate the template bank (default: tmpltbank).")
	parser.add_option("-v", "--verbose", action = "store_true", help = "Be verbose.")
	parser.add_option("--approximant", type = "string", help = "Must specify an approximant")
	parser.add_option("--f-low", type = "float", metavar = "frequency", help = "Lower frequency cutoff")
	parser.add_option("--group-by-chi", action = "store_true", help = "group templates into chi bins 0.1 in chi - helps with SVD.")
	options, filenames = parser.parse_args()

	# --sort-by has a default, so only the other four can actually be missing
	required_options = ("n", "instrument", "sort_by", "output_cache", "approximant")
	missing_options = [option for option in required_options if getattr(options, option) is None]
	if missing_options:
		# Call form of raise, matching the overlap check below; the
		# "raise Class, value" statement form is Python 2 only
		raise ValueError("missing required option(s) %s" % ", ".join("--%s" % option.replace("_", "-") for option in missing_options))

	# group_templates() pads each side by overlap // 2, hence the parity check
	if options.overlap % 2:
		raise ValueError("overlap must be even")

	return options, filenames

# Running index of the split banks
bank_count = 0

# Read the command line, then open the cache file for writing (this
# truncates any existing file of that name)
options, filenames = parse_command_line()
output_cache_file = open(options.output_cache, "w")

# Hacky way to bin by chi
def chikey(chi, group_by_chi = options.group_by_chi):
	"""
	Return the bin label for chi.

	With group_by_chi set, chi is rounded down to a multiple of 0.1.
	Otherwise every value maps to None, so all templates share one bin.
	The default for group_by_chi is taken from the command line options
	at the time this function is defined.
	"""
	if not group_by_chi:
		return None
	return math.floor(chi * 10) / 10.

# Collected (first row, rows) pairs, one per sub bank, across all input files
outputrows = []

# Key function ordering templates by the column chosen with --sort-by.
# Defined once here instead of being re-created for every chi bin.
def template_sort_key(row, column = options.sort_by):
	return getattr(row, column)

#
# For each input bank: optionally fill in f_final and template_duration, stamp
# the instrument and total mass onto every row, bin the rows by chi, sort each
# bin and cut it into sub banks with group_templates().
#

for filename in filenames:
	xmldoc = ligolw_utils.load_filename(filename, verbose = options.verbose, contenthandler = LIGOLWContentHandler)
	sngl_inspiral_table = lsctables.SnglInspiralTable.get_table(xmldoc)

	if options.add_f_final:
		if options.f_low is None:
			# Look the low frequency cutoff up under each of the three
			# spellings the bank program might have recorded.  The
			# single-target unpacking raises ValueError unless exactly
			# one value is found in total.
			flow, = ligolw_process.get_process_params(xmldoc, options.bank_program, "--flow") + ligolw_process.get_process_params(xmldoc, options.bank_program, "--low-frequency-cutoff") + ligolw_process.get_process_params(xmldoc, options.bank_program, "--f-low")
		else:
			flow = options.f_low
		for row in sngl_inspiral_table:
			# Find the total spin magnitudes
			spin1, spin2 = (row.spin1x**2 + row.spin1y**2 + row.spin1z**2)**.5, (row.spin2x**2 + row.spin2y**2 + row.spin2z**2)**.5
			# Chirptime uses SI
			m1_SI, m2_SI = MSUN_SI * row.mass1, MSUN_SI * row.mass2

			if options.approximant in templates.gstlal_IMR_approximants:
				# make sure to go a factor of 2 above the ringdown frequency for safety
				row.f_final = 2 * chirptime.ringf(m1_SI + m2_SI, chirptime.overestimate_j_from_chi(max(spin1, spin2)))
			else:
				# otherwise choose a suitable high frequency
				# NOTE not SI
				row.f_final = spawaveform.ffinal(row.mass1, row.mass2, 'bkl_isco')

			# Override the high frequency with the max if appropriate
			if options.max_f_final and (row.f_final > options.max_f_final):
				row.f_final = options.max_f_final

			# Record the conservative template duration
			row.template_duration = chirptime.imr_time(flow, m1_SI, m2_SI, spin1, spin2, f_max = row.f_final)


	for row in sngl_inspiral_table:
		row.ifo = options.instrument
		# just to make sure it is set
		row.mtotal = row.mass1 + row.mass2

	# Bin by Chi, has no effect if option is not specified, i.e. there is only one bin.
	# A plain loop is used because the grouping is a pure side effect.
	chidict = {}
	for row in sngl_inspiral_table:
		chidict.setdefault(chikey(spawaveform.computechi(row.mass1, row.mass2, row.spin1z, row.spin2z)), []).append(row)

	for chirows in chidict.values():
		# store the process params
		# NOTE(review): this registers the process once per chi bin, so a
		# document ends up with one registration per bin -- confirm
		# whether once per input file was intended.
		ligolw_process.register_to_xmldoc(xmldoc, program = "gstlal_bank_splitter", paramdict = options.__dict__, comment = "Split bank into smaller banks for SVD")

		chirows.sort(key = template_sort_key)

		for rows in group_templates(chirows, options.n, options.overlap):
			assert len(rows) >= options.n/2, "There are too few templates in this chi interval.  Requested %d: have %d" % (options.n, len(rows))
			# The first row of each sub bank is kept alongside it as
			# the key for the final ordering of the sub banks
			outputrows.append((rows[0], rows))


# One last sort now that the templates have been grouped
def sort_func(row_and_rows, column = options.sort_by):
	"""
	Sort key for the (first row, rows) pairs held in outputrows: the value
	of the --sort-by column on the first row of the sub bank.
	"""
	# Unpack in the body rather than in the signature; tuple parameters
	# are not available in newer Python versions.  Unpacking still insists
	# on exactly two elements, as the original signature did.
	row, rows = row_and_rows
	return getattr(row, column)

outputrows.sort(key=sort_func)

# Write one output document per sub bank and list each one in the cache file.
# NOTE(review): the xmldoc and sngl_inspiral_table reused here are whatever the
# last input file left behind, so with several input files every sub bank is
# written into a copy of the last document -- confirm this is intended.
try:
	for bank_count, (row, rows) in enumerate(outputrows):
		# Swap the table contents for this sub bank's rows
		sngl_inspiral_table[:] = rows
		output = inspiral_pipe.T050017_filename(options.instrument, "GSTLAL_SPLIT_BANK_%04d" % bank_count, 0, 0, ".xml.gz", path = options.output_path)
		output_cache_file.write("%s\n" % gluelal.CacheEntry.from_T050017("file://localhost%s" % os.path.abspath(output)))
		ligolw_utils.write_filename(xmldoc, output, gz = output.endswith('gz'), verbose = options.verbose)
finally:
	# Close the cache even if writing one of the banks fails, so the
	# entries recorded so far are flushed to disk
	output_cache_file.close()
