#!/usr/bin/python
#
# Copyright (C) 2012 Matthew West
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 2 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.


#
# =============================================================================
#
#                                 Preamble
#
# =============================================================================
#

from optparse import OptionParser
import sqlite3
import sys
import os
import pdb

import numpy
import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot
from matplotlib import rcParams

from glue.lal import CacheEntry
from glue.ligolw import dbtables
from pylal import ligolw_cbc_plotvt_utils as efficiency
from pylal import git_version

# Module metadata; the version string and date are read from git_version.
__author__ = "Matthew West <matthew.west@ligo.org>"
__version__ = "git id %s" % (git_version.id,)
__date__ = git_version.date

# Usage description handed to OptionParser in parse_command_line();
# optparse expands %prog to the program name when printing help.
description = """
    %prog calculates the detection efficiency as a function of distance.
    It produces plots for efficiency and cumulative VxT vs distance, as
    well as total VxT vs FAR threshold.
    """


#
# =============================================================================
#
#                                Command Line
#
# =============================================================================
#


def parse_command_line():
    """
    Parse the command line and return the populated options object.

    Builds the OptionParser for this program, parses sys.argv and checks
    that every required option was given.  Positional arguments are parsed
    but not used.

    Returns the optparse options object.

    Raises ValueError naming every required option that is missing.
    """
    parser = OptionParser(
        version = "Name: %%prog\n%s" % git_version.verbose_msg,
        usage = "%prog [options]",
        description = description
        )
    parser.add_option("-c", "--database-cache", type="string",
        metavar = "name",
        help = "Get database files from the LAL cache named filename. [Required]"
        )
    parser.add_option( "-t", "--tmp-space", action="store", type="string", default=None,
        metavar = "PATH",
        help =
            "Location of local disk on which to do work. This is optional; " +
            "it is only used to enhance performance in a networked " +
            "environment. "
        )
    parser.add_option("-P", "--output-path", action = "store", type = "string",
        default = os.getcwd(), metavar = "PATH",
        help =
            "Path where the figures should be stored. Default is current directory."
        )
    parser.add_option("-f", "--format", action = "store", default = "png",
        metavar = "{\"png\",\"pdf\",\"svg\",\"eps\",...}",
        help = "Set the output image format. (default = 'png')"
        )
    parser.add_option("--distance-scale", type="string", default="linear",
        metavar = "string",
        help =
            "How the distance bins are broken up. The default option is 'linear'. " +
            "The alternative is 'log'."
        )
    parser.add_option("--calc-dist-range", action="store", type="string",
        metavar = "string",
        help = "The range of distances used in calculations: 'd_min,d_max' (Required)"
        )
    parser.add_option("--plot-dist-range", action="store", type="string",
        metavar = "string",
        help = "The distance range for figures: 'd_min,d_max' (Required)"
        )
    parser.add_option("--distance-step", action="store", type="float", default=4.0,
        metavar = "float",
        help =
            "The step size for distance bins used in the efficiency calculation. " +
            "If the distance scale is 'linear', then step=d[j+1]-d[j]. If the " +
            "scale is 'log', then step=d[j+1]/d[j]. The default value is 4.0"
        )
    parser.add_option("--distance-type", type="string", default="decisive_distance",
        metavar = "string",
        help =
            "The desired distance parameter the efficiency histogram will be " +
            "broken up in. The default option is 'decisive_distance', which is " +
            "the second largest effective distance from a found signal. The " +
            "alternative is 'distance'."
        )
    # NOTE: the default used to be True, which made this store_true flag a
    # no-op and contradicted the help text; it is now False as documented.
    parser.add_option("--weight-dist", action = "store_true", default = False,
        help =
            "Weight the distance parameter by chirp mass of a canonical DNS " +
            "binary (1.4-1.4). The default value is False. The formula is " +
            "chirp_dist = dist*(mchirp_DNS/mchirp)**(5.0/6.0)."
        )
    parser.add_option("--confidence", type="float", default=0.682689492137,
        metavar = "float",
        help =
            "The maximum density credible interval used for errorbars in the " +
            "measured efficiencies. The value should be between 0 and 1 and " +
            "the default is 0.6827."
        )
    parser.add_option("--veto-category", type="int",
        metavar = "INT",
        help = "The highest DQ veto category the triggers survive. [Required]"
        )
    parser.add_option("--on-instruments", type="string",
        metavar = "IFOS",
        help =
            "A comma separated string of the instruments one wishes to " +
            "calculate the efficiency for. Example: 'H1,L1,V1' [Required]"
        )
    parser.add_option("--database-tags", type="string",
        metavar = "'tag-1,tag-2,...'",
        help =
            "A comma separated list of tags, one per database in the cache " +
            "and in the same order. Each tag labels the results of its " +
            "database in the plots. [Required]"
        )
    parser.add_option("--inj-usertag", type="string", default="ALL_INJ",
        metavar = "usertag",
        help =
            "The usertag for the injection run(s) used to compute an efficiency. " +
            "The default is 'ALL_INJ'."
        )
    parser.add_option("-v", "--verbose", action = "store_true", default = False,
        help = "Be verbose."
        )

    # positional arguments are accepted but ignored
    options, _ = parser.parse_args()

    # options without a default are left as None by optparse when absent
    required_options = [
        "database_cache", "database_tags",
        "veto_category", "on_instruments",
        "calc_dist_range", "plot_dist_range"
    ]
    missing_options = [
        option for option in required_options
        if getattr(options, option) is None
    ]

    if missing_options:
        missing_options = ', '.join([
            "--%s" % option.replace("_", "-")
            for option in missing_options
        ])
        raise ValueError("missing required option(s) %s" % missing_options)

    return options

