#!/usr/bin/env python

# Copyright (C) 2011 Ian W. Harry
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

"""
Aligned spin bank generator.
"""


from __future__ import division
import warnings

def _warning(
    message,
    category = UserWarning,
    filename = '',
    lineno = -1,
    file = None,
    line = None):
    """Print only the warning text to standard output.

    Intended as a replacement for ``warnings.showwarning``, so it accepts
    the full documented hook signature.  Everything except ``message`` is
    accepted for compatibility and ignored.

    Parameters
    ----------
    message : the warning instance or text to print.
    category, filename, lineno : ignored.
    file, line : ignored.  They must still be accepted: Python 2.6/2.7
        warn about ``showwarning`` overrides lacking ``line``, and
        Python 3 passes both positionally.
    """
    print(message)

# Route warnings through the plain printer defined above so that the notice
# is shown as bare text, without the "file:line: Category:" prefix.
warnings.showwarning = _warning

# DeprecationWarning is ignored by default from Python 2.7 onwards, which
# would silently drop the notice below.  Force this one message (matched on
# the start of its text) to be displayed, leaving other filters untouched.
warnings.filterwarnings('always', message='DEPRECATION WARNING',
                        category=DeprecationWarning)

warnings.warn("""DEPRECATION WARNING:
THE FOLLOWING CODE HAS BEEN DEPRECATED AND MOVED INTO THE PYCBC PACKAGE. IT WILL BE REMOVED FROM PYLAL AT SOME POINT IN THE NEAR FUTURE.

FOR NOW PLEASE CONSULT:

https://ldas-jobs.ligo.caltech.edu/~cbc/docs/pycbc/tmpltbank.html

FOR INSTRUCTIONS OF HOW TO USE THE REPLACEMENTS IN PYCBC. PYCBC BUILD INSTRUCTIONS ARE HERE

https://ldas-jobs.ligo.caltech.edu/~cbc/docs/pycbc/index.html""", DeprecationWarning)

import matplotlib
matplotlib.use('Agg')
import pylab
import time
# Reference instant for all progress messages: integer microseconds since
# the epoch, taken once when the script starts.
start = int(time.time() * 10**6)


def elapsed_time():
    """Return the integer number of microseconds elapsed since ``start``."""
    return int(time.time() * 10**6 - start)

import os,sys,optparse
import tempfile
import ConfigParser
import numpy
from pylal import geom_aligned_bank_utils,git_version
from glue import pipeline

# Module metadata.  The version string and date come from pylal's
# git_version module; __version__ is also what the option parser below
# reports for --version.
__author__  = "Ian Harry <ian.harry@astro.cf.ac.uk>"
__version__ = "git id %s" % git_version.id
__date__    = git_version.date


# Read command line options
usage = """usage: %prog [options]"""
# Drop the leading newline of the module docstring for the --help text.
_desc = __doc__[1:]
parser = optparse.OptionParser(usage, version=__version__, description=_desc)
parser.add_option("-v", "--verbose", action="store_true", default=False,
                  help="verbose output, default: %default")
parser.add_option("-L", "--log-path", action="store", type="string",
                  default=None,
                  help="Directory to store condor logs.")
parser.add_option("-a", "--psd-file", action="store", type="string",
                  default=None,
                  help="ASCII file containing the PSD.")
parser.add_option("-o", "--pn-order", action="store", type="string",
                  default=None,
                   help="""Determines the PN order to use, choices are:
    * "twoPN": will include spin and non-spin terms up to 2PN in phase
    * "threePointFivePN": will include non-spin terms to 3.5PN, spin to 2.5PN
    * "taylorF4_45PN": use the R2D2 metric with partial terms to 4.5PN""")
parser.add_option("-f", "--f0", action="store", type="float",
                  default=70.,
                  help="f0 for use in metric calculation, "
                       "default: %default")
parser.add_option("-l", "--f-low", action="store", type="float",
                  default=15.,
                  help="f_low for use in metric calculation, "
                       "default: %default")
parser.add_option("-u", "--f-upper", action="store", type="float",
                  default=2000.,
                  help="f_up for use in metric calculation, "
                       "default: %default")
parser.add_option("-d", "--delta-f", action="store", type="float",
                  default=0.001,
                  help="delta_f for use in metric calculation, "
                       "linear interpolation used to get this, "
                       "default: %default")
