#!/usr/bin/env python

# Copyright (C) 2012 Duncan M. Macleod
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

"""Veto triggers from a MultiInspiralTable based on signal parameters.
"""

from __future__ import division

import sys
import optparse
import numpy
import warnings
from os.path import basename
from time import time as time_now
warnings.filterwarnings("ignore", "column name", UserWarning)

from pylal import (git_version, coh_PTF_pyutils)
from lal import GPSTimeNow

from glue.ligolw import (ligolw, table, lsctables, utils as ligolw_utils)
from glue.ligolw.utils import (process as ligolw_process,
                               search_summary as ligolw_search_summary)

# set up module metadata; the version and date are read from the
# git_version module imported from pylal
__author__ = "Duncan Macleod <duncan.macleod@ligo.org>"
__version__ = git_version.id
__date__ = git_version.date

# module-wide output switches; both are reassigned from the parsed
# command-line options when this file is run as a script
VERBOSE = False
PROFILE = False


# reference time recorded when the module is loaded
start = time_now()


def elapsed_time():
    """Return the number of seconds elapsed since `start` was recorded.
    """
    return time_now() - start
# print verbose messages
def print_verbose(message, verbose=True, stream=sys.stdout, profile=True):
    """Print verbose messages to a file stream.

    Messages sent to any stream other than `sys.stderr` are additionally
    gated by the module-level `VERBOSE` and `PROFILE` flags; messages sent
    to `sys.stderr` honour only the arguments given here.

    @param message
        text to print
    @param verbose
        flag to print or not, default: True (print)
    @param stream
        file object stream in which to print, default: sys.stdout
    @param profile
        flag to append the elapsed time (as given by `elapsed_time`) to
        any message that ends in a newline, default: True
    """
    if stream is not sys.stderr:
        # logical (not bitwise) AND, so that any truthy flag, e.g.
        # verbose=2, is honoured rather than being masked to zero
        profile = profile and PROFILE
        verbose = verbose and VERBOSE
    if not verbose:
        # nothing will be written, so skip the timestamp formatting
        return
    if profile and message.endswith("\n"):
        message = "%s (%.2f)\n" % (message.rstrip("\n"), elapsed_time())
    stream.write(message)
    stream.flush()


# verify command-line thresholds
def parse_threshold(option, opt_str, value, parser):
    """Verify threshold argument is valid on callback.

    optparse callback that converts the option argument to a `float` (or,
    for ``--sngl-snr``, a comma-separated string to a list of `float`),
    checks that every threshold is strictly positive, and stores the result
    on `parser.values` under `option.dest`.

    @param option
        the `optparse.Option` instance that triggered the callback
    @param opt_str
        the option string as seen on the command line, used in error
        messages
    @param value
        the option argument as passed in by optparse
    @param parser
        the `optparse.OptionParser` in use; its `error` method is called
        if the value cannot be parsed or is not positive
    """
    # opt_str is whatever the user typed, so it is the short flag when that
    # was used; also consult the option's own canonical string
    is_list = "--sngl-snr" in (opt_str, option.get_opt_string())
    try:
        if is_list:
            thresholds = [float(v) for v in value.split(",")]
        else:
            thresholds = [float(value)]
    except ValueError:
        return parser.error("%s must be numeric, got %r" % (opt_str, value))
    # validate after conversion so that list-valued options are checked too
    if not all(t > 0 for t in thresholds):
        return parser.error("%s must be positive" % opt_str)
    if is_list:
        setattr(parser.values, option.dest, thresholds)
    else:
        setattr(parser.values, option.dest, thresholds[0])