#
# =============================================================================
#
#                                    Main
#
# =============================================================================
#


#
# Command line
#

opts = parse_command_line()

# get database url paths from cache; the context manager closes the cache
# file as soon as every entry has been read
with open(opts.database_cache) as cache_file:
    database_list = [CacheEntry(line).path for line in cache_file]
# make list of database tags
db_tags = opts.database_tags.split(',')
# databases and tags are paired by position further down, so a count
# mismatch is reported here instead of failing later with an Index/KeyError
if len(db_tags) != len(database_list):
    raise ValueError(
        "--database-tags lists %d tag(s) but %s contains %d database(s)"
        % (len(db_tags), opts.database_cache, len(database_list)))

# normalised instrument string: sorted, comma separated, upper case
on_ifos = ','.join(sorted(opts.on_instruments.split(','))).upper()
veto_cat = 'VETO_CAT%i_CUMULATIVE' % opts.veto_category

# per-database results, keyed by database tag
measured_eff = {}
eff_vs_dist = {}
volume_eff = {}
cumVT_fgd = {}

# make the distance bins for the efficiency histogram
dist_bins = efficiency.inj_dist_range(
    [int(edge) for edge in opts.calc_dist_range.split(',')],
    dist_scale = opts.distance_scale,
    step = opts.distance_step
)
# volume of the spherical shell between consecutive bin edges
V_shell = 4./3*numpy.pi*(dist_bins[1:]**3 - dist_bins[:-1]**3)

if opts.verbose:
    print("computing p(cd|d)")
# determine the probability of a signal having chirp dist for a given distance
prob_cd_d = efficiency.rescale_dist(
    on_ifos.split(','),
    opts.distance_type,
    opts.weight_dist,
    phys_dist = dist_bins,
    param_dist = dist_bins
)

# seconds in a year of 365.25 days; loop invariant, so computed once
secINyear = 60.*60.*24.*365.25

