#!/usr/bin/env python
#
# Copyright (C) 2010, 2011  Jordi Burguet-Castell
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 2 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

from gstlal import pipeparts
from gstlal import reference_psd
from pylal.date import XLALUTCToGPS
import time
from gstlal import simplehandler
from gstlal import datasource
from optparse import OptionParser, Option
from glue.ligolw import ligolw
from glue.ligolw import array
from glue.ligolw import param
array.use_in(ligolw.LIGOLWContentHandler)
param.use_in(ligolw.LIGOLWContentHandler)
from glue.ligolw import utils
from glue import segments
from pylal.xlal.datatypes.ligotimegps import LIGOTimeGPS
import pygtk
pygtk.require("2.0")
import gobject
gobject.threads_init()
import pygst
pygst.require("0.10")
import gst
import sys
import os
import numpy

def psd_resolution_changed(elem, pspec, psd):
	"""
	Property-notification callback that installs a reference PSD on elem.

	Elsewhere in this script it is connected to the whitener's
	"notify::f-nyquist" and "notify::delta-f" signals, with the reference
	PSD passed as the extra user argument.

	elem  -- element whose "delta-f" and "f-nyquist" properties are read
	         and whose "mean-psd" property is written
	pspec -- property spec supplied by the signal; not used here
	psd   -- PSD handed to reference_psd.interpolate_psd()
	"""
	resolution = elem.get_property("delta-f")
	nyquist = elem.get_property("f-nyquist")
	# number of bins from DC up to and including the Nyquist frequency
	bin_count = int(round(nyquist / resolution) + 1)
	# interpolate to the element's frequency resolution, then install
	# only the first bin_count samples
	resampled = reference_psd.interpolate_psd(psd, resolution)
	elem.set_property("mean-psd", resampled.data[:bin_count])

def write_graph(demux):
	"""
	Signal callback that dumps a DOT graph of the module-level pipeline.

	The output name is "<--write-pipeline value>.PLAYING".  The demux
	argument is whatever the signal passes in; it is not used.  Relies on
	the module-level names pipeline and options.
	"""
	dot_name = "%s.%s" % (options.write_pipeline, "PLAYING")
	pipeparts.write_dump_dot(pipeline, dot_name, verbose = True)
	
parser = OptionParser(description = __doc__)

#
# Data source options are contributed by the datasource module
#

datasource.append_options(parser)

#
# Options specific to this program.  Flags and settings are on the first
# line of each call, the help text on the second.
#

parser.add_option("--output-channel-name", metavar = "name",
	help = "The name of the channel in the output frames. The default is the same as the channel name")
parser.add_option("--reference-psd", metavar = "name",
	help = "Set the name of psd xml file to whiten the data with")
parser.add_option("--recolor-psd", metavar = "name",
	help = "Set the name of psd xml file to recolor the data with")
parser.add_option("--track-psd", action = "store_true",
	help = "Calculate PSD from input data and track with time.")
parser.add_option("-v", "--verbose", action = "store_true",
	help = "Be verbose (optional).")
parser.add_option("--sample-rate", metavar = "Hz", type = "int", default = 4096,
	help = "Sample rate at which to generate the data, should be less than or equal to the sample rate of the measured psds provided, default = 4096 Hz")
parser.add_option("--frame-duration", metavar = "seconds", type = "int",
	help = "Set the number of seconds for each frame.")
parser.add_option("--frames-per-file", metavar = "count", type = "int",
	help = "Set the number of frames per frame file.")
parser.add_option("--write-pipeline", metavar = "filename",
	help = "Write a DOT graph description of the as-built pipeline to this file (optional).  The environment variable GST_DEBUG_DUMP_DOT_DIR must be set for this option to work.")
parser.add_option("--write-to-shm-partition", metavar = "name",
	help = "Set the name of the shared memory partition to write to. If this is not provided, will be written to file.")
parser.add_option("--required-on", metavar = "bit", type = "int", default = 0x1,
	help = "Set the required-on bits for the ODC to DQ vector conversion.")
