#include "f_defs.h"

!-----------------------------------------------------------------------
subroutine input(crys,gvec,kg,syms,xct,kgr,index_k)
!-----------------------------------------------------------------------
!
!     Read header, G-vectors and wavefunctions from file WFN_fi,
!     build the fine k-point grid kg, distribute k-points among PEs,
!     and write the selected conduction-band wavefunctions that each
!     PE needs to its own INT_CWFN_<inode> file.
!
!     input:  xct type (nvband, ncband, nspin, nkpt, nn, freplacebz,
!             fwritebz are read; it is also passed to distrib)
!             kgr(3,xct%nkpt) : k-points that must be present in the
!                               full BZ generated from WFN_fi; they are
!                               matched modulo reciprocal lattice vectors
!                               with tolerance 1.d-4
!             index_k         : entries 1..xct%nn give, for each point
!                               jj of the output grid kg, its index in kgr
!
!     output: crys,gvec,kg,syms types
!             peinf type (from distrib.f90)
!             INT_CWFN_* files (one per PE, unit 128+2*inode+1)
!
!     Copied from BSE/input.f90, without eqp arrays.
!
  use global_m
  use fullbz_m
  use input_utils_m
  use misc_m
  use sort_m
  use wfn_rho_vxc_io_m
  implicit none

  type (crystal), intent(out) :: crys
  type (gspace), intent(out) :: gvec
  type (grid), intent(out) :: kg
  type (symmetry), intent(out) :: syms
  type (xctinfo), intent(inout) :: xct
  real(DP), intent(in) :: kgr(3,xct%nkpt)
  integer, intent(in) :: index_k(xct%nkpt)

  type (wavefunction) :: wfnc
  type (kpoints) :: kp
  type (grid) :: kgt

  character :: filenamec*20
  character :: tmpfn*16, errmsg*100
  integer :: itpc,iwrite
  integer :: ii,jj,kk,irk,gv(3)
  integer :: dest,tag
  integer :: irks
  real(DP) :: diffvol,norm,vcell,kt(3)
  real(DP) :: tol
  real(DP), allocatable :: ek_tmp(:)
  integer, allocatable :: indxk(:),k_tmp(:,:)
  integer, allocatable :: index(:),isend(:)
  SCALAR, allocatable :: cg(:,:)

  character(len=3) :: sheader
  integer :: iflavor
  type(gspace) :: gvec_kpt

  logical :: skip_checkbz

  PUSH_SUB(input)

  if(peinf%inode == 0) call open_file(25,file='WFN_fi',form='unformatted',status='old')

  sheader = 'WFN'
  iflavor = 0
  call read_binary_header_type(25, sheader, iflavor, kp, gvec, syms, crys)

  SAFE_ALLOCATE(gvec%k, (3, gvec%ng))
  call read_binary_gvectors(25, gvec%ng, gvec%ng, gvec%k)

  ! Consistency check: cell volume recomputed from bdot must match the header.
  call get_volume(vcell,crys%bdot)
  diffvol=abs(crys%celvol-vcell)
  if (diffvol.gt.1.0d-6) then
    call die('volume mismatch.', only_root_writes = .true.)
  endif

  ! Number of valence bands available at every k-point/spin, and number
  ! of bands above the highest occupied one.
  kp%nvband=minval(kp%ifmax(:,:)-kp%ifmin(:,:))+1
  kp%ncband=kp%mnband-maxval(kp%ifmax(:,:))

  if(xct%nvband.gt.kp%nvband) then
    write(errmsg,'(a,i6,a,i6,a)') 'You requested ', xct%nvband, ' valence bands but WFN_fi contains only ', kp%nvband, '.'
    call die(errmsg, only_root_writes = .true.)
  endif
  if(xct%ncband.gt.kp%ncband) then
    write(errmsg,'(a,i6,a,i6,a)') 'You requested ', xct%ncband, ' conduction bands but WFN_fi contains only ', kp%ncband, '.'
    call die(errmsg, only_root_writes = .true.)
  endif
  if(xct%nspin.ne.kp%nspin) then
    write(errmsg,'(a,2i6)') 'Number of spins mismatch: ', xct%nspin, kp%nspin
    call die(errmsg, only_root_writes = .true.)
  endif
