#! /usr/bin/env python
#
# Copyright (C) 2012 Stephen Privitera
# Copyright (C) 2011 Chad Hanna
# Copyright (C) 2010 Melissa Frei
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 2 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#   
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.


import os
import sys
import numpy
import copy
from optparse import OptionParser
from pylal import spawaveform
from glue.ligolw import ligolw
from glue.ligolw import lsctables
from glue.ligolw import utils
from glue.ligolw.utils import process as ligolw_process
from pylal.datatypes import LIGOTimeGPS

## @file gstlal_bank_splitter
#
# This program splits template banks into sub banks suitable for singular value decomposition; see gstlal_bank_splitter for more information

## @package gstlal_bank_splitter
# 
# ### Usage examples
#
# - split up bank file for H1; sort by mchirp; add final frequency and specify a maximum frequency
#
#		$ gstlal_bank_splitter --overlap 10 --instrument H1 --n 100 --sort-by mchirp --add-f-final --max-f-final 2048 H1-TMPLTBANK-871147516-2048.xml
#
# - Please add more!
#
# ### Command line interface
#
#	+ `--output-path` [path]: Set the path to the directory where output files will be written.  Default is "."
#	+ `--n` [count] (int): Set the number of templates per output file (required).
#	+ `--overlap` [count] (int): Number of templates shared by consecutive output files (default 0); must be even.
#	+ `--sort-by` [mchirp|mtotal|ffinal|chirptime]: Select the template sort order (required).
#	+ `--add-f-final`: Select whether to add f_final to the bank.
#	+ `--max-f-final` [max final freq] (float): Max f_final to populate table with; if f_final > max, use max.
#	+ `--instrument` [ifo]: Override the instrument, required
#	+ `--bank-program` [name]: Select name of the program used to generate the template bank (default: tmpltbank).
#	+ `--verbose`: Be verbose.
#


def group_templates(templates, n, overlap = 0):
	"""
	Break up the template table into consecutive sub tables.

	Sub table i starts at template i * n and holds up to n + overlap
	templates, so consecutive sub tables share `overlap` templates; the last
	one may be shorter.  If n is at least the number of templates, the whole
	table is yielded once, unsplit, and no further checks are made.

	templates -- sliceable sequence of templates (e.g. a sngl_inspiral table)
	n         -- number of templates by which each sub table advances;
	             must be >= 1 when a split is needed
	overlap   -- number of extra templates each sub table borrows from the
	             next one; must satisfy 0 <= overlap < n when splitting

	Raises ValueError if n or overlap is out of range when splitting.
	"""
	if n >= len(templates):
		yield templates
		return
	# Explicit checks rather than assert: asserts vanish under "python -O",
	# where n <= 0 would loop forever yielding empty slices and a negative
	# overlap would silently drop templates between sub tables.
	if n < 1:
		raise ValueError("n must be >= 1, got %r" % (n,))
	if not 0 <= overlap < n:
		raise ValueError("overlap must satisfy 0 <= overlap < n (n = %r), got %r" % (n, overlap))
	# Stop before len - overlap: a sub table starting later than that would
	# lie entirely inside the overlap tail of the previous one.
	for start in range(0, len(templates) - overlap, n):
		yield templates[start:start + n + overlap]


def parse_command_line():
	"""
	Parse and validate the command line.

	Returns (options, filename), where filename is the single positional
	argument naming the template bank to split.

	Raises ValueError if --n, --sort-by or --instrument is missing, if
	--sort-by is not one of mchirp, ffinal, chirptime or mtotal, if there is
	not exactly one positional argument, if --n is not positive, or if
	--overlap is negative or odd.
	"""
	parser = OptionParser()
	parser.add_option("-o", "--output-path", metavar = "path", default = ".", help = "Set the path to the directory where output files will be written.  Default is \".\".")
	parser.add_option("-n", "--n", metavar = "count", type = "int", help = "Set the number of templates per output file (required).")
	parser.add_option("-O", "--overlap", default = 0, metavar = "count", type = "int", help = "overlap the templates in each file by this amount, must be even")
	parser.add_option("-s", "--sort-by", metavar = "{mchirp|mtotal|ffinal|chirptime}", help = "Select the template sort order (required).")
	parser.add_option("-F", "--add-f-final", action = "store_true", help = "Select whether to add f_final to the bank.")
	parser.add_option("-M", "--max-f-final", metavar = "float", type="float", help = "Max f_final to populate table with; if f_final exceeds max, use max.")
	parser.add_option("-i", "--instrument", metavar = "ifo", type="string", help = "override the instrument, required")
	parser.add_option("--bank-program", metavar = "name", default = "tmpltbank", type="string", help = "Select name of the program used to generate the template bank (default: tmpltbank).")
	parser.add_option("-v", "--verbose", action = "store_true", help = "Be verbose.")
	options, filenames = parser.parse_args()

	required_options = ("n", "sort_by", "instrument")
	missing_options = [option for option in required_options if getattr(options, option) is None]
	if missing_options:
		raise ValueError("missing required option(s) %s" % ", ".join("--%s" % option.replace("_", "-") for option in missing_options))

	if options.sort_by not in ("mchirp", "ffinal", "chirptime", "mtotal"):
		raise ValueError("unrecognized --sort-by \"%s\"" % options.sort_by)

	if len(filenames) != 1:
		raise ValueError("must provide exactly one filename")

	# A non-positive template count cannot split the bank (it would also make
	# group_templates() misbehave), so reject it here with a clear message.
	if options.n < 1:
		raise ValueError("--n must be a positive integer")

	# A negative overlap would leave gaps between the sub banks.
	if options.overlap < 0:
		raise ValueError("overlap must be non-negative")
	if options.overlap % 2:
		raise ValueError("overlap must be even")

	return options, filenames[0]