for db_idx, db_path in enumerate(database_list):

    # Setup working databases and connections
    db_tag = db_tags[db_idx]
    if opts.verbose:
        print("Set up connection for %s database" % db_tag)

    working_filename = dbtables.get_connection_filename(
        db_path,
        tmp_path=opts.tmp_space,
        verbose=opts.verbose)
    connection = None
    try:
        connection = sqlite3.connect( working_filename )
        if opts.tmp_space:
            dbtables.set_temp_store_directory(
                connection,
                opts.tmp_space,
                verbose=opts.verbose)
        dbtables.DBTable_set_connection( connection )

        # create needed indices on tables if they don't already exist
        current_indices = [index[0] for index in
            connection.execute('SELECT name FROM sqlite_master WHERE type == "index"').fetchall()]

        needed_indices = (
            ('si_idx', 'CREATE INDEX si_idx ON sngl_inspiral (event_id, snr, chisq, chisq_dof, mchirp, eta);\n'),
            ('ci_idx', 'CREATE INDEX ci_idx ON coinc_inspiral (coinc_event_id);\n'),
            ('cem_idx', 'CREATE INDEX cem_idx ON coinc_event_map (coinc_event_id, event_id);\n'),
            ('es_idx', 'CREATE INDEX es_idx ON experiment_summary (experiment_id, time_slide_id, datatype);\n'),
            ('em_idx', 'CREATE INDEX em_idx ON experiment_map (coinc_event_id, experiment_summ_id);\n'),
        )
        sqlscript = ''.join([
            statement for name, statement in needed_indices
            if name not in current_indices
        ])

        connection.executescript( sqlscript )


        #
        # get lists of the found injections and their associated FAR values
        # each entry in found_inj is the tuple (simulation_id, gps_end_time, distance)
        #
        if opts.verbose:
            print("\tGet list of found injections")

        found_inj, found_fars, _ = efficiency.found_injections(
            connection,
            opts.inj_usertag,
            on_ifos,
            dist_type = opts.distance_type,
            weight_dist = opts.weight_dist,
            verbose = opts.verbose)

        # get list of tuples containing all injections that went in during desired coinc time
        all_inj = efficiency.successful_injections(
            connection,
            opts.inj_usertag,
            on_ifos,
            veto_cat,
            dist_type = opts.distance_type,
            weight_dist = opts.weight_dist,
            verbose = opts.verbose)

        # Determine the minimum false_alarm_rate one can estimate from slides
        minFAR = secINyear / efficiency.get_livetime(connection, veto_cat, on_ifos, 'slide')
        # 100 thresholds from 1e4*minFAR down to minFAR, plus FAR=0;
        # num must be an integer (newer numpy rejects the float 1e2)
        far_list = minFAR * numpy.logspace(start=4.0, stop=0.0, num=100)
        far_list = numpy.append(far_list, 0.0)
        # Calculate the foreground search time in years
        T_fgd = efficiency.get_livetime(connection, veto_cat, on_ifos, 'all_data') / secINyear
        cumVT_fgd[db_tag] = numpy.cumsum(V_shell)*T_fgd

        # ------------------------- For Each FAR Threshold ------------------------- #

        # determine the measured efficiency
        if opts.verbose:
            print("\tDetermine the measured efficiency")
        measured_eff[db_tag] = efficiency.detection_efficiency(
            all_inj,
            found_inj, found_fars,
            far_list,
            dist_bins,
            opts.confidence)

        # compute the efficiency vs distance for a physical distribution of sources
        if opts.verbose:
            print("\tCompute the efficiency vs binned distance")
        eff_vs_dist[db_tag] = efficiency.eff_vs_dist(
            measured_eff[db_tag],
            prob_cd_d)

        # compute the volume efficiency vs distance for a physical distribution of sources
        if opts.verbose:
            print("\tCompute the volume efficiency vs distance")
        volume_eff[db_tag] = efficiency.volume_efficiency(
            measured_eff[db_tag],
            V_shell,
            prob_cd_d)

        # the values from the last database processed are the ones used
        # later to name the plots
        start_time, end_time = connection.execute(
            'SELECT DISTINCT gps_start_time, gps_end_time FROM experiment'
            ).fetchone()

    finally:
        # close the connection and release the working file even when a
        # step above raised
        if connection is not None:
            connection.close()
        dbtables.discard_connection_filename(
            db_path,
            working_filename,
            verbose=opts.verbose)

#										#
# -------------------------------- Making Plots ------------------------------- #
#										#
if opts.verbose:
    print("Making Plots")

# matplotlib settings shared by every figure below
plotting_params = {
    'font.size': 14,
    'text.usetex': True,
    'xtick.labelsize': 'medium',
    'ytick.labelsize': 'medium',
    'axes.grid': False,
    'axes.titlesize': 'medium',
    'axes.labelsize': 'medium',
    'grid.color': 'k',
    'grid.linestyle': ':',
    'legend.fontsize': 'smaller',
    'legend.numpoints': 1,
    'savefig.dpi': 200,
    'savefig.format': opts.format
}
rcParams.update(plotting_params)

# Distance label for plot names, built right to left:
# 'DIST' -> 'CHIRP-DIST' -> 'DEC-CHIRP-DIST'.
# opts.distance_type is rewritten too because it is reused in plot titles.
pname_parts = ['DIST']
if opts.weight_dist:
    opts.distance_type = opts.distance_type.replace('distance','chirp_distance')
    pname_parts.insert(0, 'CHIRP')
