#!/usr/bin/python

#
# =============================================================================
#
#                                   Preamble
#
# =============================================================================
#

# Program description; passed to OptionParser in parse_command_line() and
# therefore shown in the --help output.
description = \
"Compiles rate statistics about an ihope run and creates a rate_statistics table " + \
"and a file_statistics table. The rate_statistics gives the average trigger rate for the " + \
"specified file type per ifo. The file_statistics table lists statistics about each file, " + \
"including where the file ranks for different statistics."

# Program metadata. Within this file __prog__ is only referenced by the
# commented-out process-table registration in the main section.
__prog__ = "lalapps_cbc_compute_rs"
__author__ = "Collin Capano <cdcapano@physics.syr.edu>"

import os, sys, time, datetime, re
import numpy
from optparse import OptionParser

from glue import lal
from glue import git_version
from glue.ligolw import ligolw
from glue.ligolw import table
from glue.ligolw import utils
from glue.ligolw import lsctables
from glue.ligolw import ilwd
from glue.ligolw.utils import process

from pylal.xlal.date import XLALGPSToUTC
try:
    from pylal.xlal.datatypes.ligotimegps import LIGOTimeGPS
except ImportError:
    # s6 code
    from pylal.xlal.date import LIGOTimeGPS


# =============================================================================
#
#                                   Set Options
#
# =============================================================================


def parse_command_line():
    """
    Parse the command line and check the options for consistency.

    Returns the tuple (options, filenames), where options is the optparse
    values object and filenames is the list of positional arguments (the
    ihope cache files named on the command line).

    Raises ValueError if no --file-type was given.
    """
    parser = OptionParser(
        version = git_version.verbose_msg,
        usage   = "%prog [--file-type] [options] ihope.cache1 ihope.cache2 ...",
        description = description )

    parser.add_option( "-t", "--file-type", action = "append", type = "string",
        default = [],
        help =
            "The type of files to compute rates for; e.g., 'INSPIRAL_FIRST_FULL_DATA'. " +
            "To specify multiple file types, give the argument several times; at least one file-type " +
            "is required. The order given will be the order saved to the output xml. " +
            "Tmpltbank statistics will be re-used when possible. " +
            "Note: This is an exact-match sieve. For example, THINCA_SECOND*FULL_DATA will only match " +
            "THINCA_SECOND*FULL_DATA files; THINCA_SECOND*FULL_DATA* will match THINCA_SECOND*FULL_DATA, " +
            "THINCA_SECOND*FULL_DATA_CAT_2, etc."
            )
    parser.add_option("-o", "--output", action = "store", type = "string",
        default = None,
        help =
            "File to save results to. If not specified, results will be printed to stdout."
            )
    parser.add_option("-d", "--daily-ihope-pages-location", action = "store",
        default = "https://ldas-jobs.ligo.caltech.edu/~cbc/ihope_daily",
        help =
            "Web address of the daily ihope pages. " +
            "Default is https://ldas-jobs.ligo.caltech.edu/~cbc/ihope_daily"
            )
    parser.add_option("-F", "--force", action = "store_true", default = False,
        help =
            "Compute statistics for the run even if some of the files specified in the cache " +
            "file(s) are missing. (Default action is to raise an error.)"
            )
    parser.add_option("-v", "--verbose", action = "store_true", default = False,
        help =
            "Be verbose."
            )

    (options, filenames) = parser.parse_args()

    # --file-type uses action="append", so an empty list means it was never given
    if not options.file_type:
        raise ValueError("Must specify at least one --file-type.")

    return options, filenames

# =============================================================================
#
#                       Function Definitions
#
# =============================================================================

#
#   Some helper functions for manipulating times
#
def _sunday_on_or_after(year, month, day):
    """
    Return 02:00 on the first Sunday that falls on or after the given date.
    """
    start = datetime.datetime(year, month, day, 2, 0, 0)
    # datetime.weekday() numbers Monday as 0 and Sunday as 6
    return start + datetime.timedelta(days = (6 - start.weekday()) % 7)


