#! /usr/bin/python
from optparse import *
import glob
import sys
from matplotlib import use
use('Agg')
from matplotlib import pyplot as plt
import math
import numpy as np

# Command-line interface.  Every option is long-form only; optparse stores
# each value on `opts` under the option name without the leading dashes.
parser = OptionParser(usage="usage", version="kari")
parser.add_option(
	"--files",
	default="H1L1*evaluation*.dat",
	help="your evaluation/testing files, warning: do not use training files here")
parser.add_option(
	"--zerolag",
	default=None,
	help="your zerolag evaluation files, ex. H1L1*zerolag*.dat; to avoid plotting zerolag do not give this argument")
parser.add_option(
	"--tag",
	default="histogram",
	help="label with ifo combo, run tag")
parser.add_option(
	"--title",
	default="",
	help="this will go on the top of all the plots")
parser.add_option(
	"--histograms",
	action="store_true",
	default=False,
	help="use if you want to turn on histogram production")
# No default is given here, so opts.comparison is None unless requested.
parser.add_option(
	"--comparison",
	help="name of variable (look at choices in top row of .dat files) to compare MVSC to in ROC, ex. coinc_inspiral_snr. optional.")
opts, args = parser.parse_args()

files = glob.glob(opts.files)
print "These are the .dat files you are plotting from:", files
if len(files) == 0:
	raise RuntimeError("Could not find any files!  Exiting...")
print "Is this all of your evaluation files? It should be..."
zerolag = glob.glob(opts.zerolag) if opts.zerolag else None
tag = opts.tag

variables = []
tmp = None

for f in files:
	flines = open(f).readlines()
	variables = flines[0].split()
	formats = ['i','i']+['g8' for a in range(len(variables)-2)]
	if tmp != None:
		data = np.concatenate((data,np.loadtxt(f, skiprows=1, dtype={'names':variables, 'formats':formats})), axis=0)
	else:
		tmp = np.loadtxt(f, skiprows=1, dtype={'names':variables, 'formats':formats})
		data = tmp
print "variables:", variables
print "total events", data.shape
bk_data = data[np.nonzero(data['i']==0)[0],:]
sg_data = data[np.nonzero(data['i']==1)[0],:]
print "total background events (Class 0)", bk_data.shape
print "total signal events (Class 1)", sg_data.shape
maxlen = max(len(bk_data), len(sg_data))

if opts.zerolag:
	for f in zerolag:
		flines = open(f).readlines()
		variables = flines[0].split()
		formats = ['i','i']+['g8' for a in range(len(variables)-2)]
		zl_data = np.loadtxt(f, skiprows=1, dtype={'names':variables, 'formats':formats})
	zl_data = zl_data[np.nonzero(zl_data['i']==0)[0],:]
else:
	zl_data = None

def get_limits(var, bk_data, sg_data, zl_data, opt, log=False):
	"""Return the (smallest, largest) value of column `var` over the samples.

	var     -- key used to index each sample
	bk_data -- background sample, always included
	sg_data -- signal sample, always included
	zl_data -- zerolag sample, only consulted when opt.zerolag is set
	opt     -- object with a `zerolag` attribute
	log     -- if true, return log10 of both limits instead of the raw values
	"""
	samples = [bk_data[var], sg_data[var]]
	if opt.zerolag:
		samples.append(zl_data[var])
	lowest = min(np.min(sample) for sample in samples)
	highest = max(np.max(sample) for sample in samples)
	if log:
		return (np.log10(lowest), np.log10(highest))
	return (lowest, highest)

