#
# noe.py
#
# Copyright (C) 2005  Dr. Stephane Gagne
# the full copyright notice is found in the LICENSE file in this directory
#
# This program is free software; 
# you can redistribute it and/or modify it under the terms of the 
# GNU General Public License as published by the Free Software Foundation; 
# either version 2 of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful, 
# but WITHOUT ANY WARRANTY; without even the implied warranty of 
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
# See the GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License 
# along with this program; if not, write to the Free Software Foundation, Inc., 
# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
#
# Contact:  nmr@rsvs.ulaval.ca
#

import Tkinter, os, string, math, re, time, sys, os.path
import pbar
import pymol
from pymol import cmd
from verify_common import *
from common import *

# Sentinel "distance" value: NoeViolation.checkViolations skips any calculated
# distance equal to VOID.  A plain module-level assignment already makes this
# a global name, so no ``global`` statement is needed.
VOID = 9999


#..........................................................................
# NOE distance restraints read from a restraint file (CNS format).
#
class NOEConstraints:
    """NOE distance restraints read from a restraint file.

    Attributes:
        numCsts -- number of restraints stored so far; it is also the serial
                   number given to the most recently stored restraint.
        cnst    -- list of CstItem objects, in the order they were read.
    """

    def __init__(self):
        self.numCsts = 0
        self.cnst = []



#   ............................
    class CstItem:
        """One distance restraint between two (possibly wildcarded) atoms."""
        def __init__(self):
            self.upperlimit = 0     # target + plus-correction
            self.lowerlimit = 0     # target - minus-correction
            self.target = 0         # target distance
            self.atoms = []         # not filled in by NOEConstraints
            self.raw_atoms = []     # the two atom names as written in the file
            self.resnums = []       # the two residue numbers (strings)
            self.serial = 0         # 1-based index among stored restraints
            self.items = ""         # label "res1.atom1 res2.atom2"

    def assign_line(self, rstline):
        """Parse one complete ``assign`` statement and append it to cnst.

        Only the 14-token layout
            assign (resid R1 and name A1) (resid R2 and name A2) d dminus dplus
        is recognised: tokens 2 and 7 are the residue numbers, tokens 5 and
        10 the atom names (with their trailing ')' removed), and tokens
        11-13 the target distance and its minus/plus corrections.  Anything
        else (empty text, a different token count, non-numeric distances)
        is skipped; a warning is printed in the non-numeric case.
        """
        tokens = rstline.split()
        if len(tokens) != 14:
            return

        # Parse the numbers first so a malformed line neither bumps the
        # counter nor leaves a half-built restraint behind.
        try:
            target = float(tokens[11])
            minus = float(tokens[12])
            plus = float(tokens[13])
        except ValueError:
            print("** WARNING:  cannot parse restraint: " + " ".join(tokens))
            return

        mycst = self.CstItem()
        mycst.resnums = [tokens[2], tokens[7]]
        mycst.raw_atoms = [tokens[5][:-1], tokens[10][:-1]]
        mycst.target = target
        mycst.lowerlimit = target - minus
        mycst.upperlimit = target + plus
        self.numCsts = self.numCsts + 1
        mycst.serial = self.numCsts
        mycst.items = mycst.resnums[0] + "." + mycst.raw_atoms[0] + " " + \
                      mycst.resnums[1] + "." + mycst.raw_atoms[1]
        self.cnst.append(mycst)



    def readCnsCnsts(self, maingui):
        """Read a CNS restraint file into cnst.

        The file name comes from maingui.fname and is resolved relative to
        the module-level ``workingdir``.  A restraint may span several lines;
        a new one starts at each ``assign`` keyword (case-insensitive, and
        abbreviated to 4 characters or more, e.g. ``assi``).  Everything
        after a '!' on a line is treated as a comment.  If the file cannot
        be opened an error is printed and nothing is read.
        """
        fullfname = os.path.join(workingdir, maingui.fname.get())
        try:
            rstfile = open(fullfname, "r")
        except IOError:
            print("** ERROR:  cannot open restraints file " + fullfname)
            return
        try:
            rstlines = rstfile.readlines()
        finally:
            rstfile.close()

        rstline = ""
        for line in rstlines:
            # drop '!' comments, whether full-line, indented or trailing
            code = line.split("!", 1)[0]
            tokens = code.split()
            if not tokens:
                continue
            keyword = tokens[0].lower()
            if len(keyword) >= 4 and "assign".startswith(keyword):
                # a new restraint starts: store the one accumulated so far
                self.assign_line(rstline)
                rstline = code
            else:
                # continuation line; the space keeps tokens from merging
                rstline = rstline + " " + code

        # and do the last restraint in the file...
        self.assign_line(rstline)


#   ............................
    def readCyanaCnsts(self, maingui):
        """CYANA restraint files are not supported yet; does nothing."""
        return

