#!/usr/bin/env python
#
# Copyright (C) 2007  Nickolas Fotopoulos
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 2 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
#
"""
Characterize a set of timeslide triggers.
"""

from __future__ import division

__author__ = "Nickolas Fotopoulos and Valeriu Predoi <nvf@gravity.phys.uwm.edu> <valeriu.predoi@astro.cf.ac.uk>"
__prog__ = "pylal_grbtimeslide_stats"
__title__ = "GRB timeslide and trial statistics"


import glob
import itertools
import optparse
import sys

import numpy
numpy.seterr(all="raise")

import matplotlib
matplotlib.use("Agg")
import pylab
pylab.rc("text", usetex=True)

from glue import iterutils
from glue import lal
from glue import segmentsUtils
from glue import segments
from glue.segments import segment, segmentlist, segmentlistdict
from glue.ligolw import ligolw
from glue.ligolw import table
from glue.ligolw import lsctables
from pylal import CoincInspiralUtils
from pylal import InspiralUtils
from pylal import SearchSummaryUtils
from pylal import SnglInspiralUtils
from pylal import git_version
from pylal import grbsummary
from pylal import plotutils
from pylal import rate

##############################################################################
# utility functions

def array_of_lists(shape):
    """
    Return an object array of the requested shape in which every element is
    its own, distinct empty list.

    shape is a tuple of dimensions (anything numpy.prod() and
    ndarray.reshape() accept).  The lists are created one by one so that no
    two elements share the same list object.
    """
    # numpy.prod() returns a numpy scalar (a float for an empty tuple), so
    # coerce to a plain int before using it as a size.
    size = int(numpy.prod(shape))
    arr = numpy.empty(size, dtype=object)
    # range() rather than xrange() keeps this usable on Python 2 and 3.
    for i in range(size):
        arr[i] = []
    return arr.reshape(shape)

def get_loudest_stat(coinc_list):
    """
    Find the largest value of the .stat attribute among the elements of
    coinc_list.  An empty coinc_list yields 0.
    """
    # guard clause: max() of an empty sequence would raise
    if len(coinc_list) == 0:
        return 0
    stats = [coinc.stat for coinc in coinc_list]
    return max(stats)

# some vectorized functions that act on arrays
# element-wise wrapper around the builtin len(); yields an object array
_vec_len = numpy.frompyfunc(len, 1, 1)
def vec_len(arr):
    """
    Apply the builtin len() to every element of arr.  The result has the
    same shape as arr and dtype numpy.int32.
    """
    lengths = _vec_len(arr)
    return lengths.astype(numpy.int32)

# element-wise wrapper around get_loudest_stat(); yields an object array
_vec_get_loudest_stat = numpy.frompyfunc(get_loudest_stat, 1, 1)
def vec_get_loudest_stat(arr):
    """
    Apply get_loudest_stat() to every element of arr.  The result has the
    same shape as arr and dtype numpy.float32.
    """
    loudest = _vec_get_loudest_stat(arr)
    return loudest.astype(numpy.float32)



##############################################################################
# handle user input

def comma_delimited(option, opt_str, value, parser, elem_type=str):
    """
    optparse callback that stores a comma-delimited string as a list.

    The option value is split on "," and every piece is converted with
    elem_type, which is supplied through callback_kwargs (default: str).
    The resulting list is stored on parser.values under option.dest.  The
    option must be declared with type="string" so that optparse passes the
    raw value to the callback.  Any exception raised by elem_type (e.g.
    ValueError from float()) propagates to the caller.

    Example:
    parser.add_option("-n", "--names", action="callback", type="string",
        callback=comma_delimited, callback_kwargs={"elem_type":str})

    This would allow a user to provide "--names Joe,Bob,Sarah" and the
    developer to have opts.names == ["Joe", "Bob", "Sarah"].
    """
    # Build a real list: map() would only give one on Python 2.
    elements = [elem_type(item) for item in value.split(",")]
    setattr(parser.values, option.dest, elements)


# Command-line interface.  Option names, destinations, types and defaults
# are what the rest of the script reads from `opts`.
parser = optparse.OptionParser(version=git_version.verbose_msg)

# input selection and binning
parser.add_option("-g", "--glob",
    help="glob of zero-lag thinca or thinca slide files to read")
parser.add_option("-o", "--onsource-pattern",
    help="on-source coinc file; only used for segment information without "
         "--hist-onsource")