def _sunday_on_or_before(year, month, day):
    """
    Return 02:00 on the last Sunday that falls on or before the given date.
    """
    start = datetime.datetime(year, month, day, 2, 0, 0)
    return start - datetime.timedelta(days = (start.weekday() + 1) % 7)


def get_dst_start_end(ifo, year):
    """
    Figures out what dates daylight savings time starts and ends at a given site on a given year.

    ifo is a string; an "H" or "L" in it selects the United States rules, a
    "V" or "G" the European rules ("H"/"L" take precedence if both kinds are
    present). year is an integer calendar year.

    Returns the tuple (dst_start, dst_end) of naive datetime.datetime objects,
    both at 02:00 on the relevant Sunday.

    Raises ValueError if ifo contains none of the recognized letters.

    The day of the week is taken from datetime.weekday() rather than from
    strftime('%A'), so the result does not depend on the process locale.
    """
    in_us = "H" in ifo or "L" in ifo
    # in the United States, prior to 2007, DST began on the first Sunday in April
    # and ended on the last Sunday in October (http://aa.usno.navy.mil/faq/docs/daylight_time.php)
    if in_us and year < 2007:
        dst_start = _sunday_on_or_after(year, 4, 1)
        dst_end = _sunday_on_or_before(year, 10, 31)
    # in the US, starting in 2007, DST begins on the second Sunday in March and ends on the first
    # Sunday in November
    elif in_us:
        # the second Sunday of a month is the first Sunday on or after the 8th
        dst_start = _sunday_on_or_after(year, 3, 8)
        dst_end = _sunday_on_or_after(year, 11, 1)
    # in Europe, DST begins on the last Sunday of March and ends on the last Sunday of October
    # source: http://www.timeanddate.com/news/time/europe-dst-starts-march-29-2009.html
    elif "V" in ifo or "G" in ifo:
        dst_start = _sunday_on_or_before(year, 3, 31)
        dst_end = _sunday_on_or_before(year, 10, 31)
    else:
        raise ValueError("unrecognized ifo %s" % ifo)

    return dst_start, dst_end
        

def get_sitelocaltime_from_gps(ifo, gpstime):
    """
    Convert a GPS time to the local wall-clock time at a detector site.

    ifo is a string containing "H", "L", "V" or "G"; gpstime is the GPS time
    in integer seconds. Returns a naive datetime.datetime.

    Raises ValueError for an unrecognized ifo.
    """
    # get the utc time in datetime.datetime format. The converted time is a
    # struct_time-style tuple (the same value is handed to time.strftime
    # elsewhere in this file), so only the first six fields (year ... second)
    # are datetime arguments; the seventh is the day of the week, not
    # microseconds.
    utc_fields = XLALGPSToUTC(LIGOTimeGPS(gpstime, 0))
    utctime = datetime.datetime(*utc_fields[:6])
    # figure out if daylight savings time was on or not
    dst_start, dst_end = get_dst_start_end(ifo, utctime.year)
    # figure out the appropriate time offset; these are the offsets that
    # apply while daylight savings time is in effect
    if "H" in ifo:
        toffset = datetime.timedelta(hours=-7)
    elif "L" in ifo:
        toffset = datetime.timedelta(hours=-5)
    elif "V" in ifo or "G" in ifo:
        toffset = datetime.timedelta(hours=+2)
    else:
        # get_dst_start_end() rejects such an ifo already; kept so that
        # toffset can never be left unbound
        raise ValueError("unrecognized ifo %s" % ifo)
    # apply the dst time offset to see if daylight savings was on; if not, adjust the toffset
    localtime = utctime + toffset
    if not (dst_start <= localtime < dst_end):
        localtime = localtime + datetime.timedelta(hours=-1)

    return localtime

