#!/usr/bin/python
#
# Copyright (C) 2009  Tomoki Isogai
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

"""
%prog --result_glob=Glob [options]

Tomoki Isogai (isogait@carleton.edu)

This program makes the following plots from result files of KW_veto_calc.

1. Use Percentage vs Threshold
2. GPS time vs SNR of all KW triggers, KW triggers above critical threshold and 
 coincident KW triggers
- 2 histograms and 2 additionals for candidates
1. GW Triggers Number vs SNR for low SNR range
2. GW Number vs SNR for whole SNR range
For candidates
3. Vetoed GW Triggers Number vs SNR for SNR range
4. Vetoed GW Triggers Number vs SNR for whole SNR range
    
The program also makes a thumbnail for each plot for webpage sake.

!CAUTION!: make sure to put ' mark around glob; otherwise bash wouldn't pass the right string.
"""

# =============================================================================
#
#                               PREAMBLE
#
# =============================================================================

from __future__ import division

import sys
import os
import optparse
import glob
import math

try:
    import sqlite3
except ImportError:
    # pre 2.5.x
    from pysqlite2 import dbapi2 as sqlite3

import matplotlib
matplotlib.use('Agg')
from pylab import *

from glue.segments import segment, segmentlist

from pylal import git_version
from pylal import KW_veto_utils

__author__ = "Tomoki Isogai <isogait@carleton.edu>"
__date__ = "2009/7/10"
__version__ = "2.0"

def parse_commandline():
    """
    Read and validate the command-line options.

    The module docstring is used as the usage text. Parsing aborts through
    parser.error() when --result_glob is missing. The output directory is
    created when it does not exist yet, and in verbose mode every option is
    echoed to stderr.

    Returns the optparse values object holding result_glob, out_dir,
    scratch_dir and verbose.
    """
    parser = optparse.OptionParser(usage=__doc__, version=__version__)
    parser.add_option(
        "-r", "--result_glob",
        help="Glob for result files from KW_veto_calc. Required.")
    parser.add_option(
        "-o", "--out_dir", default="plots",
        help="Output directory name. (Default: plots)")
    parser.add_option(
        "-l", "--scratch_dir", default=".",
        help="Scratch directory to be used for database engine. Specify local scratch directory for better performance and less fileserver load. (Default: current directory)")
    parser.add_option(
        "-v", "--verbose", action="store_true", default=False,
        help="Run verbosely. (Default: False)")

    # positional arguments are accepted but not used
    options, _ = parser.parse_args()

    # the glob is the only mandatory input
    if options.result_glob is None:
        parser.error("--result_glob is a required parameter")

    # the plots need somewhere to go
    if not os.path.exists(options.out_dir):
        os.makedirs(options.out_dir)

    if not options.verbose:
        return options

    # echo the run configuration
    print >> sys.stderr, ""
    print >> sys.stderr, "running KW_veto_plots..."
    print >> sys.stderr, git_version.verbose_msg
    print >> sys.stderr, ""
    print >> sys.stderr, "***************** PARAMETERS ********************"
    for name, value in options.__dict__.items():
        print >> sys.stderr, name+":"
        print >> sys.stderr, value
    print >> sys.stderr, ""

    return options

def _columns(rows, width=2):
    """
    Transpose a list of database rows into per-column sequences.

    Returns a list of `width` empty lists when `rows` is empty, so that the
    result can always be unpacked or indexed.
    """
    if not rows:
        return [[] for _ in range(width)]
    return list(zip(*rows))