parser.add_option("--mc-boundaries", action="callback", type="string",
    callback=comma_delimited, default=[-numpy.inf, numpy.inf],
    callback_kwargs={"elem_type": float},
    help="comma-separated list of mchirp bin boundaries")
parser.add_option("--coinc-statistic",
    help="coincidence statistic of interest (effective_snr or new_snr)")

# plot selection
parser.add_option("--hist-loudest", action="store_true", default=False,
    help="make a cumulative histogram that includes the loudest statistic "
         "from each trial")
parser.add_option("--plot-num-coincs", action="store_true", default=False,
    help="plot the number of coincs vs trial number")
parser.add_option("--hist-num-coincs", action="store_true", default=False,
    help="histogram the number of coincs in each trial")
parser.add_option("--hist-onsource", action="store_true", default=False,
    help="make a cumulative histogram of the on-source trial (If you want "
         "this, make sure not to veto the on-source segment)")
parser.add_option("--hist-background", action="store_true", default=False,
    help="make a cumulative histogram of the background +/- 1 sigma, "
         "including both off-source trials and time-slides.")
parser.add_option("--hist-trials", action="callback", type="string",
    callback=comma_delimited, callback_kwargs={"elem_type": int}, default=[],
    help="make a cumulative histogram of statistics for each of the given "
         "trials (comma-delimited list)")
parser.add_option("--plot-prefix", help="prefix for plot filenames")
parser.add_option("--plot-slide-loudest", action="store_true",
    default=False,
    help="plot the loudest events in each slide vs slide number")
parser.add_option("--plot-trials-loudest", action="store_true",
    default=False,
    help="plot the loudest events in each zero-lag trial vs trial number")
parser.add_option("--plot-trial-stat-autocorrelation",
    action="store_true", default=False,
    help="plot the auto-correlation of loudest statistic vs trial number")
parser.add_option("--plot-trial-number-autocorrelation",
    action="store_true", default=False,
    help="plot the auto-correlation of number of coincidences vs trial "
         "number")
parser.add_option("--show-plot", action="store_true", default=False,
    help="display plots to screen")
parser.add_option("-v", "--verbose", action="store_true", default=False,
    help="print extra information to stdout")

# InspiralUtils compatibility
parser.add_option("--ifo-tag", help="IFO coincidence time analyzed")
parser.add_option("--output-path", help="root of the HTML output")
parser.add_option("--enable-output", action="store_true", default=False,
    help="enable plots and HTML output")
parser.add_option("--html-for-cbcweb", action="store_true", default=False,
    help="enable HTML output with the appropriate headers for the CBC website")
parser.add_option("--cache-file",
    help="LAL-formatted cache file containing entries for all XML files "
         "of interest")
parser.add_option("--offsource-pattern", metavar="PAT",
    help="sieve the cache pattern for offsource coincs: "
         "THINCA_SLIDE_SECOND*OFFSOURCE for slides only; "
         "THINCA*SECOND*OFFSOURCE for slides and zero-lag; "
         "THINCA_SECOND*OFFSOURCE for zero-lag only")

# output
parser.add_option("--user-tag", default="",
    help="identifying string for the analysed data")

# segments: one --<ifo>-veto-file option per detector
for _veto_ifo in ("G1", "H1", "H2", "L1", "T1", "V1"):
    parser.add_option("--%s-veto-file" % _veto_ifo.lower(), default="",
        help="segwizard-formatted segment files that "
             "contain segments to veto in %s" % _veto_ifo)

parser.add_option("--onsource-seg", action="store", type="string",
    default=None,
    help="segwizard-formatted file containing onsource segment")
parser.add_option("--buffer-seg", action="store", type="string",
    default=None,
    help="segwizard-formatted file containing buffer segment")
parser.add_option("--offsource-seg", action="store", type="string",
    default=None,
    help="segwizard-formatted file containing offsource segments")
parser.add_option("--ifos", type="string",
    help="IFOs used, written together, like H1V1")
# the default is given as an int to match the declared option type
parser.add_option("--numloudesttable", action="store", type="int",
    default=-1,
    help="number of loudest coincident triggers to be displayed in the "
         "Loudest triggers table, defaults at -1")
parser.add_option("--gps-start-time", type="int",
    help="GPS start time of data analyzed")
parser.add_option("--gps-end-time", type="int",
    help="GPS end time of data analyzed")

