#!/usr/bin/python

# Copyright (C) 2012 Ian W. Harry, Duncan M. Macleod
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

"""Cluster MultiInspiral events generated by the CBC coherent (PTF) analysis.
"""

from __future__ import division
import re
import os
import sys
import time
import optparse
import numpy
import warnings
from os.path import basename

from pylal import (git_version, MultiInspiralUtils, ligolw_tisi)
from lal import GPSTimeNow

from glue import (lal as cache, iterutils)
from glue.segments import (segment as Segment, segmentlist as SegmentList,
                           segmentlistdict as SegmentListDict)
from glue.ligolw import (ligolw, lsctables, table, ilwd, utils as ligolw_utils)
from glue.ligolw.utils import (ligolw_add, process as ligolw_process,
                               search_summary as ligolw_search_summary,
                               segments as ligolw_segments)
warnings.filterwarnings("ignore", "column name (.*) is not lower case",
                        UserWarning)

__author__ = "Ian W. Harry <ian.harry@astro.cf.ac.uk>, Duncan M. Macleod <duncan.macleod@ligo.org>"
__version__ = git_version.id
__date__ = git_version.date

# record when this job started: an integer GPS time (later stored as the
# process start_time) and a wall-clock reference used for profiling stamps
process_start = int(GPSTimeNow())
start = time.time()


def elapsed_time():
    """Return the wall-clock seconds elapsed since ``start`` was recorded."""
    return time.time() - start

# global print options; print_verbose() reads both names, and the __main__
# section overwrites them from --verbose / --profile
# (previously misspelled ``PROFLE``, leaving ``PROFILE`` undefined until
# the command line had been parsed)
VERBOSE = False
PROFILE = False


def print_verbose(message, verbose=True, stream=sys.stdout, profile=True):
    """Print a progress message to a file stream.

    Unless ``stream`` is ``sys.stderr``, the ``verbose`` and ``profile``
    flags are AND-ed with the module-level ``VERBOSE`` and ``PROFILE``
    switches (set from ``--verbose`` / ``--profile``), so such messages are
    only emitted when the command line asked for them. Messages sent to
    ``sys.stderr`` ignore those switches.

    @param message
        text to print
    @param verbose
        flag to print or not, default: True (still gated by ``VERBOSE``
        unless printing to ``sys.stderr``)
    @param stream
        file object stream in which to print, default: ``sys.stdout``
    @param profile
        flag to append the elapsed run time, as `` (%.2f)``, to messages
        ending in a newline or carriage return, default: True (still gated
        by ``PROFILE`` unless printing to ``sys.stderr``)
    """
    if stream is not sys.stderr:
        profile &= PROFILE
        verbose &= VERBOSE
    # keep the line terminator after the timestamp so "\r" progress lines
    # still overwrite themselves
    if profile and message.endswith(("\n", "\r")):
        terminator = message[-1]
        message = "%s (%.2f)%s" % (message.rstrip(terminator),
                                   elapsed_time(), terminator)
    if verbose:
        stream.write(message)
        stream.flush()