#
#   Tables for storing statistics about the files
#
class RateStatisticsTable(table.Table):
    """
    LIGO_LW table "rate_statistics": summary trigger-rate statistics for one
    ifo and one file type per row.
    """
    tableName = "rate_statistics:table"
    # column name -> LIGO_LW type
    validcolumns = dict([
        ("ifo", "lstring"),
        ("file_type", "lstring"),
        ("average_rate", "real_8"),
        ("max_rate", "real_8"),
        ("min_rate", "real_8"),
        ("median_rate", "real_8"),
        ("standard_deviation", "real_8"),
        ("average_num_tmplts", "real_8"),
        ("average_rate_per_tmplt", "real_8"),
        ])
        
class RateStatistics(object):
    """
    Row type for RateStatisticsTable; it has one slot per table column.
    """
    __slots__ = list(RateStatisticsTable.validcolumns)

    def get_pyvalue(self):
        # NOTE(review): "value" and "type" are not columns of the table, so
        # they are not in __slots__, and no name "ligolwtypes" is imported in
        # this file -- this method looks like a leftover from another row
        # class; confirm whether it is ever called.
        if self.value is not None:
            converter = ligolwtypes.ToPyType[self.type or "lstring"]
            return converter(self.value)
        return None

class FileStatisticsTable(table.Table):
    """
    LIGO_LW table "file_statistics": statistics about a single file for a
    single ifo per row, together with the rank columns and the links to the
    daily ihope page and the site elog.
    """
    tableName = "file_statistics:table"
    # column name -> LIGO_LW type
    validcolumns = dict([
        ("ifo", "lstring"),
        ("num_files", "int_4u"),
        ("file_type", "lstring"),
        ("file_location", "lstring"),
        ("file_name", "lstring"),
        ("out_start_time", "int_4u"),
        ("out_start_time_utc", "lstring"),
        ("segment_duration", "int_4u"),
        ("num_events", "int_4u"),
        ("trigger_rate", "real_8"),
        ("trigger_rate_rank", "int_4u"),
        ("num_tmplts", "real_8"),
        ("trigger_rate_per_tmplt", "real_8"),
        ("trigger_rate_per_tmplt_rank", "int_4u"),
        ("average_snr", "real_8"),
        ("average_snr_rank", "int_4u"),
        ("max_snr", "real_8"),
        ("max_snr_rank", "int_4u"),
        ("min_snr", "real_8"),
        ("median_snr", "real_8"),
        ("median_snr_rank", "int_4u"),
        ("snr_standard_deviation", "real_8"),
        ("snr_standard_deviation_rank", "int_4u"),
        ("daily_ihope_page", "lstring"),
        ("elog_page", "lstring"),
        ])

class FileStatistics(object):
    """
    Row type for FileStatisticsTable; it has one slot per table column.
    """
    __slots__ = FileStatisticsTable.validcolumns.keys()

    def set_elog_page(self):
        """
        Set self.elog_page to an html link to the site's electronic log book.

        Reads self.ifo and self.out_start_time. For the LIGO sites ("H", "L")
        the link points at the ilog entries for the local date of
        out_start_time; for Virgo ("V") it points at the log book front page.
        An ifo without a known log book (e.g. "G") gets an empty string
        rather than failing on an unset address.
        """
        # get local time at the site; done for every ifo so that an ifo the
        # time helpers do not recognize still raises their error
        site_localtime = get_sitelocaltime_from_gps(self.ifo, self.out_start_time)
        # set site_address
        if "H" in self.ifo:
            ilog_address = "http://ilog.ligo-wa.caltech.edu/ilog/pub/ilog.cgi?group=detector"
        elif "L" in self.ifo:
            ilog_address = "http://ilog.ligo-la.caltech.edu/ilog/pub/ilog.cgi?group=detector"
        else:
            ilog_address = None

        if ilog_address is not None:
            # the ilog pages take the local date to display as a query argument
            site_address = "%s&date_to_view=%s" % ( ilog_address, site_localtime.strftime("%m/%d/%Y") )
        elif "V" in self.ifo:
            #FIXME: What's the site address and format for Virgo log book?
            site_address = "https://pub3.ego-gw.it/logbook/"
        else:
            # no log book address is known for this site
            self.elog_page = ''
            return
        self.elog_page = '<a href="%s">link</a>' % site_address

    def get_pyvalue(self):
        # NOTE(review): "value" and "type" are not columns of the table, so
        # they are not in __slots__, and no name "ligolwtypes" is imported in
        # this file -- this method looks like a leftover from another row
        # class; confirm whether it is ever called.
        if self.value is None:
            return None
        return ligolwtypes.ToPyType[self.type or "lstring"](self.value)

