#!/usr/bin/python
#
# Copyright (C) 2008  Nickolas Fotopoulos
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 2 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
#
"""
This code computes the probabilities that go into the distance upper limits
for the CBC external trigger search.

There is a lot of advanced indexing technique here.  Reference:
http://scipy.org/Cookbook/Indexing
"""

from __future__ import division

__author__ = "Nickolas Fotopoulos <nvf@gravity.phys.uwm.edu>"
__prog__ = "plotgrbl"
__title__ = "GRB likelihood diagnostics"

import cPickle as pickle
import optparse
import os.path as op
import shutil
import sys
import os

import numpy
numpy.seterr(all="raise")  # throw an exception on any funny business
from scipy import stats
import matplotlib
matplotlib.use("Agg")
import pylab

from glue.ligolw import ligolw
from glue.ligolw import lsctables
from glue.ligolw import table
from glue.ligolw import utils
from pylal import CoincInspiralUtils
from pylal import grbsummary
from pylal import InspiralUtils
from pylal import plotutils
from pylal import rate
from pylal import SnglInspiralUtils
from pylal import git_version

# override pylab defaults, optimizing for web presentation
# (note "text.usetex": every label/title below is rendered through LaTeX, so
# all plot strings must be valid LaTeX)
pylab.rcParams.update({
    "text.usetex": True,
    "text.verticalalignment": "center",
    "lines.linewidth": 2.5,
    "font.size": 16,
    "axes.titlesize": 20,
    "axes.labelsize": 16,
    "xtick.labelsize": 16,
    "ytick.labelsize": 16,
    "legend.fontsize": 16,
})

# which m2 bin to trace for diagnostics
# (used below as an index into m2_bins.centres())
trace_m2_bin = 0

# set all empty trials to have likelihoods of -infinity
# (assigned to the log-likelihood rows of categories without a candidate so
# that they can never win an argmax)
GRBL_EMPTY_TRIAL = -numpy.inf

def get_fap(snr_by_cat):
    """
    Return the false alarm probability of snr_by_cat given background alone:
    for each category, the fraction of off-source trials whose loudest
    statistic is strictly greater than the corresponding entry of snr_by_cat.

    Reads the module-level offsource_loudest_by_trial_cat and
    num_offsource_trials.
    """
    # compare every off-source trial against the candidate statistics
    is_louder_by_trial_cat = \
        offsource_loudest_by_trial_cat > snr_by_cat[None, :]
    louder_count_by_cat = is_louder_by_trial_cat.sum(axis=0)
    return louder_count_by_cat / num_offsource_trials

def get_min_fap(snr_by_cat):
    """
    Return the smallest false alarm probability among the categories that
    actually contain a candidate (statistic > 0).  If no category contains a
    candidate, return 1.
    """
    has_candidate = snr_by_cat > 0
    if not has_candidate.any():
        return 1.
    fap_by_cat = get_fap(snr_by_cat)
    return fap_by_cat[has_candidate].min()

class always_equal(object):
    """
    Sentinel object that compares equal to everything.

    Handy as the initial value of a "previous value" variable, so that the
    first comparison against it always succeeds.
    """
    def __eq__(self, other):
        return True

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__; without this method
        # "x != always_equal()" falls back to the default comparison and
        # evaluates to True, contradicting __eq__.
        return False

def create_mchirp_plot(statistic, mc_bins, ifos,
                       onsource_loudest_by_mc, p0_by_mc):
    """
    Build and return the horizontal bar plot of the loudest on-source
    statistic in each mchirp bin, with the false alarm probability written
    next to each bar.

    @param statistic: name of the statistic in use (not read here; the axis
      label is taken from the module-level ppstat)
    @param mc_bins: bins object describing the mchirp bins; its centres(),
      min, max and boundaries are used
    @param ifos: the IFO combination (as a string), appended to the title
    @param onsource_loudest_by_mc: loudest on-source statistic per mchirp bin
    @param p0_by_mc: false alarm probability per mchirp bin
    @return: the finalized plotutils.NumberVsBinBarPlot
    """
    mchirp_label = "$" + r"\langle \hat{M}_\mathrm{chirp} \rangle" + "$"
    title = "Loudest statistics by template bank mchirp %s" % ifos

    ## mchirp vs loudest stat horizontal bar graph
    bar_plot = plotutils.NumberVsBinBarPlot(ppstat, mchirp_label, title)
    bar_plot.add_content(mc_bins, onsource_loudest_by_mc)
    bar_plot.finalize(orientation="horizontal")

    # write p(c|0) and the statistic value just to the right of each bar
    annotation = r"$\mathrm{FAP} = %.3f;\;$ comb %s$=%.2f$"
    per_bin = zip(mc_bins.centres(), onsource_loudest_by_mc, p0_by_mc)
    for mc_centre, loudest, fap in per_bin:
        bar_plot.ax.text(loudest + 0.2, mc_centre,
                         annotation % (fap, ppstat, loudest),
                         verticalalignment="center")

    # make sure the x axis reaches at least 20, and span all mchirp bins in y
    if bar_plot.ax.get_xlim()[1] < 20:
        bar_plot.ax.set_xlim(xmax=20)
    bar_plot.ax.set_ylim((mc_bins.min, mc_bins.max))

    # dashed horizontal lines at the mchirp bin boundaries
    for boundary in mc_bins.boundaries:
        bar_plot.ax.plot(bar_plot.ax.get_xlim(), (boundary, boundary),
                         "k--", label="_nolegend_")

    return bar_plot


def make_html_table(two_dim_seq, header_row=None):
    """
    Given a sequence of sequences (2-D array works), make an HTML table out
    of it and return the markup as a string.  If header_row is provided, make
    it the table header.

    Cells are converted with str() and inserted verbatim (no HTML escaping),
    so callers may pass entities or markup.

    @param two_dim_seq: sequence of rows, each a sequence of cells
    @param header_row: optional sequence of column headings
    @return: the table markup, one line per row, joined by newlines
    @raise ValueError: if the header and the rows do not all have the same
      length
    """
    html = ["<table  border=\"1\" cellspacing=\"0\">"]

    # None means "no length seen yet", so the first row is always accepted
    expected_len = None
    if header_row is not None:
        html.append("  <tr><th>" + "</th><th>".join(map(str, header_row))
            + "</th></tr>")
        expected_len = len(header_row)

    for row in two_dim_seq:
        if expected_len is not None and len(row) != expected_len:
            raise ValueError("header and row lengths must all be consistent.")
        expected_len = len(row)
        html.append("  <tr><td>" + "</td><td>".join(map(str, row))
            + "</td></tr>")

    html.append("</table><br>")
    return "\n".join(html)