# NOTE(review): the 0.03 defaults on the match, mass and spin options below
# look like copy-pasted placeholders; they are kept unchanged here so that
# existing invocations behave as before -- confirm the intended values.
parser.add_option("-m", "--min-match", action="store", type="float",
                  default=0.03,
                  help="Minimum match to generate bank with, "
                       "default: %default")
parser.add_option("-y", "--min-mass1", action="store", type="float",
                  default=0.03,
                  help="Minimum mass1 to generate bank with, "
                       "mass1 *must* be larger than mass2, "
                       "default: %default")
parser.add_option("-Y", "--max-mass1", action="store", type="float",
                  default=0.03,
                  help="Maximum mass1 to generate bank with, "
                       "default: %default")
parser.add_option("-z", "--min-mass2", action="store", type="float",
                  default=0.03,
                  help="Minimum mass2 to generate bank with, "
                       "default: %default")
parser.add_option("-Z", "--max-mass2", action="store", type="float",
                  default=0.03,
                  help="Maximum mass2 to generate bank with, "
                       "default: %default")
parser.add_option("-x", "--max-ns-spin-mag", action="store", type="float",
                  default=0.03,
                  help="Maximum neutron star spin magnitude, "
                       "default: %default")
parser.add_option("-X", "--max-bh-spin-mag", action="store", type="float",
                  default=0.03,
                  help="Maximum black hole spin magnitude, "
                       "default: %default")
parser.add_option("-n", "--nsbh-flag", action="store_true", default=False,
                  help="Set this if running with NSBH, default: %default")
parser.add_option("-s", "--stack-distance", action="store", type="float",
                  default=0.2,
                  help="Minimum metric spacing before we stack, "
                       "default: %default")
parser.add_option("-3", "--threed-lattice", action="store_true",
                  default=False,
                  help="Set this to use a 3D lattice, default: %default")
parser.add_option("-S", "--split-bank-num", action="store", type="int",
                  default=100,
                  help="Number of points per job in dag, default: %default")


(opts,args) = parser.parse_args()

# --psd-file and --pn-order default to None and are handed directly to the
# metric calculation, so fail here with a usage message rather than deep
# inside the library.
if opts.psd_file is None:
  parser.error("--psd-file must be provided")
if opts.pn_order is None:
  parser.error("--pn-order must be provided")
# The bank is later cut into chunks of this many points; zero or negative
# values cannot produce a valid split.
if opts.split_bank_num < 1:
  parser.error("--split-bank-num must be a positive integer")

# Derived mass limits, stored on opts alongside the parsed values.
opts.min_total_mass = opts.min_mass1 + opts.min_mass2
opts.max_total_mass = opts.max_mass1 + opts.max_mass2
# NOTE(review): both component-mass limits are taken from mass2; since the
# help text requires mass1 >= mass2, max_comp_mass may have been meant to
# use max_mass1 -- confirm before relying on it.
opts.min_comp_mass = opts.min_mass2
opts.max_comp_mass = opts.max_mass2

# Begin by calculating a metric
evals, evecs = geom_aligned_bank_utils.determine_eigen_directions(
    opts.psd_file, opts.pn_order, opts.f0, opts.f_low, opts.f_upper,
    opts.delta_f, verbose=opts.verbose, elapsed_time=elapsed_time)

if opts.verbose:
  sys.stdout.write(
      "Calculating covariance matrix at %d." % (elapsed_time()) + "\n")

# Positional arguments shared by the two sampling passes below; only the
# covariance-related keywords differ between them.
_sampling_args = (
    1000000, opts.pn_order, evals['fixed'], evecs['fixed'],
    opts.max_mass1, opts.min_mass1, opts.max_mass2, opts.min_mass2,
    opts.max_ns_spin_mag, opts.f0)

# First pass (covary=False): sample the parameter space and take the
# eigen-decomposition of the covariance of the returned coordinates.
vals = geom_aligned_bank_utils.estimate_mass_range_slimline(
    *_sampling_args, covary=False, maxBHspin=opts.max_bh_spin_mag)
cov = numpy.cov(vals)
evalsCV, evecsCV = numpy.linalg.eig(cov)