# connect rows to the tables: each table class names its row class through the
# RowType attribute (presumably what glue.ligolw consults when it needs to
# build rows for the table -- confirm against the glue.ligolw.table docs)
RateStatisticsTable.RowType = RateStatistics
FileStatisticsTable.RowType = FileStatistics

#
#   Classes for gathering statistics about the files
#

class FileList:
    """
    A collection of per-file statistics objects (FileStats-like: each has an
    "ifos" set and per-ifo dictionaries such as "rate") together with the
    per-ifo rate statistics computed over the whole collection.
    """
    # (per-ifo value dictionary, per-ifo rank dictionary) attribute names of
    # the file objects that set_relative_ranks() works through
    _ranked_stats = (
        ("rate", "rate_rank"),
        ("rate_per_tmplt", "rate_per_tmplt_rank"),
        ("avg_stat", "avg_stat_rank"),
        ("max_stat", "max_stat_rank"),
        ("median_stat", "median_stat_rank"),
        ("std_stat", "std_stat_rank"),
        )

    def __init__(self):
        self.files = []
        # the following are all keyed by ifo
        self.avg_rate = {}
        self.max_rate = {}
        self.min_rate = {}
        self.med_rate = {}
        self.std_rate = {}

    def append_file(self, filestat_obj):
        """Add a file statistics object to the list."""
        self.files.append(filestat_obj)

    def get_distinct_ifos(self):
        """Return the set of all ifos found in any of the files."""
        return set(ifo for fstat in self.files for ifo in fstat.ifos)

    def compute_rate_stats(self):
        """
        Fill avg_rate, max_rate, min_rate, med_rate and std_rate for every
        ifo from the rate[ifo] values of the files that contain that ifo.
        """
        rates = {}
        for fstat in self.files:
            for ifo in fstat.ifos:
                rates.setdefault(ifo, []).append( fstat.rate[ifo] )
        for ifo, rates_list in rates.items():
            rates_list = numpy.array(rates_list)
            # test the number of elements: comparing against an empty array
            # is element-wise, and for a single rate it broadcasts to an
            # empty (false) result, which zeroed the statistics of any ifo
            # that had exactly one file
            if rates_list.size > 0:
                self.avg_rate[ifo] = numpy.mean(rates_list)
                self.max_rate[ifo] = numpy.max(rates_list)
                self.min_rate[ifo] = numpy.min(rates_list)
                self.med_rate[ifo] = numpy.median(rates_list)
                self.std_rate[ifo] = numpy.std(rates_list)
            else:
                self.avg_rate[ifo] = self.max_rate[ifo] = self.min_rate[ifo] = self.med_rate[ifo] = self.std_rate[ifo] = 0.

    def set_relative_ranks(self):
        """
        For every ifo, rank the files containing that ifo from the highest
        (rank 1) to the lowest value of each statistic in _ranked_stats and
        store the rank in the file's matching *_rank dictionary. Files whose
        value is None or missing are ranked last; ties keep list order.
        """
        for ifo in sorted(self.get_distinct_ifos()):
            selected_files = [fstat for fstat in self.files if ifo in fstat.ifos]
            for value_attr, rank_attr in self._ranked_stats:
                def sort_key(fstat):
                    value = getattr(fstat, value_attr).get(ifo)
                    # the leading flag orders None below every number without
                    # ever comparing None to a number
                    return (value is not None, value)
                ranked = sorted(selected_files, key=sort_key, reverse=True)
                for n, fstat in enumerate(ranked):
                    getattr(fstat, rank_attr)[ifo] = n+1