if 'decisive' in opts.distance_type:
    pname_parts.insert(0, 'DEC')
pname_dist = '-'.join(pname_parts)

# plot basename; the '%s' field is filled in with each plot's own label
basename_fields = [
    on_ifos.replace(',',''),
    '%s',
    pname_dist,
    opts.inj_usertag.replace('_','-'),
    'CAT-%i-VETO' % opts.veto_category,
    str(start_time),
    str(end_time - start_time),
]
plot_basename = '_'.join(basename_fields)

# largest FAR threshold over all databases: xlim_max for FAR_vs_VT plot
maxFAR = max(
    numpy.max(list(eff['mode'].keys())) for eff in measured_eff.values()
)
# indices (into dist_bins[1:]) of the bin edges nearest the requested
# plot range
plot_dist = [float(edge) for edge in opts.plot_dist_range.split(',')]
arg_minD = numpy.abs(dist_bins[1:] - plot_dist[0]).argmin()
arg_maxD = numpy.abs(dist_bins[1:] - plot_dist[1]).argmin()

# --------------------- Measured Efficiency vs Distance for FAR=0 -------------------- #
if opts.verbose:
    print("\tMeasured Efficiency vs %s" % opts.distance_type)

Eff_Dist_fig = pyplot.figure()
Eff_Dist_plot = Eff_Dist_fig.add_subplot(111)

# Points sit at the outer edge of each bin, so the horizontal bar reaches
# back across the bin. The real bin width is used rather than
# opts.distance_step, because on the 'log' scale the step is a ratio.
bin_widths = dist_bins[1:] - dist_bins[:-1]

for tag in db_tags:
    mode_far0 = measured_eff[tag]['mode'][0.0]
    Eff_Dist_plot.errorbar(
        dist_bins[1:], mode_far0,
        xerr=(
            bin_widths,
            numpy.zeros(len(bin_widths))
            ),
        # 'low'/'high' are interval bounds; errorbar wants offsets
        yerr=(
            mode_far0-measured_eff[tag]['low'][0.0],
            measured_eff[tag]['high'][0.0]-mode_far0
            ),
        label=tag, linestyle='None',
        marker='.', markersize=10)

Eff_Dist_plot.grid(which='major', linewidth=0.75)
Eff_Dist_plot.grid(which='minor', linewidth=0.25)
Eff_Dist_plot.set_xscale(opts.distance_scale)
# arg_minD/arg_maxD index dist_bins[1:], hence the +1 for the outer edge
Eff_Dist_plot.set_xlim(dist_bins[arg_minD], dist_bins[arg_maxD+1])
Eff_Dist_plot.set_ylim(0, 1)

Eff_Dist_plot.set_xlabel( r'Distance to Outer Edge of Volume Shell $(Mpc)$' )
Eff_Dist_plot.set_ylabel( r'Fraction of Recovered Injections in Bin' )
Eff_Dist_plot.set_title(
    r'Measured Efficiency vs %s at FAR=0' % opts.distance_type.replace('_',' ').title()
    )
Eff_Dist_plot.legend(loc='lower left')

# os.path.join supplies the separator that plain concatenation dropped
# whenever the output path had no trailing slash (e.g. the default cwd)
plot_name = os.path.join(opts.output_path, plot_basename % 'MEFF-FAR0')
pyplot.savefig(plot_name + '.' + opts.format)
pyplot.close(Eff_Dist_fig)

# --------------------- Calculated Efficiency vs Distance for FAR=0 -------------------- #
if opts.verbose:
    print("\tCalculated Efficiency vs Distance")

Eff_Dist_fig = pyplot.figure()
Eff_Dist_plot = Eff_Dist_fig.add_subplot(111)

# Points sit at the outer edge of each bin, so the horizontal bar reaches
# back across the bin. The real bin width is used rather than
# opts.distance_step, because on the 'log' scale the step is a ratio.
bin_widths = dist_bins[1:] - dist_bins[:-1]

for tag in db_tags:
    Eff_Dist_plot.errorbar(
        dist_bins[1:], eff_vs_dist[tag]['wavg'][0.0],
        xerr=(
            bin_widths,
            numpy.zeros(len(bin_widths))
            ),
        # passed straight through as lower/upper error offsets
        yerr=(
            eff_vs_dist[tag]['low'][0.0],
            eff_vs_dist[tag]['high'][0.0]
            ),
        label=tag, linestyle='None',
        marker='.', markersize=10)