(opts, args) = parser.parse_args()

# Validate the options that the rest of the script uses unconditionally.
# Each failure prints a message to stderr and exits with status 2.

# exactly one of --glob / --cache-file must be given
if not ((opts.glob is None) ^ (opts.cache_file is None)):
    print >>sys.stderr, "A glob or input file is required (but not both)"
    sys.exit(2)

if opts.onsource_seg is None:
    print >>sys.stderr, "An onsource segment is required."
    sys.exit(2)

# --buffer-seg is opened and --ifos is sliced further down without any
# fallback, so fail early with a readable message instead of a TypeError.
if opts.buffer_seg is None:
    print >>sys.stderr, "A buffer segment is required."
    sys.exit(2)

if opts.ifos is None:
    print >>sys.stderr, "The analysed IFOs (--ifos) are required."
    sys.exit(2)

if opts.coinc_statistic is None:
    print >>sys.stderr, "A coincidence statistic is required."
    sys.exit(2)

# turn a glob into a cache
if opts.glob is not None:
    tmp_files = glob.glob(opts.glob)
    if len(tmp_files) == 0:
        raise ValueError("no files match glob")
    # Cache lives in glue.lal; the bare name `Cache` is not imported here.
    cache = lal.Cache.fromfilenames(tmp_files, coltype=int)
else:
    # close the cache file once its entries have been read
    cache_file = open(opts.cache_file)
    try:
        cache = lal.Cache.fromfile(cache_file)
    finally:
        cache_file.close()

##############################################################################
# HTML initialization
page = InspiralUtils.InspiralPage(opts)

##############################################################################
# read in document
offsource_doc = grbsummary.load_cache(ligolw.Document(), cache,
    opts.offsource_pattern, verbose=opts.verbose)
nominal_num_slides = grbsummary.get_num_slides(offsource_doc)
offsource_cache = cache.sieve(description=opts.offsource_pattern)
# checkfilesexist() returns a (found, missed) pair of caches; testing the
# length of the pair itself would never report an empty result.
found, missed = offsource_cache.checkfilesexist()
if len(found) == 0:
    print >>sys.stderr, "warning: no files found for pattern %s" \
        % opts.offsource_pattern

# Count the slides that contribute background, and size the slide axis of
# the per-trial arrays.
total_num_slides = 0
array_size = 1  # require the zero-lag to exist or else indexing breaks
if any("SLIDE" in entry.description for entry in offsource_cache):
    # positive and negative slides
    total_num_slides += 2 * nominal_num_slides
    array_size += 2 * nominal_num_slides
if any("SLIDE" not in entry.description for entry in offsource_cache):
    # zero-lag files are part of the background
    total_num_slides += 1


off_segs = grbsummary.get_segs_from_doc(offsource_doc)
on_segs = segmentsUtils.fromsegwizard(open(opts.onsource_seg),
                                      coltype=int)
buffer_segs = segmentsUtils.fromsegwizard(open(opts.buffer_seg),
                                          coltype=int)

