#!/usr/bin/env python
#
# Copyright (C) 2012 Chad Hanna
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 2 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

"""
This program makes a dag to run a psd estimation dag
"""

__author__ = 'Chad Hanna <channa@caltech.edu>'

##############################################################################
# import standard library modules
import sys, os

##############################################################################
# import the modules we need to build the pipeline
from glue import segments
from optparse import OptionParser
from gstlal import dagparts
from gstlal import datasource

#
# Classes for generating reference psds
#

class gstlal_reference_psd_job(dagparts.CondorDAGJob):
	"""
	Job description for the gstlal_reference_psd executable.

	A thin subclass of dagparts.CondorDAGJob that only supplies default
	values for the executable and the tag base.
	"""
	def __init__(self, executable=dagparts.which('gstlal_reference_psd'), tag_base='gstlal_reference_psd'):
		"""
		Forward both arguments, unchanged, to dagparts.CondorDAGJob.__init__.

		executable -- program the job runs.  The default is the value
			returned by dagparts.which('gstlal_reference_psd'); being
			a default argument it is evaluated once, when this class
			is defined, not on every call (presumably a lookup of the
			program on the search path -- confirm in dagparts).
		tag_base -- second argument of the base constructor (presumably
			the stem used for naming the job's files -- confirm in
			dagparts).
		"""
		dagparts.CondorDAGJob.__init__(self, executable, tag_base)


class gstlal_reference_psd_node(dagparts.CondorDAGNode):
	"""
	A gstlal_reference_psd node: one PSD measurement over a single GPS
	interval, with the data source fixed to "frames".

	After construction self.output_name holds the path given to the
	node's "write-psd" option.
	"""
	def __init__(self, job, dag, frame_cache, frame_segments_file, frame_segments_name, gps_start_time, gps_end_time, instruments, channel_dict, injections=None, p_node=None):
		"""
		Create the node and set its command line options.

		job, dag -- forwarded to dagparts.CondorDAGNode.__init__
		frame_cache -- value of the "frame-cache" option
		frame_segments_file -- value of the "frame-segments-file" option
		frame_segments_name -- value of the "frame-segments-name" option
		gps_start_time, gps_end_time -- values of the "gps-start-time"
			and "gps-end-time" options; they are also formatted with
			%d into the output file name, so they must be numbers
		instruments -- accepted for interface compatibility; not used
			in this constructor
		channel_dict -- converted with
			datasource.pipeline_channel_list_from_channel_dict() for
			the "channel-name" option; its sorted keys, concatenated,
			form the prefix of the output file name
		injections -- if true-valued, value of the "injections" option;
			otherwise the option is not set
		p_node -- forwarded to dagparts.CondorDAGNode.__init__
			(presumably the parent nodes -- confirm in dagparts);
			None, the default, means an empty list
		"""
		# Use a fresh list for every call instead of a single mutable
		# default object shared by all calls.
		if p_node is None:
			p_node = []
		dagparts.CondorDAGNode.__init__(self, job, dag, p_node)
		self.add_var_opt("frame-cache", frame_cache)
		self.add_var_opt("frame-segments-file", frame_segments_file)
		self.add_var_opt("frame-segments-name", frame_segments_name)
		self.add_var_opt("gps-start-time",gps_start_time)
		self.add_var_opt("gps-end-time",gps_end_time)
		self.add_var_opt("data-source", "frames")
		self.add_var_opt("channel-name", datasource.pipeline_channel_list_from_channel_dict(channel_dict))
		if injections:
			self.add_var_opt("injections", injections)
		# The PSD file is placed in the directory that is current when
		# the node is constructed.
		path = os.getcwd()
		output_name = self.output_name = '%s/%s-%d-%d-reference_psd.xml.gz' % (path, "".join(sorted(channel_dict)), gps_start_time, gps_end_time)
		self.add_var_opt("write-psd",output_name)


class gstlal_median_of_psds_job(dagparts.CondorDAGJob):
	"""
	Job description for the gstlal_median_of_psds executable.

	A thin subclass of dagparts.CondorDAGJob that only supplies default
	values for the executable and the tag base.
	"""
	def __init__(self, executable=dagparts.which('gstlal_median_of_psds'), tag_base='gstlal_median_of_psds'):
		"""
		Forward both arguments, unchanged, to dagparts.CondorDAGJob.__init__.

		executable -- program the job runs.  The default is the value
			returned by dagparts.which('gstlal_median_of_psds'); being
			a default argument it is evaluated once, when this class
			is defined, not on every call (presumably a lookup of the
			program on the search path -- confirm in dagparts).
		tag_base -- second argument of the base constructor (presumably
			the stem used for naming the job's files -- confirm in
			dagparts).
		"""
		dagparts.CondorDAGJob.__init__(self, executable, tag_base)


