
# rmsd.py
#
# Copyright (C) 2005  Dr. Stephane Gagne
# the full copyright notice is found in the LICENSE file in this directory
#
# This program is free software; 
# you can redistribute it and/or modify it under the terms of the 
# GNU General Public License as published by the Free Software Foundation; 
# either version 2 of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful, 
# but WITHOUT ANY WARRANTY; without even the implied warranty of 
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
# See the GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License 
# along with this program; if not, write to the Free Software Foundation, Inc., 
# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
#
# Contact:  nmr@rsvs.ulaval.ca
#
# --------------------------------------------------------------------
# author:  Leigh Willard
# lab:     Stephane Gagne, Laval Universite
# date:    mai 2005
# --------------------------------------------------------------------
#



from pymol import cmd
from common import pdb
from Numeric import *
import time, string
import MLab



class RMSD:
    """ class to hold all rmsd data """

    def __init__(self, ref, atoms, range, bbatoms):
        """
        set up the lookup tables used by the rmsd routines

        ref     -- name of the reference object (kept as self.ref)
        atoms   -- whitespace separated atom patterns, eg. "N CA C" or "-H"
        range   -- residue range string, eg. "all" or "1,3-9" (see get_range)
        bbatoms -- backbone atom patterns, stored unchanged as self.bbatoms
        """

        self.ref = ref
        self.atoms = string.split(atoms)

        # remove any PyMOL object called "mean"
        cmd.delete("mean")
        self.tmpfile = "/tmp/pynmr.mean"
        self.bbatoms = bbatoms
        # "-" marks an exclusion pattern for atom_check
        self.heavyatoms = ["-H"]
        self.objs = cmd.get_names('objects',1)
        self.origrmsdL = []
        self.rmsdD = dict()
        for name in self.objs:
            self.rmsdD[name] = dict()
        self.rmsdD['mean'] = dict()

        self.resnums = list()
        self.numatoms = 0 # set by make_X2X
        self.res2atom = dict()
        self.res2name = dict()
        self.atom2name = dict()

        for name in self.objs:
            self.make_X2X(name)
        self.includeL, self.excludeL = self.make_lists(atoms)
        self.range = self.get_range(self.objs[0], range)

        # dictionary of coordinates, by residue, for each object
        self.res_coord_bb = dict()
        self.res_coord_heavy = dict()
        for name in self.objs:
            self.res_coord_bb[name] = dict()
            self.res_coord_heavy[name] = dict()
        self.res_rmsd_bb = dict()
        self.res_rmsd_heavy = dict()

        # fast hash table of residue # -> indexes into the pdb item list.
        # we only need to do this once, so use one object: the last one,
        # which is also the object self.numatoms was taken from.
        self.res2pdb = dict()
        last = self.objs[-1]
        for i in xrange(self.numatoms):
            resnum = int(pdb.myPDB[last].items[i].res_num)
            self.res2pdb.setdefault(resnum, []).append(i)


    #----------------------------------------------------------
    def get_range(self, struc, guirange):
        """
        get the residue numbers in a list, using the user entered residue range
        eg. "1,3-9,11-111"  (comma separated, no spaces)
        returns [1,3,4,...9,11,12...111]

        "all" returns every residue number found in struc.
        residue numbers that do not occur in struc are dropped.
        a malformed specification is reported on stdout and the offending
        part is ignored; a specification containing a space yields [].
        """

        # every residue number of struc, in file order, without repeats.
        # the dict gives constant time "already seen" tests.
        allrange = list()
        known = dict()
        for i in xrange(pdb.myPDB[struc].serial):
            resnum = int(pdb.myPDB[struc].items[i].res_num)
            if resnum in known:
                continue
            known[resnum] = 1
            allrange.append(resnum)

        if guirange=="all":
            return allrange

        if guirange.count(' ') > 0:
            print "*** error in Range specification."
            print "*** no spaces allowed.  use comma (,) as separator"
            return list()

        myrange = list()
        for part in guirange.split(','):
            chunks = part.split('-')
            try:
                if len(chunks) == 1:    # one value only ie. 2
                    myrange.append(int(chunks[0]))
                elif len(chunks) == 2:  # inclusive span ie. 3-9
                    myrange.extend(range(int(chunks[0]), int(chunks[1])+1))
                else:                   # more than one dash ie. 1-2-3
                    print "*** error in Range specification ", chunks
            except (ValueError, OverflowError):
                print "*** error in Range specification ", chunks

        # keep only the residues that exist in the structure
        return [r for r in myrange if r in known]


    #----------------------------------------------------------
    # create a list of (coord, atomnum, resnum) tuples, one per atom,
    # used later for fast coordinate lookup.
    def make_coord_list(self, obj):
        """
        return a list holding one (coord, atom number, residue number)
        tuple per atom of obj, in pdb item order.
        atom and residue numbers are converted to int; coord is stored
        as found in the pdb item.
        """
        mylist = list()

        # navigate through the items of obj
        for i in xrange(pdb.myPDB[obj].serial):
            item = pdb.myPDB[obj].items[i]
            mylist.append((item.coord, int(item.atom_num), int(item.res_num)))

        return mylist

    #----------------------------------------------------------
    # make the res2atom array - mapping from residue # to atom numbers
    # make the res2name array - mapping from residue # to residue name
    # make the atom2name array - mapping from atom # to atom name
    def make_X2X(self, obj):
        """
        fill the lookup tables from the pdb items of obj:
          self.numatoms  -- set to the serial count of obj
          self.resnums   -- residue numbers; new ones are appended, so the
                            list accumulates over successive calls
          self.res2atom  -- residue # -> atom numbers (reset on every call
                            for each residue in self.resnums)
          self.res2name  -- residue # -> residue name
          self.atom2name -- atom # -> atom name
        """

        self.numatoms = pdb.myPDB[obj].serial

        # constant time lookup of the residues collected so far
        known = dict()
        for rn in self.resnums:
            known[rn] = 1

        for i in xrange(self.numatoms):
            rn = int(pdb.myPDB[obj].items[i].res_num)
            if rn in known:
                continue
            known[rn] = 1
            self.resnums.append(rn)

        for rn in self.resnums:
            self.res2atom[rn] = list()

        for i in xrange(self.numatoms):
            item = pdb.myPDB[obj].items[i]
            rn = int(item.res_num)
            anum = int(item.atom_num)
            self.res2name[rn] = item.res
            self.atom2name[anum] = item.atom
            self.res2atom[rn].append(anum)

    #----------------------------------------------------------
    # check an atom list to determine inclusion
    # ie. list = [ C HA ] | [ -H -O ]
    def atom_check(self, atom, list):
        """
        decide whether atom is selected by the pattern list; returns 1 or 0

        an empty list selects everything.
        a pattern starting with "-" excludes every atom whose name starts
        with the rest of the pattern; when the name does not begin with a
        letter (eg. 1HB) the comparison is repeated from the second character.
        any other pattern selects the atom on an exact name match.
        patterns are tried in order and the first decisive one wins.
        """
        if len(list) == 0:
            return 1

        saw_exclude = 0
        for entry in list:
            if entry[0] != "-":
                # inclusion pattern: exact name only
                if entry == atom:
                    return 1
                continue

            # exclusion pattern
            saw_exclude = 1
            prefix = entry[1:]
            if atom.startswith(prefix):
                return 0
            if not atom[0].isalpha() and atom.startswith(prefix, 1):
                return 0

        # nothing matched: selected only if an exclusion pattern was present
        return saw_exclude
            
        
    #----------------------------------------------------------
    # check an include list for an exact atom match
    def check_include(self, atom, list, go):
        """
        return 1 when list is empty or holds a pattern equal to atom,
        otherwise 0.  the incoming value of go is not used.
        """
        if len(list) == 0:
            return 1
        for wanted in list:
            if wanted == atom:
                return 1
        return 0

    #----------------------------------------------------------
    # check an exclude list: is there a pattern that the atom name
    # (leading digits removed) starts with?
    def check_exclude(self, atom, list, go):
        """
        return 0 when the atom name, with leading digits removed, starts
        with one of the patterns in list; otherwise return go unchanged.
        """
        name = atom.lstrip("0123456789")
        for prefix in list:
            if name.startswith(prefix):
                return 0
        return go


    #----------------------------------------------------------
    # return a coordinate list (in numeric array format) 
    # given the atom and residue range
    def get_coords(self, obj, range, atom_list):

        coordlist = list()
        idx = 0

        # for each residue in user defined range