def _label_log10_snr_axis():
    """
    Label the current x axis, whose data are log10(SNR), as powers of ten.

    Major ticks are put on every integer (shown as 10^n) and minor ticks at
    2..9 times each power of ten inside the current view interval.
    """
    from matplotlib.ticker import FormatStrFormatter, MultipleLocator, FixedLocator

    xaxis = gca().xaxis
    low, high = xaxis.get_view_interval()
    # NOTE(review): this spot used to contain `if hmin == 0: hmin == 1`, a
    # comparison without any effect. It is dropped; tick placement is what
    # it has always been, starting from floor(view minimum).
    low = int(math.floor(low))
    high = int(math.ceil(high))
    minor_ticks = []
    for decade in range(low, high + 1):
        minor_ticks += [math.log(10**decade * i, 10) for i in range(1, 10)]
    xaxis.set_major_locator(MultipleLocator(1))
    xaxis.set_major_formatter(FormatStrFormatter(r'$\mathdefault{10^{%d}}$'))
    xaxis.set_minor_locator(FixedLocator(minor_ticks))


def plot_channel():
    """
    This is the main function which makes all the plots.

    It takes no arguments and returns nothing. It reads the module-level
    names `cursor`, `connection`, `params` and `opts`, and writes PNG files
    named <opts.out_dir>/<params["filePrefix"]>-<plot name>.png, each
    together with a dpi=60 thumbnail.

    Made for every channel:
      usePercentage_plot, eff_over_deadTime_plot, timeSeries,
      triggers_hist, triggers_hist_lowSnr
    Made for veto candidate channels only:
      vetoed_triggers_hist, vetoed_triggers_hist_lowSnr, coinc_triggers_dist
    """
    # avoid science notation for GPS times
    from matplotlib.ticker import FormatStrFormatter

    ## get data for plotting
    thresh, usedPer, efficiency, deadTimePer = zip(*cursor.execute("select threshold, usedPercentage, efficiency, deadTimePercentage from result order by threshold asc").fetchall())
    # candidate is a tuple:
    # (threshold, used percentage, # of coincident KW triggers above the threshold, # of total KW triggers above the threshold, # of vetoed GW triggers, veto efficiency, dead time, dead time percentage, (used percentage) / (random used percentage), (veto efficiency) / (dead time percentage))
    # that corresponds to threshold to be used for veto.
    # candidate is just None for non candidate channel
    candidate = KW_veto_utils.get_candidate(cursor, params["critical_usedPer"])
    is_candidate = candidate is not None
    # KW Significance threshold to be used for veto
    veto_thresh = None
    if is_candidate:
        veto_thresh = candidate[0]
    # file name to save plots
    prefix = os.path.join(opts.out_dir, params["filePrefix"])

    # start and end of x axis
    start = thresh[0]
    end = thresh[-1]
    # when interval is 0 (meaning there is no KW trigger) at least plot
    # one threshold step
    if start == end:
        end = start + params['resolution']

    # x ticks: both ends plus every multiple of tickspace in between;
    # list() keeps the concatenation valid where range() is not a list
    tickspace = 500
    ticks = [start] + list(range((int(start/tickspace)+1)*tickspace,
                                 (int(end/tickspace)+1)*tickspace,
                                 tickspace)) + [end]

    matplotlib.rcParams.update({
        "font.size": 8.0,
        "axes.titlesize": 10.0,
        "axes.labelsize": 10.0,
        "xtick.labelsize": 8.0,
        "ytick.labelsize": 8.0,
        "legend.fontsize": 8.0})

    ############################## plots #######################################

    ## plot use percentage
    # show critical used percentage and threshold for veto line for candidate channel
    if is_candidate:
        axhline(y = params["critical_usedPer"], linewidth = 2, color = 'r', label="Used Percentage Threshold: %d%%"%params['critical_usedPer'])
        text(end + 10, params["critical_usedPer"] - 1, '%d%%'%params["critical_usedPer"], color='k')
        axvline(x = veto_thresh, linewidth = 2, color = 'k',label="Veto Threshold: %d"%veto_thresh)
        legend(loc='lower right')
    # actual plotting
    plot(thresh, usedPer, 'k', thresh, usedPer, 'bo', markersize=3)
    xlim(start, end)
    xticks(ticks)
    xlabel('KW Significance Threshold')
    ylabel('Used Percentage (%)')
    title(params["channel"])
    savefig("%s-usePercentage_plot.png"%prefix)
    savefig("%s-usePercentage_plot-thumbnail.png"%prefix,dpi=60)
    close('all')

    ## plot veto efficiency / dead time
    # draw diagonal line
    x_max = max(efficiency) * 1.1
    y_max = max(deadTimePer) * 1.1
    max_value = min((x_max, y_max))
    plot([0,max_value],[0,max_value],'k--',linewidth = 2,label="Diagonal Line")
    plot(efficiency, deadTimePer, 'bo', markersize=5)
    # show the threshold
    if is_candidate:
        plot([candidate[5]],[candidate[7]],'r*',markersize=10,label="Veto Threshold")
    legend(loc='best')
    xlabel('Veto Efficiency (%)')
    ylabel('Dead Time Percentage (%)')
    xlim(0, x_max)
    ylim(0, y_max)
    title(params["channel"])
    savefig("%s-eff_over_deadTime_plot.png"%prefix)
    # the thumbnail name has no "_plot" part; kept as is because other
    # tools may look the file up by this exact name
    savefig("%s-eff_over_deadTime-thumbnail.png"%prefix,dpi=60)
    close('all')

    ############################## time series #################################

    ## time series for GW triggers and KW triggers
    # use common x range for both GW and KW triggers
    x_min = params["start_time"]
    x_max = params["end_time"]

    # GW trigs time series
    subplot(211)
    # an empty GWtrigs table is not supported: this unpacking fails then
    GW_time, GW_SNR = zip(*cursor.execute("select GPSTime, SNR from GWtrigs").fetchall())
    semilogy(GW_time, GW_SNR, 'bo', markersize=3, label='Triggers')
    ylabel('SNR')
    title("GW Triggers")
    del GW_time
    if is_candidate:
        # IDs of GW triggers coincident with a KW trigger above the veto
        # threshold; each row holds a comma separated list of IDs.
        # A set makes the per-row membership test below constant time.
        coincGW_id = set()
        for row in cursor.execute("select CoincidenceGWTrigID from KWtrigs where KWSignificance > ? and CoincidenceGWTrigID <> 'No Coincidence'", (veto_thresh,)).fetchall():
            coincGW_id.update([int(i) for i in row[0].split(',')])

        # original function to find coincident GW triggers
        connection.create_function("coinc", 1, lambda x: x in coincGW_id)
        # find coincident GW triggers
        coincGW_time, coincGW_SNR = _columns(cursor.execute("select GPSTime, SNR from GWtrigs where coinc(ID)").fetchall())

        semilogy(coincGW_time, coincGW_SNR, 'ro', markersize=3,
                 label='Vetoed Triggers')
        del coincGW_time
    legend(loc='lower right')
    xlim([x_min, x_max])
    gca().xaxis.set_major_formatter(FormatStrFormatter("%d"))

    # KW trigs time series
    subplot(212)
    # all the KW triggers (empty columns when there is no KW trigger)
    KW_time, KW_sig = _columns(cursor.execute("select GPSTime, KWSignificance from KWtrigs").fetchall())
    semilogy(KW_time, KW_sig, 'bx', markersize=3, label = 'All KW Triggers')
    del KW_time, KW_sig
    xlabel('GPS time')
    ylabel('KW significance')
    title(params["channel"] + " KW Triggers")
    if is_candidate:
        # KW triggers above threshold to be used for veto
        vetoKW_time, vetoKW_sig = _columns(cursor.execute("select GPSTime, KWSignificance from KWtrigs where KWSignificance > ?",(veto_thresh,)).fetchall())
        # coincident KW triggers above the threshold
        coincKW_time, coincKW_sig = _columns(cursor.execute("select GPSTime, KWSignificance from KWtrigs where KWSignificance > ? and CoincidenceGWTrigID <> 'No Coincidence'",(veto_thresh,)).fetchall())
        semilogy(vetoKW_time, vetoKW_sig, 'bo',
                 markersize = 3, label = 'KW Triggers Used for Veto')
        semilogy(coincKW_time, coincKW_sig, 'ro',
                 markersize=3,label='Coincident KW Triggers')
        axhline(y = veto_thresh, linewidth = 1, color = 'r',label="Veto Threshold: %d"%veto_thresh)
        del vetoKW_time, vetoKW_sig, coincKW_time, coincKW_sig
    legend(loc='lower right')
    xlim([x_min, x_max])

    # do not use scientific notation for GPS time
    gca().xaxis.set_major_formatter(FormatStrFormatter("%d"))

    savefig("%s-timeSeries.png"%prefix,dpi=250)
    #savefig("%s-timeSeries.eps"%prefix) # FIXME: doesn't work on condor...
    savefig("%s-timeSeries-thumbnail.png"%prefix,dpi=60)

    close('all')

    ############################## histograms ##################################

    ## histogram: GW Triggers Number vs SNR
    # whole SNR range, binned in log10(SNR)
    # set bottom to 0.1 because it's log scale and cannot plot 0
    _, all_bins, patches = hist([math.log(t,10) for t in GW_SNR], bins=100,
                                bottom=0.1, facecolor='b', log=True)
    xlabel('SNR')
    ylabel('Number of GW Triggers')
    title('GW Triggers, All SNR Range')
    if is_candidate:
        # FIXME: find more efficient way
        # figure out non vetoed triggers' SNR
        non_coincGW_id = set(range(1,params['totalGWtrigsNum']+1)) - coincGW_id
        # original function to pick non-coincident GW triggers.
        connection.create_function("non_coinc", 1, lambda x: x in non_coincGW_id)

        # one SNR value per row
        non_coincGW = [row[0] for row in cursor.execute("select SNR from GWtrigs where non_coinc(ID)").fetchall()]
        # set bottom to 0.1 because it's log scale and cannot plot 0
        _, _, patches2 = hist([math.log(t,10) for t in non_coincGW],
                              bins=all_bins, bottom=0.1,
                              facecolor='r', log=True)
        legend((patches[0],patches2[0]),("Before Veto","After Veto"),
               loc='upper right')

    # format x axis tick label
    _label_log10_snr_axis()

    savefig("%s-triggers_hist.png"%prefix,dpi=200)
    savefig("%s-triggers_hist-thumbnail.png"%prefix,dpi=60)
    close('all')

    ## histogram for low SNR range
    # plot the lowest 50% trigger
    num = int(len(GW_SNR) / 2) or 1 # so that no 0
    low_GW = sorted(GW_SNR)[:num]
    # set bottom to 0.1 because it's log scale and cannot plot 0
    _, low_bins, patches = hist(low_GW, bins = 50, bottom=0.1, facecolor='b', log=True)
    xlabel('SNR')
    ylabel('Number of GW Triggers')
    title("GW Triggers, Low SNR Range")

    if is_candidate:
        # upper edge of the low SNR range; also used for the vetoed
        # trigger histogram further down
        bound = max(low_GW)
        low_non_coincGW = [s for s in non_coincGW if s < bound]
        # set bottom to 0.1 because it's log scale and cannot plot 0
        _, _, patches2 = hist(low_non_coincGW, bins = low_bins,
                              bottom = 0.1, facecolor = 'r')
        if len(patches)>0 and len(patches2)>0:
            legend((patches[0], patches2[0]), ("Before Veto","After Veto"),
                   loc='upper right')
    savefig("%s-triggers_hist_lowSnr.png"%prefix,dpi=200)
    savefig("%s-triggers_hist_lowSnr-thumbnail.png"%prefix,dpi=60)
    close('all')

    # everything below is for veto candidate channels only
    if not is_candidate:
        return

    ## histogram: Vetoed GW Triggers Number vs SNR
    # whole SNR range, same bins as the all-trigger histogram
    # set bottom to 0.1 because it's log scale and cannot plot 0
    hist([math.log(t,10) for t in coincGW_SNR], bins=all_bins, bottom=0.1,
         facecolor='b', log=True)
    _label_log10_snr_axis()

    xlabel('SNR')
    ylabel('Number of Vetoed GW Triggers')
    title("Vetoed GW Triggers by " + params['channel'] + ", All SNR Range")
    savefig("%s-vetoed_triggers_hist.png"%prefix,dpi=200)
    savefig("%s-vetoed_triggers_hist-thumbnail.png"%prefix,dpi=60)
    close('all')

    ## histogram: Vetoed GW Triggers Number vs SNR
    # low SNR range, same bins as the low SNR all-trigger histogram
    low_coincGW = [s for s in coincGW_SNR if s < bound]
    # set bottom to 0.1 because it's log scale and cannot plot 0
    hist(low_coincGW, bins = low_bins, bottom=0.1, facecolor='b', log = True)
    xlabel('SNR')
    ylabel('Number of Vetoed GW Triggers')
    title("Vetoed GW Triggers by " + params['channel'] + ", Low SNR Range")
    savefig("%s-vetoed_triggers_hist_lowSnr.png"%prefix,dpi=200)
    savefig("%s-vetoed_triggers_hist_lowSnr-thumbnail.png"%prefix,dpi=60)
    close('all')

    ## distribution
    # original function to pick coincident GW triggers; this one takes the
    # comma separated ID list as its second argument
    connection.create_function("coinc", 2, lambda x, y: x in map(int,y.split(",")))

    coincKWtrigs = cursor.execute("select frequency, KWSignificance, CoincidenceGWTrigID from KWtrigs where KWSignificance > ? and CoincidenceGWTrigID <> 'No Coincidence'",(veto_thresh,)).fetchall()
    distData = []
    for freq, significance, id_list in coincKWtrigs:
        coincGW = cursor.execute("select SNR from GWtrigs where coinc(ID,?)", (id_list,)).fetchall()
        # loudest coincident GW trigger relative to the KW significance
        sigRatio = max(coincGW)[0]/significance
        distData.append((freq, sigRatio))
    freqs, sigRatios = _columns(distData)
    loglog(freqs, sigRatios, 'ro', markersize=5)
    xlabel('Frequency (Hz)')
    ylabel('GW / AUX Significance')
    grid()
    title(params['channel'] + ": Coinc Triggers Distribution")
    savefig("%s-coinc_triggers_dist.png"%prefix,dpi=200)
    savefig("%s-coinc_triggers_dist-thumbnail.png"%prefix,dpi=60)
    close('all')