class gstlal_median_of_psds_node(dagparts.CondorDAGNode):
	"""
	A gstlal_median_of_psds node: takes a list of input files as file
	arguments and an "output-name" option.
	"""
	def __init__(self, job, dag, output_name = "median_psd.xml.gz", input_files = None, p_node = None):
		"""
		Create the node and set its command line.

		job, dag -- forwarded to dagparts.CondorDAGNode.__init__
		output_name -- value of the "output-name" option
		input_files -- iterable of file names, each added as a file
			argument in iteration order; None, the default, means
			no input files
		p_node -- forwarded to dagparts.CondorDAGNode.__init__
			(presumably the parent nodes -- confirm in dagparts);
			None, the default, means an empty list
		"""
		# Use fresh objects for every call instead of mutable defaults
		# that would be shared by all calls.
		if input_files is None:
			input_files = ()
		if p_node is None:
			p_node = []
		dagparts.CondorDAGNode.__init__(self, job, dag, p_node)
		self.add_var_opt("output-name", output_name)
		for f in input_files:
			self.add_file_arg(f)


def parse_command_line():
	"""
	Parse the command line.

	The parser uses the module docstring as its description and accepts
	the options registered by datasource.append_options() plus
	--max-segment-length (int, default 30000) and --verbose (flag).

	Returns the (options, filenames) pair produced by
	OptionParser.parse_args().
	"""
	parser = OptionParser(description = __doc__)

	# options shared by programs that read data through gstlal.datasource
	datasource.append_options(parser)

	# options specific to this program
	parser.add_option(
		"--max-segment-length",
		type = "int",
		metavar = "dur",
		default = 30000,
		help = "Break up segments longer than dur seconds into shorter (contiguous, non-overlapping) segments. Default 30000 seconds."
	)
	parser.add_option(
		"--verbose",
		action = "store_true",
		help = "Be verbose"
	)

	return parser.parse_args()


#
# MAIN
#

options, filenames = parse_command_line()

# Data source description built from the parsed options; it provides the
# channel dictionary, the frame segments and the requested time boundary
# used below.
detectors = datasource.GWDataSourceInfo(options)
channel_dict = detectors.channel_dict

#
# Setup analysis segments
#

segs = detectors.frame_segments

# union of all single detector segments that we want to analyze
segs = segs.union(channel_dict.keys()).coalesce()

# intersect so we only analyze segments in the requested time
boundary_seg = detectors.seg
segs &= segments.segmentlist([boundary_seg])

# FIXME break up long segments into smaller ones with 1024 of overlap
# NOTE(review): the --max-segment-length help text describes the pieces as
# non-overlapping, while 1024 is passed here as the third argument --
# confirm the meaning of that argument in dagparts.breakupsegs.
segs = dagparts.breakupsegs(segs, options.max_segment_length, 1024)

# Create the "logs" directory (presumably where the jobs write their log
# files).  This is best effort: a failure, e.g. because the directory
# already exists, is ignored.  Only OSError is caught so that
# KeyboardInterrupt and SystemExit are no longer swallowed.
try:
	os.mkdir("logs")
except OSError:
	pass
dag = dagparts.CondorDAG("psd_pipe")

#
# setup the job classes
#

# one job description per executable; every node below is attached to one
# of these
refPSDJob = gstlal_reference_psd_job()
medianPSDJob = gstlal_median_of_psds_job()

#
# Precompute the PSDs for each segment
#

def hash_seg(seg):
	"""
	Return the key under which the node for seg is stored in psd_nodes:
	the string form of the segment, str(seg).
	"""
	# FIXME what is a good way to hash the segment?
	key = str(seg)
	return key

# One reference PSD node per analysis segment, keyed by hash_seg(seg).
psd_nodes = {}
for seg in segs:
	# the segment boundaries are passed to the node as their .seconds
	# attribute
	gps_start = seg[0].seconds
	gps_end = seg[1].seconds
	psd_nodes[hash_seg(seg)] = gstlal_reference_psd_node(
		refPSDJob,
		dag,
		options.frame_cache,
		options.frame_segments_file,
		options.frame_segments_name,
		gps_start,
		gps_end,
		channel_dict.keys(),
		channel_dict,
		injections=None,
		p_node=[]
	)

# A single median node takes the output file of every reference PSD node
# as input and is given all of those nodes as p_node.
reference_nodes = psd_nodes.values()
gstlal_median_of_psds_node(
	medianPSDJob,
	dag,
	output_name = "median_psd.xml.gz",
	input_files = [node.output_name for node in reference_nodes],
	p_node = reference_nodes
)

# Write out the files describing the workflow.
dag.write_sub_files()
dag.write_dag()
dag.write_script()
dag.write_cache()