#        print "DEBUG get_coords.  range = ", range
        for i in range:
            atomnums = self.res2pdb[i]
            for atm in atomnums:

                # is atomname in the user defined type?
                try: atom = pdb.myPDB[obj].items[atm].atom
                except:
                    print "WARNING, error in object", obj, " atom #", atm
                    continue


                go = self.atom_check(atom, atom_list)

                if go:
                    coordlist.append(pdb.myPDB[obj].items[atm].coord )
                    idx += 1

        x = array(coordlist, Float0)
        return x


    #----------------------------------------------------------
    # return a coordinate list (in numeric array format) 
    # given the atom and residue range
    def get_coords_slow(self, obj, range, atom_list):
        """
        same selection as get_coords, but made by scanning the first
        self.numatoms pdb items of obj in item order.
        returns the selected coordinates as a Numeric array (Float0).
        """
        picked = list()

        for i in xrange(self.numatoms):
            item = pdb.myPDB[obj].items[i]

            # residue outside the user defined range
            if int(item.res_num) not in range:
                continue

            # atom name must match the user defined type
            if self.atom_check(item.atom, atom_list):
                picked.append(item.coord)

        return array(picked, Float0)


#        for (coord, atomnum, resnum) in car_list:
#            atom = self.atom2name[atomnum]

            # if this residue is not in the range