parser.add_option("--buffer-mode", metavar = "number", type = "int", default = 2,
	help = "Set the buffer mode for the lvshmsink element. (Default=2)")
parser.add_option("--frame-type", metavar = "name", default = "FAKE_STRAIN",
	help = "Set the frame type as input to the frame writing element. (Default=FAKE_STRAIN)")
parser.add_option("--output-path", metavar = "name", default = ".",
	help = "Set the output path for writing frame files. (Default=Current)")
parser.add_option("--calibrate", action = "store_true",
	help = "Calibrate the data instead of recoloring.")
parser.add_option("--filters-file",
	help = "Name of file containing filters (in npz format)")
parser.add_option("--wings", type = "int", default = 16,
	help = "Size of wings in seconds.")
parser.add_option("--doubles", action = "store_true",
	help = "Use doubles instead of floats.")
parser.add_option("--time-domain", action = "store_true",
	help = "Apply FIR filters in the time domain.")

#
# Parse the command line
#

options, filenames = parser.parse_args()

# With no explicit output channel name, reuse the input channel name
if options.output_channel_name is None:
	options.output_channel_name = options.channel_name

# Short alias for the output sample rate, used throughout the script
sr = options.sample_rate

gw_data_source = datasource.GWDataSourceInfo(options)

# Assume instrument is the first (only) key of the channel dict
instrument = gw_data_source.channel_dict.keys()[0]

#
# Read psd file(s); only needed when recoloring
#

if not options.calibrate:
	# wpsd stays None unless a reference PSD file was given
	wpsd = None
	if options.reference_psd is not None:
		wpsd_xmldoc = utils.load_filename(options.reference_psd, verbose = options.verbose, contenthandler = ligolw.LIGOLWContentHandler)
		wpsd = reference_psd.read_psd_xmldoc(wpsd_xmldoc)[instrument]
	elif options.verbose:
		print >>sys.stderr, "No reference PSD provided, whitening will be done on the fly."
	# NOTE(review): --recolor-psd has no default, so None may be passed
	# here; confirm what utils.load_filename() does with a None filename
	rpsd_xmldoc = utils.load_filename(options.recolor_psd, verbose = options.verbose, contenthandler = ligolw.LIGOLWContentHandler)
	rpsd = reference_psd.read_psd_xmldoc(rpsd_xmldoc)[instrument]

#
# Setup the pipeline
#

# The pipeline is named after the program; the handler is given both the
# main loop and the pipeline
pipeline = gst.Pipeline(sys.argv[0])
mainloop = gobject.MainLoop()
handler = simplehandler.Handler(mainloop, pipeline)

#
# Turn off debugging tools or verboseness
#

def _passthrough(pipeline, src, *args):
	"""Stand-in for a pipeparts element factory: adds nothing and returns src."""
	return src

# Timestamp checking is always replaced by the pass-through; progress
# reports are replaced only when not running verbosely
pipeparts.mkchecktimestamps = _passthrough
if not options.verbose:
	pipeparts.mkprogressreport = _passthrough

#
# Read in data from frames or shared memory
#

# Build the head of the pipeline according to --data-source.  "lvshm" and
# "frames" each create one source element (src) that is demultiplexed into
# channels further down.  "white" builds the raw strain and ODC state vector
# streams directly from audiotestsrc elements and never defines src.
if options.data_source == "lvshm":
	src = pipeparts.mklvshmsrc(pipeline, shm_name = gw_data_source.shm_part_dict[instrument], assumed_duration = 1)
elif options.data_source == "frames":
	src = pipeparts.mklalcachesrc(pipeline, location = gw_data_source.frame_cache, cache_dsc_regex = instrument)
