!==================================================================================
!
! Routines:
!
! (1) distrib_kernel()       Originally By MLT       Last Modified 7/10/2008 (JRD)
!
!      nv = # of valence bands
!      nc = # of conduction bands
!      nk = # of k-points
!      np = # of MPI processes
!
!=================================================================================

#include "f_defs.h"

!> Distribute the BSE kernel work (and the wavefunctions it needs) over the
!! MPI processes, then report estimated per-PE memory usage.
!!
!! One of three parallelization schemes is chosen from peinf%npes and
!! xct%ilowmem:
!!   * (k,c,v)^2 : each (k,c,v,k',c',v') element is a work unit
!!                 (xct%ivpar = 1, xct%icpar = 1, 1 element per unit);
!!   * (k,c)^2   : each (k,c,k',c') block of nv^2 elements is a work unit
!!                 (xct%ivpar = 0, xct%icpar = 1);
!!   * (k)^2     : each (k,k') block of (nv*nc)^2 elements is a work unit
!!                 (xct%ivpar = 0, xct%icpar = 0).
!! Work units are dealt to PEs either through pools (createpools) or, when
!! nproc comes out as 0 in the (k,c,v)^2 scheme, round-robin ("low memory").
!!
!! Outputs / side effects:
!!   xct%ivpar, xct%icpar     parallelization flags (see above)
!!   peinf%nck, peinf%nckpe   total work units / max work units per PE
!!   peinf%ik,ic,iv,ikp,icp,ivp(pe,i)  indices of the i-th unit owned by pe
!!   peinf%wown(...)          1-based offset of each owned unit in this PE's
!!                            local storage (filled only for the local PE)
!!   peinf%ipec/ipev/ipek     per-PE local slot of each (band,k) / k needed
!!   peinf%iownwfc/v/k        number of such slots per PE
!!   peinf%myown              number of units owned by this PE
!!   gvec%nktot               set to product(gvec%kmax) on rank 0 only
!! Calls die() if the counted units per PE disagree with peinf%nckpe.
!!
!! @param xct   exciton options; nkpt/ncband/nvband/ilowmem/neps/bLowComm/
!!              icutv are read, ivpar/icpar are written
!! @param ng    number of G-vectors per wavefunction (memory estimate only)
!! @param kg    k-point grid; kg%indr(ik) indexes conduction-band slots
!!              (NOTE(review): presumably maps to the reduced k list — confirm)
!! @param kgq   grid used to index valence-band slots; per the original note
!!              it only matters for the (not working) finite-q kernel
!! @param gvec  G-space; kmax and ng are read, nktot written on rank 0
subroutine distrib_kernel(xct,ng,kg,kgq,gvec)

  use global_m
  use fftw_m
  use misc_m
  implicit none

  type (xctinfo), intent(inout) :: xct
  integer, intent(in) :: ng
  type (grid), intent(in) :: kg
  type (grid), intent(in) :: kgq !< this is only for the not-working finite-q kernel
  type (gspace), intent(inout) :: gvec

  integer :: npools,ipool,ipoolrank,iholdperown
  integer :: nproc, nown, npown
  integer :: ik,ikp,kvp,jcp,iownmax,iownwfmax,iownkmax
  integer :: ii,kv,jc,ipe,ierr
  integer :: ijk,nv,nc,nq
  integer :: Nrod,Nplane,Nfft(3),dNfft(3),dkmax(3),nmpinode
  real(DP) :: mem,rmem,rmem2,rmemtemp1,rmemtemp2
  real(DP) :: scale,dscale
  ! tmpstr: label of the distributed index space, used by the load-balancing
  ! warning; also reused for the die() message once the warning is done.
  character(len=128) :: tmpstr
  ! poolmsg: pool-layout report printed by rank 0. Kept separate from tmpstr
  ! so the report does not clobber the label used in the warning.
  character(len=128) :: poolmsg
  character(len=16) :: tmpstr1,tmpstr2,tmpstr3,tmpstr4
  integer, allocatable :: iown(:)

  PUSH_SUB(distrib_kernel)

!-----------------------------
! Determine the Parallelization Scheme

  if (peinf%npes .gt. 2*(xct%nkpt*xct%ncband)**2 .or. xct%ilowmem .eq. 1) then
    xct%ivpar = 1
    xct%icpar = 1
    nproc = peinf%npes/(xct%nkpt**2*xct%ncband**2)

    ! nproc == 0 can only happen here through ilowmem (otherwise npes is
    ! more than 2*(nk*nc)^2); it selects the pool-free round-robin layout.
    if (nproc .eq. 0) then
      npools = 0
    else
      call createpools(xct%nvband,xct%nvband,nproc,npools,nown,npown)
    endif

    ! NOTE(review): (nk*nc*nv)**2 is evaluated in default integer and can
    ! overflow for large systems; its type is fixed by peinf%nck.
    peinf%nck = (xct%nkpt*xct%ncband*xct%nvband)**2
    tmpstr = '(nkcv)^2'
    if (peinf%inode .eq. 0) write(6,*) "Parallelizing over (k,c,v)^2"
  else if (peinf%npes .gt. 2*(xct%nkpt)**2) then
    xct%ivpar = 0
    xct%icpar = 1
    nproc = peinf%npes/(xct%nkpt**2)

    call createpools(xct%ncband,xct%ncband,nproc,npools,nown,npown)

    peinf%nck = (xct%nkpt*xct%ncband)**2
    tmpstr = '(nkc)^2'
    if (peinf%inode .eq. 0) write(6,*) "Parallelizing over (k,c)^2"
  else
    xct%ivpar = 0
    xct%icpar = 0
    nproc = peinf%npes

    call createpools(xct%nkpt,xct%nkpt,nproc,npools,nown,npown)

    peinf%nck = (xct%nkpt)**2
    tmpstr = '(nk)^2'
    if (peinf%inode .eq. 0) write(6,*) "Parallelizing over (k)^2"
  endif

  ! Maximum number of work units per PE: a ceiling division in the
  ! round-robin case, nown*npown when pools are used.
  if (npools .eq. 0) then
    if (mod((xct%nkpt*xct%ncband*xct%nvband)**2,peinf%npes).eq.0) then
      peinf%nckpe = (xct%nkpt*xct%ncband*xct%nvband)**2 / peinf%npes
    else
      peinf%nckpe = (xct%nkpt*xct%ncband*xct%nvband)**2 / peinf%npes + 1
    endif
  else
    peinf%nckpe = nown * npown
  endif

  if (peinf%inode .eq. 0) then
    if (npools .eq. 0) then
      poolmsg = "Using Low Memory Option"
    else
      write(tmpstr1,'(i10)') nproc
      write(tmpstr2,'(i10)') npools
      write(tmpstr3,'(i10)') nown
      write(tmpstr4,'(i10)') npown
      poolmsg = "Using pools:  nproc = " // TRUNC(tmpstr1) // &
        "  npool = " // TRUNC(tmpstr2) // "  nown = " // &
        TRUNC(tmpstr3) // "  npown = " // TRUNC(tmpstr4)
    endif
    write(6,*) poolmsg
  endif

!---------------------------------
! Determine if the conditions for a decent load-balancing are met
! (tmpstr still holds the index-space label chosen above).

  if (mod(peinf%nck,peinf%npes).ne.0) then
    if(peinf%inode.eq.0) then
      write(0,993) TRUNC(tmpstr),peinf%nck,peinf%npes
    endif
  endif
993 format(/,1x,'WARNING:',1x,a,1x,"=",i8,/, &
      8x,'is not a multiple of the number of PEs =',i6,/, &
      8x,'Optimal load-balancing cannot be achieved.')

!---------------------------------
! Allocate and zero the ownership tables

  SAFE_ALLOCATE(peinf%ik, (peinf%npes,peinf%nckpe))
  SAFE_ALLOCATE(peinf%ic, (peinf%npes,peinf%nckpe))
  SAFE_ALLOCATE(peinf%iv, (peinf%npes,peinf%nckpe))
  SAFE_ALLOCATE(peinf%ikp, (peinf%npes,peinf%nckpe))
  SAFE_ALLOCATE(peinf%icp, (peinf%npes,peinf%nckpe))
  SAFE_ALLOCATE(peinf%ivp, (peinf%npes,peinf%nckpe))
  SAFE_ALLOCATE(peinf%ipev, (peinf%npes,xct%nvband,xct%nkpt))
  SAFE_ALLOCATE(peinf%ipec, (peinf%npes,xct%ncband,xct%nkpt))
  SAFE_ALLOCATE(peinf%ipek, (peinf%npes,xct%nkpt))

  SAFE_ALLOCATE(iown, (peinf%npes))
  SAFE_ALLOCATE(peinf%iownwfv, (peinf%npes))
  SAFE_ALLOCATE(peinf%iownwfc, (peinf%npes))
  SAFE_ALLOCATE(peinf%iownwfk, (peinf%npes))

  ! wown collapses the band dimensions that are not distributed to size 1.
  if (xct%ivpar .eq. 1) then
    SAFE_ALLOCATE(peinf%wown, (xct%nvband,xct%ncband,xct%nkpt,xct%nvband,xct%ncband,xct%nkpt))
  else if (xct%icpar .eq. 1) then
    SAFE_ALLOCATE(peinf%wown, (1,xct%ncband,xct%nkpt, 1,xct%ncband,xct%nkpt))
  else
    SAFE_ALLOCATE(peinf%wown, (1,1,xct%nkpt,1,1,xct%nkpt))
  endif

  peinf%ik=0
  peinf%ic=0
  peinf%iv=0
  peinf%ikp=0
  peinf%icp=0
  peinf%ivp=0
  peinf%ipev=0
  peinf%ipec=0
  peinf%ipek=0
  peinf%iownwfv=0
  peinf%iownwfc=0
  peinf%iownwfk=0
  peinf%wown=0

  ipe=0
  iown=0

  ! Number of kernel elements held per work unit.
  if (xct%ivpar .eq. 1) then
    iholdperown = 1
  else if (xct%icpar .eq. 1) then
    iholdperown = (xct%nvband)**2
  else
    iholdperown = (xct%nvband*xct%ncband)**2
  endif

!---------------------------------
! Assign every (k,k',c,c',v,v') to a PE and record which wavefunctions
! that PE must hold.

  do ikp=1,xct%nkpt
    do ik=1,xct%nkpt
      do jcp=1,xct%ncband
        do jc=1,xct%ncband
          do kvp=1,xct%nvband
            do kv=1,xct%nvband

              ! 0-based owning PE for this element.
              if (npools .eq. 0) then
                ipe = mod((((ik-1)*xct%nkpt+(ikp-1))*xct%ncband**2 &
                  +(jc-1)*xct%ncband+(jcp-1))*xct%nvband**2 &
                  +(kv-1)*xct%nvband+(kvp-1),peinf%npes)
              elseif (xct%icpar .ne. 1) then
                ipool = mod((ik-1),npools)
                ipoolrank = mod((ikp-1),nproc/npools)
                ipe = ipool*(nproc/npools)+ipoolrank
              else if (xct%ivpar .ne. 1) then
                ipool = mod((jc-1),npools)
                ipoolrank = mod((jcp-1),nproc/npools)
                ipe = ipool*(nproc/npools)+ipoolrank &
                  +((ik-1)*xct%nkpt+(ikp-1))*nproc
              else
                ipool = mod((kv-1),npools)
                ipoolrank = mod((kvp-1),nproc/npools)
                ipe = ipool*(nproc/npools)+ipoolrank &
                  +(((ik-1)*xct%nkpt+(ikp-1))*xct%ncband**2 &
                  +(jc-1)*xct%ncband+(jcp-1))*nproc
              endif

              ipe = ipe + 1

              ! Register a new work unit only at the first element of each
              ! block (band indices that are not distributed equal 1).
              if ((xct%ivpar .eq. 1 .or. (kv .eq. 1 .and. kvp .eq. 1)) .and. &
                (xct%icpar .eq. 1 .or. (jc .eq. 1 .and. jcp .eq. 1))) then
                iown(ipe)=iown(ipe)+1
                if (ipe .eq. peinf%inode+1) then
                  if (xct%ivpar .eq. 1) then
                    peinf%wown(kvp,jcp,ikp,kv,jc,ik) = (iown(ipe)-1) * iholdperown + 1
                  else if (xct%icpar .eq. 1) then
                    peinf%wown(1,jcp,ikp,1,jc,ik) = (iown(ipe)-1) * iholdperown + 1
                  else
                    peinf%wown(1,1,ikp,1,1,ik) = (iown(ipe)-1) * iholdperown + 1
                  endif
                endif
                peinf%iv(ipe,iown(ipe)) = kv
                peinf%ivp(ipe,iown(ipe)) = kvp
                peinf%ic(ipe,iown(ipe)) = jc
                peinf%icp(ipe,iown(ipe)) = jcp
                peinf%ik(ipe,iown(ipe)) = ik
                peinf%ikp(ipe,iown(ipe)) = ikp
              endif

              ! Conduction wavefunction slots (indexed through kg). When c is
              ! not distributed, all ncband bands of a k-point get
              ! consecutive slots at once.
              if (peinf%ipec(ipe,jc,kg%indr(ik)).eq.0) then
                if (xct%icpar .eq. 1) then
                  peinf%iownwfc(ipe)=peinf%iownwfc(ipe)+1
                  peinf%ipec(ipe,jc,kg%indr(ik))=peinf%iownwfc(ipe)
                else if (jc .eq. 1) then
                  do ijk =1, xct%ncband
                    peinf%ipec(ipe,ijk,kg%indr(ik))=peinf%iownwfc(ipe)+ijk
                  enddo
                  peinf%iownwfc(ipe)=peinf%iownwfc(ipe)+xct%ncband
                endif
              endif

              if (peinf%ipec(ipe,jcp,kg%indr(ikp)).eq.0) then
                if (xct%icpar .eq. 1) then
                  peinf%iownwfc(ipe)=peinf%iownwfc(ipe)+1
                  peinf%ipec(ipe,jcp,kg%indr(ikp))=peinf%iownwfc(ipe)
                else if (jcp .eq. 1) then
                  do ijk =1, xct%ncband
                    peinf%ipec(ipe,ijk,kg%indr(ikp))=peinf%iownwfc(ipe)+ijk
                  enddo
                  peinf%iownwfc(ipe)=peinf%iownwfc(ipe)+xct%ncband
                endif
              endif

              ! Valence wavefunction slots. They are indexed through kgq
              ! everywhere, both in the "already assigned?" guard and when
              ! filling. Mixing kg and kgq would make the guard miss its own
              ! fill, so iownwfv would be over-counted.
              if (peinf%ipev(ipe,kv,kgq%indr(ik)).eq.0) then
                if (xct%ivpar .eq. 1) then
                  peinf%iownwfv(ipe)=peinf%iownwfv(ipe)+1
                  peinf%ipev(ipe,kv,kgq%indr(ik))=peinf%iownwfv(ipe)
                else if (kv .eq. 1) then
                  do ijk =1, xct%nvband
                    peinf%ipev(ipe,ijk,kgq%indr(ik))=peinf%iownwfv(ipe)+ijk
                  enddo
                  peinf%iownwfv(ipe)=peinf%iownwfv(ipe)+xct%nvband
                endif
              endif

              if (peinf%ipev(ipe,kvp,kgq%indr(ikp)).eq.0) then
                if (xct%ivpar .eq. 1) then
                  peinf%iownwfv(ipe)=peinf%iownwfv(ipe)+1
                  peinf%ipev(ipe,kvp,kgq%indr(ikp))=peinf%iownwfv(ipe)
                else if (kvp .eq. 1) then
                  do ijk =1, xct%nvband
                    peinf%ipev(ipe,ijk,kgq%indr(ikp))=peinf%iownwfv(ipe)+ijk
                  enddo
                  peinf%iownwfv(ipe)=peinf%iownwfv(ipe)+xct%nvband
                endif
              endif

              ! Per-k slots.
              if (peinf%ipek(ipe,kgq%indr(ik)).eq.0) then
                peinf%iownwfk(ipe)=peinf%iownwfk(ipe)+1
                peinf%ipek(ipe,kgq%indr(ik))=peinf%iownwfk(ipe)
              endif

              if (peinf%ipek(ipe,kgq%indr(ikp)).eq.0) then
                peinf%iownwfk(ipe)=peinf%iownwfk(ipe)+1
                peinf%ipek(ipe,kgq%indr(ikp))=peinf%iownwfk(ipe)
              endif

            enddo
          enddo
        enddo
      enddo
    enddo
  enddo

!---------------------------------
! Calculate Max of iown and iownwf (used for the memory estimate)

  iownmax = 0
  iownwfmax = 0
  iownkmax = 0

  do ii=1,peinf%npes
    if (iown(ii) .gt. iownmax) iownmax = iown(ii)
    if ((peinf%iownwfv(ii) + peinf%iownwfc(ii)) .gt. iownwfmax) &
      iownwfmax = (peinf%iownwfc(ii)+peinf%iownwfv(ii))
  enddo

  do ii=1,peinf%npes
    if (peinf%iownwfk(ii) .gt. iownkmax) iownkmax = peinf%iownwfk(ii)
  enddo

  peinf%myown = iown(peinf%inode + 1)

  ! Consistency check: the arrays above were sized with nckpe.
  if (iownmax .ne. peinf%nckpe) then
    write(tmpstr1,'(i10)') peinf%nckpe
    write(tmpstr2,'(i10)') iownmax
    tmpstr = 'nckpe estimate wrong, nckpe = ' // &
      TRUNC(tmpstr1) // ', iownmax = ' // TRUNC(tmpstr2)
    call die(tmpstr)
  endif

  SAFE_DEALLOCATE(iown)

!---------------------------------
! Determine the available memory

  call procmem(mem,nmpinode)
  if (peinf%inode .eq. 0) then
    write(6,998) mem/1024.0d0**2
  endif
998 format(/,1x,'Memory available:',f10.1,1x,'MB per PE')

!---------------------------------
! JRD: Report Memory (rank 0 only)
!
! All element counts are formed as products of dble() factors: the
! equivalent default-integer products can overflow for large systems.

  if (peinf%inode .eq. 0) then

    ! Number of q-points: the count stored in epsmat's 8th record plus one.
    ! If the file is missing or its header cannot be read, fall back to 1;
    ! this value only feeds the memory estimate.
    nq = 1
    call open_file(11,file='epsmat',form='unformatted',status='old',iostat=ierr)
    if (ierr .eq. 0) then
      do ii = 1, 7
        if (ierr .eq. 0) read(11,iostat=ierr)
      enddo
      if (ierr .eq. 0) read(11,iostat=ierr) nq
      if (ierr .eq. 0) then
        nq = nq + 1
      else
        nq = 1
      endif
      call close_file(11)
    endif

    rmem=0D0

! Storing epsilon (replicated with bLowComm, otherwise distributed)
    if (xct%bLowComm) then
      rmem=rmem+dble(xct%neps)*dble(xct%neps)*dble(nq)
    else
      rmem=rmem+dble(xct%neps)*dble(xct%neps)*dble(nq)/dble(peinf%npes)
    endif
! Storing intwfnv and intwfnc and wfnv wnfc wfnvp wfncp
    rmem=rmem+dble(iownwfmax+4)*dble(ng)
! Storing bsemats
    rmem=rmem+4D0*dble(iownmax)*dble(iholdperown)
! Storing bsemat write temp arrays
    rmem=rmem+8D0*dble(xct%nvband)*dble(xct%ncband)*dble(xct%nkpt)
    nv=xct%nvband
    nc=xct%ncband

! JRD: Since mvv,mcc,tempb,tempw etc.. are not allocated
! at the same time as mvc etc.. we see which one is bigger

    rmemtemp1=0D0
    rmemtemp2=0D0

! Note this assumes that bare coulomb cutoff is same as wf cutoff
! Storing tempw, tempb, mvv, mvvold
    if (xct%ivpar .eq. 1) then
      rmemtemp1=rmemtemp1+4D0*dble(ng)
    else ! no mvvold in this case
      rmemtemp1=rmemtemp1+3D0*dble(ng)*dble(nv)*dble(nv)
    endif
! Storing mcc, mccold
    if (xct%icpar .eq. 1) then
      rmemtemp1=rmemtemp1+2D0*dble(ng)
    else ! no mccold in this case
      rmemtemp1=rmemtemp1+dble(ng)*dble(nc)*dble(nc)
    endif

! Storing mvc, mvpcp, mvcold, mvpcpold
    if (xct%ivpar .eq. 1 .and. xct%icpar .eq. 1) then
      rmemtemp2=rmemtemp2+4D0*dble(ng)
    else
      rmemtemp2=rmemtemp2+2D0*dble(ng)*dble(nv)*dble(nc)
    endif

! outtemp in gx_sum
    rmemtemp2=rmemtemp2+dble(iholdperown)

    rmem = rmem + max(rmemtemp1, rmemtemp2)

! FFTBOXES
    call setup_FFT_sizes(gvec%kmax,Nfft,scale)
    rmem=rmem+2D0*dble(Nfft(1))*dble(Nfft(2))*dble(Nfft(3))

    ! Everything so far is counted in scalars; convert to bytes.
    rmem=rmem * sizeof_scalar()

! Storing intwfnc%isort intwfnv%isort (4-byte integers)
    rmem=rmem+2D0*dble(iownkmax)*dble(gvec%ng)*4D0
! Array gvec%indv in input_kernel
    ! NOTE(review): this assignment to gvec%nktot runs on rank 0 only.
    gvec%nktot=product(gvec%kmax(1:3))
    rmem=rmem+dble(gvec%nktot)*4.0d0

    write(6,989) rmem/1024.0d0**2
989 format(1x,'Memory required for execution:',f7.1,1x,'MB per PE')
  endif

!---------------------------------------------------------
! (gsm) Determine the amount of memory required for Vcoul
! (gsm) Random numbers are no longer generated in kernel, so only the
! truncation schemes below contribute.

  rmem=0.0D0
  call setup_FFT_sizes(gvec%kmax,Nfft,scale)
  rmem2=0.0d0
! cell wire truncation
  if (xct%icutv.eq.4) then
    dkmax(1) = gvec%kmax(1) * n_in_wire
    dkmax(2) = gvec%kmax(2) * n_in_wire
    dkmax(3) = 1
    call setup_FFT_sizes(dkmax,dNfft,dscale)
! array fftbox_2D
    rmem2=rmem2+dble(dNfft(1))*dble(dNfft(2))*16.0d0
! array inv_indx
    rmem2=rmem2+dble(Nfft(1))*dble(Nfft(2))*dble(Nfft(3))*4.0d0
! array qran
    rmem2=rmem2+3.0D0*dble(nmc)*8.0D0
  endif
! cell box truncation (parallel version only)
  if (xct%icutv.eq.5) then
    dkmax(1:3) = gvec%kmax(1:3) * n_in_box
    call setup_FFT_sizes(dkmax,dNfft,dscale)
    ! Planes and rods per PE, rounded up.
    if (mod(dNfft(3),peinf%npes) == 0) then
      Nplane = dNfft(3)/peinf%npes
    else
      Nplane = dNfft(3)/peinf%npes+1
    endif
    if (mod(dNfft(1)*dNfft(2),peinf%npes) == 0) then
      Nrod = (dNfft(1)*dNfft(2))/peinf%npes
    else
      Nrod = (dNfft(1)*dNfft(2))/peinf%npes+1
    endif
! array fftbox_2D
    rmem2=rmem2+dble(dNfft(1))*dble(dNfft(2))*dble(Nplane)*16.0d0
! array fftbox_1D
    rmem2=rmem2+dble(dNfft(3))*dble(Nrod)*16.0d0
! arrays dummy1 and dummy2
    rmem2=rmem2+dble(Nrod)*dble(peinf%npes+1)*16.0d0
! array inv_indx
    rmem2=rmem2+dble(Nfft(1))*dble(Nfft(2))*dble(Nfft(3))*4.0d0
  endif
  if (rmem2 .gt. rmem) rmem = rmem2
  if (peinf%inode .eq. 0) then
    write(6,988) rmem/1024.0d0**2
  endif
988 format(1x,'Memory required for vcoul:',f7.1,1x,'MB per PE')

!---------------------------------

  if (peinf%inode .eq. 0) then
    write(6,*) iownmax*iholdperown, ' elements per PE'
    write(6,*) iownwfmax, ' wavefunctions stored per PE'
    write(6,*) ng, ' G-vectors per wavefunction'
  endif

  POP_SUB(distrib_kernel)

  return
end subroutine distrib_kernel