class FileStats:
    """
    Statistics about the triggers in a single ihope xml file.

    Typical use (see the main section of this file): construct, load_file(),
    compute_stats() or get_ifos()/get_stats_list()/set_num_events(),
    get_seg_times(), get_num_tmplts(), compute_rate(), then close_file() to
    drop the loaded document while keeping the computed numbers.

    Apart from the scalars name, seg_start, seg_end and duration, and the
    daily_ihope_page string, the statistics are dictionaries keyed by ifo.
    """
    def __init__(self, filename):
        """
        Record the file name and initialize all statistics to empty.

        Raises ValueError if filename is not an existing file.
        """
        if not os.path.isfile(filename):
            raise ValueError("missing file %s" % filename)
        self.name = filename
        self.xmldoc = None
        self.search_summ_table = None
        self.sngl_insp_table = None
        self.out_list = None
        self.seg_start = None
        self.seg_end = None
        self.duration = None
        self.ifos = None
        self.stats_list = {}
        self.num_events = {}
        self.events_by_param = {}
        self.rate = {}
        self.rate_rank = {}
        self.num_tmplts = {}
        self.rate_per_tmplt = {}
        self.rate_per_tmplt_rank = {}
        self.avg_stat = {}
        self.avg_stat_rank = {}
        self.max_stat = {}
        self.max_stat_rank = {}
        self.min_stat = {}
        self.median_stat = {}
        self.median_stat_rank = {}
        self.std_stat = {}
        self.std_stat_rank = {}
        self.daily_ihope_page = {}

    def load_file(self):
        """
        Load the xml document and look up its search_summary and
        sngl_inspiral tables; a file without a sngl_inspiral table gets an
        empty one.
        """
        self.xmldoc = utils.load_filename( self.name, gz = self.name.endswith(".gz") )
        self.search_summ_table = table.get_table(self.xmldoc, lsctables.SearchSummaryTable.tableName)
        try:
            self.sngl_insp_table = table.get_table(self.xmldoc, lsctables.SnglInspiralTable.tableName)
        except ValueError:
            self.sngl_insp_table = lsctables.New(lsctables.SnglInspiralTable)

    def close_file(self):
        """
        Drop the references to the loaded document, its tables and the raw
        statistic values; the computed statistics are kept.
        """
        self.stats_list = None
        self.sngl_insp_table = None
        self.search_summ_table = None
        self.xmldoc = None

    def get_ifos(self):
        """
        Set self.ifos from the search_summary table (falling back on the
        sngl_inspiral rows if that gives nothing) and reset the per-ifo
        entries of the statistics and rank dictionaries.
        """
        self.ifos = set( ifo for row in self.search_summ_table for ifo in row.get_ifos() )
        if not self.ifos:
            self.ifos = set( row.ifo for row in self.sngl_insp_table )
        per_ifo_dicts = (
            self.rate, self.rate_per_tmplt, self.avg_stat, self.max_stat,
            self.min_stat, self.median_stat, self.std_stat,
            self.rate_rank, self.rate_per_tmplt_rank, self.avg_stat_rank,
            self.max_stat_rank, self.median_stat_rank, self.std_stat_rank,
            )
        for ifo in self.ifos:
            self.stats_list[ifo] = []
            for per_ifo in per_ifo_dicts:
                per_ifo[ifo] = None

    def get_stats_list(self, stat):
        """
        Append the value of column stat of every sngl_inspiral row to the
        stats_list of the row's ifo. get_ifos() must have been called first.
        """
        for row in self.sngl_insp_table:
            self.stats_list[row.ifo].append( getattr(row, stat) )

    def set_num_events(self):
        """Set num_events[ifo] to the number of values collected for each ifo."""
        self.num_events = dict( (ifo, len(statlist)) for ifo, statlist in self.stats_list.items() )

    def compute_stats(self, stat):
        """
        Compute the mean, max, min, median and standard deviation of column
        stat for every ifo. An ifo without events keeps None for all of them.
        """
        self.get_ifos()
        self.get_stats_list(stat)
        self.set_num_events()
        for ifo, stat_vals in self.stats_list.items():
            if self.num_events[ifo] != 0:
                stat_vals = numpy.array(stat_vals)
                self.avg_stat[ifo] = numpy.mean(stat_vals)
                self.max_stat[ifo] = numpy.max(stat_vals)
                self.min_stat[ifo] = numpy.min(stat_vals)
                self.median_stat[ifo] = numpy.median(stat_vals)
                self.std_stat[ifo] = numpy.std(stat_vals)

    def get_seg_times(self):
        """
        Set out_list, seg_start and seg_end from the search_summary table.

        Raises ValueError unless the file has exactly one out segment.
        """
        self.out_list = self.search_summ_table.get_outlist()
        if len(self.out_list) > 1:
            raise ValueError("more than one out start/end time for file %s" % self.name)
        if len(self.out_list) == 0:
            raise ValueError("no out start/end time for file %s" % self.name)
        self.seg_start, self.seg_end = self.out_list[0][0].seconds, self.out_list[0][1].seconds

    def get_num_tmplts(self, tmpltbank_filelist):
        """
        Set num_tmplts[ifo] to the number of events of the template bank file
        (an entry of tmpltbank_filelist.files) that has this ifo and whose
        out_list contains this file's seg_start. If several match, the last
        one is used.

        Raises ValueError if no template bank file matches.
        """
        for ifo in self.ifos:
            matches = [
                tmpltfile.num_events[ifo] for tmpltfile in tmpltbank_filelist.files
                if ifo in tmpltfile.ifos and self.seg_start in tmpltfile.out_list
                ]
            if not matches:
                raise ValueError("no %s tmpltbank file covers the start time %s of file %s" % (ifo, self.seg_start, self.name))
            self.num_tmplts[ifo] = matches[-1]

    def compute_rate(self):
        """
        Set duration and, unless the duration is zero, the trigger rate and
        the trigger rate per template of every ifo. The rate per template of
        an ifo with no (or zero) templates is None, so that every ifo always
        has an entry.
        """
        self.duration = self.seg_end - self.seg_start
        if self.duration != 0:
            self.rate = dict( (ifo, float(self.num_events[ifo])/self.duration) for ifo in self.ifos )
            rate_per_tmplt = {}
            for ifo in self.ifos:
                num_tmplts = self.num_tmplts.get(ifo)
                if num_tmplts:
                    rate_per_tmplt[ifo] = float(self.rate[ifo])/num_tmplts
                else:
                    rate_per_tmplt[ifo] = None
            self.rate_per_tmplt = rate_per_tmplt

    def set_daily_ihope_page(self, pages_location = "https://ldas-jobs.ligo.caltech.edu/~cbc/ihope_daily"):
        """
        Set daily_ihope_page to <pages_location>/<YYYYMM>/<YYYYMMDD>/ for the
        UTC date of seg_start.
        """
        utctime = XLALGPSToUTC(LIGOTimeGPS(self.seg_start, 0))
        self.daily_ihope_page = "%s/%s/%s/" %(pages_location, time.strftime("%Y%m", utctime), time.strftime("%Y%m%d", utctime))