if __name__ == "__main__":

    # -- command line --------------------------------------------------------

    usage = "%prog [options] xmlfile"
    # strip() rather than [1:]: the module docstring does not start with a
    # newline, so slicing dropped the first letter of the description
    parser = optparse.OptionParser(usage=usage, description=__doc__.strip(),
                                   formatter=optparse.IndentedHelpFormatter(4))
    parser.add_option("-p", "--profile", action="store_true", default=False,
                      help="timestamp output, default: %default")
    parser.add_option("-v", "--verbose", action="store_true", default=False,
                      help="show verbose output, default: %default")
    parser.add_option("-V", "--version", action="version",
                      help="show program's version number and exit")
    parser.version = git_version.verbose_msg

    vetoopts = optparse.OptionGroup(parser, "Veto options")
    # scalar thresholds all share the same callback and type
    scalar_thresholds = [
        ("-s", "--snr",
         "minimal (coherent) SNR threshold, default: %default"),
        ("-r", "--new-snr",
         "minimal chisq weighted SNR threshold, default: %default"),
        ("-b", "--bank-new-snr",
         "minimal bank-chisq weighted SNR threshold, default: %default"),
        ("-c", "--cont-new-snr",
         "minimal cont-chisq weighted SNR threshold, default: %default"),
        ("-n", "--null-snr",
         "maximal null SNR threshold, default: %default"),
    ]
    for short_opt, long_opt, help_text in scalar_thresholds:
        vetoopts.add_option(short_opt, long_opt, action="callback",
                            callback=parse_threshold, type="float",
                            default=None, help=help_text)
    vetoopts.add_option("-x", "--sngl-snr", action="callback", type="string",
                        default=None, callback=parse_threshold,
                        help=("comma-separated list of minimal "
                              "single-detector SNR threshold. Give "
                              "N values to apply to the N most sensitive "
                              "instruments"))
    parser.add_option_group(vetoopts)

    outopts = optparse.OptionGroup(parser, "Output options")
    # default is None (not False) so that "no file given" is not mistaken
    # for a boolean flag and can be passed straight to the XML writer
    outopts.add_option("-o", "--output-file", action="store",
                       default=None, help="output xml file, default: stdout")
    parser.add_option_group(outopts)

    # parse args
    opts, args = parser.parse_args()

    # verify xmlfiles: exactly one input file is required
    if len(args) != 1:
        parser.error("Please give only one input XML file (you gave %d)"
                     % len(args))
    xmlfile = args[0]

    VERBOSE = opts.verbose
    PROFILE = opts.profile
    # NOTE(review): verbose messages go to stdout, which is also where the
    # XML is written when no --output-file is given -- confirm this mix is
    # acceptable

    # -- output document -----------------------------------------------------

    outxml = ligolw.Document()
    outxml.appendChild(ligolw.LIGO_LW())

    # append our process, recording only what was actually given on the
    # command line
    process = ligolw_process.append_process(outxml, program=__file__,
                                            version=__version__)
    ligolw_process.append_process_params(outxml, process,
                                         [("", "lstring", xmlfile)])
    for key, val in sorted(vars(opts).items()):
        if val is None or val is False:
            # option not given (or flag not set): nothing to record
            continue
        flag = "--%s" % key.replace("_", "-")
        if val is True:
            values = [""]
        elif isinstance(val, (list, tuple)):
            # e.g. --sngl-snr: one row per value
            values = val
        else:
            values = [val]
        ligolw_process.append_process_params(
            outxml, process, [(flag, "lstring", v) for v in values])

    # -- read input ----------------------------------------------------------

    print_verbose("Reading xml file...\n", profile=False)
    contenthandler = lsctables.use_in(ligolw.LIGOLWContentHandler)
    xmldoc = ligolw_utils.load_filename(xmlfile, gz=xmlfile.endswith(".gz"),
                                        contenthandler=contenthandler)

    # read SearchSummaryTable
    search_sum_table = table.get_table(xmldoc,
                                       lsctables.SearchSummaryTable.tableName)
    inseg = search_sum_table[0].get_in()
    outseg = search_sum_table[0].get_out()

    # read SimInspiralTable; it is optional, so a missing table is not an
    # error
    try:
        sim_table = table.get_table(xmldoc,
                                    lsctables.SimInspiralTable.tableName)
    except ValueError:
        pass
    else:
        outxml.childNodes[-1].appendChild(sim_table)
        print_verbose("SimInspiralTable read and copied.\n")

    # read TimeSlideTable
    slide_table = table.get_table(xmldoc, lsctables.TimeSlideTable.tableName)
    slides = slide_table.as_dict()
    outxml.childNodes[-1].appendChild(slide_table)
    print_verbose("TimeSlideTable read and copied.\n")

    # read MultiInspiralTable
    mi_table = table.get_table(xmldoc, lsctables.MultiInspiralTable.tableName)
    print_verbose("MultiInspiralTable read, %d events found.\n" % len(mi_table))

    # read experiment tables
    expr_table = table.get_table(xmldoc, lsctables.ExperimentTable.tableName)
    expr_summ_table = table.get_table(
                          xmldoc, lsctables.ExperimentSummaryTable.tableName)

    # -- apply vetoes --------------------------------------------------------

    # (threshold, veto function, threshold keyword, label), in the order
    # the cuts are applied; cuts whose threshold was not given are dropped
    all_cuts = [
        (opts.snr, coh_PTF_pyutils.apply_snr_veto,
         "snr", "Coherent SNR"),
        (opts.new_snr, coh_PTF_pyutils.apply_chisq_veto,
         "snr", "Chisq-weighted SNR"),
        (opts.bank_new_snr, coh_PTF_pyutils.apply_bank_veto,
         "snr", "Bank chisq-weighted SNR"),
        (opts.cont_new_snr, coh_PTF_pyutils.apply_auto_veto,
         "snr", "Auto chisq-weighted SNR"),
        (opts.sngl_snr, coh_PTF_pyutils.apply_sngl_snr_veto,
         "snrs", "Single-detector SNR"),
        (opts.null_snr, coh_PTF_pyutils.apply_null_snr_veto,
         "null_snr", "Null SNR"),
    ]
    active_cuts = [cut for cut in all_cuts if cut[0]]

    # apply vetoes separately for each experiment
    pass_mi_table = table.new_from_template(mi_table)
    for row in expr_summ_table:
        # every experiment must refer to a known time slide
        slide_id = row.time_slide_id
        if slide_id not in slides:
            raise KeyError("time slide %s not found in TimeSlideTable"
                           % slide_id)

        # extract the triggers for this slide
        print_verbose("Experiment (slide) %d:\n" % slide_id, profile=False)
        mi_expr_table = table.new_from_template(mi_table)
        mi_expr_table.extend(t for t in mi_table if
                             t.time_slide_id == slide_id)
        if len(mi_expr_table) == 0:
            print_verbose("    No events found!\n")
            continue

        # boolean mask of the triggers surviving every cut
        keep = numpy.ones(len(mi_expr_table), dtype=bool)
        if row.nevents:
            for threshold, apply_veto, keyword, label in active_cuts:
                veto = apply_veto(mi_expr_table, return_index=True,
                                  **{keyword: threshold})
                keep &= veto
                print_verbose("    %s cut kills %d events.\n"
                              % (label, row.nevents - veto.sum()))
        print_verbose("    Signal-based vetoes kill %d of %d events.\n"
                      % (row.nevents - keep.sum(), row.nevents), profile=False)
        # store a plain int rather than a numpy scalar in the table row
        row.nevents = int(keep.sum())
        pass_mi_table.extend(t for t, kept in zip(mi_expr_table, keep)
                             if kept)

    # -- finalize and close --------------------------------------------------

    outxml.childNodes[-1].appendChild(expr_summ_table)
    outxml.childNodes[-1].appendChild(expr_table)
    outxml.childNodes[-1].appendChild(
        lsctables.New(lsctables.SearchSummaryTable))
    ligolw_search_summary.append_search_summary(outxml, process, inseg=inseg,
                                                outseg=outseg,
                                                nevents=len(pass_mi_table),
                                                comment=basename(__file__))

    # construct new table and write
    outxml.childNodes[-1].appendChild(pass_mi_table)

    # finish up; with no --output-file the filename is None, so only test
    # for a ".gz" suffix on a real name
    process.end_time = int(GPSTimeNow())
    ligolw_utils.write_filename(
        outxml, opts.output_file,
        gz=(opts.output_file or "stdout").endswith(".gz"),
        verbose=opts.verbose)
    print_verbose("Done.\n")
