!! Copyright (C) 2002-2006 M. Marques, A. Castro, A. Rubio, G. Bertsch
!!
!! This program is free software; you can redistribute it and/or modify
!! it under the terms of the GNU General Public License as published by
!! the Free Software Foundation; either version 2, or (at your option)
!! any later version.
!!
!! This program is distributed in the hope that it will be useful,
!! but WITHOUT ANY WARRANTY; without even the implied warranty of
!! MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
!! GNU General Public License for more details.
!!
!! You should have received a copy of the GNU General Public License
!! along with this program; if not, write to the Free Software
!! Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
!! 02110-1301, USA.
!!
!! $Id: states.F90 15027 2016-01-09 17:32:09Z xavier $

#include "global.h"

module states_m
  use blacs_proc_grid_m
  use calc_mode_m
#ifdef HAVE_OPENCL
  use cl
#endif
  use cmplxscl_m
  use comm_m
  use batch_m
  use batch_ops_m
  use blas_m
  use datasets_m
  use derivatives_m
  use distributed_m
  use geometry_m
  use global_m
  use grid_m
  use hardware_m
  use io_m
  use kpoints_m
  use lalg_adv_m
  use lalg_basic_m
  use loct_m
  use loct_pointer_m
  use math_m
  use mesh_m
  use mesh_function_m
  use messages_m
  use modelmb_particles_m
  use mpi_m ! if not before parser_m, ifort 11.072 can`t compile with MPI2
  use mpi_lib_m
  use multicomm_m
  use ob_interface_m
#ifdef HAVE_OPENMP
  use omp_lib
#endif
  use opencl_m
  use parser_m
  use profiling_m
  use restart_m
  use simul_box_m
  use smear_m
  use states_group_m
  use states_dim_m
  use symmetrizer_m
  use types_m
  use unit_m
  use unit_system_m
  use utils_m
  use varinfo_m

  implicit none

  private

  public ::                           &
    states_t,                         &
    states_priv_t,                    &
    states_lead_t,                    &
    states_init,                      &
    states_look,                      &
    states_densities_init,            &
    states_exec_init,                 &
    states_allocate_wfns,             &
    states_allocate_intf_wfns,        &
    states_allocate_current,          &
    states_deallocate_wfns,           &
    states_null,                      &
    states_end,                       &
    states_copy,                      &
    states_generate_random,           &
    states_fermi,                     &
    states_eigenvalues_sum,           &
    states_lead_densities_init,       &
    states_lead_densities_end,        &
    states_spin_channel,              &
    states_calc_quantities,           &
    state_is_local,                   &
    state_kpt_is_local,               &
    states_distribute_nodes,          &
    states_wfns_memory,               &
    states_are_complex,               &
    states_are_real,                  &
    states_set_complex,               &
    states_blacs_blocksize,           &
    states_get_state,                 &
    states_set_state,                 &
    states_get_points,                &
    states_pack,                      &
    states_unpack,                    &
    states_sync,                      &
    states_are_packed,                &
    states_write_info,                &
    states_set_zero,                  &
    states_block_min,                 &
    states_block_max,                 &
    states_block_size,                &
    zstates_eigenvalues_sum,          &
    cmplx_array2_t,                   &
    states_wfs_t,                     &
    states_count_pairs,               &
    occupied_states,                  &
    states_get_ob_intf

  !> cmplxscl: storage for the Left (L) and Right (R) eigenstates, held both as
  !! complex (z) and as real (d) arrays. Every array is indexed as
  !! (np, st%d%dim, st%nst, st%d%nik).
  type :: states_wfs_t
    CMPLX, dimension(:, :, :, :), pointer :: zL  !< complex Left states
    CMPLX, dimension(:, :, :, :), pointer :: zR  !< complex Right states
    FLOAT, dimension(:, :, :, :), pointer :: dL  !< real Left states
    FLOAT, dimension(:, :, :, :), pointer :: dR  !< real Right states
  end type states_wfs_t
  
  !> cmplxscl: a complex rank-2 quantity stored as two separate real arrays,
  !! one for the real and one for the imaginary components.
  type :: cmplx_array2_t
    FLOAT, dimension(:, :), pointer :: Re  !< real components
    FLOAT, dimension(:, :), pointer :: Im  !< imaginary components
  end type cmplx_array2_t

  !> Per-lead data; states_t holds one of these for each of its 2*MAX_DIM
  !! ob_lead entries.
  type :: states_lead_t
    CMPLX, dimension(:, :, :, :),    pointer :: intf_psi     !< (np, st%d%dim, st%nst, st%d%nik)
    FLOAT, dimension(:, :),          pointer :: rho          !< density of the lead unit cells
    CMPLX, dimension(:, :, :, :, :), pointer :: self_energy  !< (np, np, nspin, ncs, nik) self-energy of the leads
  end type states_lead_t

  !> Components of states_t that are hidden from code outside this module.
  type :: states_priv_t
    private
    !> wavefunction type: real (TYPE_FLOAT) or complex (TYPE_CMPLX)
    type(type_t) :: wfs_type
  end type states_priv_t

  !> Main container for the electronic states: wavefunctions, eigenvalues,
  !! occupations, densities, and the bookkeeping that describes how the states
  !! are distributed over processes.
  !! states_null disassociates the pointer components; states_init fills in
  !! the sizes and allocates the per-state arrays (eigenvalues, occupations, ...).
  type states_t
    type(states_dim_t)       :: d                     !< dimension/spin/k-point data (dim, nik, nspin, ispin, block_size, ...)
    type(modelmb_particle_t) :: modelmbparticles      !< model many-body particle data (set up by modelmb_particles_init)
    type(states_priv_t)      :: priv                  !< the private components
    integer                  :: nst                   !< Number of states in each irreducible subspace

    logical                  :: only_userdef_istates  !< only use user-defined states as initial states in propagation
    !> pointers to the wavefunctions
    FLOAT, pointer           :: dpsi(:,:,:,:)         !< dpsi(sys%gr%mesh%np_part, st%d%dim, st%nst, st%d%nik)
    CMPLX, pointer           :: zpsi(:,:,:,:)         !< zpsi(sys%gr%mesh%np_part, st%d%dim, st%nst, st%d%nik)
   
   
     
    type(cmplxscl_t)         :: cmplxscl              !< contains the complex-scaling (cmplxscl) parameters
    !> Pointers to complexified quantities. 
    !! When we use complex scaling the Hamiltonian is no longer hermitian.
    !! In this case we have to distinguish between left and right eigenstates of H and
    !! both density and eigenvalues become complex.
    !! In order to modify the code to include these changes we allocate the general structures and 
    !! make the restricted quantities point to a part of the structure.
    !! For instance for the orbitals we allocate psi and make zpsi point only to the Right states as follows:
    !! zpsi => psi%zR
    !! Similarly for density and eigenvalues we make the old quantities point to the real part:
    !! rho => zrho%Re
    !! eigenval => zeigenval%Re  
    type(states_wfs_t)       :: psi          !< cmplxscl: Left psi%zL(:,:,:,:) and Right psi%zR(:,:,:,:) orbitals    
    type(cmplx_array2_t)     :: zrho         !< cmplxscl: the complexified density <psi%zL(:,:,:,:)|psi%zR(:,:,:,:)>
    type(cmplx_array2_t)     :: zeigenval    !< cmplxscl: the complexified eigenvalues 
    FLOAT,           pointer :: Imrho_core(:)        !< presumably the imaginary counterpart of rho_core -- TODO confirm
    FLOAT,           pointer :: Imfrozen_rho(:, :)   !< presumably the imaginary counterpart of frozen_rho -- TODO confirm
    type(batch_t),   pointer :: psibL(:, :)  !< Left wave-functions blocks
    logical                  :: have_left_states     !< set to .false. in states_init; presumably true once Left states are stored

    type(states_group_t)     :: group                !< reset through states_group_null in states_null

    logical             :: open_boundaries   !< copy of gr%ob_grid%open_boundaries, taken in states_init
    CMPLX, pointer      :: zphi(:, :, :, :)  !< Free states for open-boundary calculations.
    FLOAT, pointer      :: ob_eigenval(:, :) !< Eigenvalues of free states.
    type(states_dim_t)  :: ob_d              !< Dims. of the unscattered systems.
    integer             :: ob_nst            !< nst of the unscattered systems.
    FLOAT, pointer      :: ob_occ(:, :)      !< occupations
    type(states_lead_t) :: ob_lead(2*MAX_DIM) !< per-lead data, see states_lead_t

    !> used for the user-defined wavefunctions (they are stored as formula strings)
    !! (st%d%dim, st%nst, st%d%nik)
    !! states_init marks every entry as 'undefined'.
    character(len=1024), pointer :: user_def_states(:,:,:)

    !> the densities and currents (after all we are doing DFT :)
    FLOAT, pointer :: rho(:,:)         !< rho(gr%mesh%np_part, st%d%nspin)
    FLOAT, pointer :: current(:, :, :) !<   current(gr%mesh%np_part, gr%sb%dim, st%d%nspin)


    FLOAT, pointer :: rho_core(:)      !< core charge for nl core corrections

    !> It may be required to "freeze" the deepest orbitals during the evolution; the density
    !! of these orbitals is kept in frozen_rho. It is different from rho_core.
    FLOAT, pointer :: frozen_rho(:, :)

    FLOAT, pointer :: eigenval(:,:) !< the eigenvalues (nst, nik); states_init points this at zeigenval%Re
    logical        :: fixed_occ     !< should the occupation numbers be fixed?
    logical        :: restart_fixed_occ !< should the occupation numbers be fixed by restart?
    logical        :: restart_reorder_occs !< used for restart with altered occupation numbers
    FLOAT, pointer :: occ(:,:)      !< the occupation numbers (nst, nik)
    logical        :: fixed_spins   !< In spinors mode, the spin direction is set
                                    !< for the initial (random) orbitals.
    FLOAT, pointer :: spin(:, :, :) !< (3, st%nst, st%d%nik); allocated by states_init only when ispin == SPINORS

    FLOAT          :: qtot          !< (-) The total charge in the system (used in Fermi); -(val_charge + ExcessCharge)
    FLOAT          :: val_charge    !< valence charge

    logical        :: fromScratch   !< set to .true. at the start of states_init; reset if restart data are read
    type(smear_t)  :: smear         !< smearing of the electronic occupations

    !> This is stuff needed for the parallelization in states.
    logical                     :: parallel_in_states !< Am I parallel in states?
    type(mpi_grp_t)             :: mpi_grp            !< The MPI group related to the parallelization in states.
    type(mpi_grp_t)             :: dom_st_mpi_grp     !< The MPI group related to the domains-states "plane".
    type(mpi_grp_t)             :: st_kpt_mpi_grp     !< The MPI group related to the states-kpoints "plane".
    type(mpi_grp_t)             :: dom_st_kpt_mpi_grp !< The MPI group related to the domains-states-kpoints "cube".
#ifdef HAVE_SCALAPACK
    type(blacs_proc_grid_t)     :: dom_st_proc_grid   !< The BLACS process grid for the domains-states plane
#endif
    logical                     :: scalapack_compatible !< Whether the states parallelization uses ScaLAPACK layout
    integer                     :: lnst               !< Number of states on local node.
    integer                     :: st_start, st_end   !< Range of states processed by local node.
    integer, pointer            :: node(:)            !< To which node belongs each state.
    !> Node r manages states st_range(1, r) to
    !! st_range(2, r) for r = 0, ..., mpi_grp%size-1,
    !! i. e. st_start = st_range(1, r) and
    !! st_end = st_range(2, r) on node r.
    integer, pointer            :: st_range(:, :)  
    !> Number of states on node r.
    !! NOTE(review): this comment used to read "st_num(2, r)-st_num(1, r)", which
    !! cannot be meant literally because st_num has rank 1. Presumably
    !! st_num(r) = st_range(2, r) - st_range(1, r) + 1; confirm where st_num is filled.
    integer, pointer            :: st_num(:)         
    type(multicomm_all_pairs_t) :: ap                 !< All-pairs schedule.

    logical                     :: symmetrize_density !< value of the input variable SymmetrizeDensity
    logical                     :: packed             !< initialised to .false.; presumably toggled by states_pack/states_unpack
  end type states_t

  !> Generic name covering the real (d) and complex (z) state getters,
  !! each in its "1" and "2" variant.
  interface states_get_state
    module procedure dstates_get_state1
    module procedure zstates_get_state1
    module procedure dstates_get_state2
    module procedure zstates_get_state2
  end interface states_get_state

  !> Generic name covering the real (d) and complex (z) state setters,
  !! each in its "1" and "2" variant.
  interface states_set_state
    module procedure dstates_set_state1
    module procedure zstates_set_state1
    module procedure dstates_set_state2
    module procedure zstates_set_state2
  end interface states_set_state

  !> Generic name covering the real (d) and complex (z) point getters,
  !! each in its "1" and "2" variant.
  interface states_get_points
    module procedure dstates_get_points1
    module procedure zstates_get_points1
    module procedure dstates_get_points2
    module procedure zstates_get_points2
  end interface states_get_points

contains

  ! ---------------------------------------------------------
  !> Put a states_t object into a defined "empty" condition.
  !!
  !! Pointer components are disassociated (never deallocated), sub-objects that
  !! provide their own reset routine are reset through it, and a handful of
  !! flags receive their defaults. Components that are not mentioned in the
  !! body keep whatever value they had on entry.
  subroutine states_null(st)
    type(states_t), intent(inout) :: st

    integer :: ilead

    PUSH_SUB(states_null)

    call states_dim_null(st%d)
    call states_group_null(st%group)

    ! must come after states_dim_null(st%d), which resets st%d as a whole
    st%d%orth_method = 0
    call modelmb_particles_nullify(st%modelmbparticles)

    ! Flag defaults: real wavefunctions, no complex scaling, no open
    ! boundaries, not parallel in states, not packed.
    st%priv%wfs_type      = TYPE_FLOAT
    st%cmplxscl%space     = .false.
    st%open_boundaries    = .false.
    st%parallel_in_states = .false.
    st%packed             = .false.

    ! cmplxscl: left/right orbitals and complexified quantities
    nullify(st%psi%dL)
    nullify(st%psi%dR)
    nullify(st%psi%zL)
    nullify(st%psi%zR)
    nullify(st%zeigenval%Re)
    nullify(st%zeigenval%Im)
    nullify(st%zrho%Re)
    nullify(st%zrho%Im)
    nullify(st%Imrho_core)
    nullify(st%Imfrozen_rho)
    nullify(st%psibL)

    ! plain wavefunctions
    nullify(st%dpsi)
    nullify(st%zpsi)

    ! open-boundary data, including every lead slot
    nullify(st%zphi)
    nullify(st%ob_eigenval)
    nullify(st%ob_occ)
    call states_dim_null(st%ob_d)
    do ilead = 1, 2*MAX_DIM
      nullify(st%ob_lead(ilead)%intf_psi)
      nullify(st%ob_lead(ilead)%rho)
      nullify(st%ob_lead(ilead)%self_energy)
    end do

    ! user-defined formulas, densities, eigenvalues, occupations, spins
    nullify(st%user_def_states)
    nullify(st%rho)
    nullify(st%current)
    nullify(st%rho_core)
    nullify(st%frozen_rho)
    nullify(st%eigenval)
    nullify(st%occ)
    nullify(st%spin)

    ! parallelization in states
#ifdef HAVE_SCALAPACK
    call blacs_proc_grid_nullify(st%dom_st_proc_grid)
#endif
    nullify(st%node)
    nullify(st%st_range)
    nullify(st%st_num)
    nullify(st%ap%schedule)

    POP_SUB(states_null)
  end subroutine states_null


  ! ---------------------------------------------------------
  subroutine states_init(st, gr, geo)
    type(states_t), target, intent(inout) :: st
    type(grid_t),           intent(in)    :: gr
    type(geometry_t),       intent(in)    :: geo

    FLOAT :: excess_charge
    integer :: nempty, ierr, il, ntot, default, nthreads
    integer, allocatable :: ob_k(:), ob_st(:), ob_d(:)
    type(restart_t) :: ob_restart
    character(len=256)   :: restart_dir

    PUSH_SUB(states_init)

    st%fromScratch = .true. ! this will be reset if restart_read is called
    call states_null(st)


    !%Variable SpinComponents
    !%Type integer
    !%Default unpolarized
    !%Section States
    !%Description
    !% The calculations may be done in three different ways: spin-restricted (TD)DFT (<i>i.e.</i>, doubly
    !% occupied "closed shells"), spin-unrestricted or "spin-polarized" (TD)DFT (<i>i.e.</i> we have two
    !% electronic systems, one with spin up and one with spin down), or making use of two-component
    !% spinors.
    !%Option unpolarized 1
    !% Spin-restricted calculations.
    !%Option polarized 2
    !%Option spin_polarized 2
    !% (Synonym <tt>polarized</tt>.) Spin-unrestricted, also known as spin-DFT, SDFT. This mode will double the number of
    !% wavefunctions necessary for a spin-unpolarized calculation.
    !%Option non_collinear 3
    !%Option spinors 3
    !% (Synonym: <tt>non_collinear</tt>.) The spin-orbitals are two-component spinors. This effectively allows the spin-density to
    !% be oriented non-collinearly: <i>i.e.</i> the magnetization vector is allowed to take different
    !% directions at different points. This vector is always in 3D regardless of <tt>Dimensions</tt>.
    !%End
    call parse_integer(datasets_check('SpinComponents'), UNPOLARIZED, st%d%ispin)
    if(.not.varinfo_valid_option('SpinComponents', st%d%ispin)) call input_error('SpinComponents')
    call messages_print_var_option(stdout, 'SpinComponents', st%d%ispin)
    ! Use of spinors requires complex wavefunctions.
    if (st%d%ispin == SPINORS) st%priv%wfs_type = TYPE_CMPLX


    !%Variable ExcessCharge
    !%Type float
    !%Default 0.0
    !%Section States
    !%Description
    !% The net charge of the system. A negative value means that we are adding
    !% electrons, while a positive value means we are taking electrons
    !% from the system.
    !%End
    call parse_float(datasets_check('ExcessCharge'), M_ZERO, excess_charge)


    !%Variable TotalStates
    !%Type integer
    !%Default 0
    !%Section States
    !%Description
    !% This variable sets the total number of states that Octopus will
    !% use. This is normally not necessary since by default Octopus
    !% sets the number of states to the minimum necessary to hold the
    !% electrons present in the system. (This default behavior is
    !% obtained by setting <tt>TotalStates</tt> to 0).
    !%
    !% If you want to add some unoccupied states, probably it is more convenient to use the variable
    !% <tt>ExtraStates</tt>.
    !%End
    call parse_integer(datasets_check('TotalStates'), 0, ntot)
    if (ntot < 0) then
      write(message(1), '(a,i5,a)') "Input: '", ntot, "' is not a valid value for TotalStates."
      call messages_fatal(1)
    end if

    !%Variable ExtraStates
    !%Type integer
    !%Default 0
    !%Section States
    !%Description
    !% The number of states is in principle calculated considering the minimum
    !% numbers of states necessary to hold the electrons present in the system.
    !% The number of electrons is
    !% in turn calculated considering the nature of the species supplied in the
    !% <tt>Species</tt> block, and the value of the <tt>ExcessCharge</tt> variable.
    !% However, one may command <tt>Octopus</tt> to use more states, which is necessary if one wants to
    !% use fractional occupational numbers, either fixed from the beginning through
    !% the <tt>Occupations</tt> block or by prescribing
    !% an electronic temperature with <tt>Smearing</tt>, or in order to calculate
    !% excited states (including with <tt>CalculationMode = unocc</tt>).
    !%End
    call parse_integer(datasets_check('ExtraStates'), 0, nempty)
    if (nempty < 0) then
      write(message(1), '(a,i5,a)') "Input: '", nempty, "' is not a valid value for ExtraStates."
      message(2) = '(0 <= ExtraStates)'
      call messages_fatal(2)
    end if

    if(ntot > 0 .and. nempty > 0) then
      message(1) = 'You cannot set TotalStates and ExtraStates at the same time.'
      call messages_fatal(1)
    end if

    ! For non-periodic systems this should just return the Gamma point
    call states_choose_kpoints(st%d, gr%sb)

    call geometry_val_charge(geo, st%val_charge)

    if(gr%ob_grid%open_boundaries) then
      ! renormalize charge of central region to match leads (open system, not finite)
      st%val_charge = st%val_charge * (gr%ob_grid%lead(LEFT)%sb%lsize(TRANS_DIR) / gr%sb%lsize(TRANS_DIR))
    end if

    st%qtot = -(st%val_charge + excess_charge)

    do il = 1, NLEADS
      nullify(st%ob_lead(il)%intf_psi)
    end do
    ! When doing open-boundary calculations the number of free states is
    ! determined by the previous periodic calculation.
    st%open_boundaries = gr%ob_grid%open_boundaries
    if(gr%ob_grid%open_boundaries) then
      SAFE_ALLOCATE( ob_k(1:NLEADS))
      SAFE_ALLOCATE(ob_st(1:NLEADS))
      SAFE_ALLOCATE( ob_d(1:NLEADS))
      do il = 1, NLEADS
        restart_dir = trim(gr%ob_grid%lead(il)%info%restart_dir)+"/"+GS_DIR
        call restart_init(ob_restart, RESTART_UNDEFINED, RESTART_TYPE_LOAD, mpi_world, ierr, dir=restart_dir)
        ! first get nst and kpoints of all states
        if(ierr == 0) call states_look(ob_restart, ob_k(il), ob_d(il), ob_st(il), ierr)
        if(ierr /= 0) then
          message(1) = 'Could not read the states information of the periodic calculation'
          message(2) = 'from '//trim(restart_dir)//'.'
          call messages_fatal(2)
        end if
        call restart_end(ob_restart)
      end do
      ! The two leads must stem from periodic runs with the same layout: equal
      ! number of k-points, equal number of states and equal spinor dimension.
      ! Only the LEFT values are used afterwards to size the ob_* arrays, so a
      ! mismatch has to be caught here.
      ! (The states test used to compare ob_st(LEFT) with itself and therefore
      ! could never trigger.)
      if(NLEADS > 1) then
        if(ob_k(LEFT) /= ob_k(RIGHT) .or. &
          ob_st(LEFT) /= ob_st(RIGHT) .or. &
          ob_d(LEFT) /= ob_d(RIGHT)) then
          message(1) = 'The number of states for the left and right leads are not equal.'
          call messages_fatal(1)
        end if
      end if
      st%ob_d%dim = ob_d(LEFT)
      st%ob_nst   = ob_st(LEFT)
      st%ob_d%nik = ob_k(LEFT)
      st%d%nik = st%ob_d%nik
      SAFE_DEALLOCATE_A(ob_d)
      SAFE_DEALLOCATE_A(ob_st)
      SAFE_DEALLOCATE_A(ob_k)
      call distributed_nullify(st%ob_d%kpt, 0)
      if((st%d%ispin == UNPOLARIZED.and.st%ob_d%dim /= 1) .or.   &
        (st%d%ispin == SPIN_POLARIZED.and.st%ob_d%dim /= 1) .or. &
        (st%d%ispin == SPINORS.and.st%ob_d%dim /= 2)) then
        message(1) = 'The spin type of the leads calculation from '&
                     //gr%ob_grid%lead(LEFT)%info%restart_dir
        message(2) = 'and SpinComponents of the current run do not match.'
        call messages_fatal(2)
      end if
      SAFE_DEALLOCATE_P(st%d%kweights)
      SAFE_ALLOCATE(st%d%kweights(1:st%d%nik))
      st%d%kweights = M_ZERO
      st%d%kweights(1) = M_ONE
      SAFE_ALLOCATE(st%ob_d%kweights(1:st%ob_d%nik))
      SAFE_ALLOCATE(st%ob_eigenval(1:st%ob_nst, 1:st%ob_d%nik))
      SAFE_ALLOCATE(st%ob_occ(1:st%ob_nst, 1:st%ob_d%nik))
      st%ob_d%kweights = M_ZERO
      st%ob_eigenval   = huge(st%ob_eigenval)
      st%ob_occ        = M_ZERO
      call read_ob_eigenval_and_occ()
    else
      st%ob_nst   = 0
      st%ob_d%nik = 0
      st%ob_d%dim = 0
    end if

    select case(st%d%ispin)
    case(UNPOLARIZED)
      st%d%dim = 1
      st%nst = int(st%qtot/2)
      if(st%nst*2 < st%qtot) st%nst = st%nst + 1
      st%d%nspin = 1
      st%d%spin_channels = 1
    case(SPIN_POLARIZED)
      st%d%dim = 1
      st%nst = int(st%qtot/2)
      if(st%nst*2 < st%qtot) st%nst = st%nst + 1
      st%d%nspin = 2
      st%d%spin_channels = 2
    case(SPINORS)
      st%d%dim = 2
      st%nst = int(st%qtot)
      if(st%nst < st%qtot) st%nst = st%nst + 1
      st%d%nspin = 4
      st%d%spin_channels = 2
    end select
    
    if(ntot > 0) then
      if(ntot < st%nst) then
        message(1) = 'TotalStates is smaller than the number of states required by the system.'
        call messages_fatal(1)
      end if

      st%nst = ntot
    end if

    st%nst = st%nst + nempty
    if(st%nst == 0) then
      message(1) = "Cannot run with number of states = zero."
      call messages_fatal(1)
    endif

    !%Variable StatesBlockSize
    !%Type integer
    !%Section Execution::Optimization
    !%Description
    !% Some routines work over blocks of eigenfunctions, which
    !% generally improves performance at the expense of increased
    !% memory consumption. This variable selects the size of the
    !% blocks to be used. If OpenCl is enabled, the default is 32;
    !% otherwise it is max(4, 2*nthreads).
    !%End

    nthreads = 1
#ifdef HAVE_OPENMP
    !$omp parallel
    !$omp master
    nthreads = omp_get_num_threads()
    !$omp end master
    !$omp end parallel
#endif    

    if(opencl_is_enabled()) then
      default = 32
    else
      default = max(4, 2*nthreads)
    end if

    if(default > pad_pow2(st%nst)) default = pad_pow2(st%nst)

    ASSERT(default > 0)

    call parse_integer(datasets_check('StatesBlockSize'), default, st%d%block_size)
    if(st%d%block_size < 1) then
      call messages_write("The variable 'StatesBlockSize' must be greater than 0.")
      call messages_fatal()
    end if

    st%d%block_size = min(st%d%block_size, st%nst)
    conf%target_states_block_size = st%d%block_size

    ! FIXME: For now, open-boundary calculations are only possible for
    ! continuum states, i.e. for those states treated by the Lippmann-
    ! Schwinger approach during SCF.
    ! Bound states should be done with extra states, without k-points.
    if(gr%ob_grid%open_boundaries) then
      if(st%nst /= st%ob_nst .or. st%d%nik /= st%ob_d%nik) then
        message(1) = 'Open-boundary calculations for possibly bound states'
        message(2) = 'are not possible yet. You have to match your number'
        message(3) = 'of states to the number of free states of your previous'
        message(4) = 'periodic run.'
        write(message(5), '(a,i5,a)') 'Your central region contributes ', st%nst, ' states,'
        write(message(6), '(a,i5,a)') 'while your lead calculation had ', st%ob_nst, ' states.'
        write(message(7), '(a,i5,a)') 'Your central region contributes ', st%d%nik, ' k-points,'
        write(message(8), '(a,i5,a)') 'while your lead calculation had ', st%ob_d%nik, ' k-points.'
        call messages_fatal(8)
      end if
    end if

    st%d%cdft = .false. ! CDFT was removed

    !cmplxscl
    call cmplxscl_init(st%cmplxscl)

    st%have_left_states = .false.
    
    if (st%cmplxscl%space) then
      !Even for gs calculations it requires complex wavefunctions
      st%priv%wfs_type = TYPE_CMPLX
      !Allocate imaginary parts of the eigenvalues
      SAFE_ALLOCATE(st%zeigenval%Im(1:st%nst, 1:st%d%nik))
      st%zeigenval%Im = M_ZERO
    end if
    SAFE_ALLOCATE(st%zeigenval%Re(1:st%nst, 1:st%d%nik))
    st%zeigenval%Re = huge(st%zeigenval%Re)
    st%eigenval => st%zeigenval%Re(1:st%nst, 1:st%d%nik) 


    ! Periodic systems require complex wavefunctions
    ! but not if it is Gamma-point only
    if(simul_box_is_periodic(gr%sb)) then
      if(.not. (kpoints_number(gr%sb%kpoints) == 1 .and. kpoints_point_is_gamma(gr%sb%kpoints, 1))) then
        st%priv%wfs_type = TYPE_CMPLX
      endif
    endif

    ! Calculations with open boundaries require complex wavefunctions.
    if(gr%ob_grid%open_boundaries) st%priv%wfs_type = TYPE_CMPLX

    !%Variable OnlyUserDefinedInitialStates
    !%Type logical
    !%Default no
    !%Section States
    !%Description
    !% If true, then only user-defined states from the block <tt>UserDefinedStates</tt>
    !% will be used as initial states for a time-propagation. No attempt is made
    !% to load ground-state orbitals from a previous ground-state run.
    !%End
    call parse_logical(datasets_check('OnlyUserDefinedInitialStates'), .false., st%only_userdef_istates)

    ! we now allocate some arrays
    SAFE_ALLOCATE(st%occ     (1:st%nst, 1:st%d%nik))
    st%occ      = M_ZERO
    ! allocate space for formula strings that define user-defined states
    SAFE_ALLOCATE(st%user_def_states(1:st%d%dim, 1:st%nst, 1:st%d%nik))
    if(st%d%ispin == SPINORS) then
      SAFE_ALLOCATE(st%spin(1:3, 1:st%nst, 1:st%d%nik))
    else
      nullify(st%spin)
    end if

    ! initially we mark all 'formulas' as undefined
    st%user_def_states(1:st%d%dim, 1:st%nst, 1:st%d%nik) = 'undefined'

    call states_read_initial_occs(st, excess_charge, gr%sb%kpoints)
    call states_read_initial_spins(st)

    nullify(st%zphi)

    st%st_start = 1
    st%st_end = st%nst
    st%lnst = st%nst
    SAFE_ALLOCATE(st%node(1:st%nst))
    st%node(1:st%nst) = 0

    call mpi_grp_init(st%mpi_grp, -1)
    st%parallel_in_states = .false.

    nullify(st%dpsi, st%zpsi)

    call distributed_nullify(st%d%kpt, st%d%nik)

    call modelmb_particles_init (st%modelmbparticles,gr)

    !%Variable SymmetrizeDensity
    !%Type logical
    !%Default no
    !%Section States
    !%Description
    !% When enabled the density is symmetrized. Currently, this can
    !% only be done for periodic systems. (Experimental.)
    !%End
    call parse_logical(datasets_check('SymmetrizeDensity'), .false., st%symmetrize_density)
    call messages_print_var_value(stdout, 'SymmetrizeDensity', st%symmetrize_density)

    ! Why? Resulting discrepancies can be suspiciously large even at SCF convergence;
    ! the case of partially periodic systems has not been fully considered.
    if(st%symmetrize_density) call messages_experimental('SymmetrizeDensity')

#ifdef HAVE_SCALAPACK
    call blacs_proc_grid_nullify(st%dom_st_proc_grid)
#endif

    st%packed = .false.

    POP_SUB(states_init)

  contains

    !> Reads the eigenvalues, occupations and k-point weights stored in the
    !! 'occs' restart file of the left lead and copies them into
    !! st\%ob_eigenval, st\%ob_occ and st\%ob_d\%kweights.
    !! Any failure to open the restart data is fatal.
    subroutine read_ob_eigenval_and_occ()
      integer            :: occs, iline, ist, ik, idim, idir, err
      FLOAT              :: flt, eigenval, imeigenval, occ, kweights
      character          :: char
      character(len=256) :: chars
      character(len=256), allocatable :: lines(:)

      PUSH_SUB(states_init.read_ob_eigenval_and_occ)

      ! Character concatenation is '//'; '+' is not defined for strings.
      restart_dir = trim(gr%ob_grid%lead(LEFT)%info%restart_dir)//"/"//GS_DIR
      call restart_init(ob_restart, RESTART_UNDEFINED, RESTART_TYPE_LOAD, mpi_world, err, dir=restart_dir)
      if(err /= 0) then
        message(1) = 'Could not read open-boundaries eigenvalues and occupations.'
        call messages_fatal(1)
      endif

      occs = restart_open(ob_restart, 'occs')
      if(occs  <  0) then
        message(1) = 'Could not read "occs" from left lead.'
        call messages_fatal(1)
      end if

      SAFE_ALLOCATE(lines(3 + st%ob_nst*st%ob_d%nik))

      call iopar_read(mpi_world, occs, lines, 3 + st%ob_nst*st%ob_d%nik, err)

      ! Lines 1 and 2 are not parsed; one record per (state, k-point) follows.
      do iline = 3, 2 + st%ob_nst*st%ob_d%nik

        ! Extract eigenvalue.
        ! # occupations | eigenvalue[a.u.] | k-points | k-weights | filename | ik | ist | idim
        ! The single-character reads swallow the '|' separators.
        read(lines(iline), *) occ, char, eigenval, char, imeigenval, char, (flt, char, idir = 1, gr%sb%dim), kweights, &
           char, chars, char, ik, char, ist, char, idim

        if(st%d%ispin  ==  SPIN_POLARIZED) then
          call messages_not_implemented('Spin-Transport')

          if(is_spin_up(ik)) then
            !FIXME
            !              st%ob_eigenval(jst, SPIN_UP) = eigenval
            !              st%ob_occ(jst, SPIN_UP)      = occ
          else
            !              st%ob_eigenval(jst, SPIN_DOWN) = eigenval
            !              st%ob_occ(jst, SPIN_DOWN)      = occ
          end if
        else
          st%ob_eigenval(ist, ik) = eigenval
          st%ob_occ(ist, ik)      = occ
          st%ob_d%kweights(ik)    = kweights
        end if
      end do

      SAFE_DEALLOCATE_A(lines)

      call restart_close(ob_restart, occs)

      call restart_end(ob_restart)

      POP_SUB(states_init.read_ob_eigenval_and_occ)
    end subroutine read_ob_eigenval_and_occ
  end subroutine states_init

  ! ---------------------------------------------------------
  !> Reads the 'states' file in the restart directory, and finds out
  !! the nik, dim, and nst contained in it.
  ! ---------------------------------------------------------
  subroutine states_look(restart, nik, dim, nst, ierr)
    type(restart_t), intent(inout) :: restart
    integer,         intent(out)   :: nik
    integer,         intent(out)   :: dim
    integer,         intent(out)   :: nst
    integer,         intent(out)   :: ierr

    integer, parameter :: n_header = 3
    character(len=256) :: header(n_header)
    character(len=20)  :: label
    integer            :: funit

    PUSH_SUB(states_look)

    ierr = 0

    ! The first three lines of 'states' hold, in this order, nst, dim and nik,
    ! each one preceded by a text label that is read and discarded.
    funit = restart_open(restart, 'states')
    call restart_read(restart, funit, header, n_header, ierr)

    ! The output sizes are only assigned when the read succeeded.
    if (ierr == 0) then
      read(header(1), *) label, nst
      read(header(2), *) label, dim
      read(header(3), *) label, nik
    end if

    call restart_close(restart, funit)

    POP_SUB(states_look)
  end subroutine states_look

  ! ---------------------------------------------------------
  !> Allocate the lead densities.
  subroutine states_lead_densities_init(st, gr)
    type(states_t), intent(inout) :: st
    type(grid_t),   intent(in)    :: gr

    integer :: ld

    PUSH_SUB(states_lead_densities_init)

    ! Nothing to allocate unless the grid has open boundaries.
    if(.not. gr%ob_grid%open_boundaries) then
      POP_SUB(states_lead_densities_init)
      return
    end if

    ! One zeroed density per lead, sized by the mesh of that lead.
    do ld = 1, NLEADS
      SAFE_ALLOCATE(st%ob_lead(ld)%rho(1:gr%ob_grid%lead(ld)%mesh%np, 1:st%d%nspin))
      st%ob_lead(ld)%rho = M_ZERO
    end do

    POP_SUB(states_lead_densities_init)
  end subroutine states_lead_densities_init


  ! ---------------------------------------------------------
  !> Deallocate the lead density.
  subroutine states_lead_densities_end(st, gr)
    type(states_t), intent(inout) :: st
    type(grid_t),   intent(in)    :: gr

    integer :: ld

    PUSH_SUB(states_lead_densities_end)

    ! The lead densities only exist for open-boundary grids.
    if(.not. gr%ob_grid%open_boundaries) then
      POP_SUB(states_lead_densities_end)
      return
    end if

    do ld = 1, NLEADS
      SAFE_DEALLOCATE_P(st%ob_lead(ld)%rho)
    end do

    POP_SUB(states_lead_densities_end)
  end subroutine states_lead_densities_end


  ! ---------------------------------------------------------
  !> Reads from the input file the initial occupations, if the
  !! block "Occupations" is present. Otherwise, it makes an initial
  !! guess for the occupations, maybe using the "Smearing"
  !! variable.
  !!
  !! The resulting occupations are placed on the st\%occ variable. The
  !! boolean st\%fixed_occ is also set to .true., if the occupations are
  !! set by the user through the "Occupations" block; false otherwise.
  subroutine states_read_initial_occs(st, excess_charge, kpoints)
    type(states_t),  intent(inout) :: st
    FLOAT,           intent(in)    :: excess_charge
    type(kpoints_t), intent(in)    :: kpoints

    integer :: ik, ist, ispin, nspin, ncols, nrows, el_per_state, icol, start_pos
    type(block_t) :: blk
    FLOAT :: rr, charge
    logical :: integral_occs
    FLOAT, allocatable :: read_occs(:, :)
    FLOAT :: charge_in_block

    PUSH_SUB(states_read_initial_occs)

    !%Variable RestartFixedOccupations
    !%Type logical
    !%Default no
    !%Section States
    !%Description
    !% Setting this variable will make the restart proceed as
    !% if the occupations from the previous calculation had been set via the <tt>Occupations</tt> block,
    !% <i>i.e.</i> fixed. Otherwise, occupations will be determined by smearing.
    !%End
    call parse_logical(datasets_check('RestartFixedOccupations'), .false., st%restart_fixed_occ)
    ! we will turn on st%fixed_occ if restart_read is ever called

    !%Variable Occupations
    !%Type block
    !%Section States
    !%Description
    !% The occupation numbers of the orbitals can be fixed through the use of this
    !% variable. For example:
    !%
    !% <tt>%Occupations
    !% <br>&nbsp;&nbsp;2.0 | 2.0 | 2.0 | 2.0 | 2.0
    !% <br>%</tt>
    !%
    !% would fix the occupations of the five states to <i>2.0</i>. There can be
    !% at most as many columns as states in the calculation. If there are fewer columns
    !% than states, then the code will assume that the user is indicating the occupations
    !% of the uppermost states, assigning maximum occupation (i.e. 2 for spin-unpolarized
    !% calculations, 1 otherwise) to the lower states. The number of rows should be equal
    !% to the number of k-points times the number of spins. For example, for a finite system
    !% with <tt>SpinComponents == spin_polarized</tt>,
    !% this block should contain two lines, one for each spin channel.
    !% All rows must have the same number of columns.
    !% This variable is very useful when dealing with highly symmetric small systems
    !% (like an open-shell atom), for it allows us to fix the occupation numbers
    !% of degenerate states in order to help <tt>octopus</tt> to converge. This is to
    !% be used in conjunction with <tt>ExtraStates</tt>. For example, to calculate the
    !% carbon atom, one would do:
    !%
    !% <tt>ExtraStates = 2
    !% <br>%Occupations
    !% <br>&nbsp;&nbsp;2 | 2/3 | 2/3 | 2/3
    !% <br>%</tt>
    !%
    !% If you want the calculation to be spin-polarized (which makes more sense), you could do:
    !%
    !% <tt>ExtraStates = 2
    !% <br>%Occupations
    !% <br>&nbsp;&nbsp; 2/3 | 2/3 | 2/3
    !% <br>&nbsp;&nbsp; 0   |   0 |   0
    !% <br>%</tt>
    !%
    !% Note that in this case the first state is absent, the code will calculate four states
    !% (two because there are four electrons, plus two because <tt>ExtraStates</tt> = 2), and since
    !% it finds only three columns, it will occupy the first state with one electron for each
    !% of the spin options.
    !%
    !% If the sum of occupations is not equal to the total charge set by <tt>ExcessCharge</tt>,
    !% an error message is printed.
    !% If <tt>FromScratch = no</tt> and <tt>RestartFixedOccupations = yes</tt>,
    !% this block will be ignored.
    !%End

    integral_occs = .true.

    if(st%open_boundaries) then
      ! Open boundaries: occupations and k-weights are taken over from the lead data,
      ! and the total charge is recomputed from them.
      st%fixed_occ = .true.
      st%occ  = st%ob_occ
      st%d%kweights = st%ob_d%kweights
      st%qtot = M_ZERO
      do ist = 1, st%nst
        st%qtot = st%qtot + sum(st%occ(ist, 1:st%d%nik) * st%d%kweights(1:st%d%nik))
      end do

    else
      occ_fix: if(parse_block(datasets_check('Occupations'), blk)==0) then
        ! read in occupations
        st%fixed_occ = .true.

        ncols = parse_block_cols(blk, 0)
        if(ncols > st%nst) then
          message(1) = "Too many columns in block Occupations."
          call messages_warning(1)
          call input_error("Occupations")
        end if

        nrows = parse_block_n(blk)
        if(nrows /= st%d%nik) then
          message(1) = "Wrong number of rows in block Occupations."
          call messages_warning(1)
          call input_error("Occupations")
        end if

        ! Rows are numbered from 0 in the parser; row 0 provided the reference ncols.
        do ik = 1, st%d%nik - 1
          if(parse_block_cols(blk, ik) /= ncols) then
            message(1) = "All rows in block Occupations must have the same number of columns."
            call messages_warning(1)
            call input_error("Occupations")
          endif
        enddo

        ! Now we fill all the "missing" states with the maximum occupation.
        if(st%d%ispin == UNPOLARIZED) then
          el_per_state = 2
        else
          el_per_state = 1
        endif

        SAFE_ALLOCATE(read_occs(1:ncols, 1:st%d%nik))

        do ik = 1, st%d%nik
          do icol = 1, ncols
            call parse_block_float(blk, ik - 1, icol - 1, read_occs(icol, ik))
          end do
        end do

        charge_in_block = sum(read_occs)

        ! Number of lower states that have to be fully occupied so that the
        ! charge not accounted for by the block is balanced.
        start_pos = int((st%qtot - charge_in_block)/(el_per_state*st%d%nik))

        if(start_pos + ncols > st%nst) then
          message(1) = "To balance charge, the first column in block Occupations is taken to refer to state"
          write(message(2),'(a,i6,a)') "number ", start_pos, " but there are too many columns for the number of states."
          write(message(3),'(a,i6,a)') "Solution: set ExtraStates = ", start_pos + ncols - st%nst
          call messages_fatal(3)
        end if

        do ik = 1, st%d%nik
          do ist = 1, start_pos
            st%occ(ist, ik) = el_per_state
          end do
        end do

        do ik = 1, st%d%nik
          do ist = start_pos + 1, start_pos + ncols
            st%occ(ist, ik) = read_occs(ist - start_pos, ik)
            ! An occupation is "integral" when it is either 0 or el_per_state.
            integral_occs = integral_occs .and. &
              abs((st%occ(ist, ik) - el_per_state) * st%occ(ist, ik))  <=  M_EPSILON
          end do
        end do

        do ik = 1, st%d%nik
          do ist = start_pos + ncols + 1, st%nst
             st%occ(ist, ik) = M_ZERO
          end do
        end do

        call parse_block_end(blk)

        SAFE_DEALLOCATE_A(read_occs)

      else
        st%fixed_occ = .false.
        integral_occs = .false.

        ! first guess for occupation...paramagnetic configuration
        rr = M_ONE
        if(st%d%ispin == UNPOLARIZED) rr = M_TWO

        st%occ  = M_ZERO
        st%qtot = -(st%val_charge + excess_charge)

        nspin = 1
        if(st%d%nspin == 2) nspin = 2

        do ik = 1, st%d%nik, nspin
          charge = M_ZERO
          ! The state loop has to be the outer one: every spin channel of a state
          ! is filled before moving on to the next state. With the loops the other
          ! way round, and more states than needed, the first spin channel would
          ! absorb the whole charge instead of sharing it with the second one.
          do ist = 1, st%nst
            do ispin = ik, ik + nspin - 1
              st%occ(ist, ispin) = min(rr, -(st%val_charge + excess_charge) - charge)
              charge = charge + st%occ(ist, ispin)
            end do
          end do
        end do

      end if occ_fix
    end if

    !%Variable RestartReorderOccs
    !%Type logical
    !%Default no
    !%Section States
    !%Description
    !% Consider doing a ground-state calculation, and then restarting with new occupations set
    !% with the <tt>Occupations</tt> block, in an attempt to populate the orbitals of the original
    !% calculation. However, the eigenvalues may reorder as the density changes, in which case the
    !% occupations will now be referring to different orbitals. Setting this variable to yes will
    !% try to solve this issue when the restart data is being read, by reordering the occupations
    !% according to the order of the expectation values of the restart wavefunctions.
    !%End
    if(st%fixed_occ) then
      call parse_logical(datasets_check('RestartReorderOccs'), .false., st%restart_reorder_occs)
    else
      st%restart_reorder_occs = .false.
    endif

    call smear_init(st%smear, st%d%ispin, st%fixed_occ, integral_occs, kpoints)

    if(.not. smear_is_semiconducting(st%smear) .and. .not. st%smear%method == SMEAR_FIXED_OCC) then
      if((st%d%ispin /= SPINORS .and. st%nst * 2  <=  st%qtot) .or. &
         (st%d%ispin == SPINORS .and. st%nst  <=  st%qtot)) then
        call messages_write('Smearing needs unoccupied states (via ExtraStates) to be useful.')
        call messages_warning()
      endif
    endif

    ! sanity check: the k-weighted sum of the occupations must reproduce st%qtot
    charge = M_ZERO
    do ist = 1, st%nst
      charge = charge + sum(st%occ(ist, 1:st%d%nik) * st%d%kweights(1:st%d%nik))
    end do
    if(abs(charge - st%qtot) > CNST(1e-6)) then
      message(1) = "Initial occupations do not integrate to total charge."
      write(message(2), '(6x,f12.6,a,f12.6)') charge, ' != ', st%qtot
      call messages_fatal(2, only_root_writes = .true.)
    end if

    POP_SUB(states_read_initial_occs)
  end subroutine states_read_initial_occs


  ! ---------------------------------------------------------
  !> Reads, if present, the "InitialSpins" block. This is only
  !! done in spinors mode; otherwise the routine does nothing. The
  !! resulting spins are placed onto the st\%spin pointer. The boolean
  !! st\%fixed_spins is set to true if (and only if) the InitialSpins
  !! block is present.
  subroutine states_read_initial_spins(st)
    type(states_t), intent(inout) :: st

    integer :: i, j
    type(block_t) :: blk

    PUSH_SUB(states_read_initial_spins)

    st%fixed_spins = .false.
    if(st%d%ispin /= SPINORS) then
      POP_SUB(states_read_initial_spins)
      return
    end if

    !%Variable InitialSpins
    !%Type block
    !%Section States
    !%Description
    !% The spin character of the initial random guesses for the spinors can
    !% be fixed by making use of this block. Note that this will not "fix"
    !% the spins during the calculation (this cannot be done in spinors mode;
    !% being able to change the spins is why the spinors mode exists in the first
    !% place).
    !%
    !% This block is meaningless and ignored if the run is not in spinors mode
    !% (<tt>SpinComponents = spinors</tt>).
    !%
    !% The structure of the block is very simple: each column contains the desired
    !% <math>\left< S_x \right>, \left< S_y \right>, \left< S_z \right> </math> for each spinor.
    !% If the calculation is for a periodic system
    !% and there is more than one <i>k</i>-point, the spins of all the <i>k</i>-points are
    !% the same.
    !%
    !% For example, if we have two spinors, and we want one in the <math>S_x</math> "down" state,
    !% and another one in the <math>S_x</math> "up" state:
    !%
    !% <tt>%InitialSpins
    !% <br>&nbsp;&nbsp;&nbsp; 0.5 | 0.0 | 0.0
    !% <br>&nbsp;&nbsp; -0.5 | 0.0 | 0.0
    !% <br>%</tt>
    !%
    !% WARNING: if the calculation is for a system described by pseudopotentials (as
    !% opposed to user-defined potentials or model systems), this option is
    !% meaningless since the random spinors are overwritten by the atomic orbitals.
    !%
    !% There are a couple of physical constraints that have to be fulfilled:
    !% <br>(A) <math> \left| \left< S_i \right> \right| \le \frac{1}{2} </math>
    !% <br>(B) <math> \left< S_x \right>^2 + \left< S_y \right>^2 + \left< S_z \right>^2 = \frac{1}{4} </math>
    !%End
    spin_fix: if(parse_block(datasets_check('InitialSpins'), blk)==0) then
      ! One row per state, three columns per row; stored in the first k-point slot.
      do i = 1, st%nst
        do j = 1, 3
          call parse_block_float(blk, i-1, j-1, st%spin(j, i, 1))
        end do
        ! This checks (B).
        if( abs(sum(st%spin(1:3, i, 1)**2) - M_FOURTH) > CNST(1.0e-6)) call input_error('InitialSpins')
      end do
      call parse_block_end(blk)
      ! This checks (A). In fact (A) follows from (B), so maybe this is not necessary...
      ! Only the first k-point has been filled in so far; the remaining ones are
      ! still unset and must not take part in the test.
      if(any(abs(st%spin(:, :, 1)) > M_HALF)) then
        call input_error('InitialSpins')
      end if
      st%fixed_spins = .true.
      ! All k-points share the spins that were read for the first one.
      do i = 2, st%d%nik
        st%spin(:, :, i) = st%spin(:, :, 1)
      end do
    end if spin_fix

    POP_SUB(states_read_initial_spins)
  end subroutine states_read_initial_spins


  ! ---------------------------------------------------------
  !> Allocates the KS wavefunctions defined within a states_t structure.
  subroutine states_allocate_wfns(st, mesh, wfs_type, alloc_zphi, alloc_Left)
    type(states_t),         intent(inout)   :: st
    type(mesh_t),           intent(in)      :: mesh
    type(type_t), optional, intent(in)      :: wfs_type   !< TYPE_FLOAT or TYPE_CMPLX; if absent the stored type is kept
    logical,      optional, intent(in)      :: alloc_zphi !< only needed for gs transport
    logical,      optional, intent(in)      :: alloc_Left !< allocate an additional set of wfs to store left eigenstates

    integer :: st1, st2, k1, k2, np_part
    logical :: force

    PUSH_SUB(states_allocate_wfns)

    ! Allocating twice is a programming error.
    if(associated(st%dpsi).or.associated(st%zpsi)) then
      call messages_write('Trying to allocate wavefunctions that are already allocated.')
      call messages_fatal()
    end if

    if (present(wfs_type)) then
      ASSERT(wfs_type == TYPE_FLOAT .or. wfs_type == TYPE_CMPLX)
      st%priv%wfs_type = wfs_type
    end if

    ! Left states are only kept when complex scaling of space is active.
    st%have_left_states = optional_default(alloc_Left, .false.) .and. st%cmplxscl%space
    if(st%have_left_states) then
      ASSERT(st%priv%wfs_type == TYPE_CMPLX) 
    end if

    !%Variable ForceComplex
    !%Type logical
    !%Default no
    !%Section Execution::Debug
    !%Description
    !% Normally <tt>Octopus</tt> determines automatically the type necessary
    !% for the wavefunctions. When set to yes this variable will
    !% force the use of complex wavefunctions.
    !%
    !% Warning: This variable is designed for testing and
    !% benchmarking and normal users need not use it.
    !%
    !%End
    call parse_logical(datasets_check('ForceComplex'), .false., force)

    if(force) call states_set_complex(st)

    ! Local ranges of states and k-points handled here.
    st1 = st%st_start
    st2 = st%st_end
    k1 = st%d%kpt%start
    k2 = st%d%kpt%end
    np_part = mesh%np_part

    ! The big arrays are only allocated here when states are not packed.
    if(.not. st%d%pack_states) then

      if (states_are_real(st)) then
        SAFE_ALLOCATE(st%dpsi(1:np_part, 1:st%d%dim, st1:st2, k1:k2))
      else
        SAFE_ALLOCATE(st%psi%zR(1:np_part, 1:st%d%dim, st1:st2, k1:k2))
        st%zpsi => st%psi%zR
        if(st%cmplxscl%space) then
          if (st%have_left_states) then
            SAFE_ALLOCATE(st%psi%zL(1:np_part, 1:st%d%dim, st1:st2, k1:k2))
          else
            ! Without separate left states, zL is just an alias of zR.
            st%psi%zL => st%psi%zR
          end if
        end if
      end if

      if(optional_default(alloc_zphi, .false.)) then
        SAFE_ALLOCATE(st%zphi(1:np_part, 1:st%ob_d%dim, st1:st2, k1:k2))
        ! Zero the whole array that was just allocated. The former loop ran over
        ! st%d%dim although the second extent is st%ob_d%dim, which is only
        ! correct when both dimensions coincide.
        st%zphi = M_Z0
      else
        nullify(st%zphi)
      end if

    end if

    call states_init_block(st, mesh)
    call states_set_zero(st)

    POP_SUB(states_allocate_wfns)
  end subroutine states_allocate_wfns

  ! ---------------------------------------------------------
  !> Allocates the interface wavefunctions defined within a states_t structure.
  subroutine states_allocate_intf_wfns(st, ob_mesh)
    type(states_t),         intent(inout)   :: st
    type(mesh_t),           intent(in)      :: ob_mesh(:)

    integer :: ld, ist_lo, ist_hi, ik_lo, ik_hi

    PUSH_SUB(states_allocate_intf_wfns)

    ! Interface wavefunctions only make sense with open boundaries.
    ASSERT(st%open_boundaries)

    ! Local ranges of k-points and states.
    ik_lo  = st%d%kpt%start
    ik_hi  = st%d%kpt%end
    ist_lo = st%st_start
    ist_hi = st%st_end

    ! One zeroed array per lead; it must not have been allocated before.
    do ld = 1, NLEADS
      ASSERT(.not.associated(st%ob_lead(ld)%intf_psi))
      SAFE_ALLOCATE(st%ob_lead(ld)%intf_psi(1:ob_mesh(ld)%np, 1:st%d%dim, ist_lo:ist_hi, ik_lo:ik_hi))
      st%ob_lead(ld)%intf_psi(:, :, :, :) = M_z0
    end do

    ! TODO: write states_init_block for intf_psi
!    call states_init_block(st)

    POP_SUB(states_allocate_intf_wfns)
  end subroutine states_allocate_intf_wfns
  ! -----------------------------------------------------


  !---------------------------------------------------------------------
  !> Initializes the data components in st that describe how the states
  !! are distributed in blocks:
  !!
  !! st\%nblocks: this is the number of blocks in which the states are divided. Note that
  !!   this number is the total number of blocks, regardless of how many are actually stored
  !!   in each node.
  !! block_start: in each node, the index of the first block.
  !! block_end: in each node, the index of the last block.
  !!   If the states are not parallelized, then block_start is 1 and block_end is st\%nblocks.
  !! st\%iblock(1:st\%nst, 1:st\%d\%nik): it points, for each state, to the block that contains it.
  !! st\%block_is_local(): st\%block_is_local(ib) is .true. if block ib is stored in the running node.
  !! st\%block_range(1:st\%nblocks, 1:2): Block ib contains states from st\%block_range(ib, 1) to st\%block_range(ib, 2)
  !! st\%block_size(1:st\%nblocks): Block ib contains a number st\%block_size(ib) of states.
  !! st\%block_initialized: it should be .false. on entry, and .true. after exiting this routine.
  !!
  !! The set of batches st\%psib(1:st\%nblocks) contains the blocks themselves.
  subroutine states_init_block(st, mesh, verbose)
    type(states_t),           intent(inout) :: st
    type(mesh_t),   optional, intent(in)    :: mesh    !< asserted to be present when the wavefunction arrays are not associated
    logical, optional,        intent(in)    :: verbose !< print the list of blocks (defaults to .true.)

    integer :: ib, iqn, ist
    logical :: same_node, verbose_
    ! First and last state of every block; sized by st%nst, the largest possible number of blocks.
    integer, allocatable :: bstart(:), bend(:)

    PUSH_SUB(states_init_block)

    SAFE_ALLOCATE(bstart(1:st%nst))
    SAFE_ALLOCATE(bend(1:st%nst))
    SAFE_ALLOCATE(st%group%iblock(1:st%nst, 1:st%d%nik))
    st%group%iblock = 0

    verbose_ = optional_default(verbose, .true.)

    ! count and assign blocks
    ! Here ib counts the states accumulated in the block currently being built.
    ib = 0
    st%group%nblocks = 0
    bstart(1) = 1
    do ist = 1, st%nst
      INCR(ib, 1)

      ! Only the local k-points are assigned; the other entries of iblock stay at 0.
      st%group%iblock(ist, st%d%kpt%start:st%d%kpt%end) = st%group%nblocks + 1

      same_node = .true.
      if(st%parallel_in_states .and. ist /= st%nst) then
        ! We have to avoid that states that are in different nodes end
        ! up in the same block
        same_node = (st%node(ist + 1) == st%node(ist))
      end if

      ! Close the block when it is full, at the last state, or at a node boundary.
      if(ib == st%d%block_size .or. ist == st%nst .or. .not. same_node) then
        ib = 0
        INCR(st%group%nblocks, 1)
        bend(st%group%nblocks) = ist
        if(ist /= st%nst) bstart(st%group%nblocks + 1) = ist + 1
      end if
    end do

    SAFE_ALLOCATE(st%group%psib(1:st%group%nblocks, 1:st%d%nik))
    if(st%have_left_states) then
      SAFE_ALLOCATE(st%psibL(1:st%group%nblocks, 1:st%d%nik))
    end if
    SAFE_ALLOCATE(st%group%block_is_local(1:st%group%nblocks, 1:st%d%nik))
    st%group%block_is_local = .false.
    st%group%block_start  = -1
    st%group%block_end    = -2  ! this will make that loops block_start:block_end do not run if not initialized

    ! From here on ib is the block index.
    do ib = 1, st%group%nblocks
      ! A block is handled here only if it lies completely inside the local range of states.
      if(bstart(ib) >= st%st_start .and. bend(ib) <= st%st_end) then
        if(st%group%block_start == -1) st%group%block_start = ib
        st%group%block_end = ib
        do iqn = st%d%kpt%start, st%d%kpt%end
          st%group%block_is_local(ib, iqn) = .true.

          ! If the full wavefunction array is associated, the batch is initialized
          ! from the matching section of it; otherwise the batch is created with
          ! [dz]batch_new, which needs the mesh size.
          if (states_are_real(st)) then
            if(associated(st%dpsi)) then
              call batch_init(st%group%psib(ib, iqn), st%d%dim, bstart(ib), bend(ib), st%dpsi(:, :, bstart(ib):bend(ib), iqn))
            else
              ASSERT(present(mesh))
              call batch_init(st%group%psib(ib, iqn), st%d%dim, bend(ib) - bstart(ib) + 1)
              call dbatch_new(st%group%psib(ib, iqn), bstart(ib), bend(ib), mesh%np_part)
            end if
          else
            if(associated(st%zpsi)) then
              call batch_init(st%group%psib(ib, iqn), st%d%dim, bstart(ib), bend(ib), st%zpsi(:, :, bstart(ib):bend(ib), iqn))
            else
              ASSERT(present(mesh))
              call batch_init(st%group%psib(ib, iqn), st%d%dim, bend(ib) - bstart(ib) + 1)
              call zbatch_new(st%group%psib(ib, iqn), bstart(ib), bend(ib), mesh%np_part)
            end if
            if(st%have_left_states) then !cmplxscl
              if(associated(st%psi%zL)) then
                call batch_init(st%psibL(ib, iqn), st%d%dim, bstart(ib), bend(ib), st%psi%zL(:, :, bstart(ib):bend(ib), iqn))
              else
                ASSERT(present(mesh))
                call batch_init(st%psibL(ib, iqn), st%d%dim, bend(ib) - bstart(ib) + 1)
                call zbatch_new(st%psibL(ib, iqn), bstart(ib), bend(ib), mesh%np_part)
              end if 
            else
              ! No separate left states: psibL aliases the regular batches.
              ! Note that this is only done in the complex branch.
              st%psibL => st%group%psib                           
            end if            
          end if
          
        end do
      end if
    end do

    SAFE_ALLOCATE(st%group%block_range(1:st%group%nblocks, 1:2))
    SAFE_ALLOCATE(st%group%block_size(1:st%group%nblocks))

    ! Publish first state, last state and number of states of every block.
    st%group%block_range(1:st%group%nblocks, 1) = bstart(1:st%group%nblocks)
    st%group%block_range(1:st%group%nblocks, 2) = bend(1:st%group%nblocks)
    st%group%block_size(1:st%group%nblocks) = bend(1:st%group%nblocks) - bstart(1:st%group%nblocks) + 1

    st%group%block_initialized = .true.

    ! Optional report of the block layout.
    if(verbose_) then
      call messages_write('Info: Blocks of states')
      call messages_info()
      do ib = 1, st%group%nblocks
        call messages_write('      Block ')
        call messages_write(ib, fmt = 'i8')
        call messages_write(' contains ')
        call messages_write(st%group%block_size(ib), fmt = 'i8')
        call messages_write(' states')
        if(st%group%block_size(ib) > 0) then
          call messages_write(':')
          call messages_write(st%group%block_range(ib, 1), fmt = 'i8')
          call messages_write(' - ')
          call messages_write(st%group%block_range(ib, 2), fmt = 'i8')
        endif
        call messages_info()
      end do
    end if
    
!     !cmplxscl
!     if(st%have_left_states) then
!       do ib = 1, st%nblocks
!         do iqn = st%d%kpt%start, st%d%kpt%end
!           call batch_copy(st%group%psib(ib,iqn), st%psibL(ib,iqn), reference = .false.)
!         end do
!       end do
!     else
!       st%psibL => st%group%psib
!     end if

!!$!!!!DEBUG
!!$    ! some debug output that I will keep here for the moment
!!$    if(mpi_grp_is_root(mpi_world)) then
!!$      print*, "NST       ", st%nst
!!$      print*, "BLOCKSIZE ", st%d%block_size
!!$      print*, "NBLOCKS   ", st%group%nblocks
!!$
!!$      print*, "==============="
!!$      do ist = 1, st%nst
!!$        print*, st%node(ist), ist, st%group%iblock(ist, 1)
!!$      end do
!!$      print*, "==============="
!!$
!!$      do ib = 1, st%group%nblocks
!!$        print*, ib, bstart(ib), bend(ib)
!!$      end do
!!$
!!$    end if
!!$!!!!ENDOFDEBUG

    SAFE_DEALLOCATE_A(bstart)
    SAFE_DEALLOCATE_A(bend)
    POP_SUB(states_init_block)
  end subroutine states_init_block


  ! ---------------------------------------------------------
  !> Deallocates the KS wavefunctions defined within a states_t structure.
  subroutine states_deallocate_wfns(st)
    type(states_t), intent(inout) :: st

    integer :: ld, iblk, iqn

    PUSH_SUB(states_deallocate_wfns)

    ! First tear down the block structure, if it was ever set up.
    if (st%group%block_initialized) then
      do iblk = 1, st%group%nblocks
        do iqn = st%d%kpt%start, st%d%kpt%end
          ! Only locally stored blocks own a batch.
          if(.not. st%group%block_is_local(iblk, iqn)) cycle
          call batch_end(st%group%psib(iblk, iqn))
          if(st%have_left_states) then
            call batch_end(st%psibL(iblk, iqn))
          end if
        end do
      end do

      SAFE_DEALLOCATE_P(st%group%psib)

      ! psibL owns memory only when left states are kept; otherwise it is just reset.
      if(.not. st%have_left_states) then
        nullify(st%psibL)
      else
        SAFE_DEALLOCATE_P(st%psibL)
      end if
      SAFE_DEALLOCATE_P(st%group%iblock)
      SAFE_DEALLOCATE_P(st%group%block_range)
      SAFE_DEALLOCATE_P(st%group%block_size)
      SAFE_DEALLOCATE_P(st%group%block_is_local)
      st%group%block_initialized = .false.
    end if

    ! Then release the wavefunction arrays themselves.
    if (.not. states_are_real(st)) then
      nullify(st%zpsi)
      ! zL may be a mere alias of zR, in which case it must not be freed.
      if(.not. associated(st%psi%zL, target=st%psi%zR)) then
        SAFE_DEALLOCATE_P(st%psi%zL)
      else
        nullify(st%psi%zL)
      end if
      SAFE_DEALLOCATE_P(st%psi%zR)
    else
      SAFE_DEALLOCATE_P(st%dpsi)
    end if

    ! Interface wavefunctions of the leads.
    if(st%open_boundaries) then
      do ld = 1, NLEADS
        SAFE_DEALLOCATE_P(st%ob_lead(ld)%intf_psi)
      end do
    end if

    POP_SUB(states_deallocate_wfns)
  end subroutine states_deallocate_wfns


  ! ---------------------------------------------------------
  !> Allocates and zeroes the density arrays of st (and, when requested, the
  !! current and core densities), then reports the states-block size.
  subroutine states_densities_init(st, gr, geo)
    type(states_t), target, intent(inout) :: st
    type(grid_t),           intent(in)    :: gr
    type(geometry_t),       intent(in)    :: geo

    FLOAT :: block_mem

    PUSH_SUB(states_densities_init)

    ! The real part always exists, and st%rho points to it.
    SAFE_ALLOCATE(st%zrho%Re(1:gr%fine%mesh%np_part, 1:st%d%nspin))
    st%zrho%Re(:, :) = M_ZERO
    st%rho => st%zrho%Re

    ! The imaginary part is only needed with complex scaling.
    if(st%cmplxscl%space) then
      SAFE_ALLOCATE(st%zrho%Im(1:gr%fine%mesh%np_part, 1:st%d%nspin))
      st%zrho%Im(:, :) = M_ZERO
    end if

    if(st%d%cdft) then
      SAFE_ALLOCATE(st%current(1:gr%mesh%np_part, 1:gr%mesh%sb%dim, 1:st%d%nspin))
      st%current(:, :, :) = M_ZERO
    end if

    ! Core densities are only allocated when geo%nlcc is set.
    if(geo%nlcc) then
      SAFE_ALLOCATE(st%rho_core(1:gr%fine%mesh%np))
      st%rho_core = M_ZERO
      if(st%cmplxscl%space) then
        SAFE_ALLOCATE(st%Imrho_core(1:gr%fine%mesh%np))
        st%Imrho_core = M_ZERO
      end if
    end if

    ! np_part values of 8.0 each, times the number of states per block.
    block_mem = gr%mesh%np_part*CNST(8.0)*st%d%block_size

    call messages_write('Info: states-block size = ')
    call messages_write(block_mem, fmt = '(f10.1)', align_left = .true., units = unit_megabytes, print_units = .true.)
    call messages_info()

    POP_SUB(states_densities_init)
  end subroutine states_densities_init

  !> Allocates and zeroes st\%current, unless it is already associated.
  subroutine states_allocate_current(st, gr)
    type(states_t), target, intent(inout) :: st
    type(grid_t),           intent(in)    :: gr

    PUSH_SUB(states_allocate_current)

    ! Nothing to do when the current already exists.
    if(associated(st%current)) then
      POP_SUB(states_allocate_current)
      return
    end if

    SAFE_ALLOCATE(st%current(1:gr%mesh%np_part, 1:gr%mesh%sb%dim, 1:st%d%nspin))
    st%current(:, :, :) = M_ZERO

    POP_SUB(states_allocate_current)
  end subroutine states_allocate_current

  !---------------------------------------------------------------------
  !> This subroutine reads the execution options of the states from
  !! the input file:
  !! (i) whether or not to pack the states (st\%d\%pack_states);
  !! (ii) the orthogonalization method (st\%d\%orth_method);
  !! (iii) the amount of device memory reserved for the states
  !! (st\%d\%cl_states_mem).
  !!
  !! \param st  states object whose st\%d options are filled in.
  !! \param mc  multicommunicator, only queried to find out whether
  !!            states parallelization is active (this selects the
  !!            default orthogonalization method).
  subroutine states_exec_init(st, mc)
    type(states_t),    intent(inout) :: st
    type(multicomm_t), intent(in)    :: mc

    integer :: default_orth

    PUSH_SUB(states_exec_init)

    !%Variable StatesPack
    !%Type logical
    !%Default no
    !%Section Execution::Optimization
    !%Description
    !% (Experimental) When set to yes, states are stored in packed
    !% mode, which improves performance considerably. However this
    !% is not fully implemented and it might give wrong results.
    !%
    !% If OpenCL is used and this variable is set to yes, Octopus
    !% will store the wave-functions in device (GPU) memory. If
    !% there is not enough memory to store all the wave-functions,
    !% execution will stop with an error.
    !%End

    call parse_logical(datasets_check('StatesPack'), .false., st%d%pack_states)
    if(st%d%pack_states) call messages_experimental('StatesPack')

    ! Note: each obsolete synonym must carry the same numerical value
    ! as the option it stands for (gram_schmidt = cholesky_serial,
    ! par_gram_schmidt = cholesky_parallel).

    !%Variable StatesOrthogonalization
    !%Type integer
    !%Section SCF::Eigensolver
    !%Description
    !% The full orthogonalization method used by some
    !% eigensolvers. The default is <tt>cholesky_serial</tt>, except with state
    !% parallelization, the default is <tt>cholesky_parallel</tt>.
    !%Option gram_schmidt 1
    !%Option cholesky_serial 1
    !% Cholesky decomposition implemented using
    !% BLAS/LAPACK. Can be used with domain parallelization but not
    !% state parallelization. (Obsolete synonym: <tt>gram_schmidt</tt>)
    !%Option par_gram_schmidt 2
    !%Option cholesky_parallel 2
    !% Cholesky decomposition implemented using
    !% ScaLAPACK. Compatible with states parallelization. (Obsolete synonym: <tt>par_gram_schmidt</tt>)
    !%Option mgs 3
    !% Modified Gram-Schmidt orthogonalization.
    !% Can be used with domain parallelization but not state parallelization.
    !%Option qr 4
    !% (Experimental) Orthogonalization is performed based on a QR
    !% decomposition with LAPACK or ScaLAPACK.
    !% Compatible with states parallelization.
    !%End

    ! the default depends on whether the states are distributed
    if(multicomm_strategy_is_parallel(mc, P_STRATEGY_STATES)) then
      default_orth = ORTH_CHOLESKY_PARALLEL
    else
      default_orth = ORTH_CHOLESKY_SERIAL
    end if

    call parse_integer(datasets_check('StatesOrthogonalization'), default_orth, st%d%orth_method)

    if(.not.varinfo_valid_option('StatesOrthogonalization', st%d%orth_method)) call input_error('StatesOrthogonalization')
    call messages_print_var_option(stdout, 'StatesOrthogonalization', st%d%orth_method)

    if(st%d%orth_method == ORTH_QR) call messages_experimental("QR Orthogonalization")


    !%Variable StatesCLDeviceMemory
    !%Type float
    !%Section Execution::Optimization
    !%Default -512
    !%Description
    !% This variable selects the amount of OpenCL device memory that
    !% will be used by Octopus to store the states.
    !%
    !% A positive number smaller than 1 indicates a fraction of the total
    !% device memory. A number larger than one indicates an absolute
    !% amount of memory in megabytes. A negative number indicates an
    !% amount of memory in megabytes that would be subtracted from
    !% the total device memory.
    !%End
    call parse_float(datasets_check('StatesCLDeviceMemory'), CNST(-512.0), st%d%cl_states_mem)

    POP_SUB(states_exec_init)
  end subroutine states_exec_init


  ! ---------------------------------------------------------
  !> Makes stout a copy of stin. stout is first reset with states_null
  !! and then filled component by component.
  !!
  !! \param stout    destination; its previous contents are not freed
  !!                 here, only reset through states_null.
  !! \param stin     source states object.
  !! \param exclude  if present and true, the wavefunctions, densities,
  !!                 eigenvalues, occupations, spins and the state to
  !!                 node map are not copied, and the blocks are not
  !!                 initialized.
  subroutine states_copy(stout, stin, exclude)
    type(states_t), target, intent(inout) :: stout
    type(states_t),         intent(in)    :: stin
    logical, optional,      intent(in)    :: exclude !< do not copy wavefunctions, etc.

    logical :: exclude_

    PUSH_SUB(states_copy)

    exclude_ = optional_default(exclude, .false.)

    call states_null(stout)

    call states_dim_copy(stout%d, stin%d)
    call modelmb_particles_copy(stout%modelmbparticles, stin%modelmbparticles)
    stout%priv%wfs_type = stin%priv%wfs_type
    stout%nst           = stin%nst

    stout%only_userdef_istates = stin%only_userdef_istates

    if(.not. exclude_) then
      call loct_pointer_copy(stout%dpsi, stin%dpsi)

      !cmplxscl
      ! The plain pointers (zpsi, rho, eigenval) are re-targeted to the
      ! freshly copied components of stout, not to the ones of stin.
      call loct_pointer_copy(stout%psi%zR, stin%psi%zR)
      stout%zpsi => stout%psi%zR
      call loct_pointer_copy(stout%zrho%Re, stin%zrho%Re)
      stout%rho => stout%zrho%Re
      call loct_pointer_copy(stout%zeigenval%Re, stin%zeigenval%Re)
      stout%eigenval => stout%zeigenval%Re
      if(stin%cmplxscl%space) then
        call loct_pointer_copy(stout%psi%zL, stin%psi%zL)
        call loct_pointer_copy(stout%zrho%Im, stin%zrho%Im)
        call loct_pointer_copy(stout%zeigenval%Im, stin%zeigenval%Im)
        call loct_pointer_copy(stout%Imrho_core, stin%Imrho_core)
        call loct_pointer_copy(stout%Imfrozen_rho, stin%Imfrozen_rho)
      end if

      call loct_pointer_copy(stout%occ, stin%occ)
      call loct_pointer_copy(stout%spin, stin%spin)
      call loct_pointer_copy(stout%node, stin%node)
    endif

    stout%have_left_states = stin%have_left_states


    ! the call to init_block is done at the end of this subroutine
    ! it allocates iblock, psib, block_is_local
    stout%group%nblocks = stin%group%nblocks

    stout%open_boundaries = stin%open_boundaries
    ! Warning: some of the "open boundaries" variables are not copied.

    call loct_pointer_copy(stout%user_def_states, stin%user_def_states)

    call loct_pointer_copy(stout%current, stin%current)

    call loct_pointer_copy(stout%rho_core, stin%rho_core)

    call loct_pointer_copy(stout%frozen_rho, stin%frozen_rho)

    stout%fixed_occ = stin%fixed_occ
    stout%restart_fixed_occ = stin%restart_fixed_occ

    stout%fixed_spins = stin%fixed_spins

    stout%qtot       = stin%qtot
    stout%val_charge = stin%val_charge

    call smear_copy(stout%smear, stin%smear)

    ! parallelization layout: all the MPI groups of the source are
    ! carried over, including the domain-states one, which is needed
    ! by the copy as much as by the original.
    stout%parallel_in_states = stin%parallel_in_states
    call mpi_grp_copy(stout%mpi_grp, stin%mpi_grp)
    stout%dom_st_kpt_mpi_grp = stin%dom_st_kpt_mpi_grp
    stout%dom_st_mpi_grp     = stin%dom_st_mpi_grp
    stout%st_kpt_mpi_grp     = stin%st_kpt_mpi_grp

#ifdef HAVE_SCALAPACK
    call blacs_proc_grid_copy(stin%dom_st_proc_grid, stout%dom_st_proc_grid)
#endif

    stout%lnst       = stin%lnst
    stout%st_start   = stin%st_start
    stout%st_end     = stin%st_end
    call loct_pointer_copy(stout%st_range, stin%st_range)
    call loct_pointer_copy(stout%st_num, stin%st_num)

    if(stin%parallel_in_states) call multicomm_all_pairs_copy(stout%ap, stin%ap)

    stout%symmetrize_density = stin%symmetrize_density

    ! the blocks are only rebuilt when the wavefunctions were copied
    ! and the source had its blocks initialized
    if(.not. exclude_) then
      stout%group%block_initialized = .false.
      if(stin%group%block_initialized) then
        call states_init_block(stout, verbose = .false.)
      end if
    endif

    stout%packed = stin%packed

    POP_SUB(states_copy)
  end subroutine states_copy


  ! ---------------------------------------------------------
  !> Releases the memory held by the components of a states object.
  !!
  !! \param st  states object to be finalized.
  subroutine states_end(st)
    type(states_t), intent(inout) :: st

    integer :: ilead

    PUSH_SUB(states_end)

    call states_dim_end(st%d)
    call modelmb_particles_end(st%modelmbparticles)

    ! the wavefunction storage (dpsi, zpsi, psib, iblock and
    ! st%ob_lead(:)%intf_psi) is released by this call
    call states_deallocate_wfns(st)

    ! components with the ob_ prefix
    SAFE_DEALLOCATE_P(st%zphi)
    SAFE_DEALLOCATE_P(st%ob_eigenval)
    call states_dim_end(st%ob_d)
    SAFE_DEALLOCATE_P(st%ob_occ)
    do ilead = 1, 2*MAX_DIM
      SAFE_DEALLOCATE_P(st%ob_lead(ilead)%self_energy)
    end do

    SAFE_DEALLOCATE_P(st%user_def_states)

    !cmplxscl
    ! rho and eigenval are normally aliases of zrho%Re and
    ! zeigenval%Re. Sometimes, however, they are allocated outside
    ! this module and the alias is broken. Hence we test the
    ! association: an alias is nullified and its target is freed,
    ! whereas an independent array is freed directly.
    if(.not. associated(st%rho, target=st%zrho%Re)) then
      SAFE_DEALLOCATE_P(st%rho)
    else
      nullify(st%rho)
      SAFE_DEALLOCATE_P(st%zrho%Re)
    end if

    if(.not. associated(st%eigenval, target=st%zeigenval%Re)) then
      SAFE_DEALLOCATE_P(st%eigenval)
    else
      nullify(st%eigenval)
      SAFE_DEALLOCATE_P(st%zeigenval%Re)
    end if

    ! the imaginary parts are only freed with complex scaling
    if(st%cmplxscl%space) then
      SAFE_DEALLOCATE_P(st%zrho%Im)
      SAFE_DEALLOCATE_P(st%zeigenval%Im)
      SAFE_DEALLOCATE_P(st%Imrho_core)
      SAFE_DEALLOCATE_P(st%Imfrozen_rho)
    end if

    ! current and the remaining densities
    SAFE_DEALLOCATE_P(st%current)
    SAFE_DEALLOCATE_P(st%rho_core)
    SAFE_DEALLOCATE_P(st%frozen_rho)

    ! occupations and spins
    SAFE_DEALLOCATE_P(st%occ)
    SAFE_DEALLOCATE_P(st%spin)

    ! bookkeeping of the distribution of states
#ifdef HAVE_SCALAPACK
    call blacs_proc_grid_end(st%dom_st_proc_grid)
#endif
    SAFE_DEALLOCATE_P(st%node)
    SAFE_DEALLOCATE_P(st%st_range)
    SAFE_DEALLOCATE_P(st%st_num)

    if(st%parallel_in_states) then
      SAFE_DEALLOCATE_P(st%ap%schedule)
    end if

    POP_SUB(states_end)
  end subroutine states_end

  ! ---------------------------------------------------------
  !> Fills the states in the range [ist_start_, ist_end_] (default:
  !! all of them), for all k-points, with functions obtained from the
  !! mesh random generators (dmf_random for real states, zmf_random for
  !! complex ones).
  !!
  !! The random function is always drawn before testing whether the
  !! state is stored on this process, so that every process performs
  !! the same number of draws (presumably to keep the random sequences
  !! in step between processes; to be confirmed against mf_random).
  !!
  !! \param st          states to be filled.
  !! \param mesh        mesh on which the functions are defined.
  !! \param ist_start_  first state to be generated (default 1).
  !! \param ist_end_    last state to be generated (default st\%nst).
  subroutine states_generate_random(st, mesh, ist_start_, ist_end_)
    type(states_t),    intent(inout) :: st
    type(mesh_t),      intent(in)    :: mesh
    integer, optional, intent(in)    :: ist_start_, ist_end_

    integer :: ist, ik, id, ist_start, ist_end, jst
    logical :: is_local
    CMPLX   :: alpha, beta
    FLOAT, allocatable :: dpsi(:,  :)
    CMPLX, allocatable :: zpsi(:,  :), zpsi2(:)

    PUSH_SUB(states_generate_random)

    ist_start = optional_default(ist_start_, 1)
    ist_end   = optional_default(ist_end_,   st%nst)

    if (states_are_real(st)) then
      SAFE_ALLOCATE(dpsi(1:mesh%np, 1:st%d%dim))
    else
      SAFE_ALLOCATE(zpsi(1:mesh%np, 1:st%d%dim))
    end if

    select case(st%d%ispin)
    case(UNPOLARIZED, SPIN_POLARIZED)

      do ik = 1, st%d%nik
        do ist = ist_start, ist_end
          is_local = state_kpt_is_local(st, ist, ik)

          if (states_are_real(st)) then
            call dmf_random(mesh, dpsi(:, 1))
            if(is_local) call states_set_state(st, mesh, ist,  ik, dpsi)
          else
            call zmf_random(mesh, zpsi(:, 1))
            if(is_local) call states_set_state(st, mesh, ist,  ik, zpsi)

            ! The left state needs a second draw. It is performed on
            ! all processes, local or not, so that the number of draws
            ! does not depend on where the state is stored.
            if(st%have_left_states) then
              call zmf_random(mesh, zpsi(:, 1))
              if(is_local) call states_set_state(st, mesh, ist,  ik, zpsi, left = .true.)
            end if
          end if
        end do
      end do

    case(SPINORS)

      ASSERT(states_are_complex(st))

      if(st%fixed_spins) then

        ! scratch array for the orbitals we orthogonalize against,
        ! allocated once for the whole double loop
        SAFE_ALLOCATE(zpsi2(1:mesh%np))

        do ik = 1, st%d%nik
          do ist = ist_start, ist_end
            call zmf_random(mesh, zpsi(:, 1))
            if(.not. state_kpt_is_local(st, ist, ik)) cycle
            ! In this case, the spinors are made of a spatial part times a vector [alpha beta]^T in
            ! spin space (i.e., same spatial part for each spin component). So (alpha, beta)
            ! determines the spin values. The values of (alpha, beta) can be be obtained
            ! with simple formulae from <Sx>, <Sy>, <Sz>.
            !
            ! Note that here we orthonormalize the orbital part. This ensures that the spinors
            ! are untouched later in the general orthonormalization, and therefore the spin values
            ! of each spinor remain the same.
            do jst = ist_start, ist - 1
              call states_get_state(st, mesh, 1, jst, ik, zpsi2)
              zpsi(1:mesh%np, 1) = zpsi(1:mesh%np, 1) - zmf_dotp(mesh, zpsi(:, 1), zpsi2)*zpsi2(1:mesh%np)
            end do

            zpsi(1:mesh%np, 1) = zpsi(1:mesh%np, 1)/zmf_nrm2(mesh, zpsi(:, 1))
            zpsi(1:mesh%np, 2) = zpsi(1:mesh%np, 1)

            ! beta falls back to a real value only when alpha vanishes
            alpha = TOCMPLX(sqrt(M_HALF + st%spin(3, ist, ik)), M_ZERO)
            beta  = TOCMPLX(sqrt(M_ONE - abs(alpha)**2), M_ZERO)
            if(abs(alpha) > M_ZERO) then
              beta = TOCMPLX(st%spin(1, ist, ik) / abs(alpha), st%spin(2, ist, ik) / abs(alpha))
            end if
            zpsi(1:mesh%np, 1) = alpha*zpsi(1:mesh%np, 1)
            zpsi(1:mesh%np, 2) = beta*zpsi(1:mesh%np, 2)
            call states_set_state(st, mesh, ist,  ik, zpsi)
          end do
        end do

        SAFE_DEALLOCATE_A(zpsi2)

      else
        do ik = 1, st%d%nik
          do ist = ist_start, ist_end
            ! independent random function for each spinor component
            do id = 1, st%d%dim
              call zmf_random(mesh, zpsi(:, id))
            end do
            if(.not. state_kpt_is_local(st, ist, ik)) cycle
            call states_set_state(st, mesh, ist,  ik, zpsi)
          end do
        end do
      end if

    end select

    SAFE_DEALLOCATE_A(dpsi)
    SAFE_DEALLOCATE_A(zpsi)

    POP_SUB(states_generate_random)
  end subroutine states_generate_random

  ! ---------------------------------------------------------
  !> Sets the occupations of the states through the smearing module,
  !! checks that they add up to the total charge st\%qtot and, for
  !! spinors, recalculates the spin of the locally stored states.
  !!
  !! \param st    states whose occupations (and spins) are updated.
  !! \param mesh  mesh used to evaluate the spin of the spinors.
  subroutine states_fermi(st, mesh)
    type(states_t), intent(inout) :: st
    type(mesh_t),   intent(in)    :: mesh

    integer            :: ist, ik
    FLOAT              :: charge
    CMPLX, allocatable :: zpsi(:, :)
#if defined(HAVE_MPI)
    integer            :: idir, tmp
    FLOAT, allocatable :: lspin(:), lspin2(:) !< To exchange spin.
#endif

    PUSH_SUB(states_fermi)

    if(.not. st%cmplxscl%space) then
      ! standard case: find the Fermi energy, then fill the states
      call smear_find_fermi_energy(st%smear, st%eigenval, st%occ, st%qtot, &
        st%d%nik, st%nst, st%d%kweights)

      call smear_fill_occupations(st%smear, st%eigenval, st%occ, &
        st%d%nik, st%nst)
    else
      ! complex scaling: both parts of the eigenvalues are passed
      call smear_occupy_states_by_ordering(st%smear, st%zeigenval%Re, st%zeigenval%Im, st%occ, st%qtot, &
        st%d%nik, st%nst, st%cmplxscl%penalizationfactor)
    end if

    ! sanity check: k-weighted sum of the occupations vs. total charge
    charge = M_ZERO
    do ist = 1, st%nst
      charge = charge + sum(st%occ(ist, 1:st%d%nik) * st%d%kweights(1:st%d%nik))
    end do

    if(abs(charge-st%qtot) > CNST(1e-6)) then
      message(1) = 'Occupations do not integrate to total charge.'
      write(message(2), '(6x,f12.8,a,f12.8)') charge, ' != ', st%qtot
      call messages_warning(2)

      ! a vanishing charge is considered fatal
      if(charge < M_EPSILON) then
        message(1) = "There don't seem to be any electrons at all!"
        call messages_fatal(1)
      endif
    end if

    if(st%d%ispin == SPINORS) then
      ASSERT(states_are_complex(st))

      SAFE_ALLOCATE(zpsi(1:mesh%np, st%d%dim))

      do ik = st%d%kpt%start, st%d%kpt%end

        ! spin of the states stored on this process
        do ist = st%st_start, st%st_end
          call states_get_state(st, mesh, ist, ik, zpsi)
          st%spin(1:3, ist, ik) = state_spin(mesh, zpsi)
        end do

#if defined(HAVE_MPI)
        ! with states parallelization, gather the spins of all the
        ! states, one Cartesian component at a time
        if(st%parallel_in_states) then
          SAFE_ALLOCATE(lspin (1:st%lnst))
          SAFE_ALLOCATE(lspin2(1:st%nst))
          do idir = 1, 3
            lspin = st%spin(idir, st%st_start:st%st_end, ik)
            call lmpi_gen_allgatherv(st%lnst, lspin, tmp, lspin2, st%mpi_grp)
            do ist = 1, st%nst
              st%spin(idir, ist, ik) = lspin2(ist)
            enddo
          end do
          SAFE_DEALLOCATE_A(lspin)
          SAFE_DEALLOCATE_A(lspin2)
        end if
#endif
      end do

      SAFE_DEALLOCATE_A(zpsi)
    end if

    POP_SUB(states_fermi)
  end subroutine states_fermi


  ! ---------------------------------------------------------
  !> Returns the sum of the eigenvalues weighted by the occupations
  !! and the k-point weights. The sum runs over the local states and
  !! k-points and is then reduced over st\%st_kpt_mpi_grp if the run is
  !! parallel in states or k-points.
  !!
  !! \param st       the states.
  !! \param alt_eig  if present, these values are used in place of
  !!                 st\%eigenval.
  function states_eigenvalues_sum(st, alt_eig) result(tot)
    type(states_t),  intent(in) :: st
    FLOAT, optional, intent(in) :: alt_eig(st%st_start:, st%d%kpt%start:) !< (:st%st_end, :st%d%kpt%end)
    FLOAT                       :: tot

    integer :: ik, first, last

    PUSH_SUB(states_eigenvalues_sum)

    first = st%st_start
    last  = st%st_end

    tot = M_ZERO

    ! the source of the eigenvalues is chosen once, outside the loop
    if(present(alt_eig)) then
      do ik = st%d%kpt%start, st%d%kpt%end
        tot = tot + st%d%kweights(ik) * sum(st%occ(first:last, ik) * &
          alt_eig(first:last, ik))
      end do
    else
      do ik = st%d%kpt%start, st%d%kpt%end
        tot = tot + st%d%kweights(ik) * sum(st%occ(first:last, ik) * &
          st%eigenval(first:last, ik))
      end do
    end if

    if(st%parallel_in_states .or. st%d%kpt%parallel) then
      call comm_allreduce(st%st_kpt_mpi_grp%comm, tot)
    end if

    POP_SUB(states_eigenvalues_sum)
  end function states_eigenvalues_sum

  ! ---------------------------------------------------------
  !> Complex counterpart of states_eigenvalues_sum, suitable for
  !! cmplxscl: the eigenvalues are built as zeigenval\%Re + i
  !! zeigenval\%Im, weighted by occupations and k-point weights, and the
  !! result is reduced over st\%st_kpt_mpi_grp if the run is parallel
  !! in states or k-points.
  !!
  !! \param st       the states.
  !! \param alt_eig  if present, these complex values are used in place
  !!                 of the stored eigenvalues.
  function zstates_eigenvalues_sum(st, alt_eig) result(tot)
    type(states_t),  intent(in) :: st
    CMPLX, optional, intent(in) :: alt_eig(st%st_start:, st%d%kpt%start:) !< (:st%st_end, :st%d%kpt%end)
    CMPLX                       :: tot

    integer :: ik, first, last

    PUSH_SUB(zstates_eigenvalues_sum)

    first = st%st_start
    last  = st%st_end

    tot = M_ZERO

    ! the source of the eigenvalues is chosen once, outside the loop
    if(present(alt_eig)) then
      do ik = st%d%kpt%start, st%d%kpt%end
        tot = tot + st%d%kweights(ik) * sum(st%occ(first:last, ik) * &
          (alt_eig(first:last, ik)))
      end do
    else
      do ik = st%d%kpt%start, st%d%kpt%end
        tot = tot + st%d%kweights(ik) * sum(st%occ(first:last, ik) * &
          (st%zeigenval%Re(first:last, ik) + M_zI * st%zeigenval%Im(first:last, ik)))
      end do
    end if

    if(st%parallel_in_states .or. st%d%kpt%parallel) then
      call comm_allreduce(st%st_kpt_mpi_grp%comm, tot)
    end if

    POP_SUB(zstates_eigenvalues_sum)
  end function zstates_eigenvalues_sum

  ! -------------------------------------------------------
  !> Maps (ispin, ik, dim) onto a spin channel:
  !! ispin = 1 gives 1; ispin = 2 gives 1 for odd ik and 2 for even ik
  !! (for positive ik); ispin = 3 gives dim; any other value of ispin
  !! gives -1.
  integer pure function states_spin_channel(ispin, ik, dim)
    integer, intent(in) :: ispin, ik, dim

    if(ispin == 1) then
      states_spin_channel = 1
    else if(ispin == 2) then
      states_spin_channel = mod(ik+1, 2)+1
    else if(ispin == 3) then
      states_spin_channel = dim
    else
      ! unknown spin mode
      states_spin_channel = -1
    end if

  end function states_spin_channel


  ! ---------------------------------------------------------
  !> Sets up the distribution of the states among the processes.
  !!
  !! First the serial defaults are set (all states local, all in node
  !! 0) and the MPI groups of st are initialized from the
  !! communicators of mc. If states parallelization is active, the
  !! range of states is divided among the processes of the states
  !! group and st\%st_range, st\%st_num, st\%node, st\%st_start,
  !! st\%st_end and st\%lnst are filled in accordingly.
  !!
  !! The run is stopped with a fatal error if there are more processes
  !! than states or if some process would get no states.
  !!
  !! \param st  states object; st\%node must already be allocated.
  !! \param mc  multicommunicator describing the parallelization.
  subroutine states_distribute_nodes(st, mc)
    type(states_t),    intent(inout) :: st
    type(multicomm_t), intent(in)    :: mc

#ifdef HAVE_MPI
    integer :: inode, ist
#endif

    PUSH_SUB(states_distribute_nodes)

    ! Defaults.
    st%node(:)            = 0
    st%st_start           = 1
    st%st_end             = st%nst
    st%lnst               = st%nst
    st%parallel_in_states = .false.
    call mpi_grp_init(st%mpi_grp, mc%group_comm(P_STRATEGY_STATES))
    call mpi_grp_init(st%dom_st_kpt_mpi_grp, mc%dom_st_kpt_comm)
    call mpi_grp_init(st%dom_st_mpi_grp, mc%dom_st_comm)
    call mpi_grp_init(st%st_kpt_mpi_grp, mc%st_kpt_comm)

#ifdef HAVE_SCALAPACK
    !%Variable ScaLAPACKCompatible
    !%Type logical
    !%Section Execution::Parallelization
    !%Description
    !% Whether to use a layout for states parallelization which is compatible with ScaLAPACK.
    !% The default is yes for <tt>CalculationMode = gs, unocc, go</tt> without k-point parallelization,
    !% and no otherwise. (Setting to other than default is experimental.)
    !% The value must be yes if any ScaLAPACK routines are called in the course of the run;
    !% it must be set by hand for <tt>td</tt> with <tt>TDDynamics = bo</tt>.
    !% This variable has no effect unless you are using states parallelization and have linked ScaLAPACK.
    !% Note: currently, use of ScaLAPACK is not compatible with task parallelization (<i>i.e.</i> slaves).
    !%End
    ! the variable name goes through datasets_check, like all the
    ! other input variables read in this module
    call parse_logical(datasets_check('ScaLAPACKCompatible'), &
      calc_mode_scalapack_compat() .and. .not. st%d%kpt%parallel, st%scalapack_compatible)
    if((calc_mode_scalapack_compat() .and. .not. st%d%kpt%parallel) .neqv. st%scalapack_compatible) &
      call messages_experimental('Setting ScaLAPACKCompatible to other than default')

    if(st%scalapack_compatible) then
      if(mc%have_slaves) &
        call messages_not_implemented("ScaLAPACK usage with task parallelization (slaves)")
      call blacs_proc_grid_init(st%dom_st_proc_grid, st%dom_st_mpi_grp)
    else
      call blacs_proc_grid_nullify(st%dom_st_proc_grid)
    end if
#else
    st%scalapack_compatible = .false.
#endif

#if defined(HAVE_MPI)
    if(multicomm_strategy_is_parallel(mc, P_STRATEGY_STATES)) then
      st%parallel_in_states = .true.

      call multicomm_create_all_pairs(st%mpi_grp, st%ap)

      if(st%nst < st%mpi_grp%size) then
        message(1) = "Have more processors than necessary"
        write(message(2),'(i4,a,i4,a)') st%mpi_grp%size, " processors and ", st%nst, " states."
        call messages_fatal(2)
      end if

      SAFE_ALLOCATE(st%st_range(1:2, 0:st%mpi_grp%size-1))
      SAFE_ALLOCATE(st%st_num(0:st%mpi_grp%size-1))

      call multicomm_divide_range(st%nst, st%mpi_grp%size, st%st_range(1, :), st%st_range(2, :), &
        lsize = st%st_num, scalapack_compat = st%scalapack_compatible)

      message(1) = "Info: Parallelization in states"
      call messages_info(1)

      do inode = 0, st%mpi_grp%size - 1
        ! The line is written in one go: an internal write must not
        ! reference its own target string in the output list.
        if(st%st_num(inode) > 0) then
          write(message(1),'(a,i4,a,i5,a,i6,a,i6)') &
            'Info: Nodes in states-group ', inode, ' will manage ', st%st_num(inode), ' states:', &
            st%st_range(1, inode), " - ", st%st_range(2, inode)
        else
          write(message(1),'(a,i4,a,i5,a)') &
            'Info: Nodes in states-group ', inode, ' will manage ', st%st_num(inode), ' states'
        endif
        call messages_info(1)

        ! record which node owns each state
        do ist = st%st_range(1, inode), st%st_range(2, inode)
          st%node(ist) = inode
        end do
      end do

      if(any(st%st_num(:) == 0)) then
        message(1) = "Cannot run with empty states-groups. Select a smaller number of processors so none are idle."
        call messages_fatal(1, only_root_writes = .true.)
      endif

      ! range of states stored on this process
      st%st_start = st%st_range(1, st%mpi_grp%rank)
      st%st_end   = st%st_range(2, st%mpi_grp%rank)
      st%lnst     = st%st_num(st%mpi_grp%rank)

    end if
#endif

    POP_SUB(states_distribute_nodes)
  end subroutine states_distribute_nodes


  ! ---------------------------------------------------------
  !> Marks the wavefunctions of st as complex by setting the private
  !! type flag. No storage is converted or allocated here.
  subroutine states_set_complex(st)
    type(states_t), intent(inout) :: st

    PUSH_SUB(states_set_complex)

    st%priv%wfs_type = TYPE_CMPLX

    POP_SUB(states_set_complex)
  end subroutine states_set_complex

  ! ---------------------------------------------------------
  !> Returns true if the private type flag of st is the complex one.
  pure logical function states_are_complex(st) result (wac)
    type(states_t), intent(in) :: st

    if(st%priv%wfs_type == TYPE_CMPLX) then
      wac = .true.
    else
      wac = .false.
    end if

  end function states_are_complex


  ! ---------------------------------------------------------
  !> Returns true if the private type flag of st is the real one.
  pure logical function states_are_real(st) result (war)
    type(states_t), intent(in) :: st

    if(st%priv%wfs_type == TYPE_FLOAT) then
      war = .true.
    else
      war = .false.
    end if

  end function states_are_real

  ! ---------------------------------------------------------
  !
  !> This subroutine can calculate several quantities that depend on
  !! derivatives of the orbitals from the states and the density.
  !! The quantities to be calculated depend on the arguments passed;
  !! at least one of the optional arguments must be present.
  !!
  !! All the quantities are accumulated over the local states and
  !! k-points, with weight kweights*occ, and then reduced over
  !! st\%st_kpt_mpi_grp if the run is parallel in states or k-points.
  !!
  !! The gauge-invariant kinetic energy density is obtained at the
  !! end, from the complete (reduced) kinetic energy density and
  !! paramagnetic current, as tau - |jp|^2/rho for each spin channel
  !! (just tau for real states). It is not available for spinors.
  !! Points where the density is not positive are left equal to tau.
  !! NOTE(review): this assumes that st\%rho holds the complete density
  !! of each spin channel when this routine is called; to be confirmed
  !! against the callers.
  subroutine states_calc_quantities(der, st, &
    kinetic_energy_density, paramagnetic_current, density_gradient, density_laplacian, gi_kinetic_energy_density)
    type(derivatives_t),     intent(in)    :: der
    type(states_t),          intent(in)    :: st
    FLOAT, optional, target, intent(out)   :: kinetic_energy_density(:,:)       !< The kinetic energy density.
    FLOAT, optional, target, intent(out)   :: paramagnetic_current(:,:,:)       !< The paramagnetic current.
    FLOAT, optional,         intent(out)   :: density_gradient(:,:,:)           !< The gradient of the density.
    FLOAT, optional,         intent(out)   :: density_laplacian(:,:)            !< The Laplacian of the density.
    FLOAT, optional,         intent(out)   :: gi_kinetic_energy_density(:,:)    !< The gauge-invariant kinetic energy density.

    FLOAT, pointer :: jp(:, :, :)
    FLOAT, pointer :: tau(:, :)
    CMPLX, allocatable :: wf_psi(:,:), gwf_psi(:,:,:), lwf_psi(:,:)
    CMPLX   :: c_tmp
    integer :: is, ik, ist, i_dim, st_dim, ii
    FLOAT   :: ww, kpoint(1:MAX_DIM)
    logical :: something_to_do

    PUSH_SUB(states_calc_quantities)

    something_to_do = present(kinetic_energy_density) .or. present(gi_kinetic_energy_density) .or. &
      present(paramagnetic_current) .or. present(density_gradient) .or. present(density_laplacian)
    ASSERT(something_to_do)

    SAFE_ALLOCATE( wf_psi(1:der%mesh%np_part, 1:st%d%dim))
    SAFE_ALLOCATE(gwf_psi(1:der%mesh%np, 1:der%mesh%sb%dim, 1:st%d%dim))
    if(present(density_laplacian)) then
      SAFE_ALLOCATE(lwf_psi(1:der%mesh%np, 1:st%d%dim))
    endif

    ! tau and jp point to the output arguments when these are present
    nullify(tau)
    if(present(kinetic_energy_density)) tau => kinetic_energy_density

    nullify(jp)
    if(present(paramagnetic_current)) jp => paramagnetic_current

    ! for the gauge-invariant kinetic energy density we need the
    ! current and the kinetic energy density; if the caller did not
    ! ask for them, temporary arrays are used
    if(present(gi_kinetic_energy_density)) then
      if(.not. present(paramagnetic_current) .and. states_are_complex(st)) then
        SAFE_ALLOCATE(jp(1:der%mesh%np, 1:der%mesh%sb%dim, 1:st%d%nspin))
      end if
      if(.not. present(kinetic_energy_density)) then
        SAFE_ALLOCATE(tau(1:der%mesh%np, 1:st%d%nspin))
      end if
    end if

    if(associated(tau)) tau = M_ZERO
    if(associated(jp)) jp = M_ZERO
    if(present(density_gradient)) density_gradient(:,:,:) = M_ZERO
    if(present(density_laplacian)) density_laplacian(:,:) = M_ZERO
    if(present(gi_kinetic_energy_density)) gi_kinetic_energy_density = M_ZERO

    do ik = st%d%kpt%start, st%d%kpt%end

      kpoint(1:der%mesh%sb%dim) = kpoints_get_point(der%mesh%sb%kpoints, states_dim_get_kpoint_index(st%d, ik))
      is = states_dim_get_spin_index(st%d, ik)

      do ist = st%st_start, st%st_end

        ! all calculations will be done with complex wavefunctions
        call states_get_state(st, der%mesh, ist, ik, wf_psi)

        ! calculate gradient of the wavefunction
        do st_dim = 1, st%d%dim
          call zderivatives_grad(der, wf_psi(:,st_dim), gwf_psi(:,:,st_dim))
        end do

        ! calculate the Laplacian of the wavefunction
        if (present(density_laplacian)) then
          do st_dim = 1, st%d%dim
            call zderivatives_lapl(der, wf_psi(:,st_dim), lwf_psi(:,st_dim))
          end do
        end if

        ! weight of this state in all the sums
        ww = st%d%kweights(ik)*st%occ(ist, ik)

        if(present(density_laplacian)) then
          density_laplacian(1:der%mesh%np, is) = density_laplacian(1:der%mesh%np, is) + &
               ww*M_TWO*real(conjg(wf_psi(1:der%mesh%np, 1))*lwf_psi(1:der%mesh%np, 1))
          if(st%d%ispin == SPINORS) then
            density_laplacian(1:der%mesh%np, 2) = density_laplacian(1:der%mesh%np, 2) + &
                 ww*M_TWO*real(conjg(wf_psi(1:der%mesh%np, 2))*lwf_psi(1:der%mesh%np, 2))
            density_laplacian(1:der%mesh%np, 3) = density_laplacian(1:der%mesh%np, 3) + &
                 ww*real (lwf_psi(1:der%mesh%np, 1)*conjg(wf_psi(1:der%mesh%np, 2)) + &
                 wf_psi(1:der%mesh%np, 1)*conjg(lwf_psi(1:der%mesh%np, 2)))
            density_laplacian(1:der%mesh%np, 4) = density_laplacian(1:der%mesh%np, 4) + &
                 ww*aimag(lwf_psi(1:der%mesh%np, 1)*conjg(wf_psi(1:der%mesh%np, 2)) + &
                 wf_psi(1:der%mesh%np, 1)*conjg(lwf_psi(1:der%mesh%np, 2)))
          end if
        end if

        do i_dim = 1, der%mesh%sb%dim
          if(present(density_gradient)) &
               density_gradient(1:der%mesh%np, i_dim, is) = density_gradient(1:der%mesh%np, i_dim, is) + &
               ww*M_TWO*real(conjg(wf_psi(1:der%mesh%np, 1))*gwf_psi(1:der%mesh%np, i_dim, 1))
          if(present(density_laplacian)) &
               density_laplacian(1:der%mesh%np, is) = density_laplacian(1:der%mesh%np, is)         + &
               ww*M_TWO*real(conjg(gwf_psi(1:der%mesh%np, i_dim, 1))*gwf_psi(1:der%mesh%np, i_dim, 1))

          if(associated(jp)) then
            if (.not.(states_are_real(st))) then
              jp(1:der%mesh%np, i_dim, is) = jp(1:der%mesh%np, i_dim, is) + &
                   ww*aimag(conjg(wf_psi(1:der%mesh%np, 1))*gwf_psi(1:der%mesh%np, i_dim, 1) - &
                   M_zI*(wf_psi(1:der%mesh%np, 1))**2*kpoint(i_dim ) )
            else
              ! the current is set to zero for real states
              jp(1:der%mesh%np, i_dim, is) = M_ZERO
            end if
          end if

          if (associated(tau)) then
            tau (1:der%mesh%np, is)   = tau (1:der%mesh%np, is)        + &
                 ww*abs(gwf_psi(1:der%mesh%np, i_dim, 1))**2  &
                 + ww*abs(kpoint(i_dim))**2*abs(wf_psi(1:der%mesh%np, 1))**2  &
                 - ww*M_TWO*aimag(conjg(wf_psi(1:der%mesh%np, 1))*kpoint(i_dim)*gwf_psi(1:der%mesh%np, i_dim, 1) )
          end if

          if(st%d%ispin == SPINORS) then
            if(present(density_gradient)) then
              density_gradient(1:der%mesh%np, i_dim, 2) = density_gradient(1:der%mesh%np, i_dim, 2) + &
                   ww*M_TWO*real(conjg(wf_psi(1:der%mesh%np, 2))*gwf_psi(1:der%mesh%np, i_dim, 2))
              density_gradient(1:der%mesh%np, i_dim, 3) = density_gradient(1:der%mesh%np, i_dim, 3) + ww* &
                   real (gwf_psi(1:der%mesh%np, i_dim, 1)*conjg(wf_psi(1:der%mesh%np, 2)) + &
                   wf_psi(1:der%mesh%np, 1)*conjg(gwf_psi(1:der%mesh%np, i_dim, 2)))
              density_gradient(1:der%mesh%np, i_dim, 4) = density_gradient(1:der%mesh%np, i_dim, 4) + ww* &
                   aimag(gwf_psi(1:der%mesh%np, i_dim, 1)*conjg(wf_psi(1:der%mesh%np, 2)) + &
                   wf_psi(1:der%mesh%np, 1)*conjg(gwf_psi(1:der%mesh%np, i_dim, 2)))
            end if

            if(present(density_laplacian)) then
              density_laplacian(1:der%mesh%np, 2) = density_laplacian(1:der%mesh%np, 2)         + &
                   ww*M_TWO*real(conjg(gwf_psi(1:der%mesh%np, i_dim, 2))*gwf_psi(1:der%mesh%np, i_dim, 2))
              density_laplacian(1:der%mesh%np, 3) = density_laplacian(1:der%mesh%np, 3)         + &
                   ww*M_TWO*real (gwf_psi(1:der%mesh%np, i_dim, 1)*conjg(gwf_psi(1:der%mesh%np, i_dim, 2)))
              density_laplacian(1:der%mesh%np, 4) = density_laplacian(1:der%mesh%np, 4)         + &
                   ww*M_TWO*aimag(gwf_psi(1:der%mesh%np, i_dim, 1)*conjg(gwf_psi(1:der%mesh%np, i_dim, 2)))
            end if

            ! the expression for the paramagnetic current with spinors is
            !     j = ( jp(1)             jp(3) + i jp(4) )
            !         (-jp(3) + i jp(4)   jp(2)           )
            if(associated(jp)) then
              jp(1:der%mesh%np, i_dim, 2) = jp(1:der%mesh%np, i_dim, 2) + &
                   ww*aimag(conjg(wf_psi(1:der%mesh%np, 2))*gwf_psi(1:der%mesh%np, i_dim, 2))
              do ii = 1, der%mesh%np
                c_tmp = conjg(wf_psi(ii, 1))*gwf_psi(ii, i_dim, 2) - wf_psi(ii, 2)*conjg(gwf_psi(ii, i_dim, 1))
                jp(ii, i_dim, 3) = jp(ii, i_dim, 3) + ww* real(c_tmp)
                jp(ii, i_dim, 4) = jp(ii, i_dim, 4) + ww*aimag(c_tmp)
              end do
            end if

            ! the expression for the kinetic energy density with spinors is
            !     t = ( tau(1)              tau(3) + i tau(4) )
            !         ( tau(3) - i tau(4)   tau(2)            )
            if(associated(tau)) then
              tau (1:der%mesh%np, 2) = tau (1:der%mesh%np, 2) + ww*abs(gwf_psi(1:der%mesh%np, i_dim, 2))**2
              do ii = 1, der%mesh%np
                c_tmp = conjg(gwf_psi(ii, i_dim, 1))*gwf_psi(ii, i_dim, 2)
                tau(ii, 3) = tau(ii, 3) + ww* real(c_tmp)
                tau(ii, 4) = tau(ii, 4) + ww*aimag(c_tmp)
              end do
            end if

            ! the gauge-invariant quantity is not available for spinors
            ASSERT(.not. present(gi_kinetic_energy_density))

          end if !SPINORS

        end do

      end do
    end do

    SAFE_DEALLOCATE_A(wf_psi)
    SAFE_DEALLOCATE_A(gwf_psi)
    SAFE_DEALLOCATE_A(lwf_psi)

    ! The reduction is done while the temporary tau and jp still
    ! exist, as their complete values are needed below.
    if(st%parallel_in_states .or. st%d%kpt%parallel) call reduce_all(st%st_kpt_mpi_grp)

    ! The gauge-invariant kinetic energy density is nonlinear in the
    ! current, so it has to be built from the complete sums: all
    ! states, all k-points and all directions of the current.
    if(present(gi_kinetic_energy_density)) then
      ASSERT(associated(tau))
      do is = 1, st%d%nspin
        gi_kinetic_energy_density(1:der%mesh%np, is) = tau(1:der%mesh%np, is)
      end do

      if(states_are_complex(st)) then
        ASSERT(associated(jp))
        do is = 1, st%d%nspin
          do ii = 1, der%mesh%np
            ! avoid dividing by a vanishing density
            if(st%rho(ii, is) > M_ZERO) then
              gi_kinetic_energy_density(ii, is) = gi_kinetic_energy_density(ii, is) - &
                sum(jp(ii, 1:der%mesh%sb%dim, is)**2)/st%rho(ii, is)
            end if
          end do
        end do
      end if
    end if

    ! release the temporaries; arrays that belong to the caller are
    ! left alone
    if(.not. present(paramagnetic_current)) then
      SAFE_DEALLOCATE_P(jp)
    end if

    if(.not. present(kinetic_energy_density)) then
      SAFE_DEALLOCATE_P(tau)
    end if

    POP_SUB(states_calc_quantities)

  contains

    !> Sums over the processes of grp all the quantities that are
    !! linear in the states. The gauge-invariant kinetic energy
    !! density is not reduced here: it is calculated afterwards from
    !! the reduced tau and jp.
    subroutine reduce_all(grp)
      type(mpi_grp_t), intent(in)  :: grp

      PUSH_SUB(states_calc_quantities.reduce_all)

      if(associated(tau)) call comm_allreduce(grp%comm, tau, dim = (/der%mesh%np, st%d%nspin/))

      if (present(density_laplacian)) call comm_allreduce(grp%comm, density_laplacian, dim = (/der%mesh%np, st%d%nspin/))

      do is = 1, st%d%nspin
        if(associated(jp)) call comm_allreduce(grp%comm, jp(:, :, is), dim = (/der%mesh%np, der%mesh%sb%dim/))

        if(present(density_gradient)) &
          call comm_allreduce(grp%comm, density_gradient(:, :, is), dim = (/der%mesh%np, der%mesh%sb%dim/))
      end do

      POP_SUB(states_calc_quantities.reduce_all)
    end subroutine reduce_all

  end subroutine states_calc_quantities


  ! ---------------------------------------------------------
  !> Returns the three components of the spin of a two-component
  !! function, calculated from the dot product and the norms of its
  !! components (half the expectation values of the sigma matrices).
  !!
  !! \param mesh  mesh on which the dot products are evaluated.
  !! \param f1    the function; columns 1 and 2 are the two components.
  function state_spin(mesh, f1) result(spin)
    type(mesh_t), intent(in) :: mesh
    CMPLX,        intent(in) :: f1(:, :)
    FLOAT                    :: spin(1:3)

    CMPLX :: overlap

    PUSH_SUB(state_spin)

    ! dot product between the two components
    overlap = zmf_dotp(mesh, f1(:, 1) , f1(:, 2))

    ! expectation values of the sigma matrices ...
    spin(1) = M_TWO*dble(overlap)
    spin(2) = M_TWO*aimag(overlap)
    spin(3) = zmf_nrm2(mesh, f1(:, 1))**2 - zmf_nrm2(mesh, f1(:, 2))**2

    ! ... and the spin is half of them
    spin(1:3) = M_HALF*spin(1:3)

    POP_SUB(state_spin)
  end function state_spin

  ! ---------------------------------------------------------
  !> Returns true if state ist lies in the range of states
  !! [st\%st_start, st\%st_end] of this process.
  logical function state_is_local(st, ist)
    type(states_t), intent(in) :: st
    integer,        intent(in) :: ist

    PUSH_SUB(state_is_local)

    state_is_local = (st%st_start <= ist) .and. (ist <= st%st_end)

    POP_SUB(state_is_local)
  end function state_is_local

  ! ---------------------------------------------------------
  !> Returns true if both the state ist and the k-point ik lie in the
  !! ranges of this process, i.e. [st\%st_start, st\%st_end] and
  !! [st\%d\%kpt\%start, st\%d\%kpt\%end].
  logical function state_kpt_is_local(st, ist, ik)
    type(states_t), intent(in) :: st
    integer,        intent(in) :: ist
    integer,        intent(in) :: ik

    logical :: state_in_range, kpt_in_range

    PUSH_SUB(state_kpt_is_local)

    state_in_range = (st%st_start <= ist) .and. (ist <= st%st_end)
    kpt_in_range   = (st%d%kpt%start <= ik) .and. (ik <= st%d%kpt%end)

    state_kpt_is_local = state_in_range .and. kpt_in_range

    POP_SUB(state_kpt_is_local)
  end function state_kpt_is_local


  ! ---------------------------------------------------------

  !> Returns an estimate of the memory taken by the orbitals: the
  !! product of the precision constant, the global number of mesh
  !! points (np_part_global), the number of spinor components, the
  !! number of states and the global number of k-points.
  real(8) function states_wfns_memory(st, mesh) result(memory)
    type(states_t), intent(in) :: st
    type(mesh_t),   intent(in) :: mesh

    PUSH_SUB(states_wfns_memory)

    ! orbitals (the only contribution counted here)
    memory = REAL_PRECISION*dble(mesh%np_part_global)*st%d%dim*dble(st%nst)*st%d%kpt%nglobal

    POP_SUB(states_wfns_memory)
  end function states_wfns_memory

  ! ---------------------------------------------------------

  !> Selects the block sizes of the ScaLAPACK decomposition of the
  !! states. Without ScaLAPACK support all the outputs are zero.
  !!
  !! \param st         the states.
  !! \param mesh       the mesh.
  !! \param blocksize  (1) block size along the points, (2) block size
  !!                   along the states.
  !! \param total_np   blocksize(1) times the number of rows of the
  !!                   process grid.
  subroutine states_blacs_blocksize(st, mesh, blocksize, total_np)
    type(states_t),  intent(in)    :: st
    type(mesh_t),    intent(in)    :: mesh
    integer,         intent(out)   :: blocksize(2)
    integer,         intent(out)   :: total_np

    PUSH_SUB(states_blacs_blocksize)

#ifdef HAVE_SCALAPACK
    ! Choosing the block size is tricky, because the processors do
    ! not all hold the same number of points. For the moment the
    ! maximum number of points is used, and the remaining points are
    ! set to zero.

    if(.not. st%scalapack_compatible) then
      message(1) = "Attempt to use ScaLAPACK when processes have not been distributed in compatible layout."
      message(2) = "You need to set ScaLAPACKCompatible = yes in the input file and re-run."
      call messages_fatal(2, only_root_writes = .true.)
    endif

    ! points: the first spinor component counts the inner points, the
    ! other components count inner, boundary and ghost points
    if (.not. mesh%parallel_in_domains) then
      blocksize(1) = mesh%np + (st%d%dim - 1)*mesh%np_part
    else
      blocksize(1) = maxval(mesh%vp%np_local_vec) + (st%d%dim - 1) * &
       maxval(mesh%vp%np_local_vec + mesh%vp%np_bndry + mesh%vp%np_ghost)
    end if

    ! states: the largest number of states held by a process
    if (.not. st%parallel_in_states) then
      blocksize(2) = st%nst
    else
      blocksize(2) = maxval(st%st_num)
    end if

    total_np = blocksize(1)*st%dom_st_proc_grid%nprow

    ASSERT(st%d%dim*mesh%np_part >= blocksize(1))
#else
    blocksize(1:2) = 0
    total_np = 0
#endif

    POP_SUB(states_blacs_blocksize)
  end subroutine states_blacs_blocksize

  ! ------------------------------------------------------------

  !> Packs the wavefunction batches of st, within a memory budget.
  !!
  !! The states must not be packed already. The flag st%packed is set,
  !! and then the local batches are packed one by one with batch_pack
  !! (copy is forwarded to it). When OpenCL is enabled the budget is
  !! derived from the device global memory and st%d%cl_states_mem;
  !! otherwise it is unlimited. As soon as the accumulated size reported
  !! by batch_pack_size exceeds the budget, a warning is issued and the
  !! remaining batches are left unpacked.
  subroutine states_pack(st, copy)
    type(states_t),    intent(inout) :: st
    logical, optional, intent(in)    :: copy

    integer :: ik, iblock
    integer(8) :: mem_limit, mem_used
#ifdef HAVE_OPENCL
    ! not referenced in this routine at present
    FLOAT, parameter :: mem_frac = 0.75
#endif

    PUSH_SUB(states_pack)

    ASSERT(.not. st%packed)

    st%packed = .true.

    ! --- memory budget ---
    if(.not. opencl_is_enabled()) then
      mem_limit = HUGE(mem_limit)
    else
#ifdef HAVE_OPENCL
      call clGetDeviceInfo(opencl%device, CL_DEVICE_GLOBAL_MEM_SIZE, mem_limit, cl_status)
#endif
      if(st%d%cl_states_mem > CNST(1.0)) then
        ! absolute amount, scaled by 1024**2
        mem_limit = int(st%d%cl_states_mem, 8)*(1024_8)**2
      else if(st%d%cl_states_mem < CNST(0.0)) then
        ! negative value: amount (scaled by 1024**2) taken off the device memory
        mem_limit = mem_limit + int(st%d%cl_states_mem, 8)*(1024_8)**2
      else
        ! value between zero and one: fraction of the device memory
        mem_limit = int(st%d%cl_states_mem*real(mem_limit, REAL_PRECISION), 8)
      end if
    end if

    ! --- pack until the budget is exhausted ---
    mem_used = 0
    kpt_loop: do ik = st%d%kpt%start, st%d%kpt%end
      do iblock = st%group%block_start, st%group%block_end

        mem_used = mem_used + batch_pack_size(st%group%psib(iblock, ik))

        if(mem_used > mem_limit) then
          call messages_write('Not enough CL device memory to store all states simultaneously.', new_line = .true.)
          call messages_write('Only ')
          call messages_write(iblock - st%group%block_start)
          call messages_write(' of ')
          call messages_write(st%group%block_end - st%group%block_start + 1)
          call messages_write(' blocks will be stored in device memory.', new_line = .true.)
          call messages_warning()
          exit kpt_loop
        end if

        call batch_pack(st%group%psib(iblock, ik), copy)
        if(st%have_left_states)  call batch_pack(st%psibL(iblock, ik), copy)
      end do
    end do kpt_loop

    POP_SUB(states_pack)
  end subroutine states_pack

  ! ------------------------------------------------------------

  !> Unpacks the wavefunction batches of st.
  !!
  !! The states must be packed. The flag st%packed is cleared and every
  !! local batch that is currently packed is unpacked with batch_unpack
  !! (copy is forwarded to it). When st%have_left_states is set, the
  !! left-state batches in st%psibL are unpacked as well.
  subroutine states_unpack(st, copy)
    type(states_t),    intent(inout) :: st
    logical, optional, intent(in)    :: copy

    integer :: iqn, ib

    PUSH_SUB(states_unpack)

    ASSERT(st%packed)

    st%packed = .false.

    do iqn = st%d%kpt%start, st%d%kpt%end
      do ib = st%group%block_start, st%group%block_end
        if(batch_is_packed(st%group%psib(ib, iqn))) call batch_unpack(st%group%psib(ib, iqn), copy)

        ! The left states are tested on their own packed status: testing
        ! psib again at this point would normally fail, because psib has
        ! just been unpacked, and psibL would be left packed.
        ! The test is nested because the logical and operator is not
        ! guaranteed to short-circuit, so psibL is only touched when
        ! left states exist.
        if(st%have_left_states) then
          if(batch_is_packed(st%psibL(ib, iqn))) call batch_unpack(st%psibL(ib, iqn), copy)
        end if
      end do
    end do

    POP_SUB(states_unpack)
  end subroutine states_unpack

  ! ------------------------------------------------------------

  !> Calls batch_sync on every local wavefunction batch of st (and on the
  !! left-state batches when st%have_left_states is set). Nothing is done
  !! when the states are not packed.
  subroutine states_sync(st)
    type(states_t),    intent(inout) :: st

    integer :: ik, iblock

    PUSH_SUB(states_sync)

    ! unpacked states need no synchronization
    if(.not. states_are_packed(st)) then
      POP_SUB(states_sync)
      return
    end if

    do ik = st%d%kpt%start, st%d%kpt%end
      do iblock = st%group%block_start, st%group%block_end
        call batch_sync(st%group%psib(iblock, ik))
        if(st%have_left_states) call batch_sync(st%psibL(iblock, ik))
      end do
    end do

    POP_SUB(states_sync)
  end subroutine states_sync

  ! -----------------------------------------------------------

  !> Prints a short summary of st to standard output: total electronic
  !! charge, number of states, block size of the states and, when
  !! applicable, a note that left states are present.
  subroutine states_write_info(st)
    type(states_t),    intent(in) :: st

    PUSH_SUB(states_write_info)

    ! opening banner
    call messages_print_stress(stdout, "States")

    write(message(1), '(a,f12.3)') 'Total electronic charge  = ', st%qtot
    write(message(2), '(a,i8)')    'Number of states         = ', st%nst
    write(message(3), '(a,i8)')    'States block-size        = ', st%d%block_size
    call messages_info(3)

    if(st%have_left_states) then
      message(1) = 'States have left states'
      call messages_info(1)
    end if

    ! closing banner
    call messages_print_stress(stdout)

    POP_SUB(states_write_info)
  end subroutine states_write_info
 
  ! -----------------------------------------------------------

  !> Returns the value of the flag st%packed.
  pure logical function states_are_packed(st) result(is_packed)
    type(states_t),    intent(in) :: st

    is_packed = st%packed
  end function states_are_packed

  ! ------------------------------------------------------------

  !> Calls batch_set_zero on every local wavefunction batch of st, and
  !! on the left-state batches when st%have_left_states is set.
  subroutine states_set_zero(st)
    type(states_t),    intent(inout) :: st

    integer :: ik, iblock

    PUSH_SUB(states_set_zero)

    do ik = st%d%kpt%start, st%d%kpt%end
      do iblock = st%group%block_start, st%group%block_end
        call batch_set_zero(st%group%psib(iblock, ik))
        if(st%have_left_states) then
          call batch_set_zero(st%psibL(iblock, ik))
        end if
      end do
    end do

    POP_SUB(states_set_zero)
  end subroutine states_set_zero

  ! ------------------------------------------------------------

  !> Returns st%group%block_range(ib, 1) for block ib.
  !! NOTE(review): presumably the index of the first state of the block.
  pure integer function states_block_min(st, ib) result(first)
    type(states_t),    intent(in) :: st
    integer,           intent(in) :: ib

    first = st%group%block_range(ib, 1)
  end function states_block_min

  ! ------------------------------------------------------------

  !> Returns st%group%block_range(ib, 2) for block ib.
  !! NOTE(review): presumably the index of the last state of the block.
  pure integer function states_block_max(st, ib) result(last)
    type(states_t),    intent(in) :: st
    integer,           intent(in) :: ib

    last = st%group%block_range(ib, 2)
  end function states_block_max

  ! ------------------------------------------------------------

  !> Returns st%group%block_size(ib) for block ib.
  pure integer function states_block_size(st, ib) result(nstates)
    type(states_t),    intent(in) :: st
    integer,           intent(in) :: ib

    nstates = st%group%block_size(ib)
  end function states_block_size

  ! ---------------------------------------------------------
  !> Counts the occupied-unoccupied pairs of states for Casida and flags
  !! which pairs are selected.
  !!
  !! For every k-point index the orbitals are classified with
  !! occupied_states. The pairs are then selected either through the
  !! list given by CasidaKohnShamStates or, when CasidaKSEnergyWindow is
  !! not negative, through a bound on the eigenvalue difference.
  !! The array is_included is allocated in this routine.
  subroutine states_count_pairs(st, n_pairs, n_occ, n_unocc, is_included, is_frac_occ)
    type(states_t),    intent(in)  :: st
    integer,           intent(out) :: n_pairs    !< number of selected pairs, summed over ik
    integer,           intent(out) :: n_occ(:)   !< nik
    integer,           intent(out) :: n_unocc(:) !< nik
    logical, pointer,  intent(out) :: is_included(:,:,:) !< (1:maxval(n_occ), minval(n_occ)+1:st%nst, 1:st%d%nik)
    logical,           intent(out) :: is_frac_occ !< are there fractional occupations?

    integer :: ik, iocc, iunocc, nfull, npartial, nhalf
    character(len=80) :: nst_label, default_list, wfn_list
    FLOAT :: energy_window

    PUSH_SUB(states_count_pairs)

    ! classify the orbitals of every k-point
    is_frac_occ = .false.
    do ik = 1, st%d%nik
      call occupied_states(st, ik, nfull, npartial, nhalf)
      is_frac_occ = is_frac_occ .or. (npartial > 0) .or. (nhalf > 0)
      n_occ(ik) = nfull + npartial + nhalf
      n_unocc(ik) = st%nst - nfull
      ! when we implement occupations, partially occupied levels need to be counted as both occ and unocc.
    end do

    !%Variable CasidaKSEnergyWindow
    !%Type float
    !%Section Linear Response::Casida
    !%Description
    !% An alternative to <tt>CasidaKohnShamStates</tt> for specifying which occupied-unoccupied
    !% transitions will be used: all those whose eigenvalue differences are less than this
    !% number will be included. If a value less than 0 is supplied, this criterion will not be used.
    !%End

    call parse_float(datasets_check('CasidaKSEnergyWindow'), -M_ONE, energy_window, units_inp%energy)

    !%Variable CasidaKohnShamStates
    !%Type string
    !%Section Linear Response::Casida
    !%Default all states
    !%Description
    !% The calculation of the excitation spectrum of a system in the Casida frequency-domain
    !% formulation of linear-response time-dependent density functional theory (TDDFT)
    !% implies the use of a basis set of occupied/unoccupied Kohn-Sham orbitals. This
    !% basis set should, in principle, include all pairs formed by all occupied states,
    !% and an infinite number of unoccupied states. In practice, one has to truncate this
    !% basis set, selecting a number of occupied and unoccupied states that will form the
    !% pairs. These states are specified with this variable. If there are, say, 15 occupied
    !% states, and one sets this variable to the value "10-18", this means that occupied
    !% states from 10 to 15, and unoccupied states from 16 to 18 will be considered.
    !%
    !% This variable is a string in list form, <i>i.e.</i> expressions such as "1,2-5,8-15" are
    !% valid. You should include a non-zero number of unoccupied states and a non-zero number
    !% of occupied states.
    !%End

    ! start with no pair selected
    n_pairs = 0
    SAFE_ALLOCATE(is_included(maxval(n_occ), minval(n_occ) + 1:st%nst , st%d%nik))
    is_included = .false.

    if(energy_window < M_ZERO) then
      ! selection through the list of states; the default is all of them
      write(nst_label,'(i6)') st%nst
      write(default_list,'(a,a)') "1-", trim(adjustl(nst_label))
      call parse_string(datasets_check('CasidaKohnShamStates'), default_list, wfn_list)

      write(message(1),'(a,a)') "Info: States that form the basis: ", trim(wfn_list)
      call messages_info(1)

      ! a pair counts when both of its states appear in the list
      do ik = 1, st%d%nik
        do iunocc = n_occ(ik) + 1, st%nst
          if(.not. loct_isinstringlist(iunocc, wfn_list)) cycle
          do iocc = 1, n_occ(ik)
            if(.not. loct_isinstringlist(iocc, wfn_list)) cycle
            n_pairs = n_pairs + 1
            is_included(iocc, iunocc, ik) = .true.
          end do
        end do
      end do

    else ! using CasidaKSEnergyWindow

      write(message(1),'(a,f12.6,a)') "Info: including transitions with energy < ", &
        units_from_atomic(units_out%energy, energy_window), trim(units_abbrev(units_out%energy))
      call messages_info(1)

      ! a pair counts when its eigenvalue difference is below the window
      do ik = 1, st%d%nik
        do iunocc = n_occ(ik) + 1, st%nst
          do iocc = 1, n_occ(ik)
            if(st%eigenval(iunocc, ik) - st%eigenval(iocc, ik) < energy_window) then
              n_pairs = n_pairs + 1
              is_included(iocc, iunocc, ik) = .true.
            end if
          end do
        end do
      end do

    end if

    POP_SUB(states_count_pairs)
  end subroutine states_count_pairs

  ! ---------------------------------------------------------
  !> Classifies the single-particle orbitals of the _many-particle_
  !! state st, for the index ik, according to their occupation:
  !!   n_filled is the number of orbitals that are totally filled
  !!            (occupation two if ispin = UNPOLARIZED, one in the
  !!            other cases).
  !!   n_half_filled is only counted if ispin = UNPOLARIZED. It is the
  !!            number of orbitals holding exactly one electron.
  !!   n_partially_filled is the number of orbitals with a positive
  !!            occupation that are neither filled nor half-filled.
  !! Occupations are compared within a tolerance of 1.0e-6. An
  !! occupation below minus that tolerance is a fatal error.
  !! The optional integer arrays filled, partially_filled and
  !! half_filled receive the indices of the corresponding orbitals;
  !! unused entries are set to zero.
  subroutine occupied_states(st, ik, n_filled, n_partially_filled, n_half_filled, &
                             filled, partially_filled, half_filled)
    type(states_t),    intent(in)  :: st
    integer,           intent(in)  :: ik
    integer,           intent(out) :: n_filled, n_partially_filled, n_half_filled
    integer, optional, intent(out) :: filled(:), partially_filled(:), half_filled(:)

    integer :: ist
    logical :: unpolarized
    FLOAT   :: full_occ
    FLOAT, parameter :: M_THRESHOLD = CNST(1.0e-6)

    PUSH_SUB(occupied_states)

    n_filled = 0
    n_partially_filled = 0
    n_half_filled = 0
    if(present(filled))           filled(:) = 0
    if(present(partially_filled)) partially_filled(:) = 0
    if(present(half_filled))      half_filled(:) = 0

    ! occupation of a totally filled orbital, and whether half-filled
    ! orbitals are a separate category
    select case(st%d%ispin)
    case(UNPOLARIZED)
      unpolarized = .true.
      full_occ = M_TWO
    case(SPIN_POLARIZED, SPINORS)
      unpolarized = .false.
      full_occ = M_ONE
    case default
      ! any other value: nothing is classified, the counters stay at zero
      POP_SUB(occupied_states)
      return
    end select

    do ist = 1, st%nst
      if(abs(st%occ(ist, ik) - full_occ) < M_THRESHOLD) then
        n_filled = n_filled + 1
        if(present(filled)) filled(n_filled) = ist
      elseif(unpolarized .and. abs(st%occ(ist, ik) - M_ONE) < M_THRESHOLD) then
        n_half_filled = n_half_filled + 1
        if(present(half_filled)) half_filled(n_half_filled) = ist
      elseif(st%occ(ist, ik) > M_THRESHOLD ) then
        n_partially_filled = n_partially_filled + 1
        if(present(partially_filled)) partially_filled(n_partially_filled) = ist
      elseif(abs(st%occ(ist, ik)) > M_THRESHOLD ) then
        ! only a negative occupation can reach this branch
        write(message(1),*) 'Internal error in occupied_states: Illegal occupation value ', st%occ(ist, ik)
        call messages_fatal(1)
      end if
    end do

    POP_SUB(occupied_states)
  end subroutine occupied_states


  ! ---------------------------------------------------------
  !> Reads the interface regions of the wavefunctions
  !! For every local k-point, local state and spinor component, the
  !! wavefunction is fetched with states_get_state into a work array
  !! and handed to get_intf_wf, once per lead, which fills the matching
  !! slice of st%ob_lead(il)%intf_psi.
  subroutine states_get_ob_intf(st, gr)
    type(states_t),   intent(inout) :: st
    type(grid_t),     intent(in)    :: gr

    integer            :: ikpt, istate, icomp, ilead
    CMPLX, allocatable :: wf(:)

    PUSH_SUB(states_get_ob_intf)

    write(message(1), '(a,i5)') 'Info: Reading ground-state interface wavefunctions.'
    call messages_info(1)

    ! Sanity check: the interface storage must already be associated.
    do ilead = 1, NLEADS
      ASSERT(associated(st%ob_lead(ilead)%intf_psi))
      ASSERT(ilead <= 2) ! FIXME: wrong if non-transport calculation
    end do

    ! work array holding one component of one wavefunction
    SAFE_ALLOCATE(wf(1:gr%mesh%np))

    do ikpt = st%d%kpt%start, st%d%kpt%end
      do istate = st%st_start, st%st_end
        do icomp = 1, st%d%dim

          call states_get_state(st, gr%mesh, icomp, istate, ikpt, wf)

          do ilead = 1, NLEADS
            call get_intf_wf(gr%intf(ilead), wf, st%ob_lead(ilead)%intf_psi(:, icomp, istate, ikpt))
          end do

        end do
      end do
    end do

    SAFE_DEALLOCATE_A(wf)

    POP_SUB(states_get_ob_intf)
  end subroutine states_get_ob_intf


#include "undef.F90"
#include "real.F90"
#include "states_inc.F90"

#include "undef.F90"
#include "complex.F90"
#include "states_inc.F90"
#include "undef.F90"

end module states_m


!! Local Variables:
!! mode: f90
!! coding: utf-8
!! End:
