#!/usr/bin/python


# ============================================================================
#
#                               Preamble
#
# ============================================================================

from optparse import OptionParser
try:
    import sqlite3
except ImportError:
    # pre 2.5.x
    from pysqlite2 import dbapi2 as sqlite3
import sys, os
import numpy

from glue.ligolw import lsctables
from glue.ligolw import dbtables
from glue import git_version

from pylal import ligolw_sqlutils as sqlutils
from pylal import InspiralUtils
from pylal import CoincInspiralUtils

# program name and author of this script
__prog__ = "ligolw_cbc_plotslides"
__author__ = "Collin Capano <cdcapano@physics.syr.edu>"

# summary of the program; handed to the option parser as its description
description = (
    "Plots number of events and durations for each slide in each experiment "
    "time in a database."
    )

# ============================================================================
#
#                               Set Options
#
# ============================================================================

def parse_command_line():
    """
    Defines the command-line options of this program, parses the command
    line and checks that the required options were given.

    Returns a tuple (options, args): options is the object returned by
    OptionParser.parse_args() holding the option values; args is
    sys.argv[1:], i.e., the complete list of command-line arguments as they
    were typed (not just the left-over positional arguments).

    Raises ValueError if --input or --coinc-table is not specified.
    """
    parser = OptionParser(
        version = git_version.verbose_msg,
        usage   = "%prog [options]",
        description = description
        )

    # following are related to file input and output naming
    parser.add_option( "-i", "--input", action = "store", type = "string", default = None,
        help =
            "Input database to read. Can only input one at a time."
        )
    parser.add_option( "-t", "--tmp-space", action = "store", type = "string", default = None,
        metavar = "PATH",
        help =
            "Location of local disk on which to do work. This is optional; " +
            "it is only used to enhance performance in a networked " +
            "environment. "
        )
    parser.add_option( "-P", "--output-path", action = "store", type = "string",
        default = os.getcwd(), metavar = "PATH",
        help =
            "Optional. Path where the figures should be stored. Default is current directory."
        )
    parser.add_option( "-O", "--enable-output", action = "store_true",
        default =  False, metavar = "OUTPUT",
        help =
            "enable the generation of html and cache documents"
        )
    parser.add_option( "", "--coinc-table", action = "store", type = "string",
        default = None,
        help =
            "Required. Can be any table with a coinc_event_id and a time column. Ex. coinc_inspiral."
        )
    parser.add_option( "-p", "--plot-playground-only", action = "store_true",
        default = False,
        help =
            "Plot only playground data. Use if plotting un-opened box."
        )
    parser.add_option( "-u", "--user-tag", action = "store", type = "string",
        default = None, metavar = "USERTAG",
        help =
            "Add a tag in the name of the figures"
        )
    parser.add_option( "-s", "--show-plot", action = "store_true", default = False,
        help =
            "display the plots on the terminal"
        )
    parser.add_option( "-v", "--verbose", action = "store_true", default = False,
        help =
            "print information to stdout"
        )

    # left-over positional arguments are not used by this program
    options, _ = parser.parse_args()

    # check that the required options were specified
    if not options.input:
        raise ValueError("--input must be specified")
    # --coinc-table is documented as required and is validated and pasted
    # into SQL by the main program; reject a missing value here, with a clear
    # message, rather than failing later on with an unrelated error
    if not options.coinc_table:
        raise ValueError("--coinc-table must be specified")

    # the raw argument list is returned (rather than the left-over
    # positional arguments) so that the caller has the full command line
    return options, sys.argv[1:]
            

# =============================================================================
#
#                       Function Definitions
#
# =============================================================================

def get_statistics( numlist ):
    """
    Computes the mean and the standard deviation of the values in numlist.

    numlist is converted to a numpy array first, so any sequence of numbers
    can be given. The result is the tuple (mean, standard deviation), with
    the standard deviation computed using numpy's default settings.
    """
    values = numpy.array(numlist)
    mean_val = values.mean()
    std_val = values.std()
    return mean_val, std_val


# ============================================================================
#
#                                 Main
#
# ============================================================================

#
#   Generic Initialization
#

# read the options; args is the raw list of command-line arguments
opts, args = parse_command_line()

# the database to read must already exist on disk
filename = opts.input
if not os.path.isfile(filename):
    raise ValueError("The input file, %s, cannot be found." % filename)

# ask dbtables which file to connect to (tmp_space is passed on as the
# tmp_path) and open the connection on that file
if opts.verbose:
    print >> sys.stdout, "Creating a database connection..."
working_filename = dbtables.get_connection_filename(
    filename,
    tmp_path = opts.tmp_space,
    verbose = opts.verbose )
connection = sqlite3.connect(working_filename)
if opts.tmp_space:
    dbtables.set_temp_store_directory(
        connection, opts.tmp_space, verbose = opts.verbose )

# the table name is pasted into SQL statements later on, so it is passed
# through sqlutils.validate_option first
coinc_table = sqlutils.validate_option(opts.coinc_table)

#
#   Plotting Initialization
#

# When show() will not be called, switch to the Agg back-end to avoid
# display problems; this has to be done before pylab is imported
if not opts.show_plot:
    import matplotlib
    matplotlib.use('Agg')
from pylab import *
from numpy import histogram
rc('text', usetex=True)

#
#   Program-specific Initialization
#

def _draw_slide_bar( slide_num, height, color ):
    """
    Draws one bar of unit width, centered on slide_num, on the current
    figure.
    """
    # the x position is passed positionally: its keyword name is "left" in
    # older matplotlib releases and "x" in newer ones
    bar(slide_num, height, width = 1, color = color, align = 'center')


def _finish_slide_plot( values, title_text, ylabel_text, min_slide_num, max_slide_num ):
    """
    Adds the mean and +/- 1-sigma lines, the title, the axis labels and the
    axis limits to the current figure.

    values is the list of bar heights that were drawn on the figure;
    min_slide_num and max_slide_num are the smallest and largest slide
    numbers, from which the x-axis limits are set.
    """
    # plot mean, +/- 1-sigma; the lower line is not allowed to go below zero
    mean_val, std_val = get_statistics(values)
    axhline(mean_val, color = 'k', linewidth = 2)
    axhline(mean_val + std_val, color = 'k', linestyle = '--', linewidth = 2)
    axhline(max(mean_val - std_val, 0), color = 'k', linestyle = '--', linewidth = 2)
    # set labels
    title(title_text, size = 'x-large')
    xlabel( r"Slide number", size = 'x-large' )
    ylabel( ylabel_text, size = 'x-large' )
    # set limits; leave a 2-sigma margin around the plotted values, or fall
    # back on a unit range if nothing was plotted
    xlim(min_slide_num - 1, max_slide_num + 1)
    if values:
        ymax = max(values) + 2.*std_val
        ymin = max(0, min(values) - 2.*std_val)
    else:
        ymax = 1
        ymin = 0
    ylim(ymin, ymax)


def _save_slide_plot( plot_opts, plot_tag, datatype_plotted, open_box, fnameList, tagList ):
    """
    Saves the current figure through InspiralUtils and appends its file name
    to fnameList and its tag to tagList.
    """
    name = InspiralUtils.set_figure_tag( plot_tag,
        datatype_plotted = datatype_plotted, open_box = open_box )
    fname = InspiralUtils.set_figure_name(plot_opts, name)
    InspiralUtils.savefig_pylal( filename = fname )
    fnameList.append(fname)
    tagList.append(name)


def _new_rate_figure_numbers( figure_numbers, playground_only ):
    """
    Reserves the figure numbers of a set of rate plots: one figure for
    playground only, or one each for all_data, playground and exclude_play.
    The new numbers follow the last entry of figure_numbers and are added to
    that list. Returns a dictionary mapping datatype to figure number.
    """
    next_fig_num = figure_numbers[-1] + 1
    if playground_only:
        rate_fig_nums = { 'playground': next_fig_num }
    else:
        rate_fig_nums = { 'all_data': next_fig_num, 'playground': next_fig_num + 1,
            'exclude_play': next_fig_num + 2 }
    figure_numbers.extend(rate_fig_nums.values())
    figure_numbers.sort()
    return rate_fig_nums


def _plot_slide_rates( rate_fig_nums, nevents, zero_lag_color, playground_only,
        slide_numbers, datatypes, durations ):
    """
    Draws the rate bars (number of events divided by duration) on the rate
    figures and returns a dictionary mapping each datatype in rate_fig_nums
    to the list of rates drawn on its figure.

    Rates of slide entries are drawn in gray on every figure; the rate of
    any other entry is drawn in zero_lag_color on the figure of its own
    datatype only. nevents, datatypes and durations are dictionaries keyed
    by experiment_summ_id; slide_numbers maps experiment_summ_id to slide
    number.
    """
    rates = dict([ [datatype, list()] for datatype in rate_fig_nums ])
    for esid, slide_num in slide_numbers.items():
        datatype = datatypes[ esid ]
        # if plotting playground only, skip if not playground
        if playground_only and datatype in ('all_data', 'exclude_play'):
            continue
        # a zero (or missing) duration gives a rate of zero
        if durations[ esid ]:
            rate = float(nevents[ esid ])/float(durations[ esid ])
        else:
            rate = 0.
        if datatype == 'slide':
            for plot_type, fig_num in rate_fig_nums.items():
                figure(fig_num)
                _draw_slide_bar(slide_num, rate, 'gray')
                rates[ plot_type ].append( rate )
        else:
            figure(rate_fig_nums[ datatype ])
            _draw_slide_bar(slide_num, rate, zero_lag_color)
            rates[ datatype ].append( rate )
    return rates