Eff_Dist_plot.grid(which='major', linewidth=0.75)
Eff_Dist_plot.grid(which='minor', linewidth=0.25)
Eff_Dist_plot.set_xscale(opts.distance_scale)
# arg_minD/arg_maxD index dist_bins[1:], hence the +1 for the outer edge
Eff_Dist_plot.set_xlim(dist_bins[arg_minD], dist_bins[arg_maxD+1])
Eff_Dist_plot.set_ylim(0, 1)

Eff_Dist_plot.set_xlabel( r'Distance to Outer Edge of Volume Shell $(Mpc)$' )
Eff_Dist_plot.set_ylabel( r'Fractional Efficiency' )
Eff_Dist_plot.set_title( r'Calculated Detection Efficiency vs Distance at FAR=0' )
Eff_Dist_plot.legend(loc='upper right')

# os.path.join supplies the separator that plain concatenation dropped
# whenever the output path had no trailing slash (e.g. the default cwd)
plot_name = os.path.join(opts.output_path, plot_basename % 'CEFF-FAR0')
pyplot.savefig(plot_name + '.' + opts.format)
pyplot.close(Eff_Dist_fig)


# --------------- Volume Efficiency vs Distance for FAR=0 ------------------ #
if opts.verbose:
    print("\tVolume Efficiency vs Distance")

vol_eff_fig = pyplot.figure()
vol_eff_plot = vol_eff_fig.add_subplot(111)

# one series per database: FAR=0 volume efficiency at each outer bin edge
for tag in db_tags:
    vol_eff_plot.errorbar(
        dist_bins[1:], volume_eff[tag]['wavg'][0.0],
        yerr=(
            volume_eff[tag]['low'][0.0],
            volume_eff[tag]['high'][0.0]
            ),
        label=tag, linestyle='None',
        marker='.', markersize=10)

vol_eff_plot.grid(which='major', linewidth=0.75)
vol_eff_plot.grid(which='minor', linewidth=0.25)
# both axes follow the scale chosen for the distance bins
vol_eff_plot.set_xscale(opts.distance_scale)
vol_eff_plot.set_yscale(opts.distance_scale)
# arg_minD/arg_maxD index dist_bins[1:], hence the +1 for the outer edge
vol_eff_plot.set_xlim(dist_bins[arg_minD], dist_bins[arg_maxD+1])
vol_eff_plot.set_ylim(0, 1)

vol_eff_plot.set_xlabel( r'Distance to Outer Edge of Sphere $(Mpc)$' )
vol_eff_plot.set_ylabel( r'Fractional Efficiency' )
vol_eff_plot.set_title( r'Efficiency within Volume vs Distance at FAR=0' )
vol_eff_plot.legend(loc='upper right')

# os.path.join supplies the separator that plain concatenation dropped
# whenever the output path had no trailing slash (e.g. the default cwd)
plot_name = os.path.join(opts.output_path, plot_basename % 'VEFF-FAR0')
pyplot.savefig(plot_name + '.' + opts.format)
pyplot.close(vol_eff_fig)

# --------------- Sensitive Volume vs Distance for FAR=0 ------------------ #
if opts.verbose:
    print("\tAccumulated VxT vs Distance")

vol_eff_fig = pyplot.figure()
vol_eff_plot = vol_eff_fig.add_subplot(111)

# top of the upper error bar at the largest plotted distance, per database;
# used to set the y-axis limit
maxVT = []

for tag in db_tags:
    # cumulative foreground VxT scaled by the FAR=0 volume efficiency
    vol_eff_plot.errorbar(
        dist_bins[1:],
        cumVT_fgd[tag]*volume_eff[tag]['wavg'][0.0],
        yerr=(
            cumVT_fgd[tag]*volume_eff[tag]['low'][0.0],
            cumVT_fgd[tag]*volume_eff[tag]['high'][0.0]
            ),
        label=tag, linestyle='None',
        marker='.', markersize=10)

    maxVT.append( cumVT_fgd[tag][arg_maxD]*(
        volume_eff[tag]['wavg'][0.0][arg_maxD] + volume_eff[tag]['high'][0.0][arg_maxD])
    )