# =============================================================================
#
#                                     Main
#
# =============================================================================
    
# the positional arguments are the ihope cache files to read
opts, cachefiles = parse_command_line()

#
#   Create output xml doc
#
# the two tables that will hold the results
rate_table = lsctables.New(RateStatisticsTable)
fs_table = lsctables.New(FileStatisticsTable)
# the document itself, with the LIGO_LW tag as its only child
ratedoc = ligolw.Document()
ligolw_root = ligolw.LIGO_LW()
ratedoc.appendChild(ligolw_root)
# add this program's metadata
#proc_id = process.register_to_xmldoc(ratedoc, __prog__, opts.__dict__)
# hang the tables off the LIGO_LW tag
ligolw_root.appendChild(rate_table)
ligolw_root.appendChild(fs_table)

#
#   Read cache file and cycle over file-types
#
if opts.verbose:
    print >> sys.stderr, "Reading cache file(s)..."
ihope_cache = lal.Cache().fromfilenames( cachefiles)

# For every requested file type: gather the template bank statistics that go
# with it (re-using those of the previous file type when the same tmpltbank
# description applies), gather the statistics of every matching file, rank the
# files, and append the results to rate_table and fs_table.

# description of the tmpltbank files whose statistics are currently held in
# tmpltbank_filelist; '' until the first file type has been handled
tmpltbank_sieve = ''
for file_type in opts.file_type:

    if opts.verbose:
        print >> sys.stderr, "Analyzing %s..." % file_type

    # check what kind of TMPLTBANK names are in the cache; older versions of ihope had things like
    # TMPLTBANK_FULL_DATA, TMPLTBANK_PLAYGROUND, etc., whereas the current version just has TMPLTBANK for all
    # data types

    # old format
    this_tmpltbank = re.sub('(INSPIRAL_FIRST)|(INSPIRAL_SECOND)|(THINCA_FIRST)|(THINCA_SECOND)', 'TMPLTBANK', file_type)
    this_tmpltbank = re.sub('_CAT_[0-9]_VETO', '', this_tmpltbank)
    # check for old format
    if len( ihope_cache.sieve(description = this_tmpltbank) ) == 0:
        # assume new format
        this_tmpltbank = 'TMPLTBANK'
    
    # if this_tmpltbank doesn't match tmpltbank_sieve, get the tmpltbank stats
    if this_tmpltbank != tmpltbank_sieve:
        tmpltbank_sieve = this_tmpltbank
        tmpltbank_cache = ihope_cache.sieve( description = tmpltbank_sieve )
        if opts.verbose:
            print >> sys.stderr, "\tgetting tmpltbank information..."
            # num_files is only needed (and only set) for the progress output
            num_files = float(len(tmpltbank_cache))
        tmpltbank_filelist = FileList()
        # NOTE(review): unlike the loop over selected_cache below, this loop
        # does not check opts.force, so a missing tmpltbank file makes
        # FileStats() raise even with --force -- confirm that is intended
        for n, entry in enumerate(tmpltbank_cache):
            if opts.verbose:
                print >> sys.stderr, "\t\tfile %i / %i\r" % (n+1, num_files),
            tmpltbankfile = FileStats( entry.path )
            tmpltbankfile.load_file()
            tmpltbankfile.get_ifos()
            # collecting event_id gives one value per row, so that
            # set_num_events() counts the rows per ifo
            tmpltbankfile.get_stats_list('event_id')
            tmpltbankfile.get_seg_times()
            tmpltbankfile.set_num_events()
            # for tmpltbank files, just set rate to num_events
            tmpltbankfile.rate = tmpltbankfile.num_events
            tmpltbankfile.close_file()
            tmpltbank_filelist.append_file(tmpltbankfile)
        # because rate was set to num_events above, the "rate" statistics of
        # tmpltbank_filelist are statistics of the number of templates
        tmpltbank_filelist.compute_rate_stats()
        if opts.verbose:
            print >> sys.stderr, ""

    selected_cache = ihope_cache.sieve( description = file_type, exact_match = True )
    if opts.verbose:
        print >> sys.stderr, "\tgathering statistics..."
        num_files = float(len(selected_cache))
    
    filelist = FileList()
    for n, entry in enumerate(selected_cache):
        if opts.verbose:
            print >> sys.stderr, "\t\tfile %i / %i\r" % (n+1, num_files),
        # with --force a missing file is skipped with a warning; without it,
        # FileStats() raises ValueError for the missing file
        if opts.force and not os.path.isfile(entry.path):
            print >> sys.stderr, "\nWarning: File %s not found. Skipping..." % entry.path
            continue
        inspfile = FileStats( entry.path )
        inspfile.load_file()
        inspfile.compute_stats('snr')
        inspfile.get_seg_times()
        inspfile.get_num_tmplts(tmpltbank_filelist)
        inspfile.compute_rate()
        inspfile.set_daily_ihope_page(pages_location = opts.daily_ihope_pages_location)
        inspfile.close_file()
        filelist.append_file(inspfile)
    
    filelist.compute_rate_stats()
    filelist.set_relative_ranks()

    if opts.verbose:
        print >> sys.stderr, ""
    
    #
    #   Save results to tables
    #
    # one rate_statistics row per ifo, followed by one file_statistics row
    # for every file containing that ifo
    for ifo in sorted(filelist.get_distinct_ifos()):
        rate_row = RateStatistics()
        rate_row.ifo = ifo
        rate_row.file_type = file_type
        rate_row.average_rate = filelist.avg_rate[ifo]
        rate_row.max_rate = filelist.max_rate[ifo]
        rate_row.min_rate = filelist.min_rate[ifo]
        rate_row.median_rate = filelist.med_rate[ifo]
        rate_row.standard_deviation = filelist.std_rate[ifo]
        # avg_rate of the tmpltbank list is the average number of templates
        rate_row.average_num_tmplts = tmpltbank_filelist.avg_rate[ifo]
        # NOTE(review): nothing here guards against an average of zero
        # templates -- confirm that cannot occur
        rate_row.average_rate_per_tmplt = filelist.avg_rate[ifo]/tmpltbank_filelist.avg_rate[ifo]
        rate_table.append(rate_row)
        selected_files = [file for file in filelist.files if ifo in file.ifos]
        for file in selected_files:
            fs = FileStatistics()
            fs.ifo = ifo
            fs.num_files = len(selected_files)
            fs.file_type = file_type
            fs.file_location = os.path.dirname(file.name)
            fs.file_name = os.path.basename(file.name)
            fs.out_start_time = file.seg_start
            fs.out_start_time_utc = time.strftime("%a %d %b %Y %H:%M:%S", XLALGPSToUTC(LIGOTimeGPS(file.seg_start, 0)))
            fs.segment_duration = file.duration
            fs.num_events = file.num_events[ifo]
            fs.trigger_rate = file.rate[ifo]
            fs.trigger_rate_rank = file.rate_rank[ifo]
            fs.num_tmplts = file.num_tmplts[ifo] 
            fs.trigger_rate_per_tmplt = file.rate_per_tmplt[ifo] 
            fs.trigger_rate_per_tmplt_rank = file.rate_per_tmplt_rank[ifo]
            fs.average_snr = file.avg_stat[ifo]
            fs.average_snr_rank = file.avg_stat_rank[ifo]
            fs.max_snr = file.max_stat[ifo]
            fs.max_snr_rank = file.max_stat_rank[ifo]
            fs.min_snr = file.min_stat[ifo]
            fs.median_snr = file.median_stat[ifo]
            fs.median_snr_rank = file.median_stat_rank[ifo]
            fs.snr_standard_deviation = file.std_stat[ifo]
            fs.snr_standard_deviation_rank = file.std_stat_rank[ifo]
            fs.daily_ihope_page = ''.join(['<a href="', file.daily_ihope_page, '">link</a>'])
            # needs fs.ifo and fs.out_start_time, both set above
            fs.set_elog_page()
            fs_table.append(fs)

# save the results; opts.output is None when --output was not given, in which
# case (per the --output help text) the document goes to stdout
utils.write_filename( ratedoc, opts.output, xsl_file = "ligolw.xsl")

# explicit success exit status
sys.exit(0)