options, filename = parse_command_line()
# Settings recorded in the process_params table of every output file.  The
# names use the real command-line spellings; the positional bank file is
# recorded as --filename, and --verbose is omitted since it only affects
# logging.
options_params = [
	("--filename", "string", filename),
	("--output-path", "string", options.output_path),
	("--n", "int", options.n),
	("--overlap", "int", options.overlap),
	("--sort-by", "string", options.sort_by),
	("--add-f-final", "string", options.add_f_final),
	("--max-f-final", "float", options.max_f_final),
	("--instrument", "string", options.instrument),
	("--bank-program", "string", options.bank_program),
]


# Load the bank and the tables this script reads or rewrites.
xmldoc = utils.load_filename(filename, verbose = options.verbose)
sngl_inspiral_table = lsctables.table.get_table(xmldoc, lsctables.SnglInspiralTable.tableName)
process_params_table = lsctables.table.get_table(xmldoc, lsctables.ProcessParamsTable.tableName)
# ids of the process rows written by the bank-generating program (--bank-program)
tmpltbank_process_ids = lsctables.table.get_table(xmldoc, lsctables.ProcessTable.tableName).get_ids_by_program(options.bank_program)

if options.add_f_final:
	# Recover the approximant and low-frequency cutoff from process_params.
	# Start from None so that a bank lacking either parameter fails with a
	# clear message below instead of a NameError inside the template loop.
	approximant = None
	flow = None
	for row in process_params_table:
		if row.process_id in tmpltbank_process_ids and row.param == '--approximant':
			approximant = row.value
		#FIXME there should be a check on process ids, but other programs may modify the template bank after the low frequency cutoff is determined
		if row.param in ("--low-frequency-cutoff", "--flow"):
			flow = float(row.value)
	if approximant is None:
		raise ValueError("--add-f-final: no --approximant process param found for program \"%s\" (see --bank-program)" % options.bank_program)
	if flow is None:
		raise ValueError("--add-f-final: no --low-frequency-cutoff or --flow process param found")

	for row in sngl_inspiral_table:
		# mass-weighted average of the z spin components
		chi = (row.mass1*row.spin1z + row.mass2*row.spin2z)/(row.mass1+row.mass2)
		if approximant in ('IMRPhenomB', 'EOBNRv2'):
			row.f_final = 2 * spawaveform.imrffinal(row.mass1, row.mass2, chi) # over sample these waveforms
			if options.max_f_final and row.f_final > options.max_f_final:
				row.f_final = options.max_f_final
			row.template_duration = spawaveform.chirptime(row.mass1, row.mass2, 7, flow, row.f_final, chi) + 100 * (row.mass1 + row.mass2) * 5e-6 # 100 M in seconds for plenty of ringdown padding
		else:
			row.f_final = spawaveform.ffinal(row.mass1, row.mass2, 'bkl_isco')
			if options.max_f_final and (row.f_final > options.max_f_final):
				# NOTE(review): sets the template end time to the chirp time
				# between max_f_final and the uncapped f_final -- confirm intent
				row.set_end(LIGOTimeGPS(spawaveform.chirptime(row.mass1, row.mass2, 7, options.max_f_final, row.f_final, chi)))
				row.f_final = options.max_f_final
			row.template_duration = spawaveform.chirptime(row.mass1, row.mass2, 7, flow, row.f_final, chi)

# Label every template with the requested instrument and (re)compute its
# total mass, which --sort-by mtotal relies on.
for row in sngl_inspiral_table:
	row.ifo = options.instrument
	row.mtotal = row.mass1 + row.mass2

# Sort the templates in place by the quantity chosen with --sort-by
# (parse_command_line() only accepts these four values).  A key function
# replaces the Python-2-only cmp lambdas; list.sort is stable, so templates
# with equal keys keep their relative order as before.
sort_keys = {
	"mchirp": lambda row: row.mchirp,
	"ffinal": lambda row: row.f_final,
	"chirptime": lambda row: row.template_duration,
	"mtotal": lambda row: row.mtotal,
}
sngl_inspiral_table.sort(key = sort_keys[options.sort_by])

# Swap the full sngl_inspiral table out of the document for an empty table
# with the same columns.  Each sub bank is loaded into this replacement in
# turn, and the whole document is written once per sub bank.
sngl_inspiral_table_split = lsctables.table.new_from_template(sngl_inspiral_table)
parent = sngl_inspiral_table.parentNode
parent.replaceChild(sngl_inspiral_table_split, sngl_inspiral_table)

# Record this program and its settings in the process / process_params tables.
process = ligolw_process.append_process(
	xmldoc,
	program = "bank_splitter",
	comment = "split bank into smaller banks after sorting",
	ifos = None,
)
ligolw_process.append_process_params(xmldoc, process, options_params)

bank_basename = os.path.basename(filename)
sub_banks = group_templates(sngl_inspiral_table, options.n, options.overlap)
for i, rows in enumerate(sub_banks):
	# Replace the contents of the split table with this sub bank's rows.
	sngl_inspiral_table_split[:] = rows
	output = os.path.join(options.output_path, "%04d-%s_split_bank-%s" % (i, options.instrument, bank_basename))
	# Compress the output when the requested file name ends in "gz".
	utils.write_filename(xmldoc, output, gz = output.endswith('gz'), verbose = options.verbose)