def parse_args():
    """
    Define the command-line interface and parse sys.argv.

    Returns the (options, positional arguments) pair produced by
    optparse.OptionParser.parse_args().
    """
    cli = optparse.OptionParser(version=git_version.verbose_msg)

    # input: plain string-valued options naming the input files
    input_file_options = (
        ("--relic-onsource",
         "output of pylal_relic containing the onsource loudest coincs"),
        ("--relic-offsource",
         "output of pylal_relic containing the offsource loudest coincs"),
        ("--relic-injections",
         "output of pylal_relic containing the loudest injection coincs"),
        ("--grblikelihood-onsource",
         "On-source output pickle from pylal_grblikelihood"),
        ("--grblikelihood-offsource",
         "Off-source output pickle from pylal_grblikelihood"),
        ("--grblikelihood-injections",
         "Injection output pickle from pylal_grblikelihood"),
    )
    for flag, description in input_file_options:
        cli.add_option(flag, help=description)

    # InspiralUtils compatibility
    gps_options = (
        ("--gps-start-time", "GPS start time of data analyzed"),
        ("--gps-end-time", "GPS end time of data analyzed"),
    )
    for flag, description in gps_options:
        cli.add_option(flag, type="int", help=description)

    tag_options = (
        ("--ifo-tag", "IFO coincidence time analyzed"),
        ("--user-tag", "a tag to label your plots"),
        ("--output-path", "root of the HTML output"),
    )
    for flag, description in tag_options:
        cli.add_option(flag, help=description)

    switch_options = (
        ("--enable-output", "enable plots and HTML output"),
        ("--html-for-cbcweb",
         "enable HTML output with the appropriate headers for the CBC "
         "website"),
        ("--show-plot",
         "display the plots to screen if an X11 display is available"),
    )
    for flag, description in switch_options:
        cli.add_option(flag, action="store_true", default=False,
                       help=description)

    # a pickle/program version mismatch is fatal unless explicitly relaxed
    cli.set_defaults(onmismatch="raise")
    cli.add_option("--version-mismatch-ok", action="store_const",
                   const="warn", dest="onmismatch",
                   help="by default, the program will halt if the version "
                        "id stored in the pickle does not match the running "
                        "version id; specify this option to reduce to a "
                        "warning")

    # user-defined arguments
    cli.add_option("--statistic", default='effective_snr',
                   help="The statistic to use (default: effective_snr)")

    cli.add_option("--verbose", action="store_true", default=False,
                   help="extra information to the console")

    return cli.parse_args()


################################################################################
# parse arguments
# opts holds the option values; args holds any leftover positional arguments
opts, args = parse_args()


##############################################################################
# HTML initialization
# NOTE(review): InspiralUtils.initialise() is handed the whole opts object;
# presumably it consumes the "InspiralUtils compatibility" options defined in
# parse_args() -- confirm against pylal.
InspiralUtils.initialise(opts, __prog__, git_version.verbose_msg)
# page collects every figure added below via page.add_plot()
page = InspiralUtils.InspiralPage(opts)

# HTML fragment accumulated by the summary sections below
html_footer = ""

##############################################################################
# Read input

def _load_pickle(path):
    """
    Unpickle and return the object stored in the file at path, closing the
    file handle afterwards.

    NOTE: unpickling can execute arbitrary code; only use this on pickles
    produced by trusted runs of pylal_relic / pylal_grblikelihood.
    """
    pickle_file = open(path)
    try:
        return pickle.load(pickle_file)
    finally:
        pickle_file.close()

if opts.verbose:
    print("Reading in bin definitions and loudest statistics...")

# Each pickle starts with the version id of the program that wrote it; it is
# checked against the running version right after loading.
# NOTE: `id` shadows the builtin of the same name; the name is kept as-is in
# case code further down the script relies on it.
id, statistic, mc_bins, mc_ifo_cats, onsource_loudest_by_cat \
    = _load_pickle(opts.relic_onsource)
git_version.check_match(id, onmismatch=opts.onmismatch)
id, statistic, mc_bins, mc_ifo_cats, offsource_loudest_by_trial_cat \
    = _load_pickle(opts.relic_offsource)
git_version.check_match(id, onmismatch=opts.onmismatch)
id, statistic, mc_bins, mc_ifo_cats, m2_bins, D_bins, m2_D_by_inj, inj_loudest_by_inj_cat \
    = _load_pickle(opts.relic_injections)
git_version.check_match(id, onmismatch=opts.onmismatch)

# pretty-print names for the statistic: LaTeX (ppstat) and HTML (pphtmlstat).
# Only the statistics listed here are supported.
_STAT_LABELS = {
    "effective_snr": (r"$\rho_\mathrm{eff}$", "<em>&rho;</em><sub>eff</sub>"),
    "new_snr": (r"$\rho_\mathrm{new}$", "<em>&rho;</em><sub>new</sub>"),
}
if statistic not in _STAT_LABELS:
    raise NotImplementedError("unsupported statistic: %r" % (statistic,))
ppstat, pphtmlstat = _STAT_LABELS[statistic]

id, log_pc0_by_cat, log_pch_by_cat_m2, actual_candidate_mask, \
    log_L_by_m2, log_sum_L = \
    _load_pickle(opts.grblikelihood_onsource)
git_version.check_match(id, onmismatch=opts.onmismatch)
id, offsource_pc0_by_trial_cat, offsource_pch_by_trial_cat_m2, \
    offsource_log_L_by_trial_m2, offsource_log_sum_L_by_trial = \
    _load_pickle(opts.grblikelihood_offsource)
git_version.check_match(id, onmismatch=opts.onmismatch)
id, inj_log_pc0_by_trial_cat, inj_log_pch_by_trial_cat_m2, \
    inj_log_L_by_trial_m2, inj_log_sum_L_by_trial = \
    _load_pickle(opts.grblikelihood_injections)
git_version.check_match(id, onmismatch=opts.onmismatch)

# count the injections falling into each (m2, D) bin
m2_D_bins = rate.NDBins((m2_bins, D_bins))
num_sims_by_m2_D = rate.BinnedArray(m2_D_bins)
for m2_D in m2_D_by_inj:
    num_sims_by_m2_D[m2_D] += 1
# from here on, only the underlying array of counts is needed
num_sims_by_m2_D = num_sims_by_m2_D.array

# trials / injections run along the first axis of the loaded arrays
num_offsource_trials = offsource_loudest_by_trial_cat.shape[0]
num_inj = inj_log_L_by_trial_m2.shape[0]
# the relic and grblikelihood injection pickles must describe the same
# number of injections
if inj_loudest_by_inj_cat.shape[0] != num_inj:
    raise ValueError, "relic injections do not match grblikelihood injections"

# look up each injection's m2 (first element of its (m2, D) pair) in m2_bins
# NOTE(review): presumably this yields the m2 bin index -- confirm against
# pylal.rate
m2_bin_by_inj = [m2_bins[m2_D[0]] for m2_D in m2_D_by_inj]

# log likelihood ratio: ln L = ln p(c|h) - ln p(c|0), broadcast over m2
log_L_by_cat_m2 = log_pch_by_cat_m2 - log_pc0_by_cat[:, None]
inj_log_L_by_trial_cat_m2 = inj_log_pch_by_trial_cat_m2 \
    - inj_log_pc0_by_trial_cat[:, :, None]

##############################################################################
# compute unextrapolated FAP for IFAR comparison
if opts.verbose:
    print "Computing IFAR..."