!-----------------------------------------------------------------------
!     Check if all k-points are available and define grid
!
  tol = 1.d-4
  kgt%nr=kp%nrk
  SAFE_ALLOCATE(kgt%r, (3,kgt%nr))
  kgt%r(1:3,1:kgt%nr)=kp%rk(1:3,1:kp%nrk)
  call fullbz(crys,syms,kgt,syms%ntran,skip_checkbz,wigner_seitz=.true.,paranoid=.true.)
  tmpfn='WFN_fi'
  if (.not. skip_checkbz) then
    call checkbz(kgt%nf,kgt%f,kp%kgrid,kp%shift,crys%bdot, &
      tmpfn,'k',.true.,xct%freplacebz,xct%fwritebz)
  endif
!
!     indxk(jj) : index in the unfolded grid kgt%f of the point kgr(:,jj)
!
  SAFE_ALLOCATE(indxk, (xct%nkpt))
  indxk=0
  do jj=1,xct%nkpt
    do ii=1,kgt%nf
      ! Compare modulo reciprocal lattice vectors by removing the nearest
      ! integer. A tiny negative difference (e.g. -1e-7 from rounding)
      ! must count as a match; mod(x+10,1) would map it to ~1 instead.
      kt(:) = kgr(:,jj) - kgt%f(:,ii)
      kt(:) = kt(:) - anint(kt(:))
      if (all(abs(kt(:)).lt.tol)) then
        if (indxk(jj).ne.0) write(0,*) 'WARNING: multiple definition of k-point',jj,indxk(jj),kgr(:,jj)
        indxk(jj)=ii
      endif
    enddo
!
!     If some k-point listed in kgr is not found in WFN_fi, indxk
!     will store zero.
!
    if (indxk(jj).eq.0) then
      write(errmsg,'(a,3f12.6,a)') 'Could not find vector ', kgr(:,jj), ' in WFN_fi'
      call die(errmsg, only_root_writes = .true.)
    endif
  enddo
!
!   update kgt -> kg: keep the irreducible points of kgt, but only the
!   xct%nn full-BZ points selected through index_k
!
  kg%nr = kgt%nr
  kg%nf = xct%nn
  SAFE_ALLOCATE(kg%r, (3,kg%nr))
  kg%r = kgt%r
  kg%sz = kgt%sz
  SAFE_ALLOCATE(kg%itran, (kg%nf))
  SAFE_ALLOCATE(kg%indr, (kg%nf))
  SAFE_ALLOCATE(kg%f, (3,kg%nf))
  SAFE_ALLOCATE(kg%kg0, (3,kg%nf))
  do jj=1,xct%nn
    kg%itran(jj) = kgt%itran(indxk(index_k(jj)))
    kg%indr(jj) = kgt%indr(indxk(index_k(jj)))
    kg%kg0(:,jj) = kgt%kg0(:,indxk(index_k(jj)))
    kg%f(:,jj) = kgt%f(:,indxk(index_k(jj)))
  enddo
  SAFE_DEALLOCATE(indxk)
  ! The temporary full grid is no longer needed once kg is built.
  SAFE_DEALLOCATE_P(kgt%r)
  SAFE_DEALLOCATE_P(kgt%f)
  SAFE_DEALLOCATE_P(kgt%itran)
  SAFE_DEALLOCATE_P(kgt%indr)
  SAFE_DEALLOCATE_P(kgt%kg0)
!
!     indxk : stores the correspondence between k-points kg%r and kp%rk
!     (it is used to select the set of wavefunctions to be stored)
!     tol : tolerance in the coordinates of k-points
!
  SAFE_ALLOCATE(indxk, (kg%nr))
  indxk=0
  do jj=1,kg%nr
    do ii=1,kp%nrk
      kt(:) = kg%r(:,jj) - kp%rk(:,ii)
      if (all(abs(kt(:)).lt.tol)) then
        if (indxk(jj).ne.0) write(0,*) 'WARNING: multiple ', &
          'definition of k-point',jj,indxk(jj),kg%r(:,jj)
        indxk(jj)=ii
      endif
    enddo