#            if self.range.count(resnum) == 0:
#                continue

            # if the atom is in the include list or if it
            # is NOT in the exclude list

            # check include then exclude lists
#            go = self.check_include(atom, self.includeL, 0)
#            go = self.check_exclude(atom, self.excludeL, go)

            # if it matches
#            if go:
#                coordlist.append( coord )

#        x = array(coordlist, Float0)
#        return x


    #----------------------------------------------------------
    # create an exclude and include list of atoms.
    # atoms = string ex. "N CA C" or "-H" (leading "-" = exclude) or "ALL"
    def make_lists(self, atoms):
        """
        split the atom string on whitespace and sort the words into
        (include list, exclude list).
        a word starting with "-" goes to the exclude list without the "-";
        "ALL" empties both lists; every other word goes to the include list.
        """
        wanted = list()
        unwanted = list()
        for token in atoms.split():
            if token == "ALL":
                # forget everything collected so far
                wanted = []
                unwanted = []
                continue
            if token[0] == '-':
                unwanted.append(token[1:])
            else:
                wanted.append(token)
        return (wanted, unwanted)

    #----------------------------------------------------------
    # update our pdb list to reflect the transformation
    def update_pdb(self, obj, m):
        """
        apply the 16 element transformation m to every atom of obj and
        store the result back into the pdb item as an (x, y, z) tuple.
        m[3], m[7], m[11] are added to the coordinate first, the 3x3 part
        of m is then applied, and m[12], m[13], m[14] are added last.
        """
        for i in xrange(pdb.myPDB[obj].serial):
            item = pdb.myPDB[obj].items[i]
            c = item.coord

            # shift first ...
            sx = c[0]+m[3]
            sy = c[1]+m[7]
            sz = c[2]+m[11]

            # ... then rotate and shift again
            item.coord = (m[0]*sx + m[4]*sy + m[8]*sz + m[12],
                          m[1]*sx + m[5]*sy + m[9]*sz + m[13],
                          m[2]*sx + m[6]*sy + m[10]*sz + m[14])

    #----------------------------------------------------------
    def write_pdb(self, fname, obj):
        """ placeholder: nothing is written.  only prints a notice;
        fname and obj are not used """

        print "not implemented yet"
        
#        try: fp = open(fname, 'w')
#        except:
#            print "cannot write to file ", fname
#            return
    
#        print >> fp, "HEADER    Superimposed coordinates from PyNMR"
#        print >> fp, "COMPND    ", obj

#        numatoms = self.numatoms 
#        for i in xrange(numatoms):

#            (c, anum, rnum) = self.car[obj][i]
#            a = self.atom2name[anum]
#            r = self.res2name[rnum]
# this code checks the pdb array for values
#            me = pdb.myPDB[obj].items[i]
#            c =  me.coord
#            anum = me.atom_num
#            a = me.atom
#            r = me.res
#            rnum = me.res_num
            
#            print >> fp, "ATOM %6d %4s %3s %5d     %7.3f %7.3f %7.3f  1.00  0.00          " % \
#                (int(anum), a, r, int(rnum), c[0], c[1], c[2])                