elif options.data_source == "white":
	# NOTE(review): wave = 9 is presumably audiotestsrc's Gaussian noise
	# waveform, and blocksize presumably 16384 samples * 4 bytes * 1
	# channel -- confirm against the element documentation
	rawstrain = pipeparts.mkaudiotestsrc(pipeline, wave = 9, samplesperbuffer=16384, blocksize = 16384 * 4 * 1)
	rawstrain = pipeparts.mkcapsfilter(pipeline, rawstrain, "audio/x-raw-float, width=32, rate=16384")	

	# Synthetic state vector: test source -> "exp" element -> lal_fixodc.
	# NOTE(review): wave = 4 is presumably silence, which "exp" would turn
	# into a constant 1 -- confirm
	odcstatevector = pipeparts.mkaudiotestsrc(pipeline, wave = 4, samplesperbuffer=16384, blocksize = 16384 * 4 * 1)
	odcstatevector = pipeparts.mkcapsfilter(pipeline, odcstatevector, "audio/x-raw-float, width=32, rate=16384")
	odcstatevector = pipeparts.mkgeneric(pipeline, odcstatevector, "exp")
	odcstatevector = pipeparts.mkgeneric(pipeline, odcstatevector, "lal_fixodc")
else:
	raise ValueError("invalid --data-source %s" % options.data_source)

if options.data_source != "white":
	# Request the strain channel(s) plus the DQ channel from the demuxer,
	# each formatted as "<instrument>:<channel>".  The append() relies on
	# dict.items() returning a list (Python 2).
	channel_list = gw_data_source.channel_dict.items()
	dq_channel_list = (instrument, gw_data_source.dq_channel_dict[instrument])
	channel_list.append(dq_channel_list)
	demux = pipeparts.mkframecppchanneldemux(pipeline, src, do_file_checksum = True, skip_bad_files = True, channel_list = map("%s:%s".__mod__, channel_list))
	# Write the pipeline graph after pads have been hooked up to the demuxer
	if options.write_pipeline is not None:
		demux.connect("no-more-pads", write_graph)	

	# Set up the raw strain and ODC state vector branches
	# (created with no upstream element; they are linked to the demuxer below)
	rawstrain = pipeparts.mkqueue(pipeline, None)
	odcstatevector = pipeparts.mkgeneric(pipeline, None, "lal_fixodc")

	# Hook up the raw strain and ODC state vector branches to appropriate channels in the demuxer
	pipeparts.src_deferred_link(demux, "%s:%s" % (instrument, gw_data_source.channel_dict[instrument]), rawstrain.get_pad("sink"))
	pipeparts.src_deferred_link(demux, "%s:%s" % (instrument, gw_data_source.dq_channel_dict[instrument]), odcstatevector.get_pad("sink"))

# Both streams, whatever their origin, are re-blocked with block_duration = gst.SECOND
rawstrain = pipeparts.mkreblock(pipeline, rawstrain, block_duration = gst.SECOND)
odcstatevector = pipeparts.mkreblock(pipeline, odcstatevector, block_duration = gst.SECOND)

# When reading from disk, clip raw strain stream to segment list
# (the gate's control stream is a segment source built from the frame segments)
if options.data_source == "frames" and gw_data_source.frame_segments[instrument] is not None:
	rawstrain = pipeparts.mkgate(pipeline, rawstrain, threshold = 1, control = pipeparts.mksegmentsrc(pipeline, gw_data_source.frame_segments[instrument]))

#
# ODC STATE VECTOR BRANCH
#

# The active call uses status_out = 0x7; the commented-out line is the same call with 0x3
#odcstatevector = pipeparts.mkodctodqv(pipeline, odcstatevector, required_on = options.required_on, status_out = 0x3)
odcstatevector = pipeparts.mkodctodqv(pipeline, odcstatevector, required_on = options.required_on, status_out = 0x7)
# The tee's output is consumed here and again later as the control for gating the raw strain
odcstatevectortee = pipeparts.mktee(pipeline, odcstatevector)

# This is the branch that gets converted to the DQ vector
# (undersampled, then constrained by caps to integer samples at rate=1)
odcstatevector = pipeparts.mkaudioundersample(pipeline, odcstatevectortee)
odcstatevector = pipeparts.mkcapsfilter(pipeline, odcstatevector, "audio/x-raw-int, rate=1")
odcstatevector = pipeparts.mkprogressreport(pipeline, odcstatevector, "progress_odc_%s" % instrument)