# =============================================================================
#
#                                    MAIN
#
# =============================================================================

# parse command line
opts = parse_commandline()

# make a list of the result files from KW_veto_calc and check sanity
files_list = glob.glob(opts.result_glob)
if opts.verbose:
    print >> sys.stderr, "result files:", files_list
if files_list==[]: # check if there is at least one file
    print >> sys.stderr, "Error: no result files found for %s"%opts.result_glob
    sys.exit(1)

# check if they have the same name tag
if os.path.commonprefix(files_list)=="":
    print >> sys.stderr, """
    Error: Possibly files from different runs are coexisting. 
           Please check the result_glob.
    """
    sys.exit(1)

# for each veto_data files make plots
for f in files_list:
    # every result file gets its own temporary database, so each one is
    # removed as soon as its plots are done (or have failed)
    working_filename = None
    try:
        ## retrieve data from file for plots
        # cursor, connection and params are module-level names because
        # plot_channel() reads them as globals
        cursor, connection, working_filename, params = \
            KW_veto_utils.load_db(f, opts.scratch_dir, opts.verbose)

        ## plot
        if opts.verbose:
            print >> sys.stderr, "plotting %s..."%params['channel']
        # plots are made whether or not --verbose is given
        plot_channel()
        if opts.verbose:
            print >> sys.stderr, "plots for %s done."%params['channel']
    finally:
        # erase temporary database
        if working_filename is not None:
            if opts.verbose:
                print >> sys.stderr, "removing temporary workspace '%s'..." % working_filename
            os.remove(working_filename)

if opts.verbose: print >> sys.stderr, "KW_veto_plots done!"