#        fp.close()

    #----------------------------------------------------------
    # update our internal coordinate array to reflect the transformation
    def update_coords(self, cdx, m):
        """
        return a new list with the 16 element transformation m applied to
        each (coord, atom#, residue#) entry of cdx; atom and residue
        numbers are passed through untouched and cdx is not modified.
        m[3], m[7], m[11] are added to the coordinate first, the 3x3 part
        of m is then applied, and m[12], m[13], m[14] are added last.
        """
        moved = list()
        for (c, a, r) in cdx:
            sx = c[0]+m[3]
            sy = c[1]+m[7]
            sz = c[2]+m[11]
            xyz = (m[0]*sx + m[4]*sy + m[8]*sz + m[12],
                   m[1]*sx + m[5]*sy + m[9]*sz + m[13],
                   m[2]*sx + m[6]*sy + m[10]*sz + m[14])
            moved.append((xyz, a, r))
        return moved


    #----------------------------------------------------------
    # update our internal coordinate array to reflect the transformation
    def update_coords2(self, cdx, m):
        """
        apply the 16 element transformation m to every coordinate of cdx
        (a sequence of bare x, y, z triples) and return the result as a
        new Numeric array (Float0).
        m[3], m[7], m[11] are added to the coordinate first, the 3x3 part
        of m is then applied, and m[12], m[13], m[14] are added last.
        """
        moved = list()
        for c in cdx:
            sx = c[0]+m[3]
            sy = c[1]+m[7]
            sz = c[2]+m[11]
            moved.append([m[0]*sx + m[4]*sy + m[8]*sz + m[12],
                          m[1]*sx + m[5]*sy + m[9]*sz + m[13],
                          m[2]*sx + m[6]*sy + m[10]*sz + m[14]])
        return array(moved, Float0)


    #-----------------------------------------------------------
    # get the rmsd of each structure to each other structure
    def rmsd_table(self, coords):
        print "RMSD TABLE: "
        print "            ",
        for i in coords:
            print "%-10s  " %(i),
        print ""
        for i in coords:
            print "%-10s  " %(i),
            for j in coords:
                if i == j:
                    print "            ",
                    continue
                ret = sup.calc_rms(coords[i], coords[j])
                print "%-7.2f     " % (ret),
            print ""

    #-----------------------------------------------------------
    def calc_show_mean(self, showflag = 1):
        """
        calculate the mean structure and store it as pdb.myPDB["mean"]

        every atom index below self.numatoms is averaged over the objects
        in self.objs that have that atom; an object that cannot supply it
        is reported and left out of the average for that atom.
        atom and residue labels are copied from the last object that
        supplied the atom.
        when showflag is 1, "mean" is appended to self.objs.
        """

        # get rid of any old occurrences
        try: self.objs.remove("mean")
        except ValueError: pass

        # put this in the pdb data structure, to treat it like
        # any other object.
        # NOTE(review): only items is filled in below; if other code needs
        # pdb.myPDB["mean"].serial it must be set elsewhere -- confirm.
        pdb.myPDB["mean"] = pdb.PdbFile()

        for i in xrange(self.numatoms):
            total = [0.0, 0.0, 0.0]
            used = 0
            template = None
            for j in self.objs:
                try:
                    item = pdb.myPDB[j].items[i]
                    coord = item.coord
                except (KeyError, IndexError, AttributeError):
                    print "WARNING, error in object", j, " atom #", i
                    continue

                for k in [0, 1, 2]: total[k] += coord[k]
                used += 1
                template = item

            if template is None:
                # no object has this atom.  stop here: skipping it would
                # shift every later atom to a different item index.
                break

            # average over the objects that really contributed
            for k in [0, 1, 2]: total[k] = total[k] / used

            line = pdb.myPDB["mean"].PdbItem()
            line.atom_num = template.atom_num
            line.atom = template.atom
            line.res_num = template.res_num
            line.res = template.res
            line.coord = total
            pdb.myPDB["mean"].items.append(line)

        if showflag == 1:
            self.objs.append("mean")

        return


    #-----------------------------------------------------------
    def calc_mean_rmsd(self):
        """ calculate the rmsd of every structure to the mean
        (not implemented: nothing is computed and None is returned) """

        # the only live statement; the list is never filled or returned
        meanrmsd = list()

#        cdx_ref = myrmsd.return_coords(myrmsd.car["mean"]

#        for i in myrmsd.objs:
#             cdx_obj = myrmsd.return_coords(myrmsd.car[i]
#             sup.set(cdx_ref, cdx_obj)
#             result = sup.get_rms2()
#             meanrmsd.append(result) 

