#!/usr/bin/env python
#
# Copyright (C) 2013 Chad Hanna
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 2 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA

## @file
# A program to extract the loudest coincs from an offline gstlal_inspiral DAG database into single files that can be uploaded to gracedb
#
# ### Command line interface
#	+ `--fap-thresh` [probability] (float): Set the false alarm probability: default 0.01.
#	+ `--gps-times` [GPS (comma separated)]: Restrict times to this list. Give as GPS1,GPS2,...,GPSN. Each listed time is matched to within +/- 1 s.

from glue.ligolw import ligolw, lsctables, utils, dbtables
import sqlite3
import sys
from optparse import OptionParser

def snglrow(connection):
	"""
	Return the row_from_cols attribute of the SnglInspiralTable found in
	the document that dbtables.get_xml() produces for connection.  The
	script calls the returned object with one row of a
	"SELECT * FROM sngl_inspiral" result and appends what it gets back
	to an lsctables table.
	"""
	xml_view = dbtables.get_xml(connection)
	wanted_name = lsctables.SnglInspiralTable.tableName
	sngl_table = lsctables.table.get_table(xml_view, wanted_name)
	return sngl_table.row_from_cols
def coincrow(connection):
	"""
	Return the row_from_cols attribute of the CoincInspiralTable found in
	the document that dbtables.get_xml() produces for connection.  The
	script calls the returned object with one row of a
	"SELECT * FROM coinc_inspiral" result and appends what it gets back
	to an lsctables table.
	"""
	xml_view = dbtables.get_xml(connection)
	wanted_name = lsctables.CoincInspiralTable.tableName
	coinc_inspiral_table = lsctables.table.get_table(xml_view, wanted_name)
	return coinc_inspiral_table.row_from_cols
def coinceventrow(connection):
	"""
	Return the row_from_cols attribute of the CoincTable found in the
	document that dbtables.get_xml() produces for connection.  The
	script calls the returned object with one row of a
	"SELECT * FROM coinc_event" result and appends what it gets back to
	an lsctables table.
	"""
	xml_view = dbtables.get_xml(connection)
	wanted_name = lsctables.CoincTable.tableName
	coinc_event_table = lsctables.table.get_table(xml_view, wanted_name)
	return coinc_event_table.row_from_cols
def coinceventmaprow(connection):
	"""
	Return the row_from_cols attribute of the CoincMapTable found in the
	document that dbtables.get_xml() produces for connection.  The
	script calls the returned object with one row of a
	"SELECT * FROM coinc_event_map" result; the resulting row's event_id
	is then used to look up the matching sngl_inspiral rows.
	"""
	xml_view = dbtables.get_xml(connection)
	wanted_name = lsctables.CoincMapTable.tableName
	coinc_map_table = lsctables.table.get_table(xml_view, wanted_name)
	return coinc_map_table.row_from_cols

# Command line: an optional threshold, an optional comma separated list of
# GPS times, and exactly one positional argument naming the SQLite database.
parser = OptionParser()
parser.add_option("--fap-thresh", default=0.01, type="float", metavar="probability", help="Set the false alarm probability: default 0.01")
parser.add_option("--gps-times", metavar="GPS (comma separated)", help="Restrict times to this list. give as GPS1,GPS2,...GPSN. Assumes +- 1s")
options, filenames = parser.parse_args()

if len(filenames) != 1:
	sys.stderr.write("Only supports exactly 1 database currently\n")
	sys.exit(1)

db = sqlite3.connect(filenames[0])

