#!/usr/bin/python
#
# Copyright (C) 2012 Matthew West
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 2 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.


#
# =============================================================================
#
#                                 Preamble
#
# =============================================================================
#

from optparse import OptionParser
import sqlite3
import sys
import os

import numpy as np

from pylal import git_version
from pylal import sngl_tmplt_far
from pylal import InspiralUtils
from pylal import ligolw_sqlutils as sqlutils

from glue import iterutils
from glue.ligolw import dbtables
from glue.ligolw.utils import process

# Program name; passed to process.register_to_xmldoc() in the main section.
__prog__ = "ligolw_sngl_tmplt_far"
__author__ = "Matthew West <matthew.west@ligo.org>"
# Version and date strings are taken from pylal's git_version module.
__version__ = "git id %s" % git_version.id
__date__ = git_version.date

# Passed to OptionParser as the program description in parse_command_line().
# NOTE(review): the main section of this file stores FAR values in a
# coinc_inspiral column and does not produce a plot -- confirm this wording.
description = """
    %prog generates a plot comparing the all-possible-coincs
    method of background estimation to that of slides.
    """


#
# =============================================================================
#
#                                Command Line
#
# =============================================================================
#


def parse_command_line():
    """
    Parse the command line and return the validated options.

    Returns the optparse options object produced by parse_args();
    positional arguments are ignored.

    Raises ValueError if any of --input-db, --output-db, --snr-stat or
    --output-column is missing or empty, or if --snr-stat is not one of
    the recognized statistics.
    """
    # single definition of the accepted statistics, shared by the help text
    # and the validation below
    snr_stats = ("rawsnr", "snroverchi", "effsnr", "newsnr")

    parser = OptionParser(
        version = "Name: %%prog\n%s" % git_version.verbose_msg,
        usage = "%prog [options]",
        description = description
        )
    # The available command line options
    parser.add_option("-i", "--input-db", action="store", type="string", default=None,
        help =
            "Input database to read. Can only input one at a time."
        )
    parser.add_option("-o", "--output-db", action="store", type="string", default=None,
        help =
            "Name of output database to save to."
        )
    parser.add_option( "-t", "--tmp-space", action="store", type="string", default=None,
        metavar = "PATH",
        help =
            "Location of local disk on which to do work. This is optional; " +
            "it is only used to enhance performance in a networked " +
            "environment. "
        )
    parser.add_option("--snr-stat", action="store", type="string", default=None,
        metavar = "{%s}" % "|".join(snr_stats),
        help =
            "Select the desired coincident ChiSq weighted SNR (required). " +
            "\"rawsnr\" is just the matched filter SNR and lacks any weighting " +
            "by a ChiSq statistic. The statistic \"snroverchi\" is used by gstlal. " +
            "\"effsnr\" uses the Effective SNR algorithm to calculate a ChiSq " +
            "weighted detection SNR.  \"newsnr\" uses the New SNR algorithm to " +
            "calculate a ChiSq weighted SNR."
        )
    parser.add_option("--sngls-threshold", type="float", default=5.5,
        metavar = "float",
        help =
            "The minimum chisq-weighted snr for singles triggers"
        )
    parser.add_option("--sngls-bin-width", type="float", default=0.01,
        metavar = "float",
        help =
            "The width of bins in the single-ifo snr histograms"
        )
    parser.add_option("--coinc-bin-width", type="float", default=0.05,
        metavar = "float",
        help =
            "The width of bins in the coinc snr histograms"
        )
    parser.add_option("-c", "--output-column", action="store", type="string", default=None,
        help =
            "Required. Column in the output-table to store output to. If the column doesn't exist, " +
            "it will be created. Column name must be all lower-case. "
        )
    parser.add_option( "-T", "--time-units", action = "store", default = 'yr',
        metavar = "s, min, hr, days, OR yr",
        help =
            "Time units to use when calculating FARs (the units of the FARs will be the inverse of this). " +
            "Options are s, min, hr, days, or yr. Default is yr."
        )
    parser.add_option( "-v", "--verbose", action = "store_true", default = False,
        help =
            "Be verbose."
        )

    options, _ = parser.parse_args()

    # an option counts as missing when it is unset or empty
    required_options = ["input_db", "output_db", "snr_stat", "output_column"]
    missing_options = [option for option in required_options if not getattr(options, option)]

    if missing_options:
        # report the options the way they are spelled on the command line
        missing_options = ', '.join([
            "--%s" % option.replace("_", "-")
            for option in missing_options
        ])
        raise ValueError("missing required option(s) %s" % missing_options)

    if options.snr_stat not in snr_stats:
        raise ValueError("unrecognized --snr-stat %s" % options.snr_stat)

    return options