#
# RECOLORING OR CALIBRATION BRANCH
#

# Use the ODC state vector to gate the raw strain channel
# The control stream is a queued copy of the ODC tee passed through
# mkstatevector with required_on = 0x3; the gate's default state is False.
odccontrol = pipeparts.mkqueue(pipeline, odcstatevectortee)
odccontrol = pipeparts.mkstatevector(pipeline, odccontrol, required_on = 0x3)
rawstrain = pipeparts.mkgate(pipeline, rawstrain, threshold = 1, default_state = False, control =  odccontrol)

# Provide an audioconvert to allow Virgo data (which is single-precision) to be adapted into the pipeline
rawstrain = pipeparts.mkaudioconvert(pipeline, rawstrain)

# Resample, then pin the stream to 64-bit floats at the requested output rate (sr)
rawstrain = pipeparts.mkprogressreport(pipeline, rawstrain, "progress_src_%s" % instrument)
rawstrain = pipeparts.mkresample(pipeline, rawstrain, quality = 9)
rawstrain = pipeparts.mkcapsfilter(pipeline, rawstrain, "audio/x-raw-float, width=64, rate=%d" % sr)
rawstrain = pipeparts.mkaudiorate(pipeline, rawstrain, skip_to_first = True, silent = False) # This audiorate works around a bug in the resampler
# pipeparts.mkchecktimestamps is replaced by a pass-through earlier in this script, so this call adds no element
rawstrain = pipeparts.mkchecktimestamps(pipeline, rawstrain, "%s_timestamps_%d_hoft" % (instrument, sr))

# Recoloring path: whiten the raw strain, then filter it with a FIR kernel
# derived from the recolor PSD.  Skipped entirely when calibrating.
if not options.calibrate:
	rawstrain = pipeparts.mkwhiten(pipeline, rawstrain, fft_length = 8, zero_pad = 0, average_samples = 64, median_samples = 7, expand_gaps = True, name = "lal_whiten_%s" % instrument)

	# psd-mode 1 (fixed PSD) is used only when a reference PSD was loaded
	# and --track-psd was not requested; every other case uses psd-mode 0
	# (running PSD)
	have_reference_psd = wpsd is not None
	if have_reference_psd and not options.track_psd:
		rawstrain.set_property("psd-mode", 1)
	else:
		rawstrain.set_property("psd-mode", 0)
	if have_reference_psd:
		# re-install the reference PSD whenever the whitener's frequency
		# grid changes
		for signal_name in ("notify::f-nyquist", "notify::delta-f"):
			rawstrain.connect_after(signal_name, psd_resolution_changed, wpsd)
	rawstrain = pipeparts.mkchecktimestamps(pipeline, rawstrain, "%s_timestamps_%d_whitehoft" % (instrument, sr))

	# Recolor kernel
	# Keep only the PSD bins needed for the requested output sample rate
	# and invert them in place; if the requested rate is higher than the
	# PSD provides, an assert will fail later
	n_recolor_bins = int(round(1.0 / rpsd.deltaF * sr / 2.0)) + 1
	rpsd.data = 1. / rpsd.data[:n_recolor_bins]
	fir_matrix, latency, measured_sample_rate = reference_psd.psd_to_fir_kernel(rpsd)
	# Adjust the latency by one sample to fix the time stamps
	# FIXME:  remove this if reference_psd.psd_to_fir_kernel() is adjusted
	latency = latency - 1
	rawstrain = pipeparts.mkfirbank(pipeline, rawstrain, latency = latency, fir_matrix = [fir_matrix], block_stride = sr)