def _plot_histogram(var, log_x):
	"""Save an overlaid, normalised histogram of column `var` as a PNG file.

	Background (black), signal (red) and, when --zerolag was given, zerolag
	(blue) are drawn on 100 shared bins spanning the limits returned by
	get_limits(), with a logarithmic y axis.

	var   -- name of the column to histogram
	log_x -- if true, histogram log10 of the values and write
	         hist_<var>_<tag>_logx.png; otherwise histogram the raw values
	         and write hist_<var>_<tag>.png

	If the background histogram raises ValueError, a message is written to
	stderr, the figure is discarded and no file is written.
	"""
	def scaled(values):
		return np.log10(values) if log_x else values

	min_x, max_x = get_limits(var, bk_data, sg_data, zl_data, opts, log=log_x)
	hist_args = dict(bins=100, range=(min_x,max_x), normed=True, log=True)
	try:
		nbk,bins,patches = plt.hist(scaled(bk_data[var]), label="background", **hist_args)
	except ValueError:
		if log_x:
			print >>sys.stderr, 'ValueError encountered in making histogram of log %s (range may include 0 or negative values)' % var
		else:
			print >>sys.stderr, 'ValueError encountered in making histogram of %s' % var
		# Drop the failed figure so the next variable starts from a clean one.
		plt.close()
		return
	plt.setp(patches, 'facecolor', 'k', 'alpha', .5)
	plt.hold(1)
	nsg,bins,patches = plt.hist(scaled(sg_data[var]), label="signal", **hist_args)
	plt.setp(patches, 'facecolor', 'r', 'alpha', .5)
	if opts.zerolag:
		n,bins,patches = plt.hist(scaled(zl_data[var]), label="zerolag", **hist_args)
		plt.setp(patches, 'facecolor', 'b', 'alpha', .5)
	# The y range is set from the background and signal bin heights only;
	# empty bins are masked out so they do not drag the lower limit to 0.
	allplot = np.concatenate((nbk,nsg))
	plt.ylim(0.5*np.ma.masked_equal(allplot, 0, copy=False).min(), 3*allplot.max())   # make sure all samples in tails are plotted
	plt.xlim(0.95*min_x if min_x>0 else 1.05*min_x, 1.05*max_x if max_x>0 else 0.95*max_x) # allow for breathing space around hist
	plt.xlabel("log10 "+var if log_x else var, size='x-large')
	plt.title(opts.title, size='x-large')
	plt.ylabel("density", size='x-large')
	plt.legend(loc="upper right")
	plt.savefig("hist_"+var+"_"+tag+("_logx.png" if log_x else ".png"))
	plt.close()

if opts.histograms:
	# First pass: histograms of log10(value); second pass: the raw values.
	for log_x in (True, False):
		for var in variables:
			_plot_histogram(var, log_x)

#prepare to make ROC plot
all_ranks = np.concatenate((bk_data['Bagger'], sg_data['Bagger']))
all_ranks_sorted = np.sort(all_ranks)
background_ranks_sorted = np.sort(bk_data['Bagger'])
signal_ranks_sorted = np.sort(sg_data['Bagger'])
number_of_false_alarms = []
number_of_true_alarms = []
deadtime = [] #fraction
efficiency = [] #fraction

print len(all_ranks_sorted), "all ranks"
print len(signal_ranks_sorted), "signal ranks"
print len(background_ranks_sorted), "background ranks"

FAP = [] # false alarm percentage
TAP = [] # true alarm percentage
# classic ROC curve
for i,rank in enumerate(background_ranks_sorted):
	# get the number of background triggers with rank greater than or equal to a given rank
	number_of_false_alarms = len(background_ranks_sorted) - np.searchsorted(background_ranks_sorted,rank)
	# get the number of signals with rank greater than or equal to a given rank
	number_of_true_alarms = len(signal_ranks_sorted) - np.searchsorted(signal_ranks_sorted,rank)
	# calculate the total deadime if this given rank is used as the threshold
	FAP.append( number_of_false_alarms / float(len(background_ranks_sorted)))
	# calculate the fraction of correctly flagged signals
	TAP.append(number_of_true_alarms / float(len(signal_ranks_sorted)))

plt.figure(1000)
plt.semilogx(FAP,TAP, linewidth = 2.0,label='MVSC')
plt.hold(True)
#x = np.arange(0,1,.00001)
#plt.semilogy(x,x, linestyle="dashed", linewidth = 2.0)
plt.xlabel('False Alarm Fraction', size='x-large')
plt.ylabel('Efficiency', size='x-large')
plt.xlim([0,1])
plt.ylim([0,1])

if opts.comparison:
	all_ranks = np.concatenate((bk_data[opts.comparison],sg_data[opts.comparison]))
	all_ranks_sorted = np.sort(all_ranks)
	background_ranks_sorted = np.sort(bk_data[opts.comparison])
	signal_ranks_sorted = np.sort(sg_data[opts.comparison])
	number_of_false_alarms=[]
	number_of_true_alarms=[]
	deadtime=[] #fraction
	efficiency=[] #fraction
	FAP = [] # false alarm percentage
	TAP = [] # true alarm percentage
	# classic ROC curve
	for i,rank in enumerate(background_ranks_sorted):
		# get the number of background triggers with rank greater than or equal to a given rank
		number_of_false_alarms = len(background_ranks_sorted) - np.searchsorted(background_ranks_sorted,rank)
		# get the number of signals with rank greater than or equal to a given rank
		number_of_true_alarms = len(signal_ranks_sorted) - np.searchsorted(signal_ranks_sorted,rank)
		# calculate the total deadime if this given rank is used as the threshold
		FAP.append( number_of_false_alarms / float(len(background_ranks_sorted)))
		# calculate the fraction of correctly flagged signals
		TAP.append(number_of_true_alarms / float(len(signal_ranks_sorted)))
	plt.semilogx(FAP, TAP, linewidth=2.0, label=opts.comparison)
plt.legend(loc='lower right')
plt.title(opts.title, size='x-large')
plt.savefig('ROC_'+opts.tag+'.png')
plt.close()

print 'Done!'
