# 6/28/10 - Manish Jain and Brad Malone

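# Converts a UPF pseudopotential file to FHI (.cpi) format; the <PP_RHOATOM>
# data is also written to a .rho file next to it.
# Usage: python upf2fhi.py file.UPF
# Does not currently work for nonlinear core corrections, spin orbit, ...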

import os,sys,time
from math import log,exp

# Command-line handling: the single argument is the UPF file to convert.
try:
    upffilename=sys.argv[1]
except IndexError:
    # sys.exit with a string prints it to stderr and exits with status 1.
    sys.exit('Usage: python upf2fhi.py [upffilename]')

# Read the whole UPF file as a list of whitespace-stripped lines; every
# parser below searches this module-level list.
with open(upffilename,'r') as upffile:
    lines=[line.strip() for line in upffile]

def get_Z(source_lines=None):
    """Return the valence charge from the UPF header as a float.

    The value is the first whitespace-separated token of the line that
    contains 'Z valence'. If several lines match, the last one is used,
    as in the original implementation.

    Args:
        source_lines: stripped lines of the UPF file; defaults to the
            module-level ``lines`` read at start-up.

    Raises:
        ValueError: if no line contains 'Z valence' (previously this
            surfaced as an UnboundLocalError).
    """
    if source_lines is None:
        source_lines = lines
    for line in reversed(source_lines):
        if 'Z valence' in line:
            return float(line.split()[0])
    raise ValueError("no 'Z valence' line found in UPF header")

def get_proj(source_lines=None):
    """Return ``(numwfc, numproj)`` as ints from the UPF header.

    Both values are the first two tokens of the line that contains
    'Number of Projectors'. If several lines match, the last one is used,
    as in the original implementation.

    Args:
        source_lines: stripped lines of the UPF file; defaults to the
            module-level ``lines`` read at start-up.

    Raises:
        ValueError: if no line contains 'Number of Projectors' (previously
            this surfaced as an UnboundLocalError).
    """
    if source_lines is None:
        source_lines = lines
    for line in reversed(source_lines):
        if 'Number of Projectors' in line:
            fields = line.split()
            return int(fields[0]), int(fields[1])
    raise ValueError("no 'Number of Projectors' line found in UPF header")

def get_meshsize(source_lines=None):
    """Return the mesh size declared in the UPF header as an int.

    The value is the first token of the line that contains
    'Number of points in mesh'. If several lines match, the last one is
    used, as in the original implementation.

    Args:
        source_lines: stripped lines of the UPF file; defaults to the
            module-level ``lines`` read at start-up.

    Raises:
        ValueError: if no line contains 'Number of points in mesh'
            (previously this surfaced as an UnboundLocalError).
    """
    if source_lines is None:
        source_lines = lines
    for line in reversed(source_lines):
        if 'Number of points in mesh' in line:
            return int(line.split()[0])
    raise ValueError("no 'Number of points in mesh' line found in UPF header")

def read_mesh():
    # Radial grid: every number between the <PP_R> and </PP_R> lines, in
    # file order. Aborts the script (exit status 1) when the count differs
    # from the module-level meshsize.
    first = lines.index('<PP_R>')
    last = lines.index('</PP_R>')
    radii = []
    for row in lines[first+1:last]:
        radii.extend(float(tok) for tok in row.split())
    if len(radii) != meshsize:
        print('Size of array is not equal to mesh size')
        sys.exit(1)
    return radii

def read_local():
    # Local potential: every number between <PP_LOCAL> and </PP_LOCAL>.
    # If its length differs from the module-level meshsize, meshsize is
    # updated to that length (later readers see the new value).
    global meshsize
    begin = lines.index('<PP_LOCAL>') + 1
    end = lines.index('</PP_LOCAL>')
    values = [float(tok) for row in lines[begin:end] for tok in row.split()]
    if len(values) != meshsize:
        meshsize = len(values)
    return values

def read_projectors():
    # Return the <PP_BETA> projectors found inside <PP_NONLOCAL>, stored in
    # a list indexed by the l value read from each block. The list has
    # numproj+1 slots. A slot that receives no projector keeps the string
    # 'None', which the main script uses to find the local channel. Each
    # projector shorter than meshsize is zero-padded up to meshsize.
    # NOTE(review): if two projectors share the same l, the later one
    # silently replaces the earlier one.
    by_l = ['None'] * (numproj + 1)
    nl_first = lines.index('<PP_NONLOCAL>')
    nl_last = lines.index('</PP_NONLOCAL>')
    section = lines[nl_first+1:nl_last]

    # Pass 1: (opening, closing) tag positions of the first numproj
    # <PP_BETA> blocks; each search resumes after the previous closing tag.
    spans = []
    offset = 0
    for _ in range(numproj):
        remaining = section[offset:]
        opening = offset + remaining.index('<PP_BETA>')
        closing = offset + remaining.index('</PP_BETA>')
        spans.append((opening, closing))
        offset = closing + 1

    # Pass 2: the second token of the line after the opening tag is taken
    # as l. The next line is skipped, and the numbers from there up to the
    # closing tag form the projector.
    for opening, closing in spans:
        ang_mom = int(section[opening + 1].split()[1])
        beta = [float(tok) for row in section[opening+3:closing] for tok in row.split()]
        if len(beta) != meshsize:
            beta = beta + [0.0] * (meshsize - len(beta))
        by_l[ang_mom] = beta
    return by_l

def superindex(thelist,word):
    """Return the index of the first element of *thelist* that contains *word*.

    This works like ``list.index``, but it matches on substring containment
    instead of equality.

    Raises:
        ValueError: if no element contains *word*. The previous version
            failed with an UnboundLocalError in that case.
    """
    for theindex, item in enumerate(thelist):
        if word in item:
            return theindex
    raise ValueError('%r not found in any list element' % (word,))

def read_wfcs():
    # Return numwfc pseudo-wavefunctions from the <PP_PSWFC> section, one
    # list of floats each. A wavefunction starts after a line containing
    # 'Wavefunction' and runs to the next such line or to the section end.
    # Whenever a wavefunction's length differs from the module-level
    # meshsize, meshsize is updated and 'Resizing mesh' is printed.
    global meshsize
    first = lines.index('<PP_PSWFC>')
    last = lines.index('</PP_PSWFC>')
    section = lines[first+1:last]

    # Locate the numwfc header lines; each search resumes just after the
    # previous header.
    headers = []
    offset = 0
    for _ in range(numwfc):
        pos = offset + superindex(section[offset:], 'Wavefunction')
        headers.append(pos)
        offset = pos + 1
    bounds = headers + [len(section)]

    result = []
    for begin, end in zip(bounds[:-1], bounds[1:]):
        values = [float(tok) for row in section[begin+1:end] for tok in row.split()]
        if len(values) != meshsize:
            meshsize = len(values)
            print('Resizing mesh')
        result.append(values)
    return result

def read_rhoatom(source_lines=None, size=None):
    """Return the numbers stored in the <PP_RHOATOM> section as floats.

    The previous version searched for the literal line '<PP_RHOATOM' (no
    closing '>'). ``list.index`` only matches whole lines, so it could never
    match a real tag line and always raised ValueError. The tags are now
    located by prefix. This accepts '<PP_RHOATOM>' and also an opening tag
    that carries attributes on the same line.

    Args:
        source_lines: stripped lines of the UPF file; defaults to the
            module-level ``lines``.
        size: expected number of values; defaults to the module-level
            ``meshsize``.

    Raises:
        ValueError: if the opening or closing tag is missing.

    Prints a message and exits with status 1 if the number of values
    differs from *size*.
    """
    if source_lines is None:
        source_lines = lines
    if size is None:
        size = meshsize
    starts = [i for i, line in enumerate(source_lines) if line.startswith('<PP_RHOATOM')]
    ends = [i for i, line in enumerate(source_lines) if line.startswith('</PP_RHOATOM')]
    if not starts or not ends:
        raise ValueError('<PP_RHOATOM> section not found')
    array = [float(item)
             for line in source_lines[starts[0]+1:ends[0]]
             for item in line.split()]
    if len(array) != size:
        print('Size of array is not equal to mesh size')
        sys.exit(1)
    return array


# Main conversion: parse the UPF file, then write the .rho and .cpi outputs.
# Output files are opened only after parsing succeeds, so a parse failure
# no longer leaves an empty .cpi file behind.
basename=upffilename.split('.UPF')[0]
outfilename=basename+'.cpi'
rhofilename=basename+'.rho'

# Order matters: read_local and read_wfcs may update the global meshsize,
# and read_rhoatom checks against the (possibly updated) value.
zval=get_Z()
numwfc,numproj=get_proj()
meshsize=get_meshsize()
mesh=read_mesh()
local=read_local()
nonlocal_pots=read_projectors()  # renamed: 'nonlocal' is a Python 3 keyword
wfcs=read_wfcs()
rhoatom=read_rhoatom()

# Ratio of the first two radial grid points, written as the mesh parameter
# in both outputs. It must be computed before the .rho file is written
# (previously it was referenced there before assignment -> NameError).
# NOTE(review): presumably the UPF grid is logarithmic, so this ratio is
# constant along the mesh -- confirm for the inputs you convert.
a=mesh[1]/mesh[0]

with open(rhofilename,'w') as rhofile:
    rhofile.write('%d %15.14E\n' % (meshsize, a))
    for j in range(meshsize):
        rhofile.write('%4i %15.14E %15.14E\n' % (j+1, mesh[j], rhoatom[j]))

# Which l is local? read_projectors leaves the string 'None' in the slot
# of the angular momentum that has no projector.
llocal=nonlocal_pots.index('None')

with open(outfilename,'w') as outfile:
    # Header: valence charge and number of channels, then ten lines of zeros.
    outfile.write('%15.14E %i\n' % (zval, numproj+1))
    outfile.write('  0.0000    0.0000    0.0000   0.0000\n')
    for i in range(9):
        outfile.write('  0.0000    .00e+00   .00e+00\n')
    # NOTE(review): i is used both to index wfcs and as the angular momentum
    # (compared with llocal, used to index nonlocal_pots). This assumes the
    # <PP_PSWFC> wavefunctions are ordered l=0,1,2,... and that
    # numwfc == numproj+1, the channel count written in the header above.
    for i in range(numwfc):
        outfile.write('%d %15.14E\n' % (meshsize, a))
        for j in range(meshsize):
            # Potential = local + projector/wavefunction, falling back to the
            # local part for the local channel or where the wavefunction is
            # ~0. The halving is presumably Rydberg (UPF) -> Hartree (FHI).
            if i!=llocal and abs(wfcs[i][j])>1.0E-10:
                pot=(nonlocal_pots[i][j]/wfcs[i][j]+local[j])/2.0
            else:
                pot=local[j]/2.0
            outfile.write('%4i %15.14E %15.14E %15.14E\n' % (j+1, mesh[j], wfcs[i][j], pot))

print('Conversion complete. Output written to ' + outfilename)