#        return


    #-----------------------------------------------------------
    def make_coords_res(self, routine, list, varname, default):
        """ this chunk of code just gathers the coordinates together, per residue, per object

        routine -- called as routine(atom name, list, default); a true
                   result selects the atom
        list    -- pattern list handed to routine
        varname -- dictionary to fill: varname[obj][residue #] is set to a
                   Numeric array (Float0) of the selected coordinates
        default -- handed to routine unchanged

        a residue without any selected atom gets an empty array.
        """

        for obj in self.objs:
            # self.car may not exist (nothing in this class assigns it);
            # build the coordinate list on demand in that case
            try: car = self.car[obj]
            except (AttributeError, KeyError):
                car = self.make_coord_list(obj)

            tmp = dict()
            for (coords, anum, res) in car:
                name = self.atom2name[anum]
                if routine(name, list, default):
                    tmp.setdefault(res, []).append(coords)

            for res in self.resnums:
                varname[obj][res] = array(tmp.get(res, []), Float0)
            
    def set_res_idx(self, atomlist):
        """ go through the pdb items of the reference object (self.ref)
        and return a dictionary: residue # -> list of item indexes of the
        atoms accepted by atom_check(atom, atomlist).
        every residue met gets an entry, even when its list stays empty """

        table = dict()

        for i in xrange(pdb.myPDB[self.ref].serial):
            item = pdb.myPDB[self.ref].items[i]
            name = item.atom
            # make sure the residue has an entry before testing the atom
            slot = table.setdefault(int(item.res_num), [])
            if self.atom_check(name, atomlist):
                slot.append(i)

        return table

      

    def rmsd_res_fit_pairs2(self, sup):
        """ this first fits each objects together to the MEAN, then
        computes the rmsd for each residue for the two fitted objects """
    

        def rmsd_by_ref(self, cdx1, cdx2, idx, rmslist):

            for res in self.resnums:

                cdxi = array((), Float0)
                cdxj = array((), Float0)
                for i in idx[res]:
                        
                    try: cdxi = concatenate([cdxi,cdx1[i]])
                    except:
                        print "WARNING, internal error with residue ", i
                        continue

                    try: cdxj = concatenate([cdxj,cdx2[i]])
                    except:
                        print "WARNING, internal error with residue ", i
                        continue

                if ( (len(cdxi) == 0) or (len(cdxj) == 0) ):
                    continue

                cdxi = reshape(cdxi, (-1,3))
                cdxj = reshape(cdxj, (-1,3))
                    
                rms = sup.calc_rms(cdxi, cdxj)

                try: rmslist[res].append(rms)
                except: rmslist[res] = [rms]
            return rmslist
            

        # setup indexes into cdx lists
        idxbb = self.set_res_idx(self.bbatoms)
        idxhvy = self.set_res_idx(self.heavyatoms)

        rmslist_bb = dict()
        rmslist_hvy = dict()
        for i in self.objs:
            
            if i == "mean":
                continue

            # fit the two objects together (all residues, selected atoms)
            cdx1 = self.get_coords("mean", self.resnums, self.atoms)
            cdx2 = self.get_coords(i, self.resnums, self.atoms)
            rms = sup.calc_rms(cdx1, cdx2)
            myfit = sup.set(cdx1, cdx2)
            sup.run()

            # now set the coordinates to contain all atoms, not just
            # the ones to transpose on
            cdx1 = self.get_coords("mean", self.resnums, [])
            cdx2 = self.get_coords(i, self.resnums, [])

            # transform the second coordinate set 
            rot, tran=sup.get_rotran()
            TTT = [rot[0, 0], rot[0,1], rot[0,2], 0, \
            rot[1, 0], rot[1,1], rot[1,2], 0, \
            rot[2, 0], rot[2,1], rot[2,2], 0, \
            tran[0], tran[1], tran[2], 0]
            cdx2 = self.update_coords2(cdx2,TTT)

            # now get the rmsd's residue by residue,
            # using the transformed coordinates.
            # do this for bb and then heavy
            rms_bb = rmsd_by_ref(self, cdx1, cdx2, idxbb, rmslist_bb)
            rms_hvy = rmsd_by_ref(self, cdx1, cdx2, idxhvy, rmslist_hvy)

        return (rmslist_bb, rmslist_hvy)
            

    def rmsd_res_fit_pairs(self, sup, coords, rmslist, ref):
        """ this first fits each pair of objects to each OTHER, then
        computes the rmsd for each residue for the two fitted objects

        sup     -- superimposer; calc_rms, set, run and get_rotran are called
        coords  -- not used
        rmslist -- dictionary residue # -> list of rmsd values; one value
                   per compared pair is appended in place
        ref     -- not used

        returns rmslist.  each object is paired with every object that
        follows it in self.objs; "mean" is never used as the second member.
        """

        names = list()
        names.extend(self.objs)
        for n in xrange(len(names)):
            i = names[n]
            for j in names[n+1:]:
                if j == "mean":
                    continue

                # fit the two objects together (all residues, selected atoms)
                cdx1 = self.get_coords(i, self.resnums, self.atoms)
                cdx2 = self.get_coords(j, self.resnums, self.atoms)
                # NOTE(review): the result of this call is not used; kept
                # because calc_rms is defined outside this file -- confirm
                # it has no side effects, then drop it
                sup.calc_rms(cdx1, cdx2)
                sup.set(cdx1, cdx2)
                sup.run()

                # transformation that moves the second object onto the first
                rot, tran=sup.get_rotran()
                TTT = [rot[0, 0], rot[0,1], rot[0,2], 0, \
                rot[1, 0], rot[1,1], rot[1,2], 0, \
                rot[2, 0], rot[2,1], rot[2,2], 0, \
                tran[0], tran[1], tran[2], 0]

                # now get the rmsd's residue by residue, using the
                # transformed coordinates of the second object.
                # the coordinates are fetched per residue, so no separate
                # residue -> row index table is needed.
                for res in self.resnums:

                    cdxi = self.get_coords(i, [res], self.atoms)
                    cdxj = self.get_coords(j, [res], self.atoms)
                    if ( (len(cdxi) == 0) or (len(cdxj) == 0) ):
                        continue

                    cdxj = self.update_coords2(cdxj, TTT)
                    rms = sup.calc_rms(cdxi, cdxj)
                    rmslist.setdefault(res, []).append(rms)

        return rmslist
            

    #------------------------------------------------------------------------
    def get_rmsd_per_res(self, sup, ref, atoms):
        """ for each residue, compute the rms from each object to the
            "mean" object, without any fitting.

            sup   -- object providing calc_rms(coords1, coords2)
            ref   -- not used: the comparison is always made against "mean"
            atoms -- pattern list handed to get_coords

            returns a dictionary residue # -> list of rms values, one per
            object.  a residue/object pair is left out when either side
            has no selected atoms """

        per_res = dict()

        # everything except the mean itself is compared
        others = [name for name in self.objs if name != "mean"]

        for res in self.resnums:

            target = self.get_coords("mean", [res], atoms)

            for name in others:

                probe = self.get_coords(name, [res], atoms)
                if len(target) == 0 or len(probe) == 0:
                    continue

                # rms from object.res to mean.res
                value = sup.calc_rms(target, probe)
                per_res.setdefault(res, []).append(value)

        return per_res


    def cmp1(atom):
        useatom = atom.lstrip("0123456789")
        go = 0

        #check include list for an exact match
        if (len(self.includeL) != 0):
            for pat in self.includeL:
                if atom == pat:
                    go = 1
                    break
        else:
            go = 1
        if self.bbatoms.count(atom):
            return 1
        else: return 0

    def cmp2(self, atom):
        """ return 0 when the atom name, with leading digits removed,
        starts with H; otherwise 1 """
        useatom = atom.lstrip("0123456789")
        if useatom.startswith('H'):
            return 0
        return 1


    def calc_print_res_rmsd(self, ref, sup, pf_flag):
        """ print the rmsd's per residue, for backbone and heavy atoms 
            called from superimpose in nmr_cmd.py """

        
        rmsd_bb_res = dict()
        rmsd_bb_res_fit = dict()
        if pf_flag == 0:
            rmsd_bb_res = self.get_rmsd_per_res(sup, "mean", self.bbatoms)
            rmsd_heavy_res = self.get_rmsd_per_res(sup, "mean", self.heavyatoms)
        if pf_flag == 1:
            (rmsd_bb_res, rmsd_heavy_res) = self.rmsd_res_fit_pairs2(sup)

        print "----------------------------"
        print "RMSD's relative to mean     "
        if pf_flag == 1:
            print "after pairwise fit to mean"
        print "----------------------------"
        print "residue  rmsd       rmsd    "
        print "         backbone   heavy   "
        print "-------  --------   --------"
        for res in self.resnums:
            try: mean1 = MLab.mean(rmsd_bb_res[res])
            except: continue
            print "%3d    %7.2f    %7.2f" % (res, mean1, 
                MLab.mean(rmsd_heavy_res[res]))

    #-----------------------------------------------------------
