!===============================================================================
!
! Routines:
!
! (1) genwf_mpi()      modified from genwf_disk()      Last Modified 6/09/2010 (gsm)
!
! Generates valence and conduction wavefunctions at k-point rkq = rkn - rq
! from the wavefunctions of the symmetry-reduced k-point rk (found such that
! rkq = R(rk) + g0), read from structures distributed in memory.
!
!===============================================================================

#include "f_defs.h"

!-------------------------------------------------------------------------------
! genwf_mpi: build the wavefunctions at k-point rkq from the in-memory
! (distributed) wavefunctions of a reduced k-point.
!
! Steps:
!   1. Search the reduced k-points kp%rk and the symmetry operations for
!      a rotation R (index it), an integer vector g0 (kg0) and a reduced
!      k-point (index ikn1) with rkq = R(rk(ikn1)) + g0. Dies if none exists.
!   2. Copy the band energies of rk(ikn1) into wfnkq%ekq. Sort |rkq+G|^2
!      into wfnkq%isrtkq / wfnkq%ekin.
!   3. Map the stored G-vector coefficients of rk(ikn1) onto the rkq ordering
!      with gmap (index array ind, phases ph).
!   4. Normalize every band held by this node.
!
! Arguments:
!   rkq      - target k-point.
!   syms     - symmetry operations (ntran rotations in mtrx).
!   gvec     - G-vector set (ng vectors in gvec%k).
!   crys     - crystal data; crys%bdot is the metric used for |rkq+G|^2.
!   kp       - reduced k-points (nrk points in kp%rk).
!   sig      - provides ntband and nspin.
!   wfnkq    - output. nkpt, ekq, isrtkq and ekin are written into storage
!              allocated by the caller. zkq is allocated here, with size
!              (ntband_node*nkpt, nspin).
!   wfnkqmpi - distributed wavefunction data (coefficients, sort indices,
!              energies and k-points per reduced k-point).
!-------------------------------------------------------------------------------
subroutine genwf_mpi(rkq,syms,gvec,crys,kp,sig,wfnkq,wfnkqmpi)

  use global_m
  use gmap_m
  use sort_m
  implicit none

  real(DP), intent(in) :: rkq(3)
  type (symmetry), intent(in) :: syms
  type (gspace), intent(in) :: gvec
  type (crystal), intent(in) :: crys
  type (kpoints), intent(in) :: kp
  type (siginfo), intent(in) :: sig
  type (wfnkqstates), intent(inout) :: wfnkq ! zkq is allocated here; other pointers outside
  type (wfnkqmpiinfo), intent(in) :: wfnkqmpi

  ! format of the diagnostic printed when a band has (near) zero norm
  character(len=*), parameter :: norm_fmt = &
    '(/,1x,"In genwf_mpi: norm too small for state",/,' // &
    '3x,"nrk=",i5," band=",i5," xnorm=",f10.5," ispin=",1i1,/)'

  integer :: nkpti
  integer :: nbandi
  integer, allocatable :: isrtk(:)
  SCALAR, allocatable :: zin(:,:)

  integer :: i,ig,k,n,irk,ikn1,it,iband
  integer :: naddv
  integer :: kg0(3)
  integer, allocatable :: isorti(:),ind(:)
  real(DP) :: del,qk(3)
  real(DP), allocatable :: xnorm(:,:)
  real(DP), allocatable :: ekin(:)
  SCALAR, allocatable :: ph(:)
  logical :: found

!---------------------------------------
! find rotation (R), translation (g0), and reduced k-point (rk)
! such that rkq = rkn - rq = R(rk) + g0
! it = rotation, kg0 = translation, irk = reduced k-point

  PUSH_SUB(genwf_mpi)

  found = .false.
  irk_loop: do irk=1,kp%nrk
    it_loop: do it=1,syms%ntran
      do i=1,3
        qk(i) = DOT_PRODUCT(dble(syms%mtrx(i,:,it)),kp%rk(:,irk))
        del = rkq(i) - qk(i)
        ! Integer conversion truncates toward zero. Shifting by +/-TOL_Small
        ! first gives the nearest integer whenever del is within TOL_Small of
        ! one, which is exactly what the following test requires.
        if (del .ge. 0.0d0) kg0(i) = del + TOL_Small
        if (del .lt. 0.0d0) kg0(i) = del - TOL_Small
        if (abs(del-kg0(i)) .gt. TOL_Small) cycle it_loop
      enddo
      found = .true.
      exit irk_loop
    enddo it_loop
  enddo irk_loop
  if(.not. found) call die('genwf: rkq mismatch')
  ! after EXIT the DO variable keeps the value of the matching iteration
  ikn1 = irk