if opts.verbose:
  sys.stdout.write(
      "Covariance matrix calculated at %d." % (elapsed_time()) + "\n")

# Second pass (covary=True): sample again, this time supplying the
# covariance eigenvectors found above.
vals = geom_aligned_bank_utils.estimate_mass_range_slimline(
    *_sampling_args, covary=True, evecsCV=evecsCV,
    maxBHspin=opts.max_bh_spin_mag)

# Extent of the samples along the first two coordinates.
chi1Min, chi1Max = vals[0].min(), vals[0].max()
chi1Diff = chi1Max - chi1Min
chi2Min, chi2Max = vals[1].min(), vals[1].max()
chi2Diff = chi2Max - chi2Min

if opts.verbose:
  print>> sys.stdout, "Calculating lattice at %d." %(elapsed_time())

# Each coordinate range handed to the lattice generator is widened by 2% of
# its extent on both sides of the sampled minimum and maximum.
chi1Pad = 0.02*chi1Diff
chi2Pad = 0.02*chi2Diff

if not opts.threed_lattice:
  v1s,v2s = geom_aligned_bank_utils.generate_hexagonal_lattice(
      chi1Max + chi1Pad, chi1Min - chi1Pad,
      chi2Max + chi2Pad, chi2Min - chi2Pad,
      opts.min_match)
else:
  # The third coordinate's extent is only needed for the 3D lattice.
  chi3Max = vals[2].max()
  chi3Min = vals[2].min()
  chi3Diff = chi3Max - chi3Min
  chi3Pad = 0.02*chi3Diff
  v1s,v2s,v3s = geom_aligned_bank_utils.generate_anstar_3d_lattice(
      chi1Max + chi1Pad, chi1Min - chi1Pad,
      chi2Max + chi2Pad, chi2Min - chi2Pad,
      chi3Max + chi3Pad, chi3Min - chi3Pad,
      opts.min_match)

if opts.verbose:
  print>> sys.stdout, "Lattice calculated at %d." %(elapsed_time())
  print>> sys.stdout, "Lattice contains %d points." %(len(v1s))

# Dump necessary information and make a few plots
# The bank figure below is saved under plots/, which nothing else in this
# script creates; make sure it exists, as is done for the other output
# directories.
if not os.path.isdir('plots'):
  os.makedirs('plots')

geom_aligned_bank_utils.make_plots(vals[0],vals[1],vals[2],vals[3],'chi 1','chi 2','chi 3','chi 4')

# Scatter plot of the lattice points in the first two coordinates.
pylab.figure(10)
pylab.plot(v1s,v2s,'b.')
pylab.grid()
pylab.xlabel('chi 1')
pylab.ylabel('chi 2')
pylab.savefig('plots/bank.png')

# These files are referenced by name in the options given to the DAG jobs.
numpy.savetxt('metric_evecs.dat',evecs['fixed'])
numpy.savetxt('metric_evals.dat',evals['fixed'])
numpy.savetxt('covariance_evecs.dat',evecsCV)

# Write the complete lattice, one point per line; the context manager
# closes the file even if a write fails.
with open('bank_chis.dat','w') as bankFile:
  if opts.threed_lattice:
    for i in xrange(len(v1s)):
      print >> bankFile, "%e %e %e" %(v1s[i],v2s[i],v3s[i])
  else:
    for i in xrange(len(v1s)):
      print >> bankFile, "%e %e" %(v1s[i],v2s[i])

# Now begin to generate the dag
# First split the bank
for outDir in ('split_banks', 'output_banks', 'logs'):
  if not os.path.isdir(outDir):
    os.makedirs(outDir)

if opts.split_bank_num < 1:
  raise ValueError("split-bank-num must be a positive integer, got %d"
                   % opts.split_bank_num)

if opts.verbose:
  print>> sys.stdout, "Printing split banks at %d." %(elapsed_time())