# Calibrate, if this pipeline is not recoloring
if options.calibrate:
	# --filters-file has no default, so stop here with a clear message
	# instead of letting numpy.load() fail on None
	if options.filters_file is None:
		raise ValueError("--calibrate requires --filters-file")
	# The loaded archive is indexed below by "inv_sensing", "awhitening" and "actuation"
	filters = numpy.load(options.filters_file)
	# Sample rate at which the actuation filter is applied
	actuationsr = 2048
	# Sample format requested on both branches ahead of their FIR filters
	if options.doubles:
		caps = "audio/x-raw-float, width=64" # = 8 bytes, a double
	else:
		caps = "audio/x-raw-float, width=32" # = 4 bytes, a float
	# For now, tee off the PSL (or whatever other temporary channel is being read in) and feed it into both DARM_ERR and DARM_CTRL
	rawstraintee = pipeparts.mktee(pipeline, rawstrain)
	derr = pipeparts.mkqueue(pipeline, rawstraintee)
	dctrl = pipeparts.mkqueue(pipeline, rawstraintee)

	# DARM_ERR branch: convert to the requested format, then apply the inverse sensing filter
	derr = pipeparts.mkaudioconvert(pipeline, derr)
	derr = pipeparts.mkcapsfilter(pipeline, derr, caps)
	derr = pipeparts.mkchecktimestamps(pipeline, derr, "derr_pre_invsensing")
	# FIXME: The latency in this filter causes problems... Find out what and why.
	#derr = pipeparts.mkfirbank(pipeline, derr, latency = int(-filters["inv_sens_delay"]), fir_matrix = [filters["inv_sensing"]], time_domain = options.time_domain, block_stride = sr)
	derr = pipeparts.mkfirbank(pipeline, derr, fir_matrix = [filters["inv_sensing"]], time_domain = options.time_domain, block_stride = sr)
	derr = pipeparts.mkchecktimestamps(pipeline, derr, "derr_post_invsensing")

	# DARM_CTRL branch: anti-whitening filter at the output rate, actuation
	# filter at actuationsr, then back to the output rate
	dctrl = pipeparts.mkaudioconvert(pipeline, dctrl)
	dctrl = pipeparts.mkcapsfilter(pipeline, dctrl, caps)
	dctrl = pipeparts.mkchecktimestamps(pipeline, dctrl, "dctrl_pre_awhitening")
	dctrl = pipeparts.mkfirbank(pipeline, dctrl, fir_matrix = [filters["awhitening"]], time_domain = options.time_domain, block_stride = sr)
	dctrl = pipeparts.mkchecktimestamps(pipeline, dctrl, "dctrl_pre_actuation")
	dctrl = pipeparts.mkresample(pipeline, dctrl, quality = 9)
	dctrl = pipeparts.mkcapsfilter(pipeline, dctrl, "audio/x-raw-float, rate=%d" % actuationsr)
	dctrl = pipeparts.mkfirbank(pipeline, dctrl, fir_matrix = [filters["actuation"]], time_domain = options.time_domain, block_stride = actuationsr)
	dctrl = pipeparts.mkchecktimestamps(pipeline, dctrl, "dctrl_post_actuation")
	dctrl = pipeparts.mkresample(pipeline, dctrl, quality = 9)
	dctrl = pipeparts.mkcapsfilter(pipeline, dctrl, "audio/x-raw-float, rate=%d" % sr)

	# Add DARM_ERR and DARM_CTRL to make h(t)
	# (each branch reaches the adder through its own queue)
	rawstrain = gst.element_factory_make("lal_adder")
	rawstrain.set_property("sync", True)
	pipeline.add(rawstrain)
	pipeparts.mkqueue(pipeline, derr, max_size_time = gst.SECOND * 100).link(rawstrain)
	pipeparts.mkqueue(pipeline, dctrl, max_size_time = gst.SECOND * 100).link(rawstrain)
	rawstrain = pipeparts.mkprogressreport(pipeline, rawstrain, "%s_progress_hoft" % instrument)

	if options.data_source == "frames":
		# Trim --wings seconds of samples from each end of the requested
		# GPS interval, with offsets expressed in samples at the output rate
		T = int(options.gps_end_time) - int(options.gps_start_time)
		rawstrain = pipeparts.mktrim(pipeline, rawstrain, initial_offset = sr * options.wings, final_offset = sr * (T - options.wings))
		rawstrain = pipeparts.mkaudiorate(pipeline, rawstrain, silent = False, skip_to_first = True)
	