# Base selection, shared by both modes: coincs whose time slide has no
# non-zero offset and whose false_alarm_rate column is at or below the
# threshold.  The threshold itself is bound as a parameter at execute time.
# NOTE(review): --fap-thresh is documented as a probability but is compared
# against the false_alarm_rate column; presumably this pipeline stores the
# false alarm probability there -- confirm.
query = 'SELECT coinc_inspiral.coinc_event_id, coinc_inspiral.end_time, coinc_inspiral.ifos FROM coinc_inspiral JOIN coinc_event ON coinc_event.coinc_event_id == coinc_inspiral.coinc_event_id WHERE NOT EXISTS(SELECT * FROM time_slide WHERE time_slide.time_slide_id == coinc_event.time_slide_id AND time_slide.offset != 0) AND coinc_inspiral.false_alarm_rate <= ?'

if options.gps_times is not None:
	requested = [int(t) for t in options.gps_times.split(',')]
	# Accept each requested second and its two neighbours.  Going through a
	# set keeps overlapping windows from listing the same second twice.
	gps_list = sorted(set(t + dt for t in requested for dt in (-1, 0, 1)))
	sys.stderr.write("%s\n" % (gps_list,))
	# Every entry went through int() above, so formatting the values
	# straight into the statement cannot inject SQL.
	query += ' AND coinc_inspiral.end_time IN (%s)' % ",".join("%d" % t for t in gps_list)
	sys.stderr.write("%s\n" % query)

# List of (coinc_event_id, end_time, ifos) tuples, one per selected coinc.
cids = db.cursor().execute(query, (options.fap_thresh,)).fetchall()

# For every selected coinc build a stand-alone LIGO_LW document holding its
# coinc_inspiral, coinc_event, coinc_event_map and sngl_inspiral rows, then
# write that document to a gzipped XML file of its own.
for (cid, end_time, ifos) in cids:
	xmldocmain = ligolw.Document()
	xmldoc = xmldocmain.appendChild(ligolw.LIGO_LW())

	# coinc inspiral table
	coincinspiral = lsctables.New(lsctables.CoincInspiralTable)
	xmldoc.appendChild(coincinspiral)
	rowfunc = coincrow(db)
	for val in db.cursor().execute('SELECT * FROM coinc_inspiral WHERE coinc_event_id == ?', (cid,)):
		coincinspiral.append(rowfunc(val))

	# coinc event table
	coincevent = lsctables.New(lsctables.CoincTable)
	xmldoc.appendChild(coincevent)
	rowfunc = coinceventrow(db)
	for val in db.cursor().execute('SELECT * FROM coinc_event WHERE coinc_event_id == ?', (cid,)):
		coincevent.append(rowfunc(val))

	# coinc event map table; remember the event ids it points at so the
	# matching sngl_inspiral rows can be fetched below
	coinceventmap = lsctables.New(lsctables.CoincMapTable)
	xmldoc.appendChild(coinceventmap)
	rowfunc = coinceventmaprow(db)
	snglids = []
	for val in db.cursor().execute('SELECT * FROM coinc_event_map WHERE coinc_event_id == ?', (cid,)):
		row = rowfunc(val)
		snglids.append(row.event_id)
		coinceventmap.append(row)

	# sngl inspiral table
	sngl = lsctables.New(lsctables.SnglInspiralTable)
	xmldoc.appendChild(sngl)
	rowfunc = snglrow(db)
	# Bind the ids as parameters instead of pasting them into the statement
	# inside double quotes: double-quoted tokens are identifiers in SQL and
	# SQLite only treats them as strings through a legacy fallback.  One
	# "?" is generated per id; the values are passed as their str() form,
	# which is what the quoted literals contained before.
	placeholders = ",".join("?" * len(snglids))
	query = 'SELECT * FROM sngl_inspiral WHERE event_id IN (%s)' % placeholders
	for val in db.cursor().execute(query, [str(i) for i in snglids]):
		sngl.append(rowfunc(val))

	# NOTE(review): the file name depends only on ifos and end_time, so two
	# coincs sharing both would be written to the same path -- confirm that
	# this cannot happen or is acceptable.
	utils.write_filename(xmldocmain, '%s-LLOID-%d-0.xml.gz' % (ifos.replace(",",""), end_time), gz=True, verbose=True)