#   ............................
    def readCnsts(self, maingui):
        """Read restraints in the format selected by maingui.ftype."""
        if maingui.ftype.get() == "CNS":
            self.readCnsCnsts(maingui)
        else:
            self.readCyanaCnsts(maingui)


#   ..................................
    def matchCnsts(self, maingui):
        """ match the constraint to lines in an (arbitrary) PDB file.

        Not implemented yet; does nothing. """
        return





#..........................................................................
# calculated NOE distances, stored per restraint and per PyMOL object.
class NOEDistances:
    """Calculated NOE distances, keyed by (restraint serial, PyMOL object)."""


    def __init__(self):
        # maps (rst.serial, obj) -> list of (atom_tuple, distance), where
        # atom_tuple is (pdb_atom1, pdb_atom2, raw_atom1, raw_atom2)
        self.distances = dict()

    def _match_wildcard(self, prefix, names):
        """Return the entries of *names* that start with *prefix* (any case).

        The prefix is escaped, so regex metacharacters in atom names are
        taken literally, and an empty prefix matches every name.
        """
        rexp = re.compile(re.escape(prefix), re.IGNORECASE)
        return [name for name in names if rexp.match(name)]

    def get_atoms(self, obj, resnum, atom):
        """ this builds a list of atom selections for the given residue
            ie. if the residue is 3, the atom is HB*, and the obj mine.pdb
            then all HB atoms from residue 3 will be returned.
            A trailing '*' is a case-insensitive prefix wildcard; the
            residue's atom names are obtained with the pymol command
            iterate.  Without a '*' a single selection string is returned."""

        pymol.stored.tmplist = []
        base = obj + "///" + resnum + "/"
        if not atom.endswith("*"):
            return [base + atom]

        # get list of all atoms for this res, then keep the matching ones
        cmd.iterate(base, "stored.tmplist.append(name)")
        return [base + atm
                for atm in self._match_wildcard(atom[:-1], pymol.stored.tmplist)]


    def calc_distances(self, obj, myverify, hashtable):
        """Calculate and store the distance of every restraint for *obj*.

        The averaging method is myverify.maingui.avg:
            0 r^-6 average, 1 r^-3 average, 2 centre-to-centre,
            3 r^-6 sum divided by maingui.nmono, 4 first matched atom pair,
            5 closest pair.
        hashtable is passed through to find_line (the caller picks
        hash_res or hash_resP).  Restraints with no matching atom on either
        side, or with an unknown averaging method, store nothing.
        """
        avg = myverify.maingui.avg.get()

        for rst in myverify.constraints.cnst:
            mol = pdb.myPDB[obj]

            # matchX = a list of matching line # in the PDB file
            match1 = mol.find_line(rst.raw_atoms[0], rst.resnums[0], hashtable)
            match2 = mol.find_line(rst.raw_atoms[1], rst.resnums[1], hashtable)

            # check to make sure the restraint has matched something
            if len(match1) == 0 or len(match2) == 0:
                continue

            # coordinate tuples for each half: [(x, y, z), (x, y, z), ...]
            ct1 = [mol.items[idx].coord for idx in match1]
            ct2 = [mol.items[idx].coord for idx in match2]

            if avg == 0: dist = self.r6_dist(ct1, ct2)
            elif avg == 1: dist = self.r3_dist(ct1, ct2)
            elif avg == 2: dist = self.cent_dist(ct1, ct2)
            elif avg == 3: dist = self.sum_dist(myverify, ct1, ct2)
            elif avg == 4: dist = self.dist_points(ct1[0], ct2[0])
            elif avg == 5: dist = self.closest_dist(ct1, ct2)
            else:
                # unknown averaging method: nothing to record
                continue

            try:
                atoms = (mol.items[match1[0]].atom,
                         mol.items[match2[0]].atom,
                         rst.raw_atoms[0], rst.raw_atoms[1])
            except (LookupError, AttributeError, TypeError):
                # matched line without usable atom data: skip this restraint
                continue
            self.setDistance((rst.serial, obj), atoms, dist)

    def _pair_distances(self, ct1, ct2):
        """Distances for every (coord1, coord2) pair, ct1 in the outer loop."""
        return [self.dist_points(coord1, coord2)
                for coord1 in ct1 for coord2 in ct2]

    def _power_average_dist(self, ct1, ct2, power):
        """(mean of d^-power over all pairs)^(-1/power).

        Returns 9999 when there are no pairs.  If some pair coincides
        (d == 0) the average diverges, so its limit, 0.0, is returned instead
        of raising ZeroDivisionError.
        """
        dists = self._pair_distances(ct1, ct2)
        if not dists:
            return(9999)
        if 0.0 in dists:
            return(0.0)
        total = 0.0
        for d in dists:
            total += pow(d, -power)
        return(pow(total / len(dists), -(1.0 / power)))

    def r6_dist(self, ct1, ct2):
        """r^-6 averaged distance between two coordinate lists."""
        return self._power_average_dist(ct1, ct2, 6)

    def r3_dist(self, ct1, ct2):
        """r^-3 averaged distance between two coordinate lists."""
        return self._power_average_dist(ct1, ct2, 3)

    def sum_dist(self, myverify, ct1, ct2):
        """r^-6 summed distance, divided by myverify.maingui.nmono.

        (nmono is presumably the number of monomers -- TODO confirm.)
        Returns 9999 when there are no pairs, 0.0 when a pair coincides.
        """
        mono = myverify.maingui.nmono.get()
        dists = self._pair_distances(ct1, ct2)
        if 0.0 in dists:
            return(0.0)
        total = 0.0
        for d in dists:
            total += pow(d, -6)
        if total > 0:
            return(pow(total/mono, -(1.0/6)))
        else: return(9999)

    def cent_dist(self, ct1, ct2):
        """Distance between the geometric centres of two non-empty lists."""

        def centre(coordinates):
            avg = [0.0, 0.0, 0.0]
            for coord in coordinates:
                for i in (0, 1, 2):
                    avg[i] = avg[i] + coord[i]
            cnt = len(coordinates)
            return([pt/cnt for pt in avg])

        return(self.dist_points(centre(ct1), centre(ct2)))

    def closest_dist(self, ct1, ct2):
        """Smallest pairwise distance; 999.0 when either list is empty."""
        dists = self._pair_distances(ct1, ct2)
        if not dists:
            return(999.0)
        return(min(dists))


    def dist_points(self, pt0, pt1):
        """Euclidean distance between two 3-D points."""
        diffs = [pt0[i] - pt1[i] for i in (0, 1, 2)]
        return math.sqrt(sum([d * d for d in diffs]))


    def setDistance(self, key, tup1, dist):
        """Append (tup1, dist) to the list stored under *key*."""
        self.distances.setdefault(key, []).append((tup1, dist))

    def getDistance(self, rstnum, obj):
        """List of (atom_tuple, distance) for (rstnum, obj); [] if none."""
        return self.distances.get((rstnum, obj), [])





