!===============================================================================
!
! subsample_plan       Originally By FHJ       Last Modified 07/27/2014 (FHJ)
!
!
! Given a WFN file, automatically finds the radial subsampling q-points, their
! weights, and the k-points needed for the corresponding WFNq file(s). The final
! WFNq file, along with WFN, can be used in Epsilon to generate the subsampled
! eps0mat file.
!
! TODO:
! - Support HDF5
!
!===============================================================================

#include "f_defs.h"

program subsample_plan

  use global_m
  use wfn_rho_vxc_io_m
  use random_m
  use irrbz_m
  use fullbz_m
  use checkbz_m
  implicit none

  integer :: narg
  character(len=256) :: fname_wfn, tmpstr
  character(len=5) :: file_fmt

  real(DP) :: dq(3), dq_abs, qmax, bdot(3,3), ee(3), ee_abs, dq_cur
  real(DP) :: qq(3), qq_tmp(3), len_min, len_tmp
  real(DP), allocatable :: qlen(:), qq_abs(:)
  real(DP) :: dq_abs_min, dq_abs_max, weight
  integer :: nn, degree
  integer :: i1, i2, i3, iq, nq, ii, ni(3)
  type(mf_header_t) :: mf
  logical :: auto_mode, skip_checkbz
  type(grid) :: gr

  PUSH_SUB(subsample_plan)

  !---------------------------------------------------------------------------
  ! Command-line arguments
  !---------------------------------------------------------------------------
  narg = iargc()
  if (narg/=3.and.narg/=6.and.narg/=7) then
    write(0,*) 'Usage: subsample_plan.x ASCII|BIN|HDF5 WFN nq [e1 e2 e3] [degree]'
    write(0,*)
    write(0,*) 'Required arguments:'
    write(0,*) '  ASCII|BIN|HDF5: format of the WFN file'
    write(0,*) '  WFN: WFN file to used in the Epsilon/Sigma calculation'
    write(0,*) '  nq: number of radial subsampled q-points'
    write(0,*)
    write(0,*) 'Optional arguments:'
    write(0,*) '  e1, e2, e3: direction of the shift in crystal coords (defaults to 1 0 0)'
    write(0,*) '  degree: the radial intervals are define by dq(i) = dq0 * i^{degree};'
    write(0,*) '          valid options are 0, 1 and 2 (default is 1)'
    stop
  endif
  call getarg(1, file_fmt)
  call getarg(2, fname_wfn)
  call getarg(3, tmpstr)
  read(tmpstr,*) nq
  degree = 1
  auto_mode = .true.
  if (narg>3) then
    call getarg(4, tmpstr)
    read(tmpstr,*) ee(1)
    call getarg(5, tmpstr)
    read(tmpstr,*) ee(2)
    call getarg(6, tmpstr)
    read(tmpstr,*) ee(3)
    auto_mode = .false.
  endif
  if (narg==7) then
    ! The degree is the 7th argument; the 6th one is e3.
    call getarg(7, tmpstr)
    read(tmpstr,*) degree
  endif
  ! Validate the cheap inputs now, before the expensive Voronoi sampling below.
  ! nq<1 would also make dq_abs divide by zero.
  if (nq<1) call die('nq must be a positive integer.')
  if (degree<0 .or. degree>2) call die('Unsupported degree.')

  !---------------------------------------------------------------------------
  ! Read the mean-field header of WFN
  !---------------------------------------------------------------------------
  if (file_fmt=='ASCII') then
    call open_file(unit=11, file=TRIM(fname_wfn), form='formatted', status='old')
    call read_mf_header(11, mf)
    call close_file(11)
  elseif (file_fmt=='BIN  ') then
    call open_file(unit=11, file=TRIM(fname_wfn), form='unformatted', status='old')
    call read_mf_header(11, mf)
    call close_file(11)
  elseif (file_fmt=='HDF5 ') then
    ! HDF5 input is not supported yet (see the TODO in the file header).
    call die('not implemented')
  else
    call die('Unknown format "'//TRIM(file_fmt)//'". Must be either ASCII, BIN or HDF5.')
  endif
  write(6,'(a,3(1x,i0))') 'Read k-grid:', mf%kp%kgrid
  ! The mini-BZ metric below divides by the k-grid, so every entry must be >= 1.
  if (any(mf%kp%kgrid<1)) then
    call die('The k-grid in the WFN header must have positive entries.')
  endif

  ! Shift direction ee (crystal coords): either user-supplied, or the first
  ! direction with more than one k-point.
  if (auto_mode) then
    ee(:) = 0d0
    do ii=1,3
      if (mf%kp%kgrid(ii)>1) then
        ee(ii) = 1d0
        exit
      endif
    enddo
    if (all(ee==0d0)) then
      call die('There must be at least one periodic direction.')
    endif
  endif ! auto_mode
  ! Cartesian length of ee (1/bohr); used below to normalize the shift.
  ee_abs = sqrt(DOT_PRODUCT(ee,MATMUL(mf%crys%bdot,ee)))
  if (.not.(ee_abs>0d0)) then
    call die('The shift direction must be a nonzero vector.')
  endif
  write(6,'(a,3(1x,f6.3))') 'Shifting in direction:', ee
  write(6,*)

  !---------------------------------------------------------------------------
  ! FHJ: sample Voronoi cell
  !---------------------------------------------------------------------------
  nn = 10000000
  write(6,'(a,i0,a)') 'Sampling q=0 Voronoi cell stocastically with ', nn, ' points.'
  SAFE_ALLOCATE(qlen, (nn))
  call genrand_init(put=5000)

  ! Number of neighboring images searched along each direction. Directions
  ! with a single k-point are not searched, and the random coordinate along
  ! them is zeroed.
  ni(:) = merge(0, 1, mf%kp%kgrid==1)

  ! Metric of the mini-BZ: bdot(i,j) = crys%bdot(i,j) / (kgrid(i)*kgrid(j)).
  do ii = 1, 3
    bdot(:,ii) = mf%crys%bdot(:,ii)/mf%kp%kgrid(ii)
  enddo
  do ii = 1, 3
    bdot(ii,:) = bdot(ii,:)/mf%kp%kgrid(ii)
  enddo

  ! For each random point, keep the length of its shortest periodic image.
  ! Random numbers are drawn point by point in the same (iq, ii) order as a
  ! pre-generated 3 x nn table would be, so the sequence is unchanged while
  ! avoiding a 3 x nn work array.
  do iq = 1, nn
    do ii = 1, 3
      call genrand_real4(qq(ii))
      if (ni(ii)==0) qq(ii) = 0d0
    enddo
    len_min = INF
    do i1 = -ni(1), ni(1)
      qq_tmp(1) = qq(1) - i1
      do i2 = -ni(2), ni(2)
        qq_tmp(2) = qq(2) - i2
        do i3 = -ni(3), ni(3)
          qq_tmp(3) = qq(3) - i3
          len_tmp = DOT_PRODUCT(qq_tmp,MATMUL(bdot,qq_tmp))
          len_min = min(len_min, len_tmp)
        enddo
      enddo
    enddo
    qlen(iq) = sqrt(len_min)
  enddo ! iq
  qmax = maxval(qlen)
  write(6,'(2x,a,es12.6,a)') 'Maximum |q| in Voronoi cell = ', qmax,' 1/bohr'

  !---------------------------------------------------------------------------
  ! We want to bin the random q-points radially and find the appropriate weights
  ! |           *           |           *        ...        |
  ! 0           dq         2dq         3dq          (2nq)dq = qmax
  !
  ! where: * is a point where we'll calculate epsinv(q)
  !        | is the boundary of each interval
  ! (The picture is for degree=0. In general interval i has width
  ! 2*dq*i^degree, and dq is chosen so that all widths add up to qmax.)
  !---------------------------------------------------------------------------
  SAFE_ALLOCATE(qq_abs, (nq))
  select case (degree)
  case (0)
    ! sum_(i=1)^n i^0 = n
    dq_abs = qmax/(2d0*nq)
  case (1)
    ! sum_(i=1)^n i^1 = 1/2 n (n+1)
    dq_abs = qmax/(nq*(nq+1d0))
  case (2)
    ! sum_(i=1)^n i^2 = 1/6 n (n+1) (2 n+1)
    dq_abs = 3d0*qmax/(nq*(nq+1d0)*(2d0*nq+1d0))
  case default
    call die('Unsupported degree.')
  end select

  dq = ee/ee_abs * dq_abs
  write(6,'(2x,a,3(1x,es22.15),a)') 'dq = ', dq, ' (crystal units)'
  write(6,'(2x,a,1x,es22.15,1x,a)') '|dq| = ', dq_abs, ' 1/bohr'
  write(6,*)

  write(6,'(a,a)') 'Determining q0 weights.'
  dq_abs_min = 0d0
  call open_file(unit=13, file='subweights.dat', form='formatted', status='replace')
  write(13,'(i0)') nq
  do iq=1,nq
    dq_cur = 2d0*dq_abs*(iq**degree)
    dq_abs_max = dq_abs_min + dq_cur
    if (iq<nq) then
      weight = count(qlen>=dq_abs_min .and. qlen<dq_abs_max)
    else
      ! The outermost shell takes every remaining sample. This way the
      ! point(s) with |q| = qmax are not lost to roundoff in dq_abs_max, and
      ! the weights always add up to nn.
      weight = count(qlen>=dq_abs_min)
    endif
    ! The representative |q| of each shell is its midpoint.
    qq_abs(iq) = dq_abs_min + 0.5d0*dq_cur
    write(6,'(2x,3(a,es12.6,2x))') '|q| = ', qq_abs(iq), 'weight = ', weight
    write(13,'(2(es12.6,2x))') weight, qq_abs(iq)
    dq_abs_min = dq_abs_max
  enddo
  call close_file(13)
  SAFE_DEALLOCATE(qlen)
  write(6,'(2x,2a)') 'Wrote q0 weights to subweights.dat'

  call open_file(unit=13, file='epsilon_q0s.inp', form='formatted', status='replace')
  write(13,*)
  write(13,'(a)') 'subsample'
  write(13,'(a,i0)') 'number_qpoints ', nq
  write(13,'(a)') 'begin qpoints'
  do iq = 1, nq
    write(13,'(3(es22.15,1x),a)') qq_abs(iq)/dq_abs*dq, '1 1'
  enddo
  write(13,'(a)') 'end'
  write(13,*)
  call close_file(13)
  write(6,'(2x,a)') 'Wrote lines for epsilon.inp in epsilon_q0s.inp'
  write(6,*)

  !---------------------------------------------------------------------------
  ! k-points for the WFNq file(s)
  !---------------------------------------------------------------------------
  write (6,'(a)') 'Generating k-points for WFNq file(s)'
  ! FHJ: Use symmetries to unfold BZ from WFN
  gr%nr = mf%kp%nrk
  SAFE_ALLOCATE(gr%r, (3, gr%nr))
  gr%r = mf%kp%rk
  call fullbz(mf%crys, mf%syms, gr, mf%syms%ntran, skip_checkbz, wigner_seitz=.false., paranoid=.true.)
  if (.not.skip_checkbz) call checkbz(gr%nf, gr%f, mf%kp%kgrid, mf%kp%shift, &
    mf%crys%bdot, TRUNC(fname_wfn), 'k', .false., .false., .false.)
  ! FHJ: Find subgroup that leaves dq invariant
  call subgrp(dq, mf%syms)
  ! FHJ: For each dq, displace k-points by dq and use symmetries in the
  ! subgroup to fold the k-points. ii=0 writes the combined file for all q's.
  do ii=0,nq
    call gen_kpoints_file(ii)
  enddo
  write(6,*)

  SAFE_DEALLOCATE(qq_abs)

  POP_SUB(subsample_plan)

  contains

    !> Writes a "K_POINTS crystal" file for a WFNq calculation.
    !!
    !! The full-BZ k-points of WFN (gr%f) are displaced by the subsampled
    !! q-vector(s), qq_abs(iq)/dq_abs*dq, and folded with irrbz using the
    !! symmetries currently stored in mf%syms.
    !!
    !! iq_in > 0 : only q-point iq_in is used; output goes to kpoints_NNN.dat.
    !! iq_in <= 0: all nq q-points are combined; output goes to kpoints_all.dat.
    !!
    !! Each output line holds the crystal coordinates of a folded k-point and,
    !! as the last column, neq from irrbz (presumably the number of
    !! equivalent k-points folded onto it -- confirm against irrbz_m).
    subroutine gen_kpoints_file(iq_in)
      integer, intent(in) :: iq_in

      real(DP), allocatable :: kpts_new(:,:)
      real(DP) :: qshift(3)
      integer :: nk_new, nk_fold, iq_min, iq_max, jq, ik
      integer, allocatable :: neq(:), indrk(:)
      character(len=64) :: fname_kpts

      PUSH_SUB(subsample_plan.gen_kpoints_file)

      if (iq_in>0) then
        nk_new = gr%nf
        iq_min = iq_in
        iq_max = iq_in
        write(fname_kpts,'(a,i3.3,a)') 'kpoints_', iq_in, '.dat'
      else
        nk_new = gr%nf*nq
        iq_min = 1
        iq_max = nq
        fname_kpts = 'kpoints_all.dat'
      endif

      SAFE_ALLOCATE(kpts_new, (3,nk_new))
      SAFE_ALLOCATE(indrk, (nk_new))
      SAFE_ALLOCATE(neq, (nk_new))
      ! Local loop variables only: the host's iq/qq must not be clobbered
      ! through host association.
      nk_new = 0
      do jq = iq_min, iq_max
        qshift(:) = qq_abs(jq)/dq_abs*dq(:)
        do ik = 1, gr%nf
          kpts_new(:, nk_new+ik) = gr%f(:, ik) + qshift(:)
        enddo
        nk_new = nk_new + gr%nf
      enddo
      call irrbz(mf%syms, nk_new, kpts_new, nk_fold, neq, indrk)

      call open_file(unit=13, file=TRUNC(fname_kpts), form='formatted', status='replace')
      write(13,'(a)') 'K_POINTS crystal'
      write(13,'(i0)') nk_fold
      do ik = 1, nk_fold
        write(13, '(3(f13.9),f6.1)') kpts_new(:,indrk(ik)), dble(neq(ik))
      enddo
      call close_file(13)
      write(6,'(2x,2a)') 'Wrote kpoints to ', TRUNC(fname_kpts)

      SAFE_DEALLOCATE(kpts_new)
      SAFE_DEALLOCATE(neq)
      SAFE_DEALLOCATE(indrk)

      POP_SUB(subsample_plan.gen_kpoints_file)

    end subroutine gen_kpoints_file

end program subsample_plan