!---------------------------------------
! Read in (rk-q) wavefunction

  SAFE_ALLOCATE(ekin, (gvec%ng))
  SAFE_ALLOCATE(isrtk, (gvec%ng))
  SAFE_ALLOCATE(isorti, (gvec%ng))
  SAFE_ALLOCATE(ind, (gvec%ng))
  SAFE_ALLOCATE(ph, (gvec%ng))
  nkpti = wfnkqmpi%nkptotal(ikn1)
  nbandi = sig%ntband
  isrtk(1:nkpti) = wfnkqmpi%isort(1:nkpti,ikn1)
  do k = 1, sig%nspin
    wfnkq%ekq(1:nbandi,k) = wfnkqmpi%el(1:nbandi,k,ikn1)
  enddo
  qk(1:3) = wfnkqmpi%qk(1:3,ikn1)
  wfnkq%nkpt=nkpti
  ! the stored k-point must be the reduced k-point selected above
  if (any(abs(kp%rk(1:3,ikn1)-qk(1:3)) .gt. TOL_Small)) then
    call die('genwf: kp mismatch')
  endif

!---------------------------------------
! Compute inverse to array isort
! isrtk orders |rk + G|^2
! isrtkq  orders |rkn - rq + G|^2
! with rkq = rkn - rq = R(rk) + g0

  isorti(:)=0
  do i=1,nkpti
    isorti(isrtk(i))=i
  enddo
  SAFE_DEALLOCATE(isrtk)

!---------------------------------------
! compute kinetic energy for rkq = rk-q

  do i=1,gvec%ng
    qk(:)=rkq(:)+gvec%k(:,i)
    ekin(i)=DOT_PRODUCT(qk,MATMUL(crys%bdot,qk))
  enddo

!---------------------------------------
! sort array ekin to ascending order
! store indices in array wfnkq%isrtkq

  call sortrx_D(gvec%ng, ekin, wfnkq%isrtkq, gvec = gvec%k)
  do i=1,gvec%ng
    wfnkq%ekin(i)=ekin(wfnkq%isrtkq(i))
  enddo

!---------------------------------------
! map indices for rk(ikn1) to those for rkq

  call gmap(gvec,syms,nkpti,it,kg0, &
    wfnkq%isrtkq,isorti,ind,ph,.true.)

!---------------------------------------
! loop over the bands held by this node:
! map zin wfns onto wfnkq%zkq and accumulate the squared norms

  SAFE_ALLOCATE(xnorm, (sig%ntband,sig%nspin))
  SAFE_ALLOCATE(zin, (nkpti,sig%nspin))
  SAFE_ALLOCATE(wfnkq%zkq, (peinf%ntband_node*nkpti,sig%nspin))
  do n=1,peinf%ntband_node
    ! Validate the band index before copying any coefficients. The check is
    ! node-local, so the failing node reports it itself.
    iband=wfnkqmpi%band_index(n,ikn1)
    if (iband.ne.peinf%indext(n)) then
      write(0,*) 'iband=',iband,' <> ',peinf%indext(n)
      call die('genwf: iband error')
    endif

    do k = 1, sig%nspin
      zin(1:nkpti,k)=wfnkqmpi%cg(1:nkpti,n,k,ikn1)
    enddo

    ! band n occupies rows naddv+1 .. naddv+nkpti of wfnkq%zkq
    naddv = (n-1)*nkpti

    ! The G index is innermost, giving contiguous (column-major) access.
    ! Each xnorm(n,k) is still summed in ascending ig order.
    do k=1,sig%nspin
      xnorm(n,k)=0.0d0
      do ig=1,nkpti
        wfnkq%zkq(naddv+ig,k)=ph(ig)*zin(ind(ig),k)
        xnorm(n,k)=xnorm(n,k)+abs(wfnkq%zkq(naddv+ig,k))**2
      enddo
    enddo
  enddo
  SAFE_DEALLOCATE(zin)
  SAFE_DEALLOCATE(ekin)
  SAFE_DEALLOCATE(isorti)
  SAFE_DEALLOCATE(ind)
  SAFE_DEALLOCATE(ph)

!---------------------------------------
! renormalize wavefunctions
! Check every norm first (and die on a vanishing one), then rescale.

  do n=1,peinf%ntband_node
    do k=1,sig%nspin
      if (xnorm(n,k).lt.TOL_Small) then
        write(0,norm_fmt) ikn1,n,xnorm(n,k),k
        call die('genwf: normalize error')
      endif
      xnorm(n,k)=sqrt(xnorm(n,k))
    enddo
  enddo

  do n=1,peinf%ntband_node
    naddv=(n-1)*nkpti
    do k=1,sig%nspin
      wfnkq%zkq(naddv+1:naddv+nkpti,k)=wfnkq%zkq(naddv+1:naddv+nkpti,k)/xnorm(n,k)
    enddo
  enddo
  SAFE_DEALLOCATE(xnorm)

!---------------------------------------
! end of loop over wavefunctions

  POP_SUB(genwf_mpi)

  return
end subroutine genwf_mpi