# Get all the experiment_ids
experiment_query = """
    SELECT
        experiment.experiment_id,
        experiment.gps_start_time,
        experiment.gps_end_time,
        experiment.instruments
    FROM
        experiment
    """
for this_eid, gps_start_time, gps_end_time, on_instruments in connection.cursor().execute(experiment_query):
    on_instruments = lsctables.instrument_set_from_ifos(on_instruments)
    on_instr = r','.join(sorted(on_instruments)) + r' Time:'
    # get all the summary information for this experiment
    if opts.verbose:
        print >> sys.stdout, "Creating plots for %s time..." % ','.join(sorted(on_instruments))
        print >> sys.stdout, "\tcollecting durations, total nevents..."
    slide_numbers = {}
    prev_slide_id = None
    slide_num = 0
    # reset for every experiment, so that the zero-lag slide number of a
    # previous experiment is never applied to this one
    zero_lag_slide_num = None
    datatypes = {}
    durations = {}
    tot_events = {}
    # the experiment_id is bound as a parameter rather than pasted into the
    # statement, and string literals are single quoted as SQL requires
    summary_query = """
        SELECT
            experiment_summary.experiment_summ_id,
            experiment_summary.time_slide_id,
            experiment_summary.duration,
            experiment_summary.nevents,
            experiment_summary.datatype
        FROM
            experiment_summary
        JOIN
            time_slide ON (
                time_slide.time_slide_id == experiment_summary.time_slide_id )
        WHERE
            experiment_summary.datatype != 'simulation'
            AND experiment_summary.experiment_id == ?
        GROUP BY
            time_slide.time_slide_id, experiment_summary.datatype
        ORDER BY
            SUM(time_slide.offset) ASC"""
    for esid, slide_id, duration, tot_nevents, datatype in connection.cursor().execute(summary_query, (this_eid,)):
        # rows come ordered by slide; start a new slide number each time the
        # time_slide_id changes
        if prev_slide_id != slide_id:
            slide_num += 1
            prev_slide_id = slide_id
        slide_numbers[ esid ] = slide_num
        datatypes[ esid ] = datatype
        durations[ esid ] = duration
        if datatype == "all_data":
            zero_lag_slide_num = slide_num
        # if plotting playground_only, don't store tot_events
        if opts.plot_playground_only and datatype in ("all_data", "exclude_play"):
            continue
        tot_events[ esid ] = tot_nevents

    # nothing can be plotted for an experiment without any non-simulation
    # summary; skip it instead of failing on the empty slide list below
    if not slide_numbers:
        print >> sys.stderr, "Warning: no non-simulation experiment_summary entries found for %s time; skipping this experiment." % ','.join(sorted(on_instruments))
        continue

    # if more than two instruments were on, several coincidence types are
    # possible; get information for each individual coinc type
    if len(on_instruments) > 2:

        # get number of events for each coinc type
        if opts.verbose:
            print >> sys.stdout, "\tcollecting number of events per coincident ifo..."
        # if plotting playground only, turn on playground filter
        if opts.plot_playground_only:
            playground_filter = """
                AND ( experiment_summary.datatype == 'playground'
                    OR experiment_summary.datatype == 'slide' )
                """
        else:
            playground_filter = ''
        # determine all possible coinc_ifos from the on_instruments
        coinc_ifos = dict( [frozenset(ifos), {}] for ifos in CoincInspiralUtils.get_ifo_combos(list(on_instruments)) )
        # the table name cannot be bound as a parameter, so it is pasted in;
        # the experiment_id is bound
        coinc_query = ''.join(["""
            SELECT
                experiment_summary.experiment_summ_id,
                experiment_summary.datatype,
                """, coinc_table, """.ifos,
                COUNT(experiment_map.coinc_event_id)
            FROM
                experiment_summary
            JOIN
                """, coinc_table, """, experiment_map ON (
                    experiment_summary.experiment_summ_id == experiment_map.experiment_summ_id
                    AND experiment_map.coinc_event_id == """, coinc_table, """.coinc_event_id )
            WHERE
                experiment_summary.datatype != 'simulation'
                AND experiment_summary.experiment_id == ?""",
                playground_filter, """
            GROUP BY
                experiment_summary.experiment_summ_id, experiment_summary.datatype, """, coinc_table, """.ifos"""])
        for esid, datatype, ifo_set, ifos_nevents in connection.cursor().execute(coinc_query, (this_eid,)):
            ifo_set = lsctables.instrument_set_from_ifos(ifo_set)
            coinc_ifos[frozenset(ifo_set)][ esid ] = ifos_nevents

    else:
        coinc_ifos = {}

    # remap slide_numbers so that zero_lag esids have slide_num 0
    # also plug in any missing data in coinc_ifos; this can happen
    # if there are no events of particular ifo_type in a slide
    for esid in slide_numbers:
        if zero_lag_slide_num is not None:
            slide_numbers[esid] -= zero_lag_slide_num
        else:
            slide_numbers[esid] += 1
        for ifo_set in coinc_ifos:
            if esid not in coinc_ifos[ifo_set]:
                coinc_ifos[ifo_set][esid] = 0

    min_slide_num = min(slide_numbers.values())
    max_slide_num = max(slide_numbers.values())

    #
    #   Plotting
    #
    if opts.verbose:
        print >> sys.stdout, "\tcreating plot figures..."

    # set InspiralUtils options for file and plot naming
    opts.gps_start_time = gps_start_time
    opts.gps_end_time = gps_end_time
    opts.ifo_times = ''.join(sorted(on_instruments))
    opts.ifo_tag = ''
    InspiralUtilsOpts = InspiralUtils.initialise( opts, __prog__, git_version.verbose_msg )

    fnameList = []
    tagList = []
    figure_numbers = []
    fig_num = 0

    #
    # plot durations
    #
    fig_num += 1
    figure_numbers.append(fig_num)
    figure(fig_num)
    val_list = []
    for esid, slide_num in slide_numbers.items():
        # slides are drawn in gray, zero-lag (all_data) in blue
        if datatypes[ esid ] == 'slide':
            _draw_slide_bar(slide_num, durations[ esid ], 'gray')
            val_list.append( durations[ esid ] )
        elif datatypes[ esid ] == 'all_data':
            _draw_slide_bar(slide_num, durations[ esid ], 'b')
            val_list.append( durations[ esid ] )
    title_text = ' '.join([on_instr, r"All-Data Durations per Time-Slide"])
    _finish_slide_plot(val_list, title_text, r"Duration (s)", min_slide_num, max_slide_num)
    # Make the plot figure
    if opts.enable_output:
        _save_slide_plot(InspiralUtilsOpts, "durations_per_slide", "ALL_DATA", False,
            fnameList, tagList)

    #
    # plot total number of events
    #
    if not opts.plot_playground_only:
        fig_num += 1
        figure_numbers.append(fig_num)
        figure(fig_num)
        val_list = []
        for esid, slide_num in slide_numbers.items():
            if datatypes[ esid ] == 'slide':
                _draw_slide_bar(slide_num, tot_events[ esid ], 'gray')
                val_list.append( tot_events[ esid ] )
            elif datatypes[ esid ] == 'all_data':
                _draw_slide_bar(slide_num, tot_events[ esid ], 'b')
                val_list.append( tot_events[ esid ] )
        title_text = ' '.join([on_instr, r"Total Number of Coinc. Events per Time-Slide"])
        _finish_slide_plot(val_list, title_text, r"Number of Events", min_slide_num, max_slide_num)
        # Make the plot figure
        if opts.enable_output:
            _save_slide_plot(InspiralUtilsOpts, "total_nevents", "ALL_DATA", True,
                fnameList, tagList)

    #
    # plot total rates
    #
    rate_fig_nums = _new_rate_figure_numbers(figure_numbers, opts.plot_playground_only)
    val_dict = _plot_slide_rates(rate_fig_nums, tot_events, 'b', opts.plot_playground_only,
        slide_numbers, datatypes, durations)
    for datatype, fig_num in rate_fig_nums.items():
        figure(fig_num)
        title_text = ' '.join([on_instr, '-'.join(datatype.upper().split("_")), r"Total Trigger Rates per Time-Slide"])
        _finish_slide_plot(val_dict[ datatype ], title_text, "Rate (Hz)", min_slide_num, max_slide_num)
        # Make the plot figure; all_data and exclude_play are open-box plots
        if opts.enable_output:
            box_flag = datatype in ('all_data', 'exclude_play')
            _save_slide_plot(InspiralUtilsOpts, "total_rates", datatype.upper(), box_flag,
                fnameList, tagList)

    #
    #   Plot nevents and rates for individual coinc types
    #
    for ifo_set in coinc_ifos:
        fig_num = figure_numbers[-1] + 1
        figure_numbers.append(fig_num)
        figure(fig_num)
        # color used for the zero-lag bars of this coinc type
        ifo_color = InspiralUtils.get_coinc_ifo_colors(ifo_set)
        #
        # plot nevents
        #
        if not opts.plot_playground_only:
            val_list = []
            for esid, slide_num in slide_numbers.items():
                if datatypes[ esid ] == 'slide':
                    _draw_slide_bar(slide_num, coinc_ifos[ ifo_set ][ esid ], 'gray')
                    val_list.append( coinc_ifos[ ifo_set ][ esid ] )
                elif datatypes[ esid ] == 'all_data':
                    _draw_slide_bar(slide_num, coinc_ifos[ ifo_set ][ esid ], ifo_color)
                    val_list.append( coinc_ifos[ ifo_set ][ esid ] )
            title_text = ' '.join([ on_instr, r"Number of", r','.join(sorted(ifo_set)), "Events per Time-Slide" ])
            _finish_slide_plot(val_list, title_text, r"Number of Events", min_slide_num, max_slide_num)
            # Make the plot figure
            if opts.enable_output:
                _save_slide_plot(InspiralUtilsOpts, ''.join(sorted(ifo_set)) + "_nevents",
                    "ALL_DATA", True, fnameList, tagList)
        #
        # plot rates
        #
        rate_fig_nums = _new_rate_figure_numbers(figure_numbers, opts.plot_playground_only)
        val_dict = _plot_slide_rates(rate_fig_nums, coinc_ifos[ ifo_set ], ifo_color,
            opts.plot_playground_only, slide_numbers, datatypes, durations)
        for datatype, fig_num in rate_fig_nums.items():
            figure(fig_num)
            title_text = ' '.join([on_instr, '-'.join(datatype.upper().split("_")), ','.join(sorted(ifo_set)), r"Trigger Rates per Time-Slide"])
            _finish_slide_plot(val_dict[ datatype ], title_text, "Rate (Hz)", min_slide_num, max_slide_num)
            # Make the plot figure; all_data and exclude_play are open-box plots
            if opts.enable_output:
                box_flag = datatype in ('all_data', 'exclude_play')
                _save_slide_plot(InspiralUtilsOpts, ''.join(sorted(ifo_set)) + "_rates",
                    datatype.upper(), box_flag, fnameList, tagList)


    #
    #   Write the cache file
    #
    if opts.enable_output:
        if opts.verbose:
            print >> sys.stdout, "\twriting html file and cache..."
        # write closed-box plots page
        if opts.verbose:
            print >> sys.stdout, "\t\tfor closed box plots..."
        # filter the tags together with the file names; the full tagList has
        # more entries than the closed-box file list whenever open-box plots
        # were made, so passing it unfiltered would misalign tags and files
        closed_pairs = [(fname, tag) for fname, tag in zip(fnameList, tagList) if 'OPEN_BOX' not in fname]
        closed_list = [fname for fname, tag in closed_pairs]
        closed_tags = [tag for fname, tag in closed_pairs]
        closed_html = InspiralUtils.write_html_output( InspiralUtilsOpts, args, closed_list,
            closed_tags, add_box_flag = True )
        InspiralUtils.write_cache_output( InspiralUtilsOpts, closed_html, closed_list )
        if not opts.plot_playground_only:
            if opts.verbose:
                print >> sys.stdout, "\t\tfor open box plots..."
            open_html = InspiralUtils.write_html_output( InspiralUtilsOpts, args, fnameList,
                tagList, add_box_flag = True )
            InspiralUtils.write_cache_output( InspiralUtilsOpts, open_html, fnameList )

    #
    #   Show plots if desired
    #
    if opts.show_plot:
        show()

    # close the figure number for next experiment
    for fig_num in figure_numbers:
        close(fig_num)

#
#   Finished cycling over experiments; exit
#

# close the database connection first, then hand the original and working
# file names back to dbtables
connection.close()
dbtables.discard_connection_filename(
    filename,
    working_filename,
    verbose = opts.verbose )

if opts.verbose:
    print >> sys.stdout, "Finished!"

# exit with a success status
sys.exit(0)