# per-category FAP of the on-source loudest statistics, counted directly
# from the off-source trials
raw_pc0_by_cat = get_fap(onsource_loudest_by_cat)

# with multiple candidates, take the minimum FAP to get the maximum IFAR
# for FAP == 0, use 1e-5 instead
# ("x or 1e-5" substitutes 1e-5 for an exactly-zero FAP so that the
# reciprocals below stay finite)
offsource_fap_by_trial = numpy.fromiter((get_min_fap(trial) or 1e-5 \
    for trial in offsource_loudest_by_trial_cat),
    dtype=float, count=num_offsource_trials)
inj_fap_by_trial = numpy.fromiter((get_min_fap(trial) or 1e-5 for trial \
    in inj_loudest_by_inj_cat), dtype=float, count=num_inj)

# IFAR is taken to be simply 1 / FAP
on_ifar = 1 / (get_min_fap(onsource_loudest_by_cat) or 1e-5)
off_ifar = 1 / offsource_fap_by_trial
inj_ifar = 1 / inj_fap_by_trial

# get unextrapolated p(c|0) for a plot
# (one get_fap() row per injection)
inj_fap_by_trial_cat = \
    numpy.asarray(map(get_fap, inj_loudest_by_inj_cat))


##############################################################################
# Compute P(L > L_obs | 0)
if opts.verbose:
    print "Computing inputs for other plots..."

# in each m2 bin, find what fraction of trials are louder than the on-source
pL0_by_m2 = \
    (offsource_log_L_by_trial_m2 > log_L_by_m2[None, :])\
    .sum(axis=0, dtype=float) / num_offsource_trials

# find the fraction of trials whose sum log L is greater than the on-source's
# (work on a sorted copy so the loaded array keeps its trial order)
L_sum_star = offsource_log_sum_L_by_trial.copy()
L_sum_star.sort()
pLsum0 = (L_sum_star > log_sum_L).sum(dtype=float) / num_offsource_trials
# clip negative values (including -inf) to 0 for plotting
L_sum_star[L_sum_star < 0.] = 0
# L_sum_star is sorted ascending, so the i-th entry is paired with the
# fraction (N - 1 - i) / N of trials at higher positions
P_L_sum_star = numpy.arange(num_offsource_trials, dtype=float)[::-1] \
    / num_offsource_trials

# for diagnostics, find P(L > L_* | 0) for various L_*
# (same construction as above, done independently in every m2 bin)
L_star_by_m2_trial = offsource_log_L_by_trial_m2.T.copy()
L_star_by_m2_trial.sort(axis=1)
L_star_by_m2_trial[L_star_by_m2_trial < 0.] = 0.
offsource_pL0_by_trial = numpy.arange(num_offsource_trials, dtype=float)[::-1] \
    / num_offsource_trials

##############################################################################
# Compute P(L > L_obs | h)

# per (m2, D) bin, count the injections whose ln L in their own m2 bin
# exceeds the on-source ln L in that m2 bin
inj_louder_count_by_m2_D = numpy.zeros((len(m2_bins), len(D_bins)), dtype=int)
for inj_log_L_by_m2, m2_D in zip(inj_log_L_by_trial_m2, m2_D_by_inj):
    m2_ind, D_ind = m2_D_bins[m2_D]
    if inj_log_L_by_m2[m2_ind] > log_L_by_m2[m2_ind]:
        inj_louder_count_by_m2_D[m2_ind, D_ind] += 1
# the 1e-10 keeps the division finite for bins without any injections
# (numpy.seterr(all="raise") would otherwise turn 0/0 into an exception)
pLh_by_m2_D = inj_louder_count_by_m2_D / (num_sims_by_m2_D + 1e-10)

##############################################################################
# print summary of loudest events
if opts.verbose:
    print "Computation complete.  Writing summary page and plots..."

# the XML companion of the on-source pickle is expected next to it, with the
# same name but an .xml extension; copy it into the output directory
relic_onsource_xml = opts.relic_onsource.replace(".pickle", ".xml")
shutil.copy(relic_onsource_xml, op.join(opts.output_path, os.path.basename(relic_onsource_xml)))
# NOTE(review): the link below uses the original path of the XML file, not
# the location of the copy made above -- confirm this is the intended target
html_footer += "Loudest on-source coincidences [<a href=\"" \
    + relic_onsource_xml + "\">XML</a>]:<br><br>\n"
# make overall statement
html_footer += "ln (&Sigma; <em>L</em><sub>obs</sub>) = %.2g<br>\n" \
        % log_sum_L
html_footer += "P(ln (&Sigma; <em>L</em>) &gt; "\
        "ln (&Sigma; <em>L</em><sub>obs</sub>) | 0) "\
        "= %.2g<br><br>\n" % pLsum0

## x-start

# parse coincs
coinc_stat = CoincInspiralUtils.coincStatistic(statistic)
# XXX: For S5 LIGO-only searches, we have int_8s event_id columns.
# For S5 LIGO-Virgo and later searches, we have ilwd:char event_id columns.
# Try the current format first and fall back to the old one.
try:
    onsource_doc = SnglInspiralUtils.ReadSnglInspiralFromFiles(
        [relic_onsource_xml])
except ligolw.ElementError:
    onsource_doc = SnglInspiralUtils.ReadSnglInspiralFromFiles(
        [relic_onsource_xml], old_document=True)
onsource_trigs = table.get_table(onsource_doc,
    lsctables.SnglInspiralTable.tableName)
onsource_coincs = CoincInspiralUtils.coincInspiralTable(onsource_trigs,
    coinc_stat)
# at most one loudest coinc is expected per category
if len(onsource_coincs) > len(mc_ifo_cats):
    raise ValueError("more on-source coincs (%d) than categories (%d)" %
                     (len(onsource_coincs), len(mc_ifo_cats)))


# one (mchirp bin index, IFO combination) key per category
cat_keys = [cat[0] for cat in mc_ifo_cats.centres()]
# category index -> loudest on-source coinc of that category (None if the
# category holds no coinc)
cat_dict = {}

# Placeholder for empty table cells.  The terminating semicolon is required:
# "&mdash" without it is not a valid character reference and is rendered
# literally by standards-conforming browsers.
HTML_MDASH = "&mdash;"

# FIXME: use from grbsummary
ifo_list = ['H1','L1','V1']
ifos_list = ['H1L1V1','H1L1','H1V1','L1V1']
columns = ["ifos","mc_bin", "event_id"] \
          + ["%s<sub>%s</sub>" % (pphtmlstat, ifo) for ifo in ifo_list] \
          + ["combined %s" % pphtmlstat, "mchirp",
             "FAP"]