# Put the units back to strain before writing to frames
# Additionally, override the output channel name if provided from the command line
# rawstraintee feeds two consumers: the tag-injected strain stream created
# here, and the control input of the gate that builds the DQ vector below.
rawstraintee = pipeparts.mktee(pipeline, rawstrain)
straintagstr = "units=strain,channel-name=%s,instrument=%s" % (options.output_channel_name, instrument)
strain = pipeparts.mktaginject(pipeline, rawstraintee, straintagstr)

# FIXME: The code below causes the adder to get stuck after 6 seconds.  Find out why and fix it, so that the LLD-DQ_VECTOR bits can be set independently

# NOTE: the triple-quoted block that follows is a bare string literal, so
# none of it is executed.  It also refers to psltee, which is not assigned
# anywhere in the visible script.
"""
#
# H(t)-OK BIT BRANCH
#

htdqbit = pipeparts.mkchecktimestamps(pipeline, psltee)
htdqbit = pipeparts.mkbitvectorgen(pipeline, htdqbit, bit_vector = 0x4, nongap_is_control = True)
htdqbit = pipeparts.mkcapsfilter(pipeline, htdqbit, "audio/x-raw-int, width=32")
htdqbit = pipeparts.mkaudioundersample(pipeline, htdqbit)
htdqbit = pipeparts.mkcapsfilter(pipeline, htdqbit, "audio/x-raw-int, rate=1")
htdqbit = pipeparts.mkchecktimestamps(pipeline, htdqbit)

#
# COMBINE ODC VECTOR WITH H(t)-OK BIT
#

dqvector = gst.element_factory_make("lal_adder")
dqvector.set_property("sync", True)
pipeline.add(dqvector)
pipeparts.mkqueue(pipeline, odcstatevector, max_size_time = gst.SECOND * 100).link(dqvector)
pipeparts.mkqueue(pipeline, htdqbit, max_size_time = gst.SECOND * 100).link(dqvector)
dqtagstr = "channel-name=%s:LLD-DQ_VECTOR, instrument=%s" % (instrument, instrument)
dqvector = pipeparts.mktaginject(pipeline, dqvector, dqtagstr)
dqvector = pipeparts.mkchecktimestamps(pipeline, dqvector)
#pipeparts.mknxydumpsink(pipeline, dqvector, "dqvector.dump")
"""

# Until the above code is fixed, make the LLD-DQ_VECTOR by gating the DQ vector with the strain to check for gaps
# The gated stream is the queued ODC-derived vector; the control stream is a
# queued copy of the strain tee.
# NOTE(review): the 1e-300 threshold presumably lets any non-zero strain
# sample open the gate -- confirm against the gate element's semantics
dqvector = pipeparts.mkgate(pipeline, pipeparts.mkqueue(pipeline, odcstatevector, max_size_time = gst.SECOND * 100), threshold = 1e-300, default_state = False, control = pipeparts.mkqueue(pipeline, rawstraintee), hold_length = -1)
# Tag the stream with the channel name "<instrument>:LLD-DQ_VECTOR"
dqtagstr = "channel-name=%s:LLD-DQ_VECTOR, instrument=%s" % (instrument, instrument)
dqvector = pipeparts.mktaginject(pipeline, dqvector, dqtagstr)
dqvector = pipeparts.mkchecktimestamps(pipeline, dqvector)

#
# CREATE MUXER AND HOOK EVERYTHING UP TO IT
#

# Create the muxer with no upstream element; its inputs are linked pad by pad below
mux = pipeparts.mkframecppchannelmux(pipeline, None)

# --frame-duration and --frames-per-file have no defaults; when omitted the
# muxer's own property values are left untouched
if options.frame_duration is not None:
        mux.set_property("frame-duration", options.frame_duration)