#
# =============================================================================
#
#                                    Main
#
# =============================================================================
#


#
# Command line
#

opts = parse_command_line()

# sqlite callback tracebacks are switched on only for verbose runs
sqlite3.enable_callback_tracebacks(opts.verbose)

# Ask dbtables which file to open for the input database (it is given the
# optional --tmp-space path) and connect to that file
working_filename = dbtables.get_connection_filename(
    opts.input_db, tmp_path = opts.tmp_space, verbose = opts.verbose)
connection = sqlite3.connect(working_filename)
if opts.tmp_space:
    # direct sqlite's temporary storage to the requested path as well
    dbtables.set_temp_store_directory(connection, opts.tmp_space, verbose = opts.verbose)
dbtables.DBTable_set_connection(connection)

# Build the xmldoc view of the database and register this program, together
# with its command-line options, in it
xmldoc = dbtables.get_xml(connection)
proc_id = process.register_to_xmldoc(
    xmldoc,
    __prog__,
    opts.__dict__,
    version = git_version.id)

# create needed indices on tables if they don't already exist
# ('index' is single-quoted so that SQLite reads it as a string literal
# rather than falling back from an identifier lookup)
current_indices = set(index[0] for index in
    connection.execute("SELECT name FROM sqlite_master WHERE type == 'index'"))

# (index name, table, indexed columns)
needed_indices = (
    ('si_idx', 'sngl_inspiral', 'event_id, snr, chisq, chisq_dof, mchirp, eta'),
    ('ci_idx', 'coinc_inspiral', 'coinc_event_id'),
    ('cem_idx', 'coinc_event_map', 'coinc_event_id, event_id'),
    ('ts_idx', 'time_slide', 'instrument, offset'),
    ('e_idx', 'experiment', 'experiment_id, instruments'),
    ('es_idx', 'experiment_summary', 'experiment_id, time_slide_id, datatype'),
    ('em_idx', 'experiment_map', 'coinc_event_id, experiment_summ_id'),
)

# one CREATE INDEX statement per index whose name is not yet in the database
sqlscript = ''.join(
    'CREATE INDEX %s ON %s (%s);\n' % (index_name, table, columns)
    for index_name, table, columns in needed_indices
    if index_name not in current_indices)

connection.executescript( sqlscript )

# create column to store output to if it doesn't already exist
_ , output_column = sqlutils.create_column( connection, 'coinc_inspiral', opts.output_column )

# get the template mass parameters
sqlquery = """
SELECT DISTINCT mchirp, eta, mass1, mass2
FROM sngl_inspiral
"""
templates = connection.execute( sqlquery ).fetchall()
if not templates:
    raise ValueError("no templates found in the sngl_inspiral table of %s" % opts.input_db)
# only the first distinct template returned by the query is used
mchirp, eta, m1, m2 = templates[0]
m1 = round(m1, 2)
m2 = round(m2, 2)

# get the set of ifos searched
sqlquery = """
SELECT DISTINCT ifos
FROM search_summary
"""
all_ifos = [ifo[0] for ifo in connection.execute( sqlquery )]

if opts.verbose:
    print >> sys.stdout, 'Making single-ifo snr histograms'
sngl_ifo_hist, sngl_ifo_midbins = sngl_tmplt_far.all_sngl_snr_hist(
    connection,
    mchirp, eta,
    all_ifos,
    min_snr = opts.sngls_threshold,
    sngls_width = opts.sngls_bin_width,
    snr_stat = opts.snr_stat
)

# get single-ifo analyzed times
T_i = sngl_tmplt_far.get_singles_times(connection, verbose=opts.verbose)

# get list of time_slide_ids, each with its comma-joined instrument offsets
sqlquery = """
SELECT time_slide_id, group_concat(offset)
FROM time_slide GROUP BY time_slide_id
"""
time_shifts = connection.execute(sqlquery).fetchall()
tsid_list = [slide[0] for slide in time_shifts]

# the zero-lag slide is the one whose offsets are all zero
zerolag_tsids = [slide[0] for slide in time_shifts if not any(map(float, slide[1].split(',')))]
if not zerolag_tsids:
    raise ValueError("no zero-lag entry found in the time_slide table of %s" % opts.input_db)