if __name__ == '__main__':

    # NOTE: this script targets Python 2 (dict.iteritems, pylal/glue), so
    # keep any new syntax Python 2 compatible.
    import subprocess

    # external helper run (below) on input files last modified in January
    # NOTE(review): hard-coded personal path -- confirm it is still needed
    # and promote it to a command-line option if so.
    FIX_TIME_SLIDE_ID = "/home/duncan.macleod/bin/ligolw_fix_time_slide_id"
    # FIXME: HARDCODED SEGMENT OVERLAP (seconds); each search_summary
    # in_segment is contracted by half of this at both ends
    SEGMENT_OVERLAP = 128

    # parse command line
    epilog = "For help, just ask.ligo.org"

    parser = optparse.OptionParser(description=__doc__, epilog=epilog,
                                   formatter=optparse.IndentedHelpFormatter(4))
    parser.add_option("-p", "--profile", action="store_true", default=False,
                      help="timestamp output, default: %default")
    parser.add_option("-v", "--verbose", action="store_true", default=False,
                      help="verbose output, default: %default")
    parser.add_option("-V", "--version", action="version",
                      help="show program's version number and exit")
    parser.version = git_version.verbose_msg

    # input options
    inputopts = parser.add_option_group("Input options")
    inputopts.add_option("-t", "--trig-file", action="append", type="string",
                         default=[],
                         help="Path to xml file containing MultiInspiralTable")
    inputopts.add_option("-c", "--cache-file", action="store", type="string",
                         default=None,
                         help=("Path to LAL-format cache file containing "
                               "paths to xml files containing "
                               "MultiInspiralTables"))

    # output options
    outputopts = parser.add_option_group("Output options")
    outputopts.add_option("-o", "--output-file", action="store",
                          type="string", help="File path for output xml.")
    outputopts.add_option("-P", "--preserve-processes", action="store_true",
                          help=("preserve all process information from the "
                                "input files, default: %default"))

    # clustering options
    clusteropts = parser.add_option_group("Clustering options")
    clusteropts.add_option("-W", "--time-window", action="store",
                           type="float", default=None,
                           help="The cluster time window")
    clusteropts.add_option("-l", "--loudest-by", type="string", default="snr",
                           metavar="COLUMN",
                           help=("return only the loudest MultiInspiral event "
                                 "for each injection, as ranked by COLUMN."))

    (opts, args) = parser.parse_args()

    if not opts.trig_file and not opts.cache_file:
        parser.error("Must provide either --trig-file or --cache-file.")

    if not opts.time_window or opts.time_window <= 0:
        parser.error("A positive --time-window must be given.")

    # the output file is only written at the very end; fail now rather than
    # after all the loading and clustering work has been done
    if not opts.output_file:
        parser.error("Must provide --output-file.")

    VERBOSE = opts.verbose
    PROFILE = opts.profile
    outfile = opts.output_file

    # generate LIGOLw document
    outxml = ligolw.Document()
    outxml.appendChild(ligolw.LIGO_LW())

    # load trigger files
    trig_cache = cache.Cache()
    if opts.trig_file:
        trig_cache.extend(map(cache.CacheEntry.from_T050017, opts.trig_file))
    if opts.cache_file:
        with open(opts.cache_file, "r") as cache_fobj:
            trig_cache.extend(cache.Cache.fromfile(cache_fobj))
    trig_cache.sort(key=lambda e: e.segment[0])
    trig_cache.checkfilesexist(on_missing="error")

    # NOTE(review): files last modified in January (local time) are passed
    # through an external fixer before loading -- presumably a one-off
    # repair of bad time_slide_ids; confirm before relying on it.
    for fp in trig_cache.pfnlist():
        if time.localtime(os.path.getmtime(fp)).tm_mon == 1:
            print_verbose("$ %s %s\n" % (FIX_TIME_SLIDE_ID, fp))
            # argument list, not a shell string: paths are never
            # shell-interpreted; a missing helper is reported, not fatal,
            # matching the old os.system() behaviour
            try:
                subprocess.call([FIX_TIME_SLIDE_ID, fp])
            except OSError as err:
                print_verbose("Warning: failed to run %s: %s\n"
                              % (FIX_TIME_SLIDE_ID, err), stream=sys.stderr)

    N = len(trig_cache)
    print_verbose("Loading triggers from %d files... " % N)
    ligolw_add.ligolw_add(outxml, trig_cache.pfnlist())
    time_slide_table = table.get_table(outxml,
                                       lsctables.TimeSlideTable.tableName)
    table.reset_next_ids([type(time_slide_table)])
    # drop the rows whose IDs appear as keys of the vacuum mapping, then
    # rewrite references to those IDs in every table via applyKeyMapping
    time_slide_mapping = ligolw_tisi.time_slides_vacuum(
                             time_slide_table.as_dict())
    iterutils.inplace_filter(lambda row: row.time_slide_id not in
                             time_slide_mapping, time_slide_table)
    for tbl in outxml.getElementsByTagName(ligolw.Table.tagName):
        tbl.applyKeyMapping(time_slide_mapping)
    print_verbose("Done.\n")

    # extract the tables
    process_table = table.get_table(outxml,
                                    lsctables.ProcessTable.tableName)
    search_summary_table = table.get_table(outxml,
                                        lsctables.SearchSummaryTable.tableName)
    multi_inspiral_table = table.get_table(
                               outxml, lsctables.MultiInspiralTable.tableName)
    time_slide_table = table.get_table(outxml,
                                       lsctables.TimeSlideTable.tableName)
    slides = time_slide_table.as_dict()

    # uniquify the search summary
    # NOTE(review): a row is kept only if BOTH its in- and out-segment are
    # new, and SegmentList membership may test containment rather than
    # equality -- confirm this is the intended notion of "duplicate".
    unique_search_summary = table.new_from_template(search_summary_table)
    search_summary_inlist = SegmentList()
    search_summary_outlist = SegmentList()
    for search_sum in search_summary_table:
        inseg = search_sum.get_in()
        outseg = search_sum.get_out()
        if (inseg not in search_summary_inlist and
            outseg not in search_summary_outlist):
            unique_search_summary.append(search_sum)
            search_summary_inlist.append(inseg)
            search_summary_outlist.append(outseg)
    outxml.childNodes[-1].removeChild(search_summary_table)
    # from here on this name refers to the de-duplicated table, which is
    # not appended to the document
    search_summary_table = unique_search_summary
    search_summary_inlist = search_summary_table.get_inlist()
    search_summary_outlist = search_summary_table.get_outlist()

    # work out the livetime for each slide
    segments = ligolw_segments.LigolwSegments(outxml)
    expr_table = lsctables.New(lsctables.ExperimentTable)
    expr_summ_table = lsctables.New(lsctables.ExperimentSummaryTable)
    expr_map = dict()
    for slide_id, vector in slides.iteritems():
        ifos = set(vector.keys())
        # build slide segment
        expr_segments = SegmentList()
        # the slide offsets are applied negated to each IFO's segments
        reverse_slide = dict((ifo, -offset) for ifo, offset in
                             vector.iteritems())
        for search_summary in search_summary_table:
            expr_span = SegmentList([search_summary.get_in().contract(
                                         SEGMENT_OVERLAP / 2.)])
            ss_segments = SegmentListDict((ifo, expr_span) for ifo in
                                          vector.keys()).copy()
            ss_segments.offsets.update(reverse_slide)
            try:
                out_slide_segment = (
                    ss_segments.intersection(vector.keys())[0] &
                    search_summary.get_out())
            except (IndexError, ValueError):
                # IndexError: the slid segments have no common time at all
                # (empty intersection); ValueError: presumably raised by
                # ``&`` when that overlap is disjoint from the out segment
                pass
            else:
                expr_segments.append(out_slide_segment)
        # the comment string encodes the offsets; it is used at the end to
        # re-match segment_definer rows to time_slide_ids after reassign_ids
        segments.add(
            ligolw_segments.LigolwSegmentList(
                             active=expr_segments, valid=expr_segments,
                             instruments=ifos, name=slide_id,
                             comment=",".join([str(vector[ifo]) for
                                               ifo in ifos])))
        expr_segments.coalesce()
        # NOTE(review): extent() of an empty list (no usable segments for
        # this slide) will raise -- confirm that cannot happen
        expr_start, expr_end = expr_segments.extent()
        expr_id = expr_table.write_new_expr_id("cbc", "coh_PTF_inspiral",
                                               None, vector.keys(),
                                               expr_start, expr_end)
        datatype = "full_data_slide" if any(vector.values()) else "full_data"
        expr_summ_table.write_experiment_summ(expr_id, slide_id, None, datatype)
        expr_map[slide_id] = expr_summ_table[-1]
        expr_summ_table[-1].duration = abs(expr_segments)

    # cluster the triggers separately for each time slide
    slides = time_slide_table.as_dict()
    cluster_table = table.new_from_template(multi_inspiral_table)
    ifos = None
    for slide_id, vector in sorted(slides.iteritems(),
                                   key=lambda item: int(item[0])):
        print_verbose("Clustering time slide %d... " % int(slide_id))
        ifos = vector.keys()
        mi_slide_table = table.new_from_template(cluster_table)
        mi_slide_table.extend(filter(lambda row: row.time_slide_id == slide_id,
                              multi_inspiral_table))
        clusters = MultiInspiralUtils.cluster_multi_inspirals(
                       mi_slide_table, opts.time_window,
                       loudest_by=opts.loudest_by)
        cluster_table.extend(clusters)
        expr_map[slide_id].nevents = len(clusters)
        print_verbose("%d events selected.\n" % len(clusters))

    print_verbose("Clustering complete.\n")

    # append our process
    if not opts.preserve_processes:
        process_params = table.get_table(outxml,
                                         lsctables.ProcessParamsTable.tableName)
        outxml.childNodes[-1].removeChild(process_table)
        outxml.childNodes[-1].removeChild(process_params)
    process = ligolw_process.append_process(outxml, program=__file__,
                                            version=__version__)
    process.start_time = process_start
    options = [("--time-window", opts.time_window), ("--output-file", outfile),
               ("--loudest-by", opts.loudest_by)]
    if opts.cache_file:
        options.append(("--cache-file", opts.cache_file))
    # one row per input file (previously the whole list was recorded for
    # every entry)
    for fp in opts.trig_file:
        options.append(("--trig-file", fp))
    for key, val in options:
        ligolw_process.append_process_params(outxml, process,
                                             [(key, "lstring", val)])

    # append our search summary
    # The de-duplicated table was never appended to the document, so only
    # detach it if it really is a child. Test identity rather than using
    # ``in``, which may compare table contents.
    if any(child is search_summary_table
           for child in outxml.childNodes[-1].childNodes):
        outxml.childNodes[-1].removeChild(search_summary_table)
    outxml.childNodes[-1].appendChild(
        lsctables.New(lsctables.SearchSummaryTable))
    for inseg, outseg in zip(search_summary_inlist, search_summary_outlist):
        # wrap in a SegmentList so membership is a containment test; with a
        # plain list, ``in`` would compare trigger times for equality with
        # the segment itself (assuming vetoed() tests ``time in seglist``)
        nevents = len(cluster_table.vetoed(SegmentList([outseg])))
        ligolw_search_summary.append_search_summary(outxml, process,
                                                    nevents=nevents,
                                                    comment=os.path.basename(
                                                                __file__),
                                                    inseg=inseg,
                                                    outseg=outseg)

    # uniquify the simulations (the table is optional in the input)
    try:
        sim_inspiral_table = table.get_table(
                                 outxml, lsctables.SimInspiralTable.tableName)
    except ValueError:
        pass
    else:
        unique_sims = table.new_from_template(sim_inspiral_table)
        unique_sim_times = []
        for sim in sim_inspiral_table:
            if sim.get_time_geocent() in unique_sim_times:
                continue
            if not opts.preserve_processes:
                sim.process_id = process.process_id
            unique_sims.append(sim)
            unique_sim_times.append(sim.get_time_geocent())
        outxml.childNodes[-1].removeChild(sim_inspiral_table)
        outxml.childNodes[-1].appendChild(unique_sims)

    # set the process_ids
    if not opts.preserve_processes:
        for row in cluster_table:
            row.process_id = process.process_id
        for row in time_slide_table:
            row.process_id = process.process_id

    # move the time slide table to the end of the document
    outxml.childNodes[-1].removeChild(time_slide_table)
    outxml.childNodes[-1].appendChild(time_slide_table)

    # finalize the segments
    segments.optimize()
    segments.coalesce()
    segments.finalize(process)

    # write the output file
    outxml.childNodes[-1].appendChild(expr_table)
    outxml.childNodes[-1].appendChild(expr_summ_table)
    outxml.childNodes[-1].removeChild(multi_inspiral_table)
    outxml.childNodes[-1].appendChild(cluster_table)
    table.reset_next_ids(lsctables.TableByName.values())
    table.reassign_ids(outxml)
    # re-link segment definitions to the (reassigned) time_slide_ids by
    # matching the offset string written into their comment column
    segment_def_table = table.get_table(outxml,
                                        lsctables.SegmentDefTable.tableName)
    for time_slide_id, vector in time_slide_table.as_dict().iteritems():
        comment = ",".join([str(vector[ifo]) for ifo in set(vector.keys())])
        for row in segment_def_table:
            if row.comment == comment:
                row.name = time_slide_id
    process.set_ifos(ifos)
    process.end_time = int(GPSTimeNow())
    ligolw_utils.write_filename(outxml, outfile, gz=outfile.endswith(".gz"),
                                verbose=opts.verbose)
    print_verbose("Done.\n")
