#!/usr/bin/env python

# Copyright (C) 2012 Duncan M. Macleod
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

"""Veto triggers from a MultiInspiralTable based on data-quality vetoes
"""

from __future__ import division

import sys
import optparse
import warnings
import itertools

from time import time as time_now
from os.path import basename

from pylal import (git_version, coh_PTF_pyutils)
from lal import GPSTimeNow

from glue.segments import (segment as Segment, segmentlist as SegmentList,
                           segmentlistdict as SegmentListDict)
from glue.ligolw import (ligolw, table, lsctables, utils as ligolw_utils)
from glue.ligolw.utils import (process as ligolw_process,
                               search_summary as ligolw_search_summary)
# Ignore any UserWarning whose message matches the regular expression
# "column name (.*) is not lower case".
warnings.filterwarnings("ignore", "column name (.*) is not lower case",
                        UserWarning)

# Attach the coh_PTF_pyutils veto helpers to MultiInspiralTable as class
# attributes, so they are available as `table.veto(...)` / `table.vetoed(...)`.
# NOTE(review): presumably both helpers take the table as their first
# argument -- confirm against coh_PTF_pyutils.
lsctables.MultiInspiralTable.veto = coh_PTF_pyutils.veto
lsctables.MultiInspiralTable.vetoed = coh_PTF_pyutils.vetoed

# Module metadata; version and date come from pylal's git_version module.
__author__ = "Duncan Macleod <duncan.macleod@ligo.org>"
__version__ = git_version.id
__date__ = git_version.date

# Module-level switches consulted by print_verbose(). They default to off and
# are rebound from the --profile / --verbose command-line options.
PROFILE = False
VERBOSE = False


# Wall-clock reference used to timestamp profiled messages.
job_start = time_now()


def elapsed_time():
    """Return the number of seconds elapsed since `job_start`.
    """
    return time_now() - job_start


def print_verbose(message, verbose=True, stream=sys.stdout, profile=True):
    """Print verbose messages to a file stream.

    For any stream other than sys.stderr the `verbose` and `profile` flags
    are additionally gated by the module-level VERBOSE and PROFILE switches;
    messages sent to sys.stderr depend on the arguments alone.

    @param message
        text to print
    @param verbose
        flag to print or not, default: True (print, subject to VERBOSE)
    @param stream
        file object stream in which to print, default: sys.stdout
    @param profile
        flag to append the elapsed run time to messages ending in a
        newline, default: True (subject to PROFILE)
    """
    if stream is not sys.stderr:
        # logical rather than bitwise AND: any truthy flag (e.g. verbose=2)
        # must survive, whereas `2 & True` evaluates to 0
        profile = bool(profile) and PROFILE
        verbose = bool(verbose) and VERBOSE
    # only complete lines are timestamped, so partial-line output (progress
    # messages finished by a later call) is left untouched
    if profile and message.endswith("\n"):
        message = "%s (%.2f)\n" % (message.rstrip("\n"), elapsed_time())
    if verbose:
        stream.write(message)
        stream.flush()


def parse_threshold(option, opt_str, value, parser):
    """Verify threshold argument is valid on callback.

    optparse callback: converts `value` to a float, checks that it is
    strictly positive and stores it on `parser.values` under `option.dest`.

    @param option
        the optparse Option that triggered the callback
    @param opt_str
        the option string as seen on the command line, used in messages
    @param value
        the option argument, either a number or a numeric string
    @param parser
        the OptionParser; its error() method is called on invalid input
    """
    # convert before comparing: a raw string argument would otherwise slip
    # through the sign check (Python 2) or raise TypeError (Python 3)
    try:
        threshold = float(value)
    except (TypeError, ValueError):
        return parser.error("%s must be a number, got %r" % (opt_str, value))
    # `not > 0` also rejects NaN, which `<= 0` would let through
    if not threshold > 0:
        return parser.error("%s must be positive" % opt_str)
    setattr(parser.values, option.dest, threshold)