zerolag_tsid = zerolag_tsids[0]

# The UPDATE that stores each coinc's FAR is the same for every coinc, so it
# is built once here.  A column name cannot be bound as a parameter, so it is
# spliced into the statement; it is the name returned by
# sqlutils.create_column() earlier in this script.
update_query = ''.join(["""
    UPDATE coinc_inspiral
    SET """, output_column, """ = :far
    WHERE coinc_event_id == :ceid
    """])

# loop over every combination of two or more of the searched ifos
for num_ifos in range(2, len(all_ifos)+1):
    for ifos in iterutils.choices(all_ifos, num_ifos):
        # use slides to determine the coincidence window
        tau = sngl_tmplt_far.get_coinc_window(connection, ifos)

        # make combined snr bins
        # NOTE(review): the lower edge uses len(all_ifos) rather than
        # len(ifos), so it is the same for every ifo combination -- confirm
        # this is intended.
        max_comb_snr = sngl_tmplt_far.quadrature_sum([max(sngl_ifo_midbins[ifo]) for ifo in ifos])
        combined_bins = np.arange(
            len(all_ifos)**(1./2)*opts.sngls_threshold,
            max_comb_snr + 2*opts.coinc_bin_width,
            opts.coinc_bin_width)

        # ----------------------- All Possible Coincs ----------------------- #
        if opts.verbose:
            print >> sys.stdout, 'Creating rates distribution from all-possible-coincs'
        apc_hist = sngl_tmplt_far.all_possible_coincs(
            sngl_ifo_hist,
            sngl_ifo_midbins,
            combined_bins,
            ifos
        )

        # ----------------------- Zerolag & Slide Coincs ------------------------ #
        # compute effective background time
        t_bkgd = sngl_tmplt_far.eff_bkgd_time(T_i, tau, ifos, opts.time_units)

        if opts.verbose:
            print >> sys.stdout, 'Get zerolag coincs from %s' % zerolag_tsid
        zerolag_coincs = sngl_tmplt_far.coinc_snr_hist(
            connection,
            ifos,
            mchirp, eta,
            min_snr = opts.sngls_threshold,
            datatype = "all_data",
            slide_id = zerolag_tsid,
            combined_bins = combined_bins,
            snr_stat = opts.snr_stat
        )

        for tsid in tsid_list:
            if tsid != zerolag_tsid:
                datatype = 'slide'
                if opts.verbose:
                    print >> sys.stdout, 'Get slide coincs from %s' % tsid
                coincs = sngl_tmplt_far.coinc_snr_hist(
                    connection,
                    ifos,
                    mchirp, eta,
                    min_snr = opts.sngls_threshold,
                    datatype = datatype,
                    slide_id = tsid,
                    combined_bins = combined_bins,
                    snr_stat = opts.snr_stat
                )
                # for a slide, both its own coincs and the zerolag coincs
                # are taken out of the all-possible-coincs histogram
                coincs2remove = np.add(coincs['snr_hist'], zerolag_coincs['snr_hist'])
            else:
                # the zerolag histogram was already computed above; reuse it
                datatype = 'all_data'
                coincs = zerolag_coincs
                coincs2remove = coincs['snr_hist']

            # number of combinations with snr >= rho ["foreground" coincs are removed from combos]
            # (reverse, cumulative-sum, reverse again: sum from each bin upward)
            snr_survival = np.cumsum((apc_hist - coincs2remove)[::-1])[::-1]
            # the amount of coinc time in this slide
            t_fgd = sngl_tmplt_far.coinc_time(connection, datatype, tsid, ifos, opts.time_units)

            if opts.verbose:
                print >> sys.stdout, 'Compute desired significance statistic for coincs from %s' % tsid
            coincs['fars'] = sngl_tmplt_far.calc_far(
                coincs['snrs'],
                snr_survival,
                combined_bins,
                output_column,
                t_fgd,
                t_bkgd
            )

            # store the FAR of every coinc found in this slide
            connection.executemany(update_query, [
                {'ceid': ceid, 'far': coincs['fars'][i]}
                for i, ceid in enumerate(coincs['ids'])
            ])

#
#       Save and Exit
#
# commit the pending changes and release the working database
connection.commit()
connection.close()

# write output database: hand the working file to dbtables under the
# requested output name
dbtables.put_connection_filename(opts.output_db, working_filename, verbose = opts.verbose)

if opts.verbose:
    print >> sys.stdout, "Finished!"
# explicit success exit status
sys.exit(0)