class NoeViolation(Violations):
    """Checks calculated NOE distances against their restraint bounds."""

    def checkViolations(self, obj, rsts, distances, violations, maingui):
        """Record every calculated distance of *obj* in *violations*.

        For each restraint in rsts.cnst and each distance stored for it
        (distances equal to VOID are skipped), one entry is added with
        violations.vappend.  The violation amount is
            lowlim - actual   if actual <= lowlim - cutoff
            actual - uplim    if actual >= uplim + cutoff
            0                 otherwise
        where cutoff is maingui.cutoff.  Zero-valued entries are recorded on
        purpose (every restraint generates a line) and are meant to be
        weeded out later.
        """
        cutoff = maingui.cutoff.get()

        for cst in rsts.cnst:
            uplim = cst.upperlimit
            lowlim = cst.lowerlimit
            for ((at1, at2, raw1, raw2), actual) in \
                    distances.getDistance(cst.serial, obj):
                if actual == VOID:
                    continue
                if actual <= (lowlim - cutoff):
                    amount = lowlim - actual
                elif actual >= (uplim + cutoff):
                    amount = actual - uplim
                else:
                    amount = 0
                label = cst.resnums[0] + '.' + raw1 + ' ' + \
                        cst.resnums[1] + '.' + raw2
                violations.vappend(obj, lowlim, uplim, actual, cst.serial,
                                   label, "NOE", amount, at1, at2)


    # function to do the checking of NOE restraints.
    ## this is called from verify.py
    def doVerifyNOE(self, myverify):
        """Calculate NOE distances and check them for each enabled object.

        For every object returned by cmd.get_names('objects', 1), distances
        are calculated into myverify.distances and then checked into
        myverify.violations.  A progress bar is shown while this runs and
        is closed even if an error is raised.
        """
        from pymol import cmd

        # correlate the restraints to matching pdb lines
        # NOTE(review): matchCnsts is declared matchCnsts(self, maingui) but
        # is passed this NoeViolation object; harmless while it is a stub.
        myverify.constraints.matchCnsts(self)

        # for each on-screen object (pdb file)
        objs = cmd.get_names('objects', 1)
        nobjs = len(objs)
        if nobjs == 0:
            return

        # put up progress bar because this takes a while
        mybar = pbar.ProgressBar(max=nobjs)
        try:
            counter = 1
            for obj in objs:
                # this is needed in find_line (called in calc_distances);
                # it is chosen here, once per object, to speed execution
                if myverify.maingui.proton.get() == 1:
                    hashtable = pdb.myPDB[obj].hash_resP
                else:
                    hashtable = pdb.myPDB[obj].hash_res

                # calculate the noe distances for each restraint item and object
                myverify.distances.calc_distances(obj, myverify, hashtable)

                self.checkViolations(obj, myverify.constraints,
                    myverify.distances, myverify.violations, myverify.maingui)

                mybar.update(counter)
                counter += 1
        finally:
            mybar.done()