!
!     If some k-point listed in kg%r is not found in WFN_fi, indxk
!     will store zero. Later, the job will stop in genwf.
!
    if (indxk(jj).eq.0) write(0,'(a,3f12.6,a)') 'WARNING: could not find vector ',kg%r(:,jj),' in WFN_fi'
  enddo

!-----------------------------------------------------------------------
!       Distribute kpoints among the PEs
!
#ifdef VERBOSE
  call logit('input:  calling distrib')
#endif
  call distrib(xct)
!

!-----------------------------------------------------------------------
!     Order g-vectors with respect to their kinetic energy
!
#ifdef VERBOSE
  call logit('input:  reordering gvecs')
#endif
  SAFE_ALLOCATE(index, (gvec%ng))
  SAFE_ALLOCATE(gvec%ekin, (gvec%ng))
  do ii=1,gvec%ng
    gv(:)=gvec%k(:,ii)
    norm=DOT_PRODUCT(gv,MATMUL(crys%bdot,gv ))
    gvec%ekin(ii)=norm
  enddo
  call sortrx_D(gvec%ng, gvec%ekin, index, gvec = gvec%k)

  ! Apply the sort permutation to ekin and k via temporary copies.
  SAFE_ALLOCATE(ek_tmp, (gvec%ng))
  ek_tmp = gvec%ekin
  SAFE_ALLOCATE(k_tmp, (3,gvec%ng))
  k_tmp = gvec%k
  do ii=1,gvec%ng
    gvec%ekin(ii) = ek_tmp(index(ii))
    gvec%k(:,ii) = k_tmp(:,index(ii))
  enddo
  SAFE_DEALLOCATE(ek_tmp)
  SAFE_DEALLOCATE(k_tmp)
  SAFE_DEALLOCATE(index)

  call gvec_index(gvec)

!-----------------------------------------------------------------------
!     Read the wavefunctions and create INT_CWFN_*
!

#ifdef VERBOSE
  call logit('input:  reading WFN_fi')
#endif
  wfnc%nband=xct%ncband
  wfnc%nspin=kp%nspin

  if(peinf%inode.lt.10000) then
    write(filenamec,'(a,i4.4)') 'INT_CWFN_', peinf%inode
  else
    call die('input: cannot use more than 10000 nodes')
  endif
  itpc=128+(2*peinf%inode)+1
  call open_file(itpc,file=filenamec,form='unformatted',status='replace')
!
  SAFE_ALLOCATE(wfnc%isort, (gvec%ng))
  do irk=1,kp%nrk

    ! irks: position of this WFN_fi k-point in kg%r (0 if unused).
    irks = 0
    do ii=1,kg%nr
      if (irk.eq.indxk(ii)) then
        irks=ii
        exit
      endif
    enddo

    ! The G-vector list must be read even for unused k-points to advance
    ! the file.
    SAFE_ALLOCATE(gvec_kpt%k, (3, kp%ngk(irk)))
    call read_binary_gvectors(25, kp%ngk(irk), kp%ngk(irk), gvec_kpt%k)

    SAFE_ALLOCATE(cg, (kp%ngk(irk),kp%nspin))
    if(irks > 0) then

      ! Entries beyond ngk are written to INT_CWFN_* below; zero them
      ! so no uninitialized or stale values from a previous k-point
      ! reach the file.
      wfnc%isort(:) = 0
      do ii = 1, kp%ngk(irk)

        call findvector(wfnc%isort(ii), gvec_kpt%k(1, ii), gvec_kpt%k(2, ii), gvec_kpt%k(3, ii), gvec)
        if(wfnc%isort(ii) == 0) call die('input: could not find gvec')
      enddo
!

      wfnc%ng=kp%ngk(irk)
      SAFE_ALLOCATE(wfnc%cg, (wfnc%ng,wfnc%nband,wfnc%nspin))

!       Determine which PEs will write the wavefunctions for this k-point
      iwrite=0
      do ii=1, peinf%ikt(peinf%inode+1)
        if(kg%indr(peinf%ik(peinf%inode+1,ii)).eq.irks) then
          iwrite=1
          exit
        endif
      enddo


