#!/usr/bin/env python

# Copyright (C) 2011 Ian W. Harry
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

# =============================================================================
# Preamble
# =============================================================================

from __future__ import division

import os,sys,time
from optparse import OptionParser
from pylal import MultiInspiralUtils,coh_PTF_pyutils,git_version
from glue.ligolw import table,lsctables,utils,ligolw,ilwd
from glue import segments,lal

# Module metadata. The version string and date are taken from the
# git_version module imported from pylal above.
__author__  = "Ian Harry <ian.harry@astro.cf.ac.uk>"
__version__ = "git id %s" % git_version.id
__date__    = git_version.date

# set up timer: remember when the script started, in integer microseconds
start = int(time.time()*10**6)

def elapsed_time():
  """
  Return the integer number of microseconds elapsed since the module-level
  `start` stamp was taken.
  """
  return int(time.time()*10**6-start)

# =============================================================================
# Parse command line
# =============================================================================

def parse_command_line(argv=None):
  """
  Build the option parser for this script and parse the command line.

  @param argv: optional list of argument strings to parse instead of
               sys.argv[1:]; the default (None) keeps the original behaviour
               of reading the process command line.

  @returns: the (opts, args) pair produced by OptionParser.parse_args.

  Calls parser.error (which prints the usage and exits) when --ifo-tag,
  --user-tag, --segment-dir, --segment-length or --pad-data is missing or
  empty, or when neither --cache nor --slide-cache is given.
  """

  usage = """usage: %prog [options] 

coh_PTF_trig_combiner is designed to coalesce the triggers made from each element of the split template bank used by the search. The required arguments are

--ifo-tag
--user-tag
--cache and/or --slide-cache
--segment-dir
--segment-length
--pad-data
"""

  parser = OptionParser(usage, version=__version__)

  parser.add_option("-i", "--ifo-tag", action="store", type="string",\
                     default=None,\
                     help="The ifo tag, H1L1 or H1L1V1 for instance")

  parser.add_option("-u", "--user-tag", action="store", type="string",\
                     default='COH_PTF_INSPIRAL_FIRST',\
                     help="The user tag, COH_PTF_INSPIRAL_FIRST for instance")

  parser.add_option("-S", "--slide-tag", action="store", type="string",\
                     default='',\
                     help="The slide tag, used to differentiate long slides")

  parser.add_option("-n", "--grb-name", action="store", type="string",\
                     default=None, help="GRB name, such as 090802 "+\
                                        "(will be appended to output user tag)")

  parser.add_option("-T", "--num-trials", action="store", type="int",\
                     default=6,\
                     help="The number of off source trials, default: %default")

  parser.add_option("-s", "--segment-length", action="store", type="float",\
                     default=None,\
                     help="The length of analysis segments.")

  parser.add_option("-p", "--pad-data", action="store", type="float",\
                     default=None,\
                     help="The length of padding around analysis chunk.")

  parser.add_option("-a", "--segment-dir", action="store", type="string",\
                     help="directory holding buffer, on and off source "+\
                          "segment files.")

  parser.add_option("-o", "--output-dir", action="store", type="string",\
                     default=os.getcwd(), help="output directory, "+\
                                               "default: %default")

  parser.add_option("-t", "--slide-cache", action="store",default=None,\
                    help="Read slide triggers from this cache.")

  parser.add_option("-v", "--verbose", action="store_true", default=False,\
                     help="verbose output with microsecond timer, "+\
                          "default: %default")

  parser.add_option("-c", "--cache", action="store", type="string",\
                     default=None, help="read trigger files from this cache.")

  # argv=None makes optparse fall back to sys.argv[1:]
  (opts,args) = parser.parse_args(argv)

  if not opts.ifo_tag:
    parser.error("must provide --ifo-tag")

  if not opts.user_tag:
    parser.error("must provide --user-tag")

  if not opts.cache and not opts.slide_cache:
    parser.error("must provide --cache or --slide-cache")

  if not opts.segment_dir:
    parser.error("must provide --segment-dir")

  # main() computes int(segLength//4 + padData) unconditionally, so a missing
  # value would otherwise surface as a TypeError traceback. Compare against
  # None so that an explicit 0 is still accepted.
  if opts.segment_length is None:
    parser.error("must provide --segment-length")

  if opts.pad_data is None:
    parser.error("must provide --pad-data")

  return opts,args