output_list = []
rows = []
for mc_ind, ifos in cat_keys:

    # get the basic index and the mchirps and ifos
    cat_index = mc_ifo_cats[(mc_ind, ifos)]
    mchirp_low = mc_bins.lower()[mc_ind]
    mchirp_upp = mc_bins.upper()[mc_ind]

    # prepare dictionary to save the data
    output_dict = {'event_id':None,'snrH':None,'snrL':None,'snrV':None}

    # retrieve all coincs in this mchirp range
    ifos_coinc = onsource_coincs.getChirpMass(mchirp_low, mchirp_upp)

    # keep only the (first) one with the correct IFO combination
    use_coinc = None
    for coinc in ifos_coinc:
        if coinc.ifos == ifos:
            use_coinc = coinc
            break

    # store the coinc for later
    cat_dict[cat_index] = use_coinc

    ifo_str = "".join(sorted(ifos))
    row = ["%s" % ifo_str, "[%.2f, %.2f)" % (mchirp_low, mchirp_upp)]
    output_dict['ifos'] = ifo_str
    output_dict['mc-bins'] = [mchirp_low, mchirp_upp]

    if use_coinc:
        # event id: third colon-separated field of the ilwd:char id
        use_id = str(use_coinc.event_id).split(':')[2]
        row += [use_id]
        output_dict['event_id'] = use_id

        # single-IFO statistic, one column per IFO in ifo_list
        # NOTE(review): this always reports the effective SNR, even when
        # the selected statistic is new_snr -- confirm this is intended.
        for ifo in ifo_list:
            if hasattr(use_coinc, ifo):
                trig = getattr(use_coinc, ifo)
                ifo_snr = trig.get_effective_snr()
                row += ["%.4g" % ifo_snr]
                output_dict['snr' + ifo[0].upper()] = ifo_snr
            else:
                row += [HTML_MDASH]

        # combined statistic, mean chirp mass and unextrapolated FAP
        pc0 = raw_pc0_by_cat[cat_index]
        mean_mchirp = grbsummary.get_mean_mchirp(use_coinc)
        row += ["%.4g" % use_coinc.stat, "%.4g" % mean_mchirp,
                "%.4g" % pc0]
        output_dict['snr_effective'] = use_coinc.stat
        output_dict['mchirp'] = mean_mchirp
        output_dict['prob'] = pc0

    else:
        # no coinc in this category: fill every remaining column with a dash
        row += [HTML_MDASH] * (len(columns) - len(row))

    # add this row to the rows-list
    rows.append(row)
    output_list.append(output_dict)

html_footer += make_html_table(rows, columns)

# print per-m2 info
# For every m2 bin, find the category with the largest ln L; categories
# without an actual candidate are forced to -infinity so they cannot win
# unless no category has a candidate.
temp_log_L_by_cat_m2 = log_pch_by_cat_m2 - log_pc0_by_cat[:, None]
temp_log_L_by_cat_m2[~actual_candidate_mask, :] = GRBL_EMPTY_TRIAL
cat_ind_by_m2 = temp_log_L_by_cat_m2.argmax(axis=0)

m2_ranges = zip(m2_bins.lower(), m2_bins.upper())
columns = ["m2_bin", "event_id", "ln <em>L</em>(<em>m</em><sub>2</sub>)",
    "<em>P</em>(<em>L</em>(<em>m</em><sub>2</sub>) &gt; "\
    "<em>L</em><sub>obs</sub>(<em>m</em><sub>2</sub>) | 0)"]
rows = []
for m2_range, cat_ind, log_L, pL0 in zip(m2_ranges, cat_ind_by_m2,
    log_L_by_m2, pL0_by_m2):
    row = ["[%.2f, %.2f)" % m2_range]

    # cat_dict maps a category to its loudest on-source coinc, or to None
    # when the category is empty (e.g. when there is no candidate at all and
    # argmax simply returned the first category)
    m2_coinc = cat_dict[cat_ind]
    if m2_coinc is None:
        row.append("&mdash;")
    else:
        use_id = str(m2_coinc.event_id).split(':')[2]
        row.append(use_id)
    row.extend(["%.4g" % log_L, "%.4g" % pL0])
    rows.append(row)

html_footer += make_html_table(rows, columns)


################################################################################
# plots
fnameList = []
tagList = []

# LaTeX fragments and bin ranges shared by the plots below
c_in_E = r"\textrm{any candidate in } \mathcal{E}(c_0)"
mc_latex = r"\langle \hat{M}_\mathrm{chirp} \rangle"
mc_ranges = zip(mc_bins.lower(), mc_bins.upper())
m2_ranges = zip(m2_bins.lower(), m2_bins.upper())

# wording of "the observation" depends on whether any on-source candidate
# exists at all
if actual_candidate_mask.any():
    observation = r"L(m_2) > L_\mathrm{obs}(m_2)"
else:
    observation = r"\textrm{any candidates}"

## Significance
text = "Significance"

# these thresholds are on the mean L
# (L_sum_star holds the sorted, clipped off-source ln(sum L) values, so the
# entries at 95% / 99% of its length are the 5% / 1% false-alarm levels;
# dividing exp() of them by the number of m2 bins turns the sum into a mean)
thresh_05 = numpy.log(numpy.exp(L_sum_star[int(0.95 * len(L_sum_star))]) \
                      / len(m2_bins))
thresh_01 = numpy.log(numpy.exp(L_sum_star[int(0.99 * len(L_sum_star))]) \
                      / len(m2_bins))

plot = plotutils.SimplePlot(r"$m_2\ (M_\odot)$", r"$\ln L$",
    "Significance")
if actual_candidate_mask.any():
    plot.add_content(m2_bins.centres(), log_L_by_m2)
else:
    # get plot limits to autoscale the way we want
    # (an invisible line of zeros spanning all m2 bins)
    plot.add_content(m2_bins.centres(), numpy.zeros(len(m2_bins)),
        visible=False)
plot.finalize()
plot.ax.set_ylim(ymin=-0.5)