if options.frames_per_file is not None:
        mux.set_property("frames-per-file", options.frames_per_file)

# Link the manipulated ODC state vector to muxer
# (the muxer pad is looked up by the channel name "<instrument>:LLD-DQ_VECTOR")
dqvector.get_pad("src").link(mux.get_pad("%s:LLD-DQ_VECTOR" % instrument))

# Link fake strain (recolored PSL) to the muxer
# (through a queue, onto the pad named "<instrument>:<output channel name>")
pipeparts.mkqueue(pipeline, strain, max_size_time = gst.SECOND * 100).get_pad("src").link(mux.get_pad("%s:%s" % (instrument, options.output_channel_name)))

# From here on mux names the progress-report stage that follows the muxer.
# When not verbose, mkprogressreport is a pass-through and mux is unchanged.
mux = pipeparts.mkprogressreport(pipeline, mux, "progress_sink_%s" % instrument)

# Choose the sink: a shared-memory partition when --write-to-shm-partition is
# given, otherwise frame files on disk
if options.write_to_shm_partition is not None:
	# --frame-duration and --frames-per-file have no defaults, so either
	# may be None here.  The blocksize below is computed from both, so
	# report the missing options clearly instead of failing with a
	# TypeError on the multiplication.
	if options.frame_duration is None or options.frames_per_file is None:
		raise ValueError("--write-to-shm-partition requires --frame-duration and --frames-per-file")
	lvshmsink = gst.element_factory_make("gds_lvshmsink")
	lvshmsink.set_property("shm-name", options.write_to_shm_partition)
	lvshmsink.set_property("num-buffers", 10)
	# NOTE(review): 405338 is presumably a size estimate in bytes per
	# second of frame data -- confirm where this figure comes from
	lvshmsink.set_property("blocksize", 405338 * options.frame_duration * options.frames_per_file)
	lvshmsink.set_property("buffer-mode", options.buffer_mode)
	pipeline.add(lvshmsink)
	mux.link(lvshmsink)
else:
	pipeparts.mkframecppfilesink(pipeline, mux, frame_type = options.frame_type, path = options.output_path, instrument = instrument)

# Run pipeline

# DOT graphs are dumped only when --write-pipeline was given
dump_graphs = options.write_pipeline is not None

# Graph of the pipeline as built, before any state change
if dump_graphs:
	null_dot_name = "%s.%s" % (options.write_pipeline, "NULL")
	pipeparts.write_dump_dot(pipeline, null_dot_name, verbose = options.verbose)

# Frame data: seek using the event prepared by the data source object
if gw_data_source.data_source == "frames":
	datasource.do_seek(pipeline, gw_data_source.seekevent)
	print >>sys.stderr, "seeking GPS start and stop times ..."

# Synthetic data: seek to the current GPS time (in nanoseconds), open-ended
if options.data_source == "white":
	tm = time.gmtime()
	NOW = XLALUTCToGPS(tm).ns()
	seek_flags = gst.SEEK_FLAG_FLUSH | gst.SEEK_FLAG_KEY_UNIT
	seek = gst.event_new_seek(1., gst.FORMAT_TIME, seek_flags, gst.SEEK_TYPE_SET, NOW, gst.SEEK_TYPE_SET, -1)
	datasource.do_seek(pipeline, seek)

if options.verbose:
	print >>sys.stderr, "setting pipeline state to playing ..."
state_change = pipeline.set_state(gst.STATE_PLAYING)
if state_change == gst.STATE_CHANGE_FAILURE:
	raise RuntimeError("pipeline failed to enter PLAYING state")

# Graph of the pipeline right after the PLAYING state was requested
if dump_graphs:
	playing_dot_name = "%s.%s" % (options.write_pipeline, "PLAYING")
	pipeparts.write_dump_dot(pipeline, playing_dot_name, verbose = options.verbose)

if options.verbose:
	print >>sys.stderr, "running pipeline ..."

mainloop.run()