!       Determine to which PEs the wavefunctions for this k-point
!       need to be sent...
      SAFE_ALLOCATE(isend, (peinf%npes))
      isend=0
      if(peinf%inode.eq.0) then
        do jj=2,peinf%npes
          do ii=1, peinf%ikt(jj)
            if(kg%indr(peinf%ik(jj,ii)).eq.irks) then
              isend(jj)=1
              exit
            endif
          enddo
        enddo
      endif
    endif
    ! The per-k-point G-vector list is no longer needed; free it here so
    ! the allocation at the top of the next iteration does not leak it.
    SAFE_DEALLOCATE_P(gvec_kpt%k)
!
!       Loop over the bands
!
    do ii=1,kp%mnband

      ! Every band record must be read to advance the file, even when
      ! this k-point is not used.
      call read_binary_data(25, kp%ngk(irk), kp%ngk(irk), kp%nspin, cg)

      if(irks == 0) cycle

      if(peinf%inode.eq.0) then
        do kk = 1, kp%nspin
          call checknorm('WFN_fi',ii,irks,kk,kp%ngk(irk),cg(:,kk))
        enddo
      endif

!         If ii is one of the selected bands...
      if((ii.gt.kp%nvband).and.(ii.le.kp%nvband+xct%ncband)) then
#ifdef MPI
        if(peinf%inode.eq.0) then
          do jj=2,peinf%npes
            if(isend(jj).eq.1) then
              dest=jj-1
              tag=1000+dest
              call MPI_SEND(cg,kp%ngk(irk)*kp%nspin,MPI_SCALAR, &
                dest,tag,MPI_COMM_WORLD,mpierr)
            endif
          enddo
        else
          if(iwrite.eq.1) then
            tag=1000+peinf%inode
            call MPI_RECV(cg,kp%ngk(irk)*kp%nspin,MPI_SCALAR, &
              0,tag,MPI_COMM_WORLD,mpistatus,mpierr)
          endif
        endif
#endif

        if(iwrite.eq.1) &
          wfnc%cg(1:wfnc%ng,ii-kp%nvband,1:wfnc%nspin)= &
          cg(1:wfnc%ng,1:wfnc%nspin)

      endif !ii is one of the selected bands

    enddo
    SAFE_DEALLOCATE(cg)
    if(irks == 0) cycle

    ! Record layout: header line, then isort(1:gvec%ng) followed by
    ! cg(1:ng,1:nband,1:nspin).
    if(iwrite.eq.1) then

      write(itpc) irks,wfnc%ng,wfnc%nband,wfnc%nspin
      write(itpc) (wfnc%isort(ii),ii=1,gvec%ng), &
        (((wfnc%cg(ii,jj,kk),ii=1,wfnc%ng),jj=1,wfnc%nband),kk=1,wfnc%nspin)
    endif

    SAFE_DEALLOCATE(isend)
    SAFE_DEALLOCATE_P(wfnc%cg)

  enddo !end loop over k-points
  SAFE_DEALLOCATE_P(wfnc%isort)
  SAFE_DEALLOCATE(indxk)
  call close_file(itpc)

  if(peinf%inode.eq.0) then
    write(6,3004)
3004 format(/,2x,'crystal wavefunctions read from tape WFN_fi')
    write(6,3007) kg%nr
3007 format(/,6x,'nrk= ',i6,26x)
    write(6,'(12x,3f10.4)') ((kg%r(ii,jj),ii=1,3),jj=1,kg%nr)
    write(6,3070) kg%nf,kg%sz
3070 format(/,6x,'  fine grid     nfk= ',i6,4x,'ksz=',f10.5)

    call close_file(25)
  endif !end if(inode.eq.0)

  SAFE_DEALLOCATE_P(kp%rk)
  SAFE_DEALLOCATE_P(kp%ifmin)
  SAFE_DEALLOCATE_P(kp%ifmax)
  SAFE_DEALLOCATE_P(kp%el)
  SAFE_DEALLOCATE_P(kp%ngk)

  ! only needed for comm_disk
#ifdef MPI
  call MPI_Barrier(MPI_COMM_WORLD, mpierr)
#endif

  POP_SUB(input)

  return
end subroutine input