# =============================================================================
# Main function
# =============================================================================

def _load_sieved_cache(cachePath, ifoTag, userTag, extraTag=None):
  """
  Read a cache file and keep only the entries relevant to this run.

  The cache is sieved on ifos=ifoTag and description=userTag and, when
  extraTag is non-empty, sieved again on description=extraTag. If nothing
  survives, a message is written to stderr and the process exits with
  status 1; otherwise the number of files found is written to stdout and
  the sieved cache is returned.
  """
  # Close the handle explicitly instead of leaving it to garbage collection.
  # NOTE(review): assumes lal.Cache.fromfile has finished reading the file by
  # the time it returns - confirm against glue.lal.
  cacheFP = open(cachePath, 'r')
  try:
    cache = lal.Cache.fromfile(cacheFP)
  finally:
    cacheFP.close()

  cache = cache.sieve(ifos=ifoTag, description=userTag)
  if extraTag:
    cache = cache.sieve(description=extraTag)

  if len(cache)<1:
    sys.stderr.write("No files found at %d.\n" % elapsed_time())
    sys.exit(1)

  sys.stdout.write("%d files found at %d.\n" % (len(cache), elapsed_time()))
  return cache


def main(cacheFile, ifoTag, userTag, segdir, outdir, grbTag=None, numTrials=6,\
         timeSlides=None, slideTag=None, segLength=None, padData=None,\
         verbose=False):
  """
  Combine the multi_inspiral triggers listed in the given cache file(s), bin
  them by end time and write one xml.gz file per bin into outdir.

  @param cacheFile:  path of the cache listing the trigger files (may be
                     None if timeSlides is given)
  @param ifoTag:     ifo tag, used to sieve the cache(s) and in file names
  @param userTag:    user tag, used to sieve the cache(s) and in file names
  @param segdir:     directory passed to coh_PTF_pyutils.readSegFiles; the
                     result is indexed with 'on', 'off' and 'buffer'
  @param outdir:     directory the output files are written to
  @param grbTag:     if set, '_GRB<grbTag>' is appended to the user tag of
                     the output file names
  @param numTrials:  number of off source trials to construct
  @param timeSlides: path of the cache listing time slide trigger files
  @param slideTag:   if set, additionally sieve the slide cache on this
                     description and append it to the slide output names
  @param segLength:  analysis segment length (required)
  @param padData:    padding around the analysis chunk (required)
  @param verbose:    write progress messages to stdout

  Raises ValueError when neither cacheFile nor timeSlides is given, or when
  segLength or padData is None. Exits with status 1 when a cache contains no
  matching files.
  """

  # fail early with a clear message rather than with a NameError/TypeError
  # further down
  if not cacheFile and not timeSlides:
    raise ValueError("must provide cacheFile or timeSlides")
  if segLength is None or padData is None:
    raise ValueError("must provide segLength and padData")

  if verbose: sys.stdout.write("Getting segments...\n")

  # get ifos
  ifos = sorted(lsctables.instrument_set_from_ifos(ifoTag))

  #
  # construct segments
  #

  # get segments
  lostTime       = int(segLength//4 + padData)
  segs           = coh_PTF_pyutils.readSegFiles(segdir)
  trialTime      = abs(segs['on'])

  # FIXME: These off source trials may no longer be the right ones!
  # construct off source trials: entry 0 is the full off source segment,
  # entry 1 is the first trial, and every later trial is the previous one
  # shifted by trialTime
  offSourceSegs  = segments.segmentlist()
  offSourceSegs.append(segs['off'])
  offSourceSegs.append(segments.segment(segs['off'][0]+lostTime,\
                                        segs['off'][0]+lostTime+trialTime))
  for i in xrange(numTrials-1):
    t = i+2
    offSourceSegs.append(segments.segment(tuple(map(trialTime.__add__,\
                                                    offSourceSegs[t-1]))))

  if verbose:
    sys.stdout.write("Constructed %d off source trials with segments at %d:\n"\
                     % (numTrials, elapsed_time()))
    for i in xrange(numTrials):
      sys.stdout.write("%s\n" % str(offSourceSegs[i+1]))

  #
  # load triggers
  #

  # get trigger files
  if verbose: sys.stdout.write("\nFinding files...\n")

  if cacheFile:
    cache = _load_sieved_cache(cacheFile, ifoTag, userTag)

  if timeSlides:
    slideCache = _load_sieved_cache(timeSlides, ifoTag, userTag,\
                                    extraTag=slideTag)

  # set columns
  ifoAtt = { 'G1':'g', 'H1':'h1', 'H2':'h2', 'L1':'l', 'V1':'v', 'T1':'t' }
  # set up columns
  validcols = lsctables.MultiInspiralTable.validcolumns.keys()
  cols = ['end_time', 'end_time_ns', 'ifos', 'process_id', 'ra', 'dec', 'snr',\
          'null_statistic', 'null_stat_degen', 'mass1', 'mass2', 'mchirp']
  cols.extend(['time_slide_id'])
  # add snr
  cols.extend(['snr_%s' % ifoAtt[i] for i in ifos\
               if 'snr_%s' % ifoAtt[i] in validcols])
  # add chisq
  cols.extend(['chisq', 'chisq_dof'])
  cols.extend(['chisq_%s' % ifoAtt[i] for i in ifos\
               if 'chisq_%s' % ifoAtt[i] in validcols])
  # add bank chisq
  cols.extend(['bank_chisq', 'bank_chisq_dof'])
  #cols.extend(['bank_chisq_%s' % ifoAtt[i] for i in ifos\
  #             if 'bank_chisq_%s' % ifoAtt[i] in validcols])
  # add auto chisq
  cols.extend(['cont_chisq', 'cont_chisq_dof'])
  #cols.extend(['cont_chisq_%s' % ifoAtt[i] for i in ifos\
  #             if 'cont_chisq_%s' % ifoAtt[i] in validcols])
  # add sigmasq
  cols.extend(['sigmasq_%s' % ifoAtt[i] for i in ifos\
               if 'sigmasq_%s' % ifoAtt[i] in validcols])
  # add amplitude terms
  cols.extend(['amp_term_%d' % i for i in xrange(1,11)])
  # keep only the columns the table actually knows about
  cols = [c for c in cols if c in validcols]
  lsctables.MultiInspiralTable.loadcolumns = cols

  def newTrigTable():
    # empty multi_inspiral table restricted to the columns selected above
    return lsctables.New(lsctables.MultiInspiralTable,\
                         columns=lsctables.MultiInspiralTable.loadcolumns)

  # prepare triggers
  trigs = {}
  trigs['ALL_TIMES'] = newTrigTable()

  if verbose: sys.stdout.write("\nLoading all triggers...\n")

  # load triggers
  if not timeSlides:
    trigs['ALL_TIMES'].extend(MultiInspiralUtils.ReadMultiInspiralFromFiles(\
                                                       e.path for e in cache))
    # This is a softcopy, so adding values to trigs also adds to zerolagtrigs
    zeroLagTrigs = trigs
    # Also load the basic time slide table
    timeSlideDoc = utils.load_filename((cache[0]).path,
        gz=( (cache[0]).path ).endswith(".gz"), contenthandler = lsctables.use_in(ligolw.LIGOLWContentHandler))
    timeSlideTable = table.get_table(timeSlideDoc,
          lsctables.TimeSlideTable.tableName)
    segmentTable = table.get_table(timeSlideDoc,
          lsctables.SegmentTable.tableName)
    timeSlideSegmentMap = table.get_table(timeSlideDoc,
          lsctables.TimeSlideSegmentMapTable.tableName)
  else:
    if cacheFile:
      inpFiles = [e.path for e in cache]
    else:
      inpFiles = []
    inpFiles.extend([e.path for e in slideCache])
    if not cacheFile:
      # later steps read the search summary and time extent from `cache`
      cache = slideCache
    tempTrigs,slideDictList,segmentDict,timeSlideTable,\
        segmentTable,timeSlideSegmentMap = \
            MultiInspiralUtils.ReadMultiInspiralTimeSlidesFromFiles(inpFiles,\
                generate_output_tables=True)
    trigs['ALL_TIMES'].extend(tempTrigs)
    # Also want to separate out the zerolag trigs
    zeroLagTrigs = {}
    zeroLagTrigs['ALL_TIMES'] = newTrigTable()
    # Identify the zero lag slide: the first slide whose offset is zero for
    # every ifo
    for slideID,slideDict in enumerate(slideDictList):
      for ifo in ifos:
        if slideDict[ifo] != 0:
          break
      else:
        zeroLagID = slideID
        break
    else:
      # No zero lag trigs here
      zeroLagTrigs = None

    # Add the zero-lag triggers
    if zeroLagTrigs:
      zeroLagTrigs['ALL_TIMES'].extend(trig for trig in trigs['ALL_TIMES'] if \
           int(trig.time_slide_id) == zeroLagID)

  # Temporary hack to allow the code to work with tables that dont have
  # the new time slide ID column.
  if 'time_slide_id' not in trigs['ALL_TIMES'].columnnames:
    trigs['ALL_TIMES'].appendColumn('time_slide_id')
  for trig in trigs['ALL_TIMES']:
    if not hasattr(trig, 'time_slide_id'):
      trig.time_slide_id = ilwd.ilwdchar("multi_inspiral:time_slide_id:0")

  if verbose: sys.stdout.write("%d triggers found at %d.\n"\
                               % (len(trigs['ALL_TIMES']), elapsed_time()))

  #
  # load search summary table
  #

  # get search summary table from old file
  oldxml   = utils.load_filename(cache[0].path,\
                                  gz = cache[0].path.endswith("gz"), contenthandler = lsctables.use_in(ligolw.LIGOLWContentHandler))
  oldSearchSummTable = table.get_table(oldxml, "search_summary")

  #
  # bin triggers
  #

  if verbose: sys.stdout.write("\nSeparating triggers by end time...\n")

  if zeroLagTrigs:
    zeroLagTrigs['ONSOURCE'] = newTrigTable()
    zeroLagTrigs['OFFSOURCE'] = newTrigTable()
    for i in xrange(numTrials):
      zeroLagTrigs['OFFTRIAL_%d' % (i+1)] = newTrigTable()

    # separate into correct bins
    for trig in zeroLagTrigs['ALL_TIMES']:
      # compute the end time once per trigger
      trigEnd = trig.get_end()
      if trigEnd not in segs['buffer']:
        zeroLagTrigs['OFFSOURCE'].append(trig)
      elif trigEnd in segs['on']:
        zeroLagTrigs['ONSOURCE'].append(trig)
      # a trigger goes into at most one off source trial
      for i in xrange(numTrials):
        if trigEnd in offSourceSegs[i+1]:
          zeroLagTrigs['OFFTRIAL_%d' % (i+1)].append(trig)
          break

  if verbose: sys.stdout.write("Done at %d.\n" % (elapsed_time()))

  if timeSlides:
    trigs['OFFSOURCE'] = newTrigTable()
    for trig in trigs['ALL_TIMES']:
      trigEnd = trig.get_end()
      slideOffsets = slideDictList[int(trig.time_slide_id)]
      # keep the trigger only if, for every ifo, its end time plus that
      # ifo's slide offset falls outside the buffer segment
      for ifo in ifos:
        if trigEnd+slideOffsets[ifo] in segs['buffer']:
          break
      else:
        trigs['OFFSOURCE'].append(trig)

  #
  # write combined triggers to file
  #

  if verbose: sys.stdout.write("\nWriting new xml files...\n")

  # prepare xmldocument
  xmldoc = ligolw.Document()
  xmldoc.appendChild(ligolw.LIGO_LW())

  # append process params table
  xmldoc = coh_PTF_pyutils.append_process_params(xmldoc, sys.argv, __version__,\
                                                 __date__)

  # get search summary table from old file
  xmldoc.childNodes[-1].appendChild(oldSearchSummTable)

  # Add time slide table
  xmldoc.childNodes[-1].appendChild(timeSlideTable)
  # Add segment table
  xmldoc.childNodes[-1].appendChild(segmentTable)
  # And the map between the two
  xmldoc.childNodes[-1].appendChild(timeSlideSegmentMap)

  # construct new xml file names
  # (named so as not to shadow the module-level timer origin `start`)
  segStart, segEnd = cache.to_segmentlistdict().extent_all()
  duration = segEnd-segStart
  if timeSlides:
    tsUserTag = '%s_TIMESLIDES' %(userTag)
    if grbTag:
      tsUserTag = '%s_GRB%s' % (tsUserTag, grbTag)
  if grbTag:
    userTag = '%s_GRB%s' % (userTag, grbTag)

  # write new xml files: each trigger table is attached to the shared
  # document, written out, and detached again before the next one
  if zeroLagTrigs:
    for trigTime in zeroLagTrigs:
      xmlName = '%s/%s-%s_%s-%d-%d.xml.gz'\
                % (outdir, ifoTag, userTag, trigTime, segStart, duration)
      xmldoc.childNodes[-1].appendChild(zeroLagTrigs[trigTime])
      utils.write_filename(xmldoc, xmlName, gz = xmlName.endswith('gz'))
      xmldoc.childNodes[-1].removeChild(zeroLagTrigs[trigTime])
      if verbose:
        sys.stdout.write("%s written at %d\n" % (xmlName, elapsed_time()))

  if timeSlides:
    # the tag does not depend on the bin, so build it once
    if slideTag:
      fullTag = '%s-%s_%s' %(ifoTag, tsUserTag, slideTag)
    else:
      fullTag = '%s-%s' %(ifoTag, tsUserTag)
    for trigTime in trigs:
      xmlName = '%s/%s_%s-%d-%d.xml.gz'\
                % (outdir, fullTag, trigTime, segStart, duration)
      xmldoc.childNodes[-1].appendChild(trigs[trigTime])
      utils.write_filename(xmldoc, xmlName, gz = xmlName.endswith('gz'))
      xmldoc.childNodes[-1].removeChild(trigs[trigTime])
      if verbose:
        sys.stdout.write("%s written at %d\n" % (xmlName, elapsed_time()))


if __name__=='__main__':

  # parse the command line; positional arguments are not used
  opts, args = parse_command_line()

  # keyword arguments for main(), taken straight from the parsed options
  mainKwargs = dict(grbTag=opts.grb_name,
                    numTrials=opts.num_trials,
                    timeSlides=opts.slide_cache,
                    slideTag=opts.slide_tag,
                    segLength=opts.segment_length,
                    padData=opts.pad_data,
                    verbose=opts.verbose)

  # the output and segment directories are handed over as absolute paths
  main(opts.cache, opts.ifo_tag, opts.user_tag,
       os.path.abspath(opts.segment_dir), os.path.abspath(opts.output_dir),
       **mainKwargs)

  if opts.verbose:
    sys.stdout.write("Done at %d.\n" % (elapsed_time()))