# Write consecutive chunks of at most split_bank_num lattice points to
# split_banks/split_bank_00000.dat, split_bank_00001.dat, ...  Each file is
# opened in a context manager so it is closed even if a write fails.
numPoints = len(v1s)
# max(numPoints, 1) keeps the historical behaviour of always creating
# split_bank_00000.dat, even when the lattice is empty.
chunkStarts = xrange(0, max(numPoints, 1), opts.split_bank_num)
for bankNum, firstIdx in enumerate(chunkStarts):
  lastIdx = min(firstIdx + opts.split_bank_num, numPoints)
  with open('split_banks/split_bank_%05d.dat'%(bankNum),'w') as bankFile:
    for i in xrange(firstIdx, lastIdx):
      if opts.threed_lattice:
        print >> bankFile, "%e %e %e" %(v1s[i],v2s[i],v3s[i])
      else:
        print >> bankFile, "%e %e" %(v1s[i],v2s[i])

# And begin dag generation
# Reserve a uniquely named, already-created log file for the DAG under
# --log-path (or the default temporary directory when that is unset).
# mkstemp is used because assigning tempfile.template does not change the
# prefix used by mktemp, and mktemp leaves a window between choosing the
# name and creating the file.
logFd, logfile = tempfile.mkstemp(prefix='bank_gen.dag.log.',
                                  dir=opts.log_path)
os.close(logFd)
dag = pipeline.CondorDAG(logfile, False)
dag.set_dag_file('bank_generation')

# Job that processes one split bank file
exe_path = geom_aligned_bank_utils.which('pylal_geom_aligned_2dstack')
job = pipeline.CondorDAGJob('vanilla',exe_path)
job.set_stdout_file('logs/bank_gen-$(cluster)-$(process).out')
job.set_stderr_file('logs/bank_gen-$(cluster)-$(process).err')
job.set_sub_file('bank_gen.sub')
# Add global job options
cp = ConfigParser.ConfigParser()
cp.add_section('bank')
cp.set('bank','pn-order',opts.pn_order)
cp.set('bank','metric-evals-file','metric_evals.dat')
cp.set('bank','metric-evecs-file','metric_evecs.dat')
cp.set('bank','cov-evecs-file','covariance_evecs.dat')
cp.set('bank','f0',str(opts.f0))
cp.set('bank','max-mass1',str(opts.max_mass1))
cp.set('bank','min-mass1',str(opts.min_mass1))
cp.set('bank','max-mass2',str(opts.max_mass2))
cp.set('bank','min-mass2',str(opts.min_mass2))
cp.set('bank','max-ns-spin-mag',str(opts.max_ns_spin_mag))
cp.set('bank','max-bh-spin-mag',str(opts.max_bh_spin_mag))
cp.set('bank','min-match',str(opts.min_match))
cp.set('bank','stack-distance',str(opts.stack_distance))
# Boolean switches are passed on as options with an empty value.
if opts.nsbh_flag:
  cp.set('bank','nsbh-flag','')
if opts.threed_lattice:
  cp.set('bank','threed-lattice','')
job.add_ini_opts(cp,'bank')
job.add_condor_cmd('Requirements', 'Memory >= 1390')
job.add_condor_cmd('request_memory', '1400')
job.add_condor_cmd('getenv','True')
# Make the output job
cat_path = geom_aligned_bank_utils.which('pylal_aligned_bank_cat')
job_cat = pipeline.CondorDAGJob('vanilla',cat_path)
job_cat.set_stdout_file('logs/bank_cat-$(cluster)-$(process).out')
job_cat.set_stderr_file('logs/bank_cat-$(cluster)-$(process).err')
job_cat.set_sub_file('bank_cat.sub')
job_cat.add_condor_cmd('getenv','True')

# Make the output node
cat_node = pipeline.CondorDAGNode(job_cat)
cat_node.add_var_opt('input-glob','output_banks/output_bank_*.dat')
cat_node.add_var_opt('output-file','aligned_bank.xml')

# Make the nodes
# One node per split bank file: ceil(len(v1s) / split_bank_num), computed
# in integer arithmetic.
numBanks = (len(v1s) + opts.split_bank_num - 1)//opts.split_bank_num
for i in xrange(numBanks):
  node = pipeline.CondorDAGNode(job)
  node.add_var_opt('input-bank-file','split_banks/split_bank_%05d.dat'%(i))
  node.add_var_opt('output-bank-file','output_banks/output_bank_%05d.dat'%(i))
  # The concatenation node is made to depend on every bank-generation node.
  cat_node.add_parent(node)
  dag.add_node(node)

dag.add_node(cat_node)
dag.write_sub_files()
dag.write_dag()
dag.write_script()