if not actual_candidate_mask.any():
    # centre the notice on the middle m2 bin
    plot.ax.text(m2_bins.centres()[len(m2_bins)//2], 0.0,
        "No candidates found.", horizontalalignment="center",
        fontsize="x-large")

# add thresholds
# (remember the x limits first and restore them afterwards, so that drawing
# the threshold lines does not rescale the axis)
xlims = tuple(plot.ax.get_xlim())
plot.ax.plot(xlims, (thresh_05, thresh_05), linestyle="--")
plot.ax.text(xlims[1], thresh_05, r"$5\%\textrm{ false alarm on }\ln(\bar{L})$",
    horizontalalignment="right", verticalalignment="bottom")
plot.ax.plot(xlims, (thresh_01, thresh_01), linestyle="--")
plot.ax.text(xlims[1], thresh_01, r"$1\%\textrm{ false alarm on }\ln(\bar{L})$",
    horizontalalignment="right", verticalalignment="bottom")
plot.ax.set_xlim(xlims)

page.add_plot(plot.fig, 'max_log_L_by_mc_m2')

## P(L_sum > L_sum^* | 0) vs L_sum^*
text = "P(sum(ln L) > sum(ln L)^*| 0) vs sum(ln L)^*"

plot = plotutils.SimplePlot(\
    r"$\ln(\Sigma L)^*$",
    r"$P(\ln(\Sigma L) > \ln(\Sigma L)^* \,|\,0)$",
    r"$\textrm{Background distribution of } \ln(\Sigma L)$")
# only strictly positive values can be shown on the logarithmic x axis
non_empty_ind = L_sum_star > 0.
plot.add_content(L_sum_star[non_empty_ind], P_L_sum_star[non_empty_ind],
    marker=".", linestyle="None")
plot.finalize()
plot.ax.set_xscale("log")
plot.ax.autoscale_view()

# pLsum0
# (vertical dashed line at the on-source ln(sum L); the x limits are saved
# and restored so the line does not rescale the axis)
if actual_candidate_mask.any():
    xlims = tuple(plot.ax.get_xlim())
    plot.ax.plot((log_sum_L, log_sum_L), (0, 1), "k--")
    plot.ax.set_xlim(xlims)
plot.ax.set_ylim((0, 1))

page.add_plot(plot.fig, 'P_L_sum_by_L')

## mchirp vs loudest stat horizontal bar graph
# loudest coinc per mchirp plot; one for each IFO
# collect the distinct IFO combinations: the second element of each
# category key
ifo_set = set()
for a in mc_ifo_cats.centres(): ifo_set.add(a[0][1])

# NOTE: this message is printed regardless of opts.verbose
print "Creating the blue-bar plots"
for ifos in ifo_set:
  ifos_str = "".join(sorted(ifos))
  # category indices belonging to this IFO combination
  indices = [mc_ifo_cats[cat] for cat in cat_keys if cat[1]==ifos]

  # create the 'blue-bar' plot
  plot = create_mchirp_plot(statistic, mc_bins, ifos_str, \
                            onsource_loudest_by_cat[indices], \
                            raw_pc0_by_cat[indices])

  page.add_plot(plot.fig, "loudest_stats_by_mchirp"+ifos_str)


  
## L vs m2 for loudest candidates
# TODO: before and after MC and calibration systematics
text = "ln L(c(mc)|h(m2))"

# get likelihoods at extrapolation points
# XXX: 0.3 is a magic number; we start extrapolating at p(c|0) = 0.3
# For each category, pick the off-source trial sitting int(0.3 * N) places
# from the top of the ascending sort, N being the number of non-empty
# off-source trials of that category.
# NOTE(review): for N < 4, int(0.3 * N) is 0 and the index -0 selects the
# *quietest* trial instead of one near the top -- confirm whether that is
# acceptable.
num_nonempty_offsource_trials = (offsource_loudest_by_trial_cat > 0).sum(axis=0)
# sort once; the ordering does not depend on the category being looked up
offsource_order_by_trial_cat = offsource_loudest_by_trial_cat.argsort(axis=0)
extrap_ind = \
    [offsource_order_by_trial_cat[-int(0.3 * N), i] \
     for i,N in enumerate(num_nonempty_offsource_trials)]
extrap_snr_by_cat = \
    offsource_loudest_by_trial_cat[extrap_ind, range(len(mc_ifo_cats))]

# one figure per IFO combination, one curve per mchirp bin
for ifos in ifo_set:
    ifos_str = "".join(sorted(ifos))
    # category indices belonging to this IFO combination
    indices = [mc_ifo_cats[cat] for cat in cat_keys if cat[1]==ifos]

    plot = plotutils.SimplePlot(
        r"$m_2\ (M_\odot)$",
        r"$\ln L$",
        r"Significance of candidates in different mchirp regions in %s"%ifos_str)

    for mc_range, mc_log_L_by_m2, is_candidate, color in \
            zip(mc_ranges, log_L_by_cat_m2[indices],
                actual_candidate_mask[indices], plotutils.default_colors()):

        # solid line for categories with an actual candidate, dashed otherwise
        if is_candidate:
            linestyle = "-"
        else:
            linestyle = "--"
        plot.add_content(m2_bins.centres(), mc_log_L_by_m2, linestyle=linestyle,
                         color=color,
                         label=r"$" + mc_latex + r"\in [%4.2f, %4.2f)$" % mc_range)
    plot.finalize(loc=4)
    plot.ax.set_ylim(ymin=-0.5)

    page.add_plot(plot.fig, "log_L_by_mc_m2_%s"%ifos_str)

## L_sum vs rho
text = "ln sum L vs max rho"

# One figure per (mchirp bin, IFO combination) category.  (A figure used to
# be created here before the loop as well; it was never drawn on or added to
# the page, so it has been dropped.)
for mc_ind, ifos in cat_keys:

    cat_index = mc_ifo_cats[(mc_ind, ifos)]
    mchirp_low = mc_bins.lower()[mc_ind]
    mchirp_upp = mc_bins.upper()[mc_ind]
    ifos_str = "".join(sorted(ifos))

    plot = plotutils.SimplePlot("max %s" % ppstat, r"$\ln (\Sigma L)$",
                                r"$\max \ln (\Sigma L)$ vs max %s %.2f-%.2f %s"%\
                                (ppstat, mchirp_low, mchirp_upp, ifos_str))

    # injections
    # (only points with positive statistic and positive ln(sum L) can be
    # shown on the log-log axes)
    x_data = inj_loudest_by_inj_cat[:,cat_index]
    y_data = inj_log_sum_L_by_trial
    non_empty_trials = (x_data > 0) & (y_data > 0)
    x_data = x_data[non_empty_trials]
    y_data = y_data[non_empty_trials]
    plot.add_content(x_data, y_data, marker=".", color="b", linestyle="None",
                     label="injections", markersize=4)

    # background
    x_data = offsource_loudest_by_trial_cat[:,cat_index]
    y_data = offsource_log_sum_L_by_trial
    non_empty_trials = (x_data > 0) & (y_data > 0)
    x_data = x_data[non_empty_trials]
    y_data = y_data[non_empty_trials]
    plot.add_content(x_data, y_data, marker="x", color="r", linestyle="None",
                     label="background", markersize=6)

    # on-source observation
    if log_L_by_m2.sum() > 0:
        plot.add_content([onsource_loudest_by_cat[cat_index]], [log_sum_L],
                         marker="D", markeredgecolor="g", linestyle="None", markeredgewidth=2,
                         label="on-source observation", markersize=8, markerfacecolor="g")

    plot.finalize(loc=4)
    plot.ax.set_xscale("log")
    plot.ax.set_yscale("log")
    plot.ax.axis("tight")
    page.add_plot(plot.fig, "log_sum_L_vs_rho_max_%d"%cat_index)

# same scatter plot, but using the loudest statistic over all categories
plot = plotutils.SimplePlot("max %s" % ppstat, r"$\ln (\Sigma L)$",
    r"$\max \ln (\Sigma L)$ vs max %s" % ppstat)

# injections
# (only points with positive statistic and positive ln(sum L) can be shown
# on the log-log axes)
x_data = inj_loudest_by_inj_cat.max(axis=1)
y_data = inj_log_sum_L_by_trial
non_empty_trials = (x_data > 0) & (y_data > 0)
x_data = x_data[non_empty_trials]
y_data = y_data[non_empty_trials]
plot.add_content(x_data, y_data, marker=".", color="b", linestyle="None",
    label="injections", markersize=4)

# background
x_data = offsource_loudest_by_trial_cat.max(axis=1)
y_data = offsource_log_sum_L_by_trial
non_empty_trials = (x_data > 0) & (y_data > 0)
x_data = x_data[non_empty_trials]
y_data = y_data[non_empty_trials]
plot.add_content(x_data, y_data, marker="x", color="r", linestyle="None",
    label="background", markersize=6)

# on-source observation
# (drawn only when the on-source ln L values sum to something positive)
if log_L_by_m2.sum() > 0:
    plot.add_content([onsource_loudest_by_cat.max()], [log_sum_L],
        marker="D", markeredgecolor="g", linestyle="None", markeredgewidth=2,
        label="on-source observation", markersize=8, markerfacecolor="g")

plot.finalize(loc=4)
plot.ax.set_xscale("log")
plot.ax.set_yscale("log")
plot.ax.axis("tight")
page.add_plot(plot.fig, "log_sum_L_vs_rho_max")

## IFAR vs rho
text = "IFAR vs max rho"

plot = plotutils.SimplePlot("max %s" % ppstat, r"IFAR",
    r"max IFAR vs max %s" % ppstat)

# injections
# (x: loudest statistic over all categories; y: 1 / minimum FAP; only
# positive points can be shown on the log-log axes)
x_data = inj_loudest_by_inj_cat.max(axis=1)
y_data = inj_ifar
non_empty_trials = (x_data > 0) & (y_data > 0)
x_data = x_data[non_empty_trials]
y_data = y_data[non_empty_trials]
plot.add_content(x_data, y_data, marker=".", color="b", linestyle="None",
    label="injections", markersize=4)

# background
x_data = offsource_loudest_by_trial_cat.max(axis=1)
y_data = off_ifar
non_empty_trials = (x_data > 0) & (y_data > 0)
x_data = x_data[non_empty_trials]
y_data = y_data[non_empty_trials]
plot.add_content(x_data, y_data, marker="x", color="r", linestyle="None",
    label="background", markersize=6)

# on-source observation
# (drawn only when the on-source ln L values sum to something positive)
if log_L_by_m2.sum() > 0:
    plot.add_content([onsource_loudest_by_cat.max()], [on_ifar],
        marker="D", markeredgecolor="g", linestyle="None", markeredgewidth=2,
        label="on-source observation", markersize=8, markerfacecolor="g")

plot.finalize(loc=4)
plot.ax.set_xscale("log")
plot.ax.set_yscale("log")
plot.ax.axis("tight")

page.add_plot(plot.fig, "IFAR_vs_rho_max")

## L_sum vs IFAR
text = "max ln sum L vs max IFAR"

plot = plotutils.SimplePlot(r"IFAR", r"$\ln (\Sigma L)$",
    r"$\max \ln (\Sigma L)\textrm{ vs max IFAR}$")

# injections
# (x: 1 / minimum FAP; y: ln(sum L); only positive points can be shown on
# the log-log axes)
x_data = inj_ifar
y_data = inj_log_sum_L_by_trial
non_empty_trials = (x_data > 0) & (y_data > 0)
x_data = x_data[non_empty_trials]
y_data = y_data[non_empty_trials]
plot.add_content(x_data, y_data, marker=".", color="b", linestyle="None",
    label="injections", markersize=4)

# background
x_data = off_ifar
y_data = offsource_log_sum_L_by_trial
non_empty_trials = (x_data > 0) & (y_data > 0)
x_data = x_data[non_empty_trials]
y_data = y_data[non_empty_trials]
plot.add_content(x_data, y_data, marker="x", color="r", linestyle="None",
    label="background", markersize=6)

# on-source observation
# (drawn only when the on-source ln L values sum to something positive)
if log_L_by_m2.sum() > 0:
    plot.add_content([on_ifar], [log_sum_L],
        marker="D", markeredgecolor="g", linestyle="None", markeredgewidth=2,
        label="on-source observation", markersize=8, markerfacecolor="g")

plot.finalize(loc=2)
plot.ax.set_xscale("log")
plot.ax.set_yscale("log")
plot.ax.axis("tight")
page.add_plot(plot.fig, "log_sum_L_vs_IFAR_max")

## ROC plot
text = "ROC curve (uniform priors)"

# one curve per ranking statistic: loudest statistic, IFAR and ln(sum L);
# each curve is given the background values first, then the injection values
plot = plotutils.ROCPlot("False alarm probability", "Efficiency",
    r"$\textrm{ROC curve (uniform priors)}$")
plot.add_content(offsource_loudest_by_trial_cat.max(axis=1),
    inj_loudest_by_inj_cat.max(axis=1),
    color="r", linestyle="-", label=ppstat)
plot.add_content(off_ifar, inj_ifar,
    color="g", linestyle="-", label="IFAR")
plot.add_content(offsource_log_sum_L_by_trial,
    inj_log_sum_L_by_trial,
    color="b", linestyle="-", label=r"$\ln(\Sigma L)$")

plot.finalize(loc=4)

# diagonal line represents blind guessing
# (the smallest resolvable FAP is one over the number of off-source trials;
# "1 / len(...)" is a true division thanks to "from __future__ import
# division" at the top of the file)
min_fap = 1 / len(offsource_loudest_by_trial_cat)
diag_line = numpy.logspace(numpy.log10(min_fap), 0, 100)
plot.ax.plot(diag_line, diag_line, linestyle="--", color="k",
    label="_nolegend_")
plot.ax.set_xlim((min_fap, 1))
plot.ax.set_ylim((min_fap, 1))
plot.ax.set_xscale("log")
page.add_plot(plot.fig, "ROC_uniform")

## ROC plot (D^3 reweighted)
text = "ROC curve (D^3 weighting)"

# get D^3 weighting; to revert to unweighted, set all weights to 1 / num_inj.
# The array is explicitly float: the in-place normalization below would
# otherwise fail or truncate if the distances happened to be integers.
weight_by_inj = numpy.array([m2_D[1]**3 for m2_D in m2_D_by_inj],
    dtype=float)
weight_by_inj /= weight_by_inj.sum()

# one curve per ranking statistic: loudest statistic, IFAR and ln(sum L);
# each curve is given the background values, the injection values and the
# per-injection weights
plot = plotutils.ROCPlot("False alarm probability", "Efficiency",
    r"$\textrm{ROC curve (}D^3\textrm{ weighting)}$")
plot.add_content(offsource_loudest_by_trial_cat.max(axis=1),
    inj_loudest_by_inj_cat.max(axis=1), weight_by_inj,
    color="r", linestyle="-", label=ppstat)
plot.add_content(1 / offsource_fap_by_trial, 1 / inj_fap_by_trial,
    weight_by_inj,
    color="g", linestyle="-", label="IFAR")
plot.add_content(offsource_log_sum_L_by_trial,
    inj_log_sum_L_by_trial, weight_by_inj,
    color="b", linestyle="-", label=r"$\ln(\Sigma L)$")

plot.finalize(loc=4)

# diagonal line represents blind guessing
# (the smallest resolvable FAP is one over the number of off-source trials)
min_fap = 1.0 / len(offsource_loudest_by_trial_cat)
diag_line = numpy.logspace(numpy.log10(min_fap), 0, 100)
plot.ax.plot(diag_line, diag_line, linestyle="--", color="k",
    label="_nolegend_")
plot.ax.set_xlim((min_fap, 1))
plot.ax.set_ylim((min_fap, 1))
plot.ax.set_xscale("log")

page.add_plot(plot.fig, "ROC_D3")



## p(c|0) vs rho for each mc bin
text = "ln P(c|0) vs rho for each mc bin for m2 = " \
    + str(m2_bins.centres()[trace_m2_bin])

# Injections louder than the extrapolation point of their category and with
# a non-zero unextrapolated FAP (so that its logarithm is finite).  This does
# not depend on the IFO combination, so it is computed once.
extrap_trials_by_inj_cat = \
    (inj_loudest_by_inj_cat > extrap_snr_by_cat[None, :]) & \
    (inj_fap_by_trial_cat > 0)

# loop over the IFOs
for ifos in ifo_set:

    ifos_str = "".join(sorted(ifos))
    # category indices belonging to this IFO combination
    indices = [mc_ifo_cats[cat] for cat in cat_keys if cat[1]==ifos]

    plot = plotutils.SimplePlot(ppstat, "",
                            r"$\ln P(" + c_in_E + r"\,|\,0)\textrm{ vs stat}$ %s"%ifos_str)

    # add unextrapolated p(c|0) to the right of the extrapolation points
    # The arrays are indexed (injection, category): select this IFO
    # combination's categories and transpose, so that the zip below walks
    # over categories (one per mchirp range) rather than over injections.
    for (mc_range, inj_loudest_by_inj, inj_fap_by_trial, extrap_trials_by_inj, \
         inj_log_pc0_by_trial, color) \
        in zip(mc_ranges,
               inj_loudest_by_inj_cat[:,indices].T,
               inj_fap_by_trial_cat[:,indices].T,
               extrap_trials_by_inj_cat[:,indices].T,
               inj_log_pc0_by_trial_cat[:,indices].T,
               plotutils.default_colors()):
        # unextrapolated
        plot.add_content(inj_loudest_by_inj[extrap_trials_by_inj],
            numpy.log(inj_fap_by_trial[extrap_trials_by_inj]), marker="^",
            markersize=4, markeredgewidth=1, markerfacecolor=color,
            linestyle="None", markeredgecolor="k")
        # extrapolated
        plot.add_content(inj_loudest_by_inj, inj_log_pc0_by_trial,
            marker=".", linestyle="None", color=color,
            label="$" + mc_latex + r"\in [%4.2f, %4.2f)$" % mc_range)

    plot.finalize(loc=1)

    # add SNRs at which we begin extrapolating
    # (vertical dashed lines; y limits are saved and restored so that the
    # lines do not rescale the axis)
    ylims = tuple(plot.ax.get_ylim())
    for extrap_snr, color in zip(extrap_snr_by_cat[indices], plotutils.default_colors()):
        plot.ax.plot((extrap_snr, extrap_snr), ylims,
            color + "--", label="_nolegend_")
    plot.ax.set_ylim(ylims)
    page.add_plot(plot.fig, "log_pc0_vs_rho_"+ifos_str)

    # save a zoomed version, too
    plot.ax.set_xlim((4, 16))
    plot.ax.set_ylim((-6, 0.2))
    page.add_plot(plot.fig, "log_pc0_vs_rho_zoom_"+ifos_str)


    
# same plot without splitting by IFO combination
plot = plotutils.SimplePlot(ppstat, "",
    r"$\ln P(" + c_in_E + r"\,|\,0)$ vs %s" % ppstat)

# add unextrapolated p(c|0) to the right of the extrapolation points
# (injections louder than their category's extrapolation point and with a
# non-zero unextrapolated FAP, so that its logarithm is finite)
extrap_trials_by_inj_cat = \
  (inj_loudest_by_inj_cat > extrap_snr_by_cat[None, :]) & \
  (inj_fap_by_trial_cat > 0)

# the arrays are transposed so that the zip walks over categories
# NOTE(review): zip() stops at the shortest input, so only the first
# len(mc_ranges) categories are drawn and labelled by mchirp range alone --
# confirm this is intended when several IFO combinations are present
for (mc_range, inj_loudest_by_inj, inj_fap_by_trial, extrap_trials_by_inj, \
     inj_log_pc0_by_trial, color) \
    in zip(mc_ranges, inj_loudest_by_inj_cat.T, inj_fap_by_trial_cat.T,
           extrap_trials_by_inj_cat.T, inj_log_pc0_by_trial_cat.T,
           plotutils.default_colors()):
    # unextrapolated
    plot.add_content(inj_loudest_by_inj[extrap_trials_by_inj],
        numpy.log(inj_fap_by_trial[extrap_trials_by_inj]), marker="^",
        markersize=4, markeredgewidth=1, markerfacecolor=color,
        linestyle="None", markeredgecolor="k")
    # extrapolated
    plot.add_content(inj_loudest_by_inj, inj_log_pc0_by_trial,
        marker=".", linestyle="None", color=color,
        label="$" + mc_latex + r"\in [%4.2f, %4.2f)$" % mc_range)

plot.finalize(loc=1)

# add SNRs at which we begin extrapolating
# (vertical dashed lines; y limits are saved and restored so that the lines
# do not rescale the axis)
ylims = tuple(plot.ax.get_ylim())
for extrap_snr, color in zip(extrap_snr_by_cat, plotutils.default_colors()):
    plot.ax.plot((extrap_snr, extrap_snr), ylims,
        color + "--", label="_nolegend_")
plot.ax.set_ylim(ylims)
page.add_plot(plot.fig, "log_pc0_vs_rho")

# save a zoomed version, too
plot.ax.set_xlim((4, 16))
plot.ax.set_ylim((-6, 0.2))
page.add_plot(plot.fig, "log_pc0_vs_rho_zoom")


# Hack: to use for ImagePlots with non-uniform bins
def fake_nonuniform_bins(orig_data, bins):
    """Repeat rows of 2D-array orig_data to simulate non-uniform bins.

    Row i is repeated in proportion to the width of bin i, so that an image
    drawn with equally tall rows shows each bin at its relative size.
    Symbolically, for bin widths in the ratio 3:7:2,
    [a,b,c] -> [a,a,a,b,b,b,b,b,b,b,c,c].

    Arguments:
      orig_data -- 2D numpy array; axis 0 runs over the bins.
      bins -- object whose upper() and lower() methods return arrays of bin
              edges, one entry per row of orig_data.

    Returns a new array with the same columns as orig_data and one row per
    unit of the common bin-width divisor; orig_data itself is not modified.
    """
    def mcd(x, y):
        """Approximate greatest common divisor of two positive floats."""
        for _ in range(6):  # no more than 6 iterations, please
            if y > x:
                x, y = y, x  # so x > y
            if x % y < 1e-3 * x:
                return y  # if y divides x "sufficiently", we're done
            x = x % y
        return x  # hopefully we won't get here

    sizes = bins.upper() - bins.lower()  # bin sizes
    smallest = min(sizes)
    # worst residual of approximating a size by an integer times the
    # smallest size
    worst = numpy.argmax(sizes % smallest)
    # delta_size, the common row height used for the expanded array
    ds = mcd(smallest, sizes[worst])

    # Number of rows for each bin.  Round rather than truncate: the ratio is
    # a float and may fall just below the integer it represents.  Emit
    # exactly that many rows per bin so the row proportions match the bin
    # widths.
    repeats = numpy.rint(sizes[:orig_data.shape[0]] / ds).astype(int)
    return numpy.repeat(orig_data, repeats, axis=0)


## p(c|h) from max L mchirp bins vs (m2, D) image
text = "p(observation|h(m2, D)) after 90% MC errors"

# number of standard deviations at the 90th percentile of a unit normal
num_sigmas = stats.norm.ppf(0.9)

# binomial standard error of the Monte Carlo estimate in every (m2, D) cell;
# the 1e-10 keeps the division finite where the injection count is zero
binomial_variance = pLh_by_m2_D * (1 - pLh_by_m2_D) \
                    / (num_sims_by_m2_D + 1e-10)
MC_sigma = numpy.sqrt(binomial_variance)

afterMCerr_title = (r"$P(" + observation + r"\,|\,h(m_2, D)) "
                    + r"\textrm{ w/ 90\% MC error}$")
plot = plotutils.ImagePlot(r"$m_2\ (M_\odot)$", r"$D\ \mathrm{(Mpc)}$",
                           afterMCerr_title)

# expand both arrays along m2 to mimic the non-uniform bins, then lower the
# estimate by num_sigmas standard errors
pLh_image = fake_nonuniform_bins(pLh_by_m2_D, m2_bins).T
sigma_image = fake_nonuniform_bins(MC_sigma, m2_bins).T
plot.add_content(pLh_image - num_sigmas * sigma_image, m2_bins, D_bins)
plot.finalize()

# XXX: Hack around an apparent matplotlib API backwards-incompatibility
# When all clusters have 0.90 or higher, remove the try-except.
# Pin the color scale of the first image on the axes to [0, 1].
try:
    axes_images = [child for child in plot.ax.get_children()
                   if isinstance(child, matplotlib.image.AxesImage)]
    image = axes_images[0]
    image.set_clim((0, 1))
except AttributeError:
    pass

page.add_plot(plot.fig, "pLh_afterMCerr_by_m2_D")


## p(c|h) from max L mchirp bins vs (m2, D) image
text = "p(observation|h(m2, D)) before MC errors"

pLh_title = r"$P(" + observation + r"\,|\,h(m_2, D))$"
plot = plotutils.ImagePlot(r"$m_2\ (M_\odot)$", r"$D\ \mathrm{(Mpc)}$",
                           pLh_title)

# expand along m2 to mimic the non-uniform bins, then put m2 on the x axis
pLh_image = fake_nonuniform_bins(pLh_by_m2_D, m2_bins).T
plot.add_content(pLh_image, m2_bins, D_bins)
plot.finalize()

# XXX: Hack around an apparent matplotlib API backwards-incompatibility
# When all clusters have 0.90 or higher, remove the try-except.
# Pin the color scale of the first image on the axes to [0, 1].
try:
    axes_images = [child for child in plot.ax.get_children()
                   if isinstance(child, matplotlib.image.AxesImage)]
    image = axes_images[0]
    image.set_clim((0, 1))
except AttributeError:
    pass

page.add_plot(plot.fig, "pLh_by_m2_D")

## P(L > L_0 | 0) vs m2
text = "lim_{D->inf} P(observation| h(m2, D)) vs m_2"

# Build this plot with the title font at 90% of its configured size, then
# put the global setting back once the plot has been finalized.
orig_size = pylab.rcParams["axes.titlesize"]
pylab.rcParams["axes.titlesize"] = orig_size * 0.9

pL0_ylabel = r"$P(" + observation + r"\,|\,0)$"
pL0_title = (r"\[P(" + observation
             + r"\,|\,0) \left(= \lim_{D\rightarrow\infty} P("
             + observation + r"\,|\,h(m_2, D))\right)\]")
plot = plotutils.SimplePlot(r"$m_2\ (M_\odot)$", pL0_ylabel, pL0_title)
plot.add_content(m2_bins.centres(), pL0_by_m2)
plot.finalize(loc=4)
plot.ax.set_ylim((0, 1))

pylab.rcParams["axes.titlesize"] = orig_size

# the title is stupidly huge; shrink plot to fit it
plot.ax.set_position((0.12, 0.1, 0.8, 0.78))
page.add_plot(plot.fig, "pL0_by_m2")


## P(L > L_* | 0) vs L_*
text = "P(L(m_2) > L_*(m_2) | 0) vs ln L_*"

# Every "ln" in the title is written as \ln so that LaTeX typesets it as an
# upright operator, matching the y-axis label.
plot = plotutils.SimplePlot(
    r"$\ln L_*$",
    r"$P(\ln L(m_2) > \ln L_* \,|\,0)$",
    r"\[P(\ln L(m_2) > \ln L_* \,|\,0) "
    r"\left(= \lim_{D\rightarrow\infty} P(\ln L(m_2) > \ln L_* "
    r"\,|\,h(m_2, D))\right)\]")

# one curve per m2 bin, labelled with the bin's [lower, upper) edges
m2_ranges = zip(m2_bins.lower(), m2_bins.upper())
for m2_range, Lstar_by_trial in zip(m2_ranges, L_star_by_m2_trial):
    plot.add_content(Lstar_by_trial, offsource_pL0_by_trial,
                     label=r"$m_2 \in [%4.1f, %4.1f) M_\odot$" % m2_range)
plot.finalize(loc=4)
plot.ax.set_yscale("log")
plot.ax.set_xlim(xmin=0)

# the title is stupidly huge; shrink plot to fit it
plot.ax.set_position((0.12, 0.1, 0.8, 0.78))
page.add_plot(plot.fig, "pL0_by_L")

## injection count vs (m2, D) image
text = "Injection count"

# Raw strings: the labels contain LaTeX backslashes (\ , \odot, \mathrm)
# that must reach matplotlib untouched rather than rely on Python leaving
# unrecognized escape sequences alone.
plot = plotutils.ImagePlot(r"$m_2\ (M_\odot)$", r"$D\ \mathrm{(Mpc)}$",
    r"Injections made")

# expand along m2 to mimic the non-uniform bins, then put m2 on the x axis
plot.add_content(fake_nonuniform_bins(num_sims_by_m2_D, m2_bins).T,
                 m2_bins, D_bins)
plot.finalize()
page.add_plot(plot.fig, "injection_count_by_m2_D")




#############################################################################
# Generate HTML and cache file
if opts.enable_output:
    # append the footer, then write out the page
    page.write(html_footer)
    page.write_page()

# optionally write the page a second time in its cbcweb form
if opts.enable_output and opts.html_for_cbcweb:
    page.write_page(cbcweb=True)