# Split e.g. "H1V1" into the set of lower-case two-character IFO names.
# Floor division is required: with `from __future__ import division` a
# plain "/" would hand range() a float.
ifo_list = set(opts.ifos[2*i:2*i+2].lower()
               for i in range(len(opts.ifos) // 2))

# get slide amounts: the per-IFO shift of one slide step, taken from the
# "--<ifo>-slide" entries of the process params table
shift_unit_vector = dict((ifo, 0) for ifo in ifo_list)  # default to 0
for row in table.get_table(offsource_doc,
                           lsctables.ProcessParamsTable.tableName):
    for ifo in ifo_list:
        if row.param == "--%s-slide" % ifo:
            shift_unit_vector[ifo] = int(float(row.value))
##############################################################################
# make mchirp bins
mc_bins = rate.IrregularBins(opts.mc_boundaries)

##############################################################################
# make trial bins

# The analysed stretch is the extent of the rings together with the
# on-source and buffer segments.  It is cut into num_trials equal bins of
# trial_len each, where trial_len is int(abs(on_segs)).
rings = grbsummary.retrieve_ring_boundaries(offsource_doc)
_analysed_segs = rings | on_segs | buffer_segs
trial_seg = _analysed_segs.extent()
trial_len = int(abs(on_segs))
num_trials = int(abs(trial_seg)) // trial_len

# any remainder at the end of trial_seg that does not fill a whole trial
# is left outside the binning
_trial_start = trial_seg[0]
_trial_end = _trial_start + trial_len * num_trials
trial_bins = rate.LinearBins(_trial_start, _trial_end, num_trials)

##############################################################################
# veto trials

# Per-IFO veto segments: read the segwizard file named by the matching
# --<ifo>-veto-file option, add the buffer segments, and clip the result to
# trial_seg.
# NOTE(review): the veto-file options default to "", so open() fails for any
# IFO listed in --ifos whose veto file was not supplied.
veto_segs = segmentlistdict()
for ifo in ifo_list:
  veto_segs[ifo] = segmentsUtils.fromsegwizard(open(getattr(opts, ifo + "_veto_file")), coltype=int)
  veto_segs[ifo] |= buffer_segs
  veto_segs[ifo] &= segmentlist([trial_seg])

# Boolean mask with one row per slide index and one column per trial bin;
# entries are set to True below for trials touched by a veto segment.
trial_veto_mask_2d = numpy.zeros((array_size, num_trials), dtype=numpy.bool8)
for slide_num in range(-nominal_num_slides, nominal_num_slides+1):
  # per-IFO shift for this slide: unit shift times the slide number
  shift_dict = dict((key, slide * slide_num) for (key, slide) in shift_unit_vector.iteritems())
  for ring in rings:
    # presumably shifts each IFO's veto segments by shift_dict on this ring;
    # see SnglInspiralUtils.slideSegListDictOnRing for the exact semantics
    slid_seglistdict = SnglInspiralUtils.slideSegListDictOnRing(ring, veto_segs, shift_dict)
    # a trial is vetoed if any IFO's slid veto segments overlap it
    slid_seglist = slid_seglistdict.union(slid_seglistdict.keys())
    # slide_num is used directly as the row index, so negative slide numbers
    # address rows counted from the end of the array
    trial_veto_mask_2d[slide_num, :] |= rate.bins_spanned(trial_bins, slid_seglist, dtype=numpy.bool8)
# row 0 (slide number 0) additionally masks every trial spanned by the buffer
trial_veto_mask_2d[0, :] |= rate.bins_spanned(trial_bins, buffer_segs, dtype=numpy.bool8)
if total_num_slides != array_size:  # these are equal iff we are including zero-lag background
  # zero-lag is not part of the background: mask the whole of row 0
  trial_veto_mask_2d[0] |= True

##############################################################################
# read triggers and reconstruct coincidences
#
# mc_bins was already built from opts.mc_boundaries before the trial bins
# were made; it is not rebuilt here.

# single-detector triggers of the off-source document
triggers = table.get_table(offsource_doc, lsctables.SnglInspiralTable.tableName)

# coincidences ranked by the statistic named with --coinc-statistic
stat = CoincInspiralUtils.coincStatistic(opts.coinc_statistic)
coincTable = CoincInspiralUtils.coincInspiralTable(triggers, stat)
##############################################################################
# sort coincidences into (mchirp, slide, trial) bins

# one empty list per (mchirp bin, slide, trial) cell; coincs are appended
# to the cell they fall in
coincs = array_of_lists(shape=(len(mc_bins), array_size, num_trials))

# coincs whose mean mchirp lies outside the mchirp binning are dropped and
# only counted
num_discarded = 0
for c in coincTable:
    try:
        mc_ind = mc_bins[grbsummary.get_mean_mchirp(c)]
    except IndexError:
        num_discarded += 1
    else:
        # NB: c.slide_num indexing OK with current inspiral slide
        # implementation
        _slide_ind = c.slide_num
        _trial_ind = trial_bins[c.get_time()]
        coincs[mc_ind, _slide_ind, _trial_ind].append(c)

if num_discarded:
    print >>sys.stderr, "warning: %d coincs outside mchirp binning" \
        % num_discarded

# number of coincs in every cell
counts_by_mc_slide_trial = vec_len(coincs)
# number of (slide, trial) cells that are not vetoed
eff_num_trials = (~trial_veto_mask_2d.flatten()).sum()

##############################################################################
# output statistics by mchirp

# number of (slide, trial) cells left after vetoes; computed once and reused
# by the summary lines below
_num_unvetoed = numpy.sum(~trial_veto_mask_2d.flatten())

page.write("Total number of slides and zero lag: "
           "(pos + neg + zero lag): %d" % total_num_slides)
page.write("Total number of analyzed background trials: "
           "%d of %d" % (_num_unvetoed, num_trials * total_num_slides))
page.write("Number of %d second "
           "segments per slide: %d"
           % (trial_len, _num_unvetoed / total_num_slides))


# table of the loudest coincs, rendered as HTML by write_coinc_summ_table
coincSumm = InspiralUtils.write_coinc_summ_table(
    tableList=[coincTable],
    commentList=['Loudest triggers'],
    stat=stat,
    statTag='Effective SNR',
    number=opts.numloudesttable,
    format='html',
    followup=None,
    followupOpts=None)
page.write(coincSumm)

def _write_count_stats(out_page, heading, counts, keep_mask, num_kept):
    """
    Write coincidence-count statistics of the kept trials to out_page.

    heading is written first; counts is an array of coincs per
    (slide, trial) cell; keep_mask is a boolean array of the same shape
    selecting the cells to use; num_kept is the number of selected cells and
    serves as the normalization.  Writes the total, the mean and standard
    deviation per trial, and the fraction of trials p(i|0) with exactly i
    coincs for i below min(10, max count + 1).
    """
    kept = counts[keep_mask]  # select once, reuse for every statistic
    total_counts = kept.sum()
    norm = 1 / num_kept
    mean = norm * total_counts
    stdev = numpy.sqrt(norm * ((kept - mean)**2).sum())

    out_page.write(heading)
    out_page.write("  Total number of coincidences: %d" % total_counts)
    out_page.write("  Mean coincidences per trial: %f" % mean)
    out_page.write("  Stdev of trials: %f" % stdev)

    for i in range(min(10, kept.max() + 1)):
        prob = norm * (kept == i).sum()
        out_page.write("  p(%d|0): %f" % (i, prob))

# by mchirp range
for counts_by_slide_trial, mc_range in \
    zip(counts_by_mc_slide_trial, zip(mc_bins.lower(), mc_bins.upper())):
    _write_count_stats(page, "Mean mchirp in [%f, %f):" % mc_range,
                       counts_by_slide_trial, ~trial_veto_mask_2d,
                       eff_num_trials)
    # blank line between the per-bin sections
    page.write("")

# all together
if len(mc_bins) > 1:
    counts_by_slide_trial = counts_by_mc_slide_trial.sum(axis=0)
    _write_count_stats(page, "Combined mchirp bins:",
                       counts_by_slide_trial, ~trial_veto_mask_2d,
                       eff_num_trials)

##############################################################################
# prepare loudest event statistics (squared) for plots
# index axes: trial numbers from 0, slide numbers centred on zero-lag
slide_numbers = numpy.arange(-nominal_num_slides, nominal_num_slides+1)
trial_numbers = numpy.arange(num_trials, dtype=int)

# square of the loudest statistic found in every (mchirp, slide, trial) cell
_loudest_by_mc_slide_trial = vec_get_loudest_stat(coincs)
loudest_sq_by_mc_slide_trial = _loudest_by_mc_slide_trial**2

##############################################################################
# plots

# statistic name with underscores escaped for use in LaTeX labels
safe_stat = opts.coinc_statistic.replace("_", r"\_")
fnameList, tagList = [], []

# decide what to iterate over: the (lower, upper) bounds of each mchirp bin
mc_ranges = zip(mc_bins.lower(), mc_bins.upper())

# make plots for each mchirp range

def _output_plot(plot, name):
    """
    Hand a finalized plot over: attach its figure to the HTML page under
    name when --enable-output is set, and close it unless it has to stay
    open for --show-plot.
    """
    if opts.enable_output:
        page.add_plot(plot.fig, name)
    if not opts.show_plot:
        plot.close()

# Sanity-check --hist-trials once, before any plotting.  The trial number is
# used as an index along the trial axis, which has num_trials entries.
for trial_num in opts.hist_trials:
    if trial_num < 0 or trial_num >= num_trials:
        raise ValueError("--hist-trials must provide numbers "
            "between 0 and the number of trials minus 1")

for loudest_sq, counts, mc_range in \
    zip(loudest_sq_by_mc_slide_trial, counts_by_mc_slide_trial, mc_ranges):

    # LaTeX title fragment and filename suffix identifying this mchirp bin
    mc_str = r"\textrm{ for }\langle \hat{M}_\mathrm{chirp} \rangle" \
        r"\in [%4.2f, %4.2f)" % mc_range
    fname_suffix = "-%4.2f-%4.2f" % mc_range

    # plot cumulative histogram of the unvetoed slide-trials' loudest event
    # statistics
    if opts.hist_loudest:
        # initialize plot
        plot = plotutils.CumulativeHistogramPlot(
            r"$\textrm{%s}^2$" % safe_stat,
            r"$\textrm{\# trials with " + safe_stat \
                + r"}_\mathrm{loudest}^2 > \mathrm{" + safe_stat + "}^2$",
            r"$\textrm{Cumulative histogram of " + safe_stat \
                + r"}^2\textrm{ of loudest events}" + mc_str + "$")

        # finalize plot, normalized by the number of unvetoed trials
        plot.add_content(loudest_sq[~trial_veto_mask_2d].flatten())
        plot.finalize(num_bins=50, normalization=1/eff_num_trials)

        # output it as appropriate
        _output_plot(plot, "trial_loudest_%s_cum_hist_mchirp" \
            % opts.coinc_statistic + fname_suffix)

    # plot cumulative histogram of specified trials' loudest statistics
    if opts.hist_onsource or opts.hist_background or len(opts.hist_trials) > 0:
        # initialize plot
        plot = plotutils.CumulativeHistogramPlot(
                r"$\textrm{%s}^2$" % safe_stat,
                r"$\textrm{\# loudest events louder than %s}^2$" % safe_stat,
                r"$\textrm{Cumulative histogram of loudest %s}^2" % safe_stat \
                + mc_str + "$")

        # add hist_trials (assume zero lag)
        for trial_num in opts.hist_trials:
            plot.add_content([loudest_sq[:, trial_num]],
                             "trial %d" % trial_num)

        # add background
        if opts.hist_background:
            plot.add_background(loudest_sq[~trial_veto_mask_2d][:, None],
                                label=r"off-source\ trials")

        # finalize plot
        plot.finalize(num_bins=50)
        plot.ax.set_ylim(ymax=1.2)

        # output it as appropriate
        _output_plot(plot,
            "trials_%s_cum_hist" % opts.coinc_statistic + fname_suffix)

    # plot loudest event statistic vs trial number
    if opts.plot_trials_loudest:
        # initialize plot
        plot = plotutils.BarPlot("trial number", safe_stat,
            r"$\textrm{Loudest events in each trial}" + mc_str + "$")

        # prepare data: unvetoed trials only, renumbered consecutively
        y_data = loudest_sq[~trial_veto_mask_2d].flatten()
        x_data = range(0, len(y_data))
        plot.add_content(x_data, y_data)
        plot.finalize()

        # output it as appropriate
        _output_plot(plot,
            "trial_loudest_%s" % opts.coinc_statistic + fname_suffix)

    # plot the number of coincs in each trial
    if opts.plot_num_coincs:
        # initialize plot
        plot = plotutils.BarPlot(r"trial \#", r"\# coincs",
            r"$\textrm{\# coincs in each trial}" + mc_str + "$")

        # prepare data: unvetoed trials only, renumbered consecutively
        y_data = counts[~trial_veto_mask_2d].flatten()
        x_data = range(0, len(y_data))

        # complete plot
        plot.add_content(x_data, y_data)
        plot.finalize()

        # output it as appropriate
        _output_plot(plot, "plot_num_coincs" + fname_suffix)

    # plot a histogram of the number of coincs in each trial
    if opts.hist_num_coincs:
        # initialize plot
        plot = plotutils.BarPlot(r"\# coincs", r"\# trials",
            r"$\textrm{Histogram of \# coincs in each trial}" + mc_str + "$")

        # prepare data: one bar per count from 0 up to and including the
        # largest count, so trials with the maximum are not dropped
        x_data = numpy.arange(counts.max() + 1, dtype=int)
        y_data = [(counts[~trial_veto_mask_2d] == i).sum() for i in x_data]
        plot.add_content(x_data, y_data)

        # complete plot
        plot.finalize()

        # output it as appropriate
        _output_plot(plot, "hist_num_coincs" + fname_suffix)

#############################################################################
# Generate HTML and cache file
# emit everything that was written to / attached to the page above
page.write_page()

# NOTE(review): matplotlib.use("Agg") is selected at import time; Agg is a
# non-interactive backend, so show() may not open any window -- confirm.
if opts.show_plot:
    pylab.show()