if __name__ == "__main__":

    # ----------------------------------------------------------------------
    # command line

    usage = "%prog [options] file"
    # strip() rather than slicing off the first character: the module
    # docstring starts directly with text, so [1:] truncated the description
    parser = optparse.OptionParser(usage=usage, description=__doc__.strip(),
                                   formatter=optparse.IndentedHelpFormatter(4))
    parser.add_option("-p", "--profile", action="store_true", default=False,
                      help="timestamp output, default: %default")
    parser.add_option("-v", "--verbose", action="store_true", default=False,
                      help="show verbose output, default: %default")
    parser.add_option("-V", "--version", action="version",
                      help="show program's version number and exit")
    parser.version = git_version.verbose_msg

    vetoopts = parser.add_option_group("Veto options")
    vetoopts.add_option("-s", "--segment-file", action="append",
                        type="string", default=[], metavar="FILE",
                        help=("path to LIGO_LW xml file containing "
                              "data quality segments to apply. "
                              "Can be given multiple times, or as a "
                              "comma-separated list."))

    outopts = parser.add_option_group("Output options")
    # default is None (not False) so that the path can be tested and passed
    # straight through to the writer when no file is requested
    outopts.add_option("-o", "--output-file", action="store", type="string",
                       default=None, metavar="FILE",
                       help="output xml file, default: stdout")

    # parse args; parser.error() prints the usage message and exits
    opts, args = parser.parse_args()
    if len(args) != 1:
        parser.error("Please give exactly one input XML file (you gave %d)"
                     % len(args))
    xmlfile = args[0]

    # each -s argument may itself be a comma-separated list of paths; this is
    # always defined (possibly empty) so that running without vetoes works
    segment_files = [fp for arg in opts.segment_file
                     for fp in arg.split(",") if fp]

    VERBOSE = opts.verbose
    PROFILE = opts.profile

    # ----------------------------------------------------------------------
    # output document

    outxml = ligolw.Document()
    outxml.appendChild(ligolw.LIGO_LW())

    # append our process, recording the input file and every option given
    process = ligolw_process.append_process(outxml, program=basename(__file__),
                                            version=__version__)
    ligolw_process.append_process_params(outxml, process,
                                         [("", "lstring", xmlfile)])
    for key, val in vars(opts).items():
        key = "--%s" % key.replace("_", "-")
        if key == "--segment-file":
            # one process_params row per -s argument, as given
            for fp in val:
                ligolw_process.append_process_params(outxml, process,
                                                     [(key, "lstring", fp)])
        elif val is not None:
            ligolw_process.append_process_params(outxml, process,
                                                 [(key, "lstring", val)])

    # ----------------------------------------------------------------------
    # read input tables

    contenthandler = lsctables.use_in(ligolw.LIGOLWContentHandler)

    print_verbose("Reading xml file...\n", profile=False)
    xmldoc = ligolw_utils.load_filename(xmlfile, gz=xmlfile.endswith(".gz"),
                                        contenthandler=contenthandler)

    # read SearchSummaryTable; fall back on the ProcessTable for the
    # instrument list if the last summary row does not name any
    search_summ_table = table.get_table(xmldoc,
                                        lsctables.SearchSummaryTable.tableName)
    ifos = search_summ_table[-1].get_ifos()
    if not ifos:
        process_table = table.get_table(xmldoc,
                                        lsctables.ProcessTable.tableName)
        ifos = process_table[0].get_ifos()
    process.set_ifos(ifos)
    outxml.childNodes[-1].appendChild(search_summ_table)

    # read Experiment and ExperimentSummary tables
    expr_table = table.get_table(xmldoc, lsctables.ExperimentTable.tableName)
    expr_summ_table = table.get_table(
                          xmldoc, lsctables.ExperimentSummaryTable.tableName)

    # read MultiInspiralTable
    mi_table = table.get_table(xmldoc, lsctables.MultiInspiralTable.tableName)
    nevents = len(mi_table)
    print_verbose("MultiInspiralTable read, %d events found.\n" % nevents)

    # read SimInspiralTable (if it has one)
    try:
        sim_table = table.get_table(xmldoc,
                                    lsctables.SimInspiralTable.tableName)
    except ValueError:
        sim_table = None
    else:
        print_verbose("SimInspiralTable read.\n")

    # read TimeSlideTable
    slide_table = table.get_table(xmldoc, lsctables.TimeSlideTable.tableName)
    slides = slide_table.as_dict()
    outxml.childNodes[-1].appendChild(slide_table)
    print_verbose("TimeSlideTable read and copied.\n")

    # ----------------------------------------------------------------------
    # read veto segments into a dict by instrument

    veto_segments = SegmentListDict((ifo, SegmentList()) for ifo in ifos)
    for fp in segment_files:
        segdoc = ligolw_utils.load_filename(fp, gz=fp.endswith(".gz"),
                                            contenthandler=contenthandler)
        seg_def_table = table.get_table(segdoc,
                                        lsctables.SegmentDefTable.tableName)
        seg_table = table.get_table(segdoc, lsctables.SegmentTable.tableName)
        # map each segment definer ID onto the instruments it applies to
        ifo_index = dict((int(row.segment_def_id), row.get_ifos()) for
                         row in seg_def_table)
        # make sure every instrument named in this file has a segment list;
        # iterate the instruments (the dict values), not the integer IDs
        for seg_ifos in ifo_index.values():
            for ifo in seg_ifos:
                if ifo not in veto_segments:
                    veto_segments[ifo] = SegmentList()
        for row in seg_table:
            seg = Segment(row.start_time, row.end_time)
            for ifo in ifo_index[int(row.segment_def_id)]:
                veto_segments[ifo].append(seg)
    veto_segments.coalesce()
    print_verbose("Veto segments loaded.\n")

    # ----------------------------------------------------------------------
    # apply vetoes

    # veto injections, FIXME: needs to be modified for slid injections
    if sim_table:
        nsim = len(sim_table)
        sim_table = sim_table.veto(veto_segments.union(ifos))
        print_verbose("Vetoes killed %d software injections.\n"
                      % (nsim - len(sim_table)))
        outxml.childNodes[-1].appendChild(sim_table)

    # apply veto individually for each experiment (slide)
    out_table = table.new_from_template(mi_table)
    for row in expr_summ_table:
        # extract experiment information (IndexError if there is no match)
        expr = [e for e in expr_table
                if e.experiment_id == row.experiment_id][0]
        start = expr.gps_start_time
        end = expr.gps_end_time
        expr_span = SegmentList([Segment(start, end)])
        expr_ifos = expr.get_instruments()
        # extract time slide information
        slide_id = row.time_slide_id
        slide_vector = slides[slide_id]
        # slide vetoes instead of triggers: shift each instrument's veto
        # segments by the negative of its trigger offset
        inverse_vector = dict((ifo, -offset) for ifo, offset in
                              slide_vector.items())
        slide_veto_dict = veto_segments.copy(expr_ifos)
        slide_veto_dict.offsets.update(inverse_vector)
        slide_veto_list = slide_veto_dict.union(expr_ifos)# & expr_span
        print_verbose("Experiment (slide) %d:\n" % slide_id, profile=False)
        print_verbose("    Livetime: %.2f seconds\n" % row.duration,
                      profile=False)
        for ifo in sorted(expr_ifos):
            print_verbose("    Deadtime %s: %.2f seconds\n"
                          % (ifo, abs(slide_veto_dict[ifo] & expr_span)),
                          profile=False)
        print_verbose("    Total deadtime: %.2f seconds\n"
                      % abs(slide_veto_list), profile=False)
        # extract the triggers for this slide that survive the vetoes
        mi_slide_events = table.new_from_template(mi_table)
        mi_slide_events.extend(t for t in mi_table if
                               (t.time_slide_id == slide_id and
                                float(t.get_end()) not in slide_veto_list))
        print_verbose("    %d out of %d events remain.\n"
                      % (len(mi_slide_events), row.nevents))
        # update the experiment summary to reflect the vetoed analysis
        row.nevents = len(mi_slide_events)
        row.duration = abs(expr_span - slide_veto_list)
        out_table.extend(mi_slide_events)
    print_verbose("Vetoes killed %d of %d events.\n"
                  % (nevents - len(out_table), nevents))

    # ----------------------------------------------------------------------
    # finalize and close

    outxml.childNodes[-1].appendChild(expr_table)
    outxml.childNodes[-1].appendChild(expr_summ_table)
    outxml.childNodes[-1].appendChild(out_table)
    process.end_time = int(GPSTimeNow())
    # with no --output-file the path is None (documented above as stdout),
    # in which case there is no file name to test for a .gz suffix
    gzip_output = (opts.output_file is not None and
                   opts.output_file.endswith(".gz"))
    ligolw_utils.write_filename(outxml, opts.output_file, gz=gzip_output,
                                verbose=opts.verbose)
    print_verbose("Done.\n")