vol_eff_plot.grid(which='major', linewidth=0.75)
vol_eff_plot.grid(which='minor', linewidth=0.25)
# both axes follow the scale chosen for the distance bins
vol_eff_plot.set_xscale(opts.distance_scale)
vol_eff_plot.set_yscale(opts.distance_scale)
# arg_minD/arg_maxD index dist_bins[1:], hence the +1 for the outer edge
vol_eff_plot.set_xlim(dist_bins[arg_minD], dist_bins[arg_maxD+1])
vol_eff_plot.set_ylim(0, numpy.max(maxVT)*1.05)

vol_eff_plot.set_xlabel( r'Distance to Outer Edge of Sphere $(Mpc)$' )
vol_eff_plot.set_ylabel( r'Cumulative Sensitive 4-Volume $(Mpc^{3}yr)$' )
vol_eff_plot.set_title( r'$V_{eff}T_{fgd}$ vs Distance at FAR=0' )
vol_eff_plot.legend(loc='upper left')

# os.path.join supplies the separator that plain concatenation dropped
# whenever the output path had no trailing slash (e.g. the default cwd)
plot_name = os.path.join(opts.output_path, plot_basename % 'VT-FAR0')
pyplot.savefig(plot_name + '.' + opts.format)
pyplot.close(vol_eff_fig)

# -------------------------------- VxT vs FAR ------------------------------- #
if opts.verbose:
    print("\tTotal 4-Volume (VxT) vs FAR threshold")

veff_far_fig = pyplot.figure()
veff_far_plot = veff_far_fig.add_subplot(111)

minVT = []
maxVT = []
# x position standing in for FAR=0, one per database
far_floors = []

for tag in db_tags:
    # the way to plot FAR=0 point on a log plot: place it at the power of
    # ten at or below the smallest non-zero threshold (index 1, because the
    # smallest key is FAR=0 itself)
    sorted_fars = sorted(volume_eff[tag]['wavg'].keys())
    new_minFAR = 10**numpy.floor(numpy.log10(sorted_fars[1]))
    # if that threshold is itself an exact power of ten, drop one more
    # decade so the FAR=0 data cannot overwrite an existing entry
    if new_minFAR in volume_eff[tag]['wavg']:
        new_minFAR = new_minFAR / 10.0
    far_floors.append(new_minFAR)
    # Replace FAR=0 with the new_minFAR (this modifies volume_eff in place)
    for key, eff_dict_far in volume_eff[tag].items():
        volume_eff[tag][key][new_minFAR] = eff_dict_far.pop(0.0)

    # Compute the total effective 4-Volume (VxT) at the largest plotted
    # distance; keys() and values() of the same unmodified dict line up
    far_thresholds = list(volume_eff[tag]['wavg'].keys())
    VT_wavg = cumVT_fgd[tag][arg_maxD] * numpy.array(
        list(volume_eff[tag]['wavg'].values()))[:,arg_maxD]

    veff_far_plot.plot(
        far_thresholds,
        VT_wavg,
        label=tag, linestyle='None',
        marker='.', markersize=10)

    minVT.append(numpy.min(VT_wavg))
    maxVT.append(numpy.max(VT_wavg))

veff_far_plot.grid(which='major', linewidth=0.75)
veff_far_plot.grid(which='minor', linewidth=0.25)
veff_far_plot.set_xscale('log')
veff_far_plot.set_yscale(opts.distance_scale)
# the lower limit covers every database, not just the last one looped over
veff_far_plot.set_xlim(min(far_floors), 2.0*maxFAR)
veff_far_plot.set_ylim(numpy.min(minVT)*0.90, numpy.max(maxVT)*1.10)

veff_far_plot.set_xlabel( r'False Alarm Rate Threshold $(yr^{-1})$' )
veff_far_plot.set_ylabel( r'Total Sensitive 4-Volume $(Mpc^{3}yr)$' )
veff_far_plot.set_title( r'Asymptotic $V_{eff}T_{fgd}$ vs FAR @ Distance = %s$Mpc$' % dist_bins[arg_maxD+1] )
veff_far_plot.legend(loc='upper left')

# os.path.join supplies the separator that plain concatenation dropped
# whenever the output path had no trailing slash (e.g. the default cwd)
plot_name = os.path.join(opts.output_path, plot_basename % 'VT-FAR')
pyplot.savefig(plot_name + '.' + opts.format)
pyplot.close(veff_far_fig)



# Every plot has been written: report completion and exit with status 0.
if opts.verbose:
    print("Finished!")
sys.exit(0)

