# -*- coding: utf-8 -*-
# HORTON: Helpful Open-source Research TOol for N-fermion systems.
# Copyright (C) 2011-2016 The HORTON Development Team
#
# This file is part of HORTON.
#
# HORTON is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 3
# of the License, or (at your option) any later version.
#
# HORTON is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, see <http://www.gnu.org/licenses/>
#
# --


import json

import numpy as np
from nose.tools import assert_raises

from horton import *  # pylint: disable=wildcard-import,unused-wildcard-import


def test_gpt_coeff():
    """Compare ``gpt_coeff`` with a direct Python evaluation of the same sum.

    The reference ``py_gpt_coeff(k, n0, n1, pa, pb)`` sums
    ``binom(n0, i0)*binom(n1, i1)*pa**(n0-i0)*pb**(n1-i1)`` over all index
    pairs with ``i0 + i1 == k``, ``0 <= i0 <= n0`` and ``0 <= i1 <= n1``.
    Both symmetric (n0 == n1) and asymmetric (n0 != n1) cases are checked.
    """
    def py_gpt_coeff(k, n0, n1, pa, pb):
        result = 0
        # Parametrize i0 = (k+q)/2 and i1 = (k-q)/2 with q stepping by 2.
        # The bounds follow from 0 <= i0 <= n0 (q in [-k, 2*n0-k]) and
        # 0 <= i1 <= n1 (q in [k-2*n1, k]).
        for q in xrange(max(-k, k-2*n1), min(k, 2*n0-k)+1, 2):
            assert (k+q) % 2 == 0
            assert (k-q) % 2 == 0
            # Floor division keeps the indexes integral regardless of the
            # division semantics in effect.
            i0 = (k+q)//2
            i1 = (k-q)//2
            result += binom(n0, i0)*binom(n1, i1)*pa**(n0-i0)*pb**(n1-i1)
        return result

    pa = 0.8769
    pb = 0.123
    # (n0, n1) pairs; k runs over every power 0 .. n0+n1 of the product polynomial.
    for n0, n1 in (2, 2), (3, 3), (1, 2), (2, 1), (0, 3), (3, 1):
        for k in xrange(n0 + n1 + 1):
            check = py_gpt_coeff(k, n0, n1, pa, pb)
            result = gpt_coeff(k, n0, n1, pa, pb)
            assert abs(check - result) < 1e-10, (k, n0, n1, check, result)


def test_gb_overlap_int1d():
    """Check ``gb_overlap_int1d`` against tabulated reference values.

    Each case is ``(n0, n1, pa, pb, inv_gamma, expected)`` and the computed
    value must agree with ``expected`` to within 1e-5. The fifth argument is
    always written as a reciprocal; presumably it is 1/gamma (TODO confirm
    against the gb_overlap_int1d signature).
    """
    cases = [
        # Both distances zero (pa = pb = 0) for several exponents.
        (0, 0, 0.0, 0.0, 1/4.0, 0.886227),
        (2, 2, 0.0, 0.0, 1/5.0, 0.023780),
        (0, 0, 0.0, 0.0, 1/5.0, 0.792665),
        (2, 2, 0.0, 0.0, 1/1.0, 1.329340),
        (1, 1, 0.0, 0.0, 1/1.0, 0.886227),
        (2, 2, 0.0, 0.0, 1/2.0, 0.234996),
        (1, 1, 0.0, 0.0, 1/2.0, 0.313329),
        (2, 2, 0.0, 0.0, 1/3.0, 0.085277),
        (1, 1, 0.0, 0.0, 1/3.0, 0.170554),
        (2, 2, 0.0, 0.0, 1/4.0, 0.041542),
        (1, 1, 0.0, 0.0, 1/4.0, 0.110778),
        # s-s overlaps with displaced centers over a range of exponents.
        (0, 0, 0.000000, -0.377945, 1/211400.020633, 0.003855),
        (0, 0, 0.000000, -0.377945, 1/31660.020633, 0.009961),
        (0, 0, 0.000001, -0.377944, 1/7202.020633, 0.020886),
        (0, 0, 0.000004, -0.377941, 1/2040.020633, 0.039243),
        (0, 0, 0.000012, -0.377934, 1/666.420633, 0.068660),
        (0, 0, 0.000032, -0.377913, 1/242.020633, 0.113933),
        (0, 0, 0.000082, -0.377864, 1/95.550633, 0.181325),
        (0, 0, 0.000194, -0.377751, 1/40.250633, 0.279376),
        (0, 0, 0.000440, -0.377506, 1/17.740633, 0.420814),
        (0, 0, 0.000972, -0.376974, 1/8.025633, 0.625656),
        # Fixed exponent, increasing n1 and both signs of the displacements.
        (0, 0, -0.014528, 0.363418, 1/8.325, 0.614303),
        (0, 3, 0.014528, -0.363418, 1/8.325, -0.069710),
        (0, 4, -0.014528, 0.363418, 1/8.325, 0.046600),
        (0, 3, -0.014528, 0.363418, 1/8.325, 0.069710),
        (0, 1, 0.014528, -0.363418, 1/8.325, -0.223249),
        (0, 0, -0.101693, 2.543923, 1/8.325, 0.614303),
        (0, 2, -0.014528, 0.363418, 1/8.325, 0.118028),
        (0, 1, -0.014528, 0.363418, 1/8.325, 0.223249),
    ]
    for n0, n1, pa, pb, inv_gamma, expected in cases:
        result = gb_overlap_int1d(n0, n1, pa, pb, inv_gamma)
        assert abs(result - expected) < 1e-5, (n0, n1, pa, pb, inv_gamma, result, expected)


def test_gb2integral_exceptions():
    """``reset`` raises ValueError for shell types of magnitude 3 when the
    integral object was built with a maximum shell type of 2."""
    integral = GB2OverlapIntegral(2)
    center = np.random.uniform(-1, 1, 3)
    invalid_pairs = [(-3, 0), (3, 0), (0, -3), (0, 3)]
    for shell_type0, shell_type1 in invalid_pairs:
        with assert_raises(ValueError):
            integral.reset(shell_type0, shell_type1, center, center)


def test_overlap_norm():
    """Self-overlap times the squared Cartesian normalization must be one.

    For each shell type up to 3 and several exponents, a shell is overlapped
    with itself on the same random center. Every diagonal element multiplied by
    ``gob_cart_normalization(alpha, exponents)**2`` must equal 1 within 1e-10.
    The exponents are walked with ``iter_pow1_inc`` starting from
    ``[shell_type, 0, 0]``.
    """
    max_shell_type = 3
    integral = GB2OverlapIntegral(max_shell_type)
    for shell_type in xrange(max_shell_type + 1):
        nbasis = get_shell_nbasis(shell_type)
        for alpha in np.arange(0.5, 2.51, 0.5):
            scales = np.ones(nbasis, float)
            center = np.random.uniform(-1, 1, 3)
            integral.reset(shell_type, shell_type, center, center)
            integral.add(1.0, alpha, alpha, scales, scales)
            diagonal = np.diag(integral.get_work(nbasis, nbasis))
            exponents = np.array([shell_type, 0, 0])
            ibasis = 0
            more = True
            while more:
                norm = gob_cart_normalization(alpha, exponents)
                assert abs(diagonal[ibasis]*norm**2 - 1) < 1e-10
                more = iter_pow1_inc(exponents)
                ibasis += 1


def test_gb2_overlap_integral_class():
    """Exercise the GB2OverlapIntegral API: attributes, reset, add and get_work."""
    max_shell_type = 4
    max_nbasis = get_shell_nbasis(max_shell_type)
    r0 = np.array([2.645617, 0.377945, -0.188973])
    r1 = np.array([0.000000, 0.000000, 0.188973])

    gb2i = GB2OverlapIntegral(max_shell_type)
    assert gb2i.max_shell_type == max_shell_type
    assert gb2i.max_nbasis == max_nbasis
    assert gb2i.nwork == max_nbasis**2

    # case 1: shell types 3 and 4 on different centers.
    nbasis0 = get_shell_nbasis(3)
    nbasis1 = get_shell_nbasis(4)
    scales0 = np.ones(nbasis0, float)
    scales1 = np.ones(nbasis1, float)
    gb2i.reset(3, 4, r0, r1)
    # Right after reset the work array is all zeros, and each get_work call
    # returns a distinct array object.
    work = gb2i.get_work(max_nbasis, max_nbasis)
    assert work.shape == (max_nbasis, max_nbasis)
    assert (work == 0.0).all()
    assert work is not gb2i.get_work(max_nbasis, max_nbasis)
    gb2i.add(1.0, 5.398, 0.320, scales0, scales1)
    work = gb2i.get_work(nbasis0, nbasis1)
    assert abs(work[0, 3] - -0.001256) < 1e-5
    # add accumulates: a second identical primitive pair with coefficient 0.5
    # scales the element by 1.5.
    gb2i.add(0.5, 5.398, 0.320, scales0, scales1)
    work = gb2i.get_work(nbasis0, nbasis1)
    assert abs(work[0, 3] - -1.5*0.001256) < 1e-5

    # case 2: shell type 2 with shell type 0 on the same center.
    # The six type-2 functions have exponents 200 110 101 020 011 002, so
    # row 3 is the 020 component.
    nbasis0 = get_shell_nbasis(2)
    nbasis1 = get_shell_nbasis(0)
    scales0 = np.ones(nbasis0, float)
    scales1 = np.ones(nbasis1, float)
    gb2i.reset(2, 0, r0, r0)
    gb2i.add(1.0, 0.463, 8.005, scales0, scales1)
    work = gb2i.get_work(nbasis0, nbasis1)
    assert abs(work[3, 0] - 0.013343) < 1e-5


def test_nuclear_attraction_helper():
    """Compare nuclear_attraction_helper with closed forms for n0 + n1 <= 2.

    The helper fills ``work_g`` with ``n0 + n1 + 1`` terms. Each term is checked
    against explicit expressions built from ``gpt_coeff`` for 100 random
    parameter sets.
    """
    def my_nah(index, n0, n1, pa, pb, cp, gamma_inv):
        # Closed-form expressions taken from the THO paper.
        f = [gpt_coeff(k, n0, n1, pa, pb) for k in xrange(3)]
        order = n0 + n1
        if order == 0:
            terms = (1.0,)
        elif order == 1:
            terms = (f[0], -cp)
        elif order == 2:
            terms = (
                f[0] + 0.5*f[2]*gamma_inv,
                -f[1]*cp - f[2]*0.5*gamma_inv,
                cp**2,
            )
        else:
            raise NotImplementedError
        if not 0 <= index < len(terms):
            raise ValueError
        return terms[index]

    # (n0, n1, number of terms produced by the helper)
    cases = [(0, 0, 1), (0, 1, 2), (1, 1, 3)]
    for _ in xrange(100):
        pa, pb, cp, gamma_inv = np.random.uniform(0.5, 2.0, 4)
        for n0, n1, nterm in cases:
            work_g = np.zeros(nterm, float)
            nuclear_attraction_helper(work_g, n0, n1, pa, pb, cp, gamma_inv)
            for index in xrange(nterm):
                expected = my_nah(index, n0, n1, pa, pb, cp, gamma_inv)
                assert abs(work_g[index] - expected) < 1e-10


def check_overlap(alphas0, alphas1, r0, r1, scales0, scales1, shell_type0, shell_type1, result0):
    """Compare a HORTON overlap block with reference data computed by PyQuante.

    Each primitive pair from ``zip(alphas0, alphas1)`` is added with unit
    coefficient to a GB2OverlapIntegral (max shell type 4) for the shells
    ``(shell_type0, r0)`` and ``(shell_type1, r1)``. The accumulated
    ``(nbasis0, nbasis1)`` block must match ``result0`` to within 1e-8.
    """
    max_shell_type = 4
    expected_max_nbasis = get_shell_nbasis(max_shell_type)
    integral = GB2OverlapIntegral(max_shell_type)
    assert integral.max_nbasis == expected_max_nbasis
    shape = (get_shell_nbasis(shell_type0), get_shell_nbasis(shell_type1))
    assert result0.shape == shape
    # Clear the working memory, then accumulate the contributions.
    integral.reset(shell_type0, shell_type1, r0, r1)
    for alpha0, alpha1 in zip(alphas0, alphas1):
        integral.add(1.0, alpha0, alpha1, scales0, scales1)
    deviation = abs(integral.get_work(*shape) - result0).max()
    assert deviation < 1e-8


def test_overlap_0_0():
    """Overlap block for shell types 0 and 0 against PyQuante reference data.

    Arguments follow check_overlap: alphas0, alphas1, r0, r1, scales0,
    scales1, shell_type0, shell_type1 and the expected block result0.
    """
    check_overlap(
        np.array([ 0.38155,  1.55654,  1.03605]), np.array([ 0.38563,  1.9267 ,  1.20415]),
        np.array([ 0.18522,  0.43272,  0.10418]), np.array([-0.56975, -0.53701,  0.27943]),
        np.array([ 0.57244]),
        np.array([ 1.81523]),
        0, 0,
        np.array([[ 7.37505911]]))


def test_overlap_1_0():
    """Overlap block for shell types 1 and 0 against PyQuante reference data.

    Arguments follow check_overlap: alphas0, alphas1, r0, r1, scales0,
    scales1, shell_type0, shell_type1 and the expected block result0.
    """
    check_overlap(
        np.array([ 1.3641 ,  0.79165,  0.96826]), np.array([ 1.47887,  1.69581,  1.47239]),
        np.array([ 0.69486, -0.56009, -0.98666]), np.array([ 0.50462, -0.28959, -0.6542 ]),
        np.array([ 1.61394,  1.94395,  0.70231]),
        np.array([ 0.70378]),
        1, 0,
        np.array([[-0.46483016],
                  [ 0.79608132],
                  [ 0.3534869 ]]))


def test_overlap_1_1():
    """Overlap block for shell types 1 and 1 against PyQuante reference data.

    Arguments follow check_overlap: alphas0, alphas1, r0, r1, scales0,
    scales1, shell_type0, shell_type1 and the expected block result0.
    """
    check_overlap(
        np.array([ 0.35734,  1.59801,  0.65735]), np.array([ 1.51192,  1.57289,  1.5284 ]),
        np.array([-0.24525, -0.26472,  0.23287]), np.array([ 0.88163, -0.36455, -0.23466]),
        np.array([ 1.8095 ,  1.97681,  1.06664]),
        np.array([ 1.36562,  1.29883,  1.52082]),
        1, 1,
        np.array([[ 0.04779076,  0.12575296,  0.68959185],
                  [ 0.14444486,  1.58823333, -0.06673933],
                  [ 0.3650089 , -0.03075454,  0.8424836 ]]))


def test_overlap_2_0():
    """Overlap block for shell types 2 and 0 against PyQuante reference data.

    Arguments follow check_overlap: alphas0, alphas1, r0, r1, scales0,
    scales1, shell_type0, shell_type1 and the expected block result0.
    """
    check_overlap(
        np.array([ 0.91255,  1.69396,  0.53409]), np.array([ 1.94475,  0.60338,  1.06963]),
        np.array([ 0.97877, -0.22518,  0.41346]), np.array([ 0.55014, -0.68058,  0.67553]),
        np.array([ 1.86953,  0.89686,  0.9576 ,  0.75579,  1.9237 ,  1.69169]),
        np.array([ 1.48799]),
        2, 0,
        np.array([[ 4.00988365],
                  [ 0.39770136],
                  [-0.24436603],
                  [ 1.66169765],
                  [-0.52156028],
                  [ 3.18632383]]))


def test_overlap_2_1():
    """Overlap block for shell types 2 and 1 against PyQuante reference data.

    Arguments follow check_overlap: alphas0, alphas1, r0, r1, scales0,
    scales1, shell_type0, shell_type1 and the expected block result0.
    """
    check_overlap(
        np.array([ 0.16658,  0.52428,  0.48323]), np.array([ 1.84574,  1.94547,  0.32043]),
        np.array([ 0.15899, -0.44708,  0.27887]), np.array([-0.2667 ,  0.52481, -0.67718]),
        np.array([ 1.48149,  1.7914 ,  0.74833,  1.8111 ,  1.64518,  0.96205]),
        np.array([ 0.55303,  0.98614,  1.20578]),
        2, 1,
        np.array([[-0.45537915, -3.02546143,  3.63902144],
                  [ 1.53581866, -0.80941594, -0.58059756],
                  [-0.63110848, -0.19835644, -0.42140112],
                  [ 1.12966856,  1.36593798,  5.53168681],
                  [-0.24455507, -1.66947439,  2.1151424 ],
                  [ 0.59537584, -2.42384453, -0.89574395]]))


def test_overlap_2_2():
    """Overlap block for shell types 2 and 2 against PyQuante reference data.

    Arguments follow check_overlap: alphas0, alphas1, r0, r1, scales0,
    scales1, shell_type0, shell_type1 and the expected block result0.
    """
    check_overlap(
        np.array([ 1.8772 ,  0.66306,  1.62698]), np.array([ 1.48297,  0.80119,  1.89849]),
        np.array([ 0.58757, -0.44683,  0.92865]), np.array([-0.53945, -0.60169,  0.20352]),
        np.array([ 0.96814,  1.78969,  1.96564,  1.39243,  1.78382,  1.56938]),
        np.array([ 0.93523,  1.76753,  0.64758,  1.64208,  1.31197,  0.99962]),
        2, 2,
        np.array([[ 0.37405284, -0.0053374 , -0.00915657,  0.68544647,  0.04045487,  0.55501502],
                  [ 0.00958391,  0.01848274,  0.00061483, -0.06141207, -0.28364134,  0.07593485],
                  [ 0.0492885 ,  0.00184313,  0.01045512,  0.47947731, -0.03888278, -0.09362596],
                  [ 0.47923418, -0.07319961,  0.09660233,  1.33667794, -0.03495828,  0.37640877],
                  [ 0.06262875, -0.38087778, -0.01741702, -0.03938322,  0.27117485, -0.01167484],
                  [ 0.78637264,  0.09320441, -0.09051771,  0.76244565, -0.02519833,  0.73705081]]))


def test_overlap_3_0():
    """Overlap block for shell types 3 and 0 against PyQuante reference data.

    Arguments follow check_overlap: alphas0, alphas1, r0, r1, scales0,
    scales1, shell_type0, shell_type1 and the expected block result0.
    """
    check_overlap(
        np.array([ 0.20951,  0.28247,  1.21724]), np.array([ 1.90283,  0.19887,  1.34656]),
        np.array([ 0.17818, -0.86045,  0.10812]), np.array([-0.25553, -0.10358, -0.61739]),
        np.array([ 0.71955,  0.89878,  1.32627,  0.58463,  1.06506,  1.1792 ,  1.13933,  0.7315 ,  1.63897,  1.07301]),
        np.array([ 1.09825]),
        3, 0,
        np.array([[ -6.84412275],
                  [  5.1732252 ],
                  [ -7.3174884 ],
                  [ -2.15734506],
                  [  0.59561345],
                  [ -4.29554216],
                  [ 19.69066116],
                  [ -4.51540609],
                  [ 10.41893612],
                  [-17.69112812]]))


def test_overlap_3_1():
    """Overlap block for shell types 3 and 1 against PyQuante reference data.

    Arguments follow check_overlap: alphas0, alphas1, r0, r1, scales0,
    scales1, shell_type0, shell_type1 and the expected block result0.
    """
    check_overlap(
        np.array([ 0.15121,  1.63537,  1.85481]), np.array([ 1.2291 ,  1.5473 ,  0.51975]),
        np.array([-0.05191, -0.84862,  0.44892]), np.array([ 0.12233,  0.5476 ,  0.05299]),
        np.array([ 1.50846,  1.88884,  1.30416,  1.77135,  0.92562,  0.55497,  0.99776,  1.43107,  1.57878,  1.7836 ]),
        np.array([ 1.19057,  0.83325,  0.83198]),
        3, 1,
        np.array([[ 2.10014988, -0.11304073,  0.03200642],
                  [ 0.75386786,  0.22795289,  0.11082966],
                  [-0.1476031 ,  0.07663976,  0.40530456],
                  [ 3.84197477,  0.34663155,  0.05492151],
                  [-0.45461677, -0.01775139,  0.13471654],
                  [ 0.31978933, -0.01740981, -0.04302698],
                  [-0.22607852,  3.31678897,  0.35899451],
                  [ 0.06349521, -0.63634866,  2.08774836],
                  [-0.07076617,  0.24600963, -0.98084179],
                  [ 0.05813693,  0.3260457 ,  2.08322125]]))


def test_overlap_3_2():
    """Overlap block for shell types 3 and 2 against PyQuante reference data.

    Arguments follow check_overlap: alphas0, alphas1, r0, r1, scales0,
    scales1, shell_type0, shell_type1 and the expected block result0.
    """
    check_overlap(
        np.array([ 1.20011,  0.33878,  1.51811]), np.array([ 0.3459 ,  1.96178,  1.7336 ]),
        np.array([-0.29858, -0.7967 ,  0.39854]), np.array([-0.90944, -0.64625, -0.73663]),
        np.array([ 0.54737,  1.43055,  0.65339,  0.74062,  0.80494,  1.02926,  0.68909,  1.70993,  1.87537,  1.07788]),
        np.array([ 1.06983,  1.0181 ,  1.308  ,  0.59744,  0.97754,  1.73034]),
        3, 2,
        np.array([[ 0.04997506, -0.03609568,  0.34989787, -0.06339275,  0.0155446 , -0.38756342],
                  [ 0.06280297,  0.0060115 , -0.00013651, -0.02756347,  0.27915505,  0.10056387],
                  [-0.21643043, -0.00004853,  0.00398969, -0.06289051, -0.00458798,  0.11069749],
                  [ 0.09594925,  0.00073215,  0.13957776, -0.06967677, -0.00094373, -0.15983133],
                  [ 0.00495924, -0.06123233,  0.00272798,  0.00217767, -0.0095553 ,  0.01249752],
                  [ 0.08633778, -0.02767254,  0.0644089 , -0.10327526,  0.00316029, -0.38169709],
                  [ 0.03991552,  0.20348681,  0.02507555, -0.01121691,  0.3630777 ,  0.10835293],
                  [-0.25403295, -0.00226927,  0.06159929, -0.29894421, -0.0012535 ,  0.39095044],
                  [ 0.09695665,  0.25868317,  0.00770482, -0.02943601,  0.10966853,  0.17128969],
                  [-0.73142258,  0.03540743,  0.23754038, -0.32921368, -0.04372348, -0.01565985]]))


def test_overlap_3_3():
    """Overlap block for shell types 3 and 3 against PyQuante reference data.

    Arguments follow check_overlap: alphas0, alphas1, r0, r1, scales0,
    scales1, shell_type0, shell_type1 and the expected block result0.
    """
    check_overlap(
        np.array([ 0.61505,  1.89932,  0.88091]), np.array([ 1.73721,  0.84781,  0.2702 ]),
        np.array([ 0.13589, -0.64805, -0.00985]), np.array([-0.87779, -0.3078 ,  0.4045 ]),
        np.array([ 0.87562,  0.94676,  1.00714,  1.64365,  0.50092,  1.70516,  1.04766,  0.92249,  0.70682,  0.5254 ]),
        np.array([ 0.97169,  1.39497,  1.30833,  0.76837,  1.11379,  1.57636,  1.54807,  0.73938,  0.69   ,  1.18188]),
        3, 3,
        np.array([[ 4.15859105, -0.40861149, -0.46669399,  0.48447216,  0.11002233,  1.0556848 ,  0.61893723,  0.13137188,  0.1069884 ,  0.58861062],
                  [ 0.23832739,  1.10191452, -0.06694614, -0.08428443, -0.08285553,  0.03197369,  1.69516343,  0.07072293,  0.26634209, -0.08342378],
                  [ 0.30874014, -0.07593169,  1.07114098,  0.01867899, -0.07057921, -0.22090277, -0.11375454,  0.27791106,  0.06910082,  1.4302854 ],
                  [ 2.45406222, -0.16904268, -0.4388101 ,  0.60911125,  0.02232811,  0.52328746,  0.69429659,  0.19405029,  0.00957155,  0.34738557],
                  [ 0.0147559 ,  0.03819706,  0.02877506, -0.00588988,  0.07985295, -0.01191429, -0.05963298,  0.02256946,  0.02584141, -0.03809331],
                  [ 2.56582207, -0.4009092 , -0.20296195,  0.25032315,  0.02357056,  1.29331229,  0.38446027,  0.01251804,  0.15373757,  0.66630371],
                  [ 0.59350333,  2.95315879, -0.17682239, -0.32601776, -0.5561459 ,  0.27100409,  7.24017008,  0.12823454,  0.75868102, -0.21839837],
                  [ 0.21620374, -0.01288184,  0.79126583,  0.1131799 , -0.04714846, -0.38743476, -0.15928094,  0.55088509,  0.0233582 ,  1.12461804],
                  [ 0.13788668,  0.66969011, -0.00975483, -0.12147174, -0.04603221,  0.14558949,  1.09409377,  0.01973007,  0.40312736, -0.09270714],
                  [ 0.36414052, -0.09494291,  1.36719578,  0.076182  , -0.22519195, -0.40464075, -0.14088458,  0.37745611,  0.05945119,  2.79182182]]))


def test_overlap_4_0():
    """Overlap block for shell types 4 and 0 against PyQuante reference data.

    Arguments follow check_overlap: alphas0, alphas1, r0, r1, scales0,
    scales1, shell_type0, shell_type1 and the expected block result0.
    """
    check_overlap(
        np.array([ 1.33264,  1.93806,  1.93727]), np.array([ 1.95318,  0.55839,  0.92698]),
        np.array([ 0.29297, -0.59716,  0.32057]), np.array([ 0.62773,  0.09809, -0.56828]),
        np.array([ 1.74264,  1.93814,  0.5884 ,  0.78401,  1.5155 ,  0.85113,  0.60545,  0.60674,  1.12304,  1.02276,  1.27453,  1.2872 ,  1.51933,  1.10252,  1.68958]),
        np.array([ 1.1598]),
        4, 0,
        np.array([[ 0.36764944],
                  [ 0.05504003],
                  [-0.02136258],
                  [ 0.06849119],
                  [-0.04183808],
                  [ 0.08601565],
                  [ 0.01999582],
                  [-0.01165504],
                  [ 0.02116684],
                  [-0.04818207],
                  [ 0.39378848],
                  [-0.11287608],
                  [ 0.19922923],
                  [-0.10787124],
                  [ 0.6775918 ]]))


def test_overlap_4_1():
    """Overlap block for shell types 4 and 1 against PyQuante reference data.

    Arguments follow check_overlap: alphas0, alphas1, r0, r1, scales0,
    scales1, shell_type0, shell_type1 and the expected block result0.
    """
    check_overlap(
        np.array([ 1.76653,  0.22375,  0.12324]), np.array([ 1.85803,  1.4349 ,  0.93092]),
        np.array([ 0.59731, -0.13831, -0.78078]), np.array([-0.48602, -0.93713, -0.80125]),
        np.array([ 1.22477,  1.9374 ,  1.4286 ,  1.43287,  1.80204,  1.16563,  1.09655,  0.92286,  0.99676,  1.19032,  1.90032,  1.7857 ,  0.85302,  0.77905,  0.64177]),
        np.array([ 0.64838,  1.68203,  1.74876]),
        4, 1,
        np.array([[-14.8625165 ,   4.46965981,   0.11908028],
                  [ -7.6491756 , -14.83952385,   0.07473575],
                  [ -0.14453562,   0.05300578, -13.52562103],
                  [ -3.3080767 ,  -9.91756325,   0.04641387],
                  [ -0.05627861,  -0.15768523,  -7.6219426 ],
                  [ -1.27143428,   0.65549744,  -0.23523715],
                  [ -1.63774007, -11.27900516,   0.0342184 ],
                  [ -0.01812762,  -0.08131027,  -3.66390153],
                  [ -0.3625845 ,  -1.51048321,  -0.10005337],
                  [ -0.03327069,   0.02567911,  -6.66779844],
                  [  2.01530868, -37.6707845 ,   0.10270681],
                  [  0.02066042,  -0.34706269, -10.14125911],
                  [  0.17329651,  -1.92805616,  -0.11916524],
                  [  0.00647855,  -0.06688925,  -3.21789763],
                  [  0.18290074,   0.34987084,  -0.26228858]]))


def test_overlap_4_2():
    """Overlap block for shell types 4 and 2 against PyQuante reference data.

    Arguments follow check_overlap: alphas0, alphas1, r0, r1, scales0,
    scales1, shell_type0, shell_type1 and the expected block result0.
    """
    check_overlap(
        np.array([ 1.81625,  0.42484,  1.78998]), np.array([ 0.87881,  1.6645 ,  0.62636]),
        np.array([-0.09577,  0.43432,  0.30419]), np.array([-0.39053, -0.41847,  0.0288 ]),
        np.array([ 1.3704 ,  0.72173,  0.68438,  0.94581,  0.91228,  1.87348,  0.52649,  0.65946,  0.61436,  1.88314,  0.64125,  1.91603,  1.94277,  1.61052,  1.59068]),
        np.array([ 1.85355,  0.65632,  1.9032 ,  1.96359,  1.10161,  1.86939]),
        4, 2,
        np.array([[ 1.32805811, -0.04613372, -0.04320096,  0.58551373,  0.04457861,  0.34759665],
                  [ 0.11276588,  0.02042907, -0.02682566,  0.02262235, -0.00200546,  0.06185421],
                  [ 0.03453078, -0.00877211,  0.12673123,  0.02491492, -0.01460948,  0.00181613],
                  [ 0.37844919,  0.01703868, -0.00756702,  0.21371166, -0.00032273,  0.15894243],
                  [ 0.06573236,  0.00201802,  0.03335586,  0.01076307,  0.01448151,  0.00366559],
                  [ 0.37521478, -0.0052174 ,  0.0417082 ,  0.266232  , -0.00860362,  0.3851452 ],
                  [ 0.01354328,  0.04480448, -0.02351642,  0.08535673, -0.00471426,  0.06669629],
                  [ 0.00460124,  0.00530731,  0.07977768,  0.01892326,  0.00960536,  0.00459219],
                  [ 0.00251714,  0.00564864,  0.02093618,  0.00750272,  0.00229216,  0.04777277],
                  [ 0.00499544, -0.02203157,  0.34241384,  0.06804191, -0.04310194,  0.09592031],
                  [ 0.50827674, -0.02143765,  0.02259588,  1.02481778, -0.03361778,  0.50960077],
                  [ 0.22567413, -0.01022147, -0.09250465,  0.29022148,  0.27601592,  0.04561346],
                  [ 0.31981046, -0.00029871, -0.01568837,  0.4313964 ,  0.05507843,  0.78424487],
                  [ 0.12735302, -0.00264882, -0.06368496,  0.04653848,  0.07559615,  0.23733819],
                  [ 0.39226618,  0.03258814, -0.05009992,  0.66393196, -0.08389847,  1.5393678 ]]))


def test_overlap_4_3():
    """Overlap block for shell types 4 and 3 against PyQuante reference data.

    Arguments follow check_overlap: alphas0, alphas1, r0, r1, scales0,
    scales1, shell_type0, shell_type1 and the expected block result0.
    """
    check_overlap(
        np.array([ 0.36707,  0.71034,  0.35369]), np.array([ 1.52363,  1.55468,  0.2221 ]),
        np.array([-0.93555,  0.50321, -0.46688]), np.array([ 0.62248,  0.98146,  0.5302 ]),
        np.array([ 0.86788,  1.64811,  0.69014,  1.94977,  1.56664,  0.64777,  1.84159,  1.27696,  1.02837,  0.86408,  1.381  ,  1.71595,  1.03498,  1.98586,  1.04105]),
        np.array([ 0.95256,  1.06661,  1.8854 ,  1.77207,  1.23133,  0.60978,  0.89928,  0.77363,  1.8449 ,  1.90911]),
        4, 3,
        np.array([[ -23.35054619,  -14.64710204,  -53.97896897,   24.40319684,    3.15490521,   10.90521115,  -20.7156497 ,  -13.17357764,  -19.59728824, -101.45798541],
                  [  16.46167958,   -2.33042958,    0.48218883,  -11.247298  ,  -11.7691257 ,    2.71959188,   48.61433368,    3.56956768,   43.04797077,  -15.18179023],
                  [  14.37143634,    0.1142274 ,   -1.40086936,    5.31377017,   -1.835243  ,   -2.8330638 ,   -2.70950356,    4.48851543,    2.97968989,   46.51000352],
                  [ -37.43418335,    1.63051563,  -26.93045037,   -7.73052765,    0.08158491,   -1.36935888,  -18.53947922,  -19.3001896 ,    2.63693345,  -58.83127508],
                  [  -2.34551703,    7.00884266,    4.5874329 ,    0.37583369,   -1.02014207,    0.10896179,   12.41212122,   -2.43030574,  -12.98139482,   13.52097277],
                  [ -13.99347083,   -2.74474081,    0.82822427,   -1.13526391,    0.01010153,   -0.78990337,   -4.52615655,    0.25690086,   -6.57786251,  -24.86036948],
                  [  17.30703576,  -19.66444697,    4.65805441,   -0.68608558,   -6.94224197,    1.52608824,   74.09284612,    0.2698057 ,   43.75198404,  -14.7933046 ],
                  [   8.54902249,   -0.08772057,   -6.25776106,    4.2614252 ,    0.1224609 ,   -0.93972887,   -2.22444903,    6.36918665,   -0.29969637,   26.09690375],
                  [   3.70778095,   -4.08870657,   -0.03135636,   -1.41442273,    0.09893209,    0.50598813,   10.41081975,   -0.04587527,   14.15889481,   -3.4831098 ],
                  [  17.640422  ,    1.28087764,  -14.68186   ,    3.45051844,   -1.41201897,   -0.24804164,   -3.10743482,    6.03176982,    0.33992194,   67.25730374],
                  [ -83.96861477,   19.99273697,  -53.88334134, -172.65772694,    7.5577065 ,  -18.9823001 ,  -12.94719451,  -48.23837147,   24.1335551 ,  -91.17059692],
                  [  -8.11565395,   21.6191248 ,   13.80912443,    0.57585134,  -21.89784237,    2.19626236,   44.18167207,   -0.51103685,  -31.5871051 ,   31.45204663],
                  [ -23.69834124,    0.90945428,    1.16638359,  -30.24182003,   -0.03673576,   -8.26571321,   -8.17772371,    0.64930327,    1.66131007,  -29.37678972],
                  [  -9.86645423,   25.41633554,   19.30186526,    9.17270375,  -29.57497867,    0.25821069,   34.61936665,   -7.91484336,   -4.0299597 ,   47.44749525],
                  [ -79.50480614,  -13.8365909 ,   49.56559163,  -40.12379873,    5.07719789,  -40.08498025,  -17.62392674,   10.92515487,  -37.22721022,  -40.4914588 ]]))


def test_overlap_4_4():
    """Overlap block for shell types 4 and 4 against PyQuante reference data.

    Arguments follow check_overlap: alphas0, alphas1, r0, r1, scales0,
    scales1, shell_type0, shell_type1 and the expected block result0.
    """
    check_overlap(
        np.array([ 1.49743,  0.28887,  1.30672]), np.array([ 1.5195 ,  0.12034,  1.33652]),
        np.array([-0.76864,  0.21477, -0.87085]), np.array([ 0.30839, -0.47598, -0.31725]),
        np.array([ 1.3229 ,  1.25303,  1.00305,  1.48236,  1.84543,  1.90289,  0.8467 ,  1.11986,  1.73924,  0.53723,  0.87138,  1.44704,  1.64225,  1.97351,  1.37167]),
        np.array([ 1.90328,  1.52245,  1.36394,  1.0657 ,  0.85058,  0.70004,  0.70973,  1.64782,  1.61763,  1.54071,  0.56068,  1.86628,  1.4608 ,  1.38572,  0.56277]),
        4, 4,
        np.array([[ 10860.93874874,   -452.32599228,    324.77228431,    899.09223991,    -93.68752283,    556.18615513,     59.87469175,    -41.66304979,     48.0582979 ,   -101.90066974,    438.82233443,   -173.07090079,    365.36199859,   -125.70452051,    394.38259926],
                  [   475.2771819 ,   1035.09439807,     65.56562056,    -51.42318671,     20.11984948,     10.53552285,    313.64297985,    -51.31692065,    222.27364818,     41.57845445,    121.32573759,    -97.59579562,     48.44847787,    -62.58436268,    -17.3352106 ],
                  [  -304.91814246,     58.58481486,    765.74752471,    -10.92616379,    -20.72984408,     22.08926511,     15.67361047,    198.5377844 ,    -41.10492632,    526.58416418,     12.37591298,     88.78432215,    -33.64414605,     69.93075464,    -77.45449685],
                  [  2199.59949491,    -33.5476083 ,    124.14423181,    518.58703461,     -4.44827034,    136.11229179,    -20.17647609,     17.30106581,     -1.64238931,     17.95648562,    600.53073365,   -118.54871265,    308.43629497,     -8.7327609 ,    141.26801918],
                  [   -71.70703077,    -63.88214282,     73.66368941,     21.65014537,    154.20905777,     14.49632897,     -4.41512691,    -13.14804504,     10.22154124,     11.55457128,     32.24873739,    472.15886478,    -42.07293154,    328.34912518,     32.09031457],
                  [  2790.60394171,   -219.35762785,     29.05788034,    279.14867697,     -5.36600661,    441.86367078,    -13.38540454,      1.71720346,    -27.48837382,     45.52045725,    199.41748731,    -14.50149471,    424.83281568,   -114.14422254,    747.43305291],
                  [   282.36623922,    552.58118862,     34.17329897,    -33.46378077,     27.86578869,     14.2478907 ,    288.39135652,    -15.40385753,    141.99723498,     25.93005542,     97.44726363,    -95.68804905,     15.50381178,    -42.63710347,    -11.52847906],
                  [  -101.99288532,      3.33568027,    224.70852502,    -19.60892488,     -3.06253165,     16.81702703,     10.52690609,    164.32555779,     -2.80099854,    184.91666273,     17.69689654,     63.58385024,    -29.69265992,      5.07943112,    -29.00478529],
                  [   195.33619167,    373.22132967,      4.36166091,    -48.10367836,      3.47286376,     25.22196873,    135.33126588,     -4.08568295,    245.4130332 ,     35.8525322 ,     55.82388002,     -8.2817618 ,     57.04013653,    -57.54634631,    -33.2509476 ],
                  [  -143.01608249,     24.106398  ,    319.86768096,    -11.66557203,    -22.4655688 ,     11.28594366,      7.71688018,     99.24362198,     -9.68722427,    398.45761849,      6.49725333,     47.32706642,     -8.452195  ,     56.42345531,    -50.11075747],
                  [  1294.93125402,    104.07861972,    129.28774282,    784.41910294,     12.67812788,    130.19664906,   -145.83644757,    200.22393936,     35.81998833,    131.40252247,   2082.44694422,   -197.11423855,    821.14664281,     43.81010174,    237.82033605],
                  [   -55.1610876 ,    -88.97845702,    100.17157688,     13.00930011,    197.27217335,     18.10664837,    -62.19064165,    -60.46253073,    100.04826349,    110.40604806,     85.60298441,   1530.13398283,    -44.50790212,    739.31491352,     70.53772374],
                  [   804.56395952,    -21.70147325,     14.81385235,    308.19358507,     -0.48807095,    206.9843575 ,    -91.74326344,     14.51240206,    -19.11372082,    163.14773703,    628.15471113,    -22.87790567,    825.53145046,    -18.25969718,    616.23858118],
                  [   -74.92617811,   -117.99902178,    138.13925207,     36.73285767,    265.5961531 ,     11.7668321 ,    -57.30764656,   -173.26469261,     63.49957629,    230.45395691,     96.29166324,   1431.32187793,    -60.12636755,   1506.73280576,    118.06009694],
                  [  1991.81957825,   -276.96615132,   -121.00405373,    323.73524702,     20.52904741,    825.80781981,   -118.7596791 ,    -50.28343635,   -393.05518418,    404.29891122,    407.02620486,     97.66902465,   1397.44765615,   -233.21165378,   3279.21648727]]))


def check_kinetic(alphas0, alphas1, r0, r1, scales0, scales1, shell_type0, shell_type1, result0):
    """Compare a HORTON kinetic integral block with PyQuante reference data.

    Each primitive pair from ``zip(alphas0, alphas1)`` is added with unit
    coefficient to a GB2KineticIntegral (max shell type 4) for the shells
    ``(shell_type0, r0)`` and ``(shell_type1, r1)``. The accumulated
    ``(nbasis0, nbasis1)`` block must match ``result0`` to within 1e-8.
    """
    max_shell_type = 4
    expected_max_nbasis = get_shell_nbasis(max_shell_type)
    integral = GB2KineticIntegral(max_shell_type)
    assert integral.max_nbasis == expected_max_nbasis

    shape = (get_shell_nbasis(shell_type0), get_shell_nbasis(shell_type1))
    assert result0.shape == shape
    # Clear the working memory, then accumulate the contributions.
    integral.reset(shell_type0, shell_type1, r0, r1)
    for alpha0, alpha1 in zip(alphas0, alphas1):
        integral.add(1.0, alpha0, alpha1, scales0, scales1)
    deviation = abs(integral.get_work(*shape) - result0).max()
    assert deviation < 1e-8


def test_kinetic_0_0():
    """Kinetic integral block for shell types 0 and 0 against PyQuante data.

    Arguments follow check_kinetic: alphas0, alphas1, r0, r1, scales0,
    scales1, shell_type0, shell_type1 and the expected block result0.
    """
    check_kinetic(
        np.array([ 1.06724,  0.56887,  0.93402]), np.array([ 1.08863,  0.56363,  1.28424]),
        np.array([ 0.31645,  0.88108,  0.52757]), np.array([-0.44259,  0.84872, -0.01053]),
        np.array([ 0.83689]),
        np.array([ 0.55898]),
        0, 0,
        np.array([[ 2.32594819]]))


def test_kinetic_1_0():
    """Kinetic integral block for shell types 1 and 0 against PyQuante data.

    Arguments follow check_kinetic: alphas0, alphas1, r0, r1, scales0,
    scales1, shell_type0, shell_type1 and the expected block result0.
    """
    check_kinetic(
        np.array([ 1.85746,  1.90999,  0.25927]), np.array([ 0.63002,  1.96671,  0.24987]),
        np.array([ 0.9007 , -0.59203, -0.53519]), np.array([ 0.95855, -0.8384 , -0.09042]),
        np.array([ 1.17263,  0.77808,  1.26641]),
        np.array([ 0.59591]),
        1, 0,
        np.array([[ 0.26388818],
                  [-0.74570602],
                  [ 2.19111611]]))


def test_kinetic_1_1():
    """Check kinetic integrals for shell types (1, 1) against PyQuante data."""
    alphas0 = np.array([ 1.04297,  1.23914,  1.07897])
    alphas1 = np.array([ 1.5801 ,  0.49409,  0.50955])
    r0 = np.array([ 0.17893,  0.37431,  0.38481])
    r1 = np.array([-0.82787,  0.94762,  0.18045])
    scales0 = np.array([ 1.30039,  1.47598,  1.24465])
    scales1 = np.array([ 0.70985,  1.77583,  1.44967])
    expected = np.array([[-0.28397831,  2.4907056 , -0.72476387],
                         [ 1.13004164,  2.54841539,  0.46843539],
                         [-0.33967814,  0.48389229,  2.7216765 ]])
    check_kinetic(alphas0, alphas1, r0, r1, scales0, scales1, 1, 1, expected)


def test_kinetic_2_0():
    """Check kinetic integrals for shell types (2, 0) against PyQuante data."""
    alphas0 = np.array([ 1.17521,  0.65934,  1.65479])
    alphas1 = np.array([ 0.41698,  0.11244,  0.85642])
    r0 = np.array([ 0.72241,  0.89753, -0.00851])
    r1 = np.array([ 0.59045,  0.12826,  0.3632 ])
    scales0 = np.array([ 1.49595,  0.99344,  1.12762,  1.36581,  1.17942,  1.88005])
    scales1 = np.array([ 0.9147])
    expected = np.array([[ 2.3471677 ],
                         [ 0.07211026],
                         [-0.03954974],
                         [ 2.7039091 ],
                         [-0.24114921],
                         [ 3.11215989]])
    check_kinetic(alphas0, alphas1, r0, r1, scales0, scales1, 2, 0, expected)


def test_kinetic_2_1():
    """Check kinetic integrals for shell types (2, 1) against PyQuante data."""
    alphas0 = np.array([ 1.36516,  1.17638,  1.55071])
    alphas1 = np.array([ 1.93595,  0.44769,  0.90763])
    r0 = np.array([ 0.18404, -0.68828,  0.98654])
    r1 = np.array([-0.93064, -0.72908,  0.55764])
    scales0 = np.array([ 1.16665,  0.57554,  1.10899,  1.73997,  1.46346,  0.71421])
    scales1 = np.array([ 0.93698,  0.56054,  0.9824 ])
    expected = np.array([[ 0.38930604,  0.03676936,  0.67742892],
                         [ 0.00753004, -0.18992091,  0.00773599],
                         [ 0.15252653,  0.00850524, -0.48608683],
                         [ 0.92267705, -0.02192089,  0.3722319 ],
                         [ 0.01876132, -0.18581636, -0.02347886],
                         [ 0.47411344,  0.0103817 , -0.12729666]])
    check_kinetic(alphas0, alphas1, r0, r1, scales0, scales1, 2, 1, expected)


def test_kinetic_2_2():
    """Check kinetic integrals for shell types (2, 2) against PyQuante data."""
    alphas0 = np.array([ 1.89979,  1.76424,  1.9046 ])
    alphas1 = np.array([ 0.62318,  1.06094,  1.08178])
    r0 = np.array([-0.04399, -0.55893, -0.27125])
    r1 = np.array([ 0.75855, -0.69409,  0.11446])
    scales0 = np.array([ 0.84952,  1.68385,  1.38139,  0.51791,  1.84538,  1.15837])
    scales1 = np.array([ 1.20277,  1.33187,  0.90398,  1.40662,  1.00062,  1.51356])
    expected = np.array([[-0.03119284, -0.01743525,  0.03377053, -0.11678781, -0.04301438,  0.03720988],
                         [ 0.08141835, -0.00770716,  0.00068113,  0.1964952 , -0.21007486, -0.04035534],
                         [-0.19061098,  0.00082328, -0.00569026,  0.03835642,  0.04806574, -0.44178996],
                         [ 0.13897445,  0.01886571,  0.0835409 ,  0.26167285,  0.006812  , -0.04892507],
                         [-0.05299278, -0.30644261,  0.05800892,  0.10349719,  0.36203394,  0.09939553],
                         [ 0.39410443, -0.11022501, -0.05508249, -0.21071154,  0.01026846,  0.41588123]])
    check_kinetic(alphas0, alphas1, r0, r1, scales0, scales1, 2, 2, expected)


def test_kinetic_3_0():
    """Check kinetic integrals for shell types (3, 0) against PyQuante data."""
    alphas0 = np.array([ 0.24487,  1.2973 ,  0.78529])
    alphas1 = np.array([ 1.08555,  0.20016,  1.09331])
    r0 = np.array([ 0.81851, -0.4641 ,  0.45828])
    r1 = np.array([ 0.34426,  0.32795, -0.58748])
    scales0 = np.array([ 1.60856,  1.83875,  1.01534,  1.39612,  1.13447,  1.29609,  1.69782,  1.94584,  1.91287,  1.80643])
    scales1 = np.array([ 1.60581])
    expected = np.array([[ 1.45712499],
                         [ 0.04786773],
                         [-0.03489885],
                         [-1.21160003],
                         [ 1.98999625],
                         [-2.40460791],
                         [-0.15201511],
                         [-3.72364559],
                         [ 5.92706686],
                         [-3.71976609]])
    check_kinetic(alphas0, alphas1, r0, r1, scales0, scales1, 3, 0, expected)


def test_kinetic_3_1():
    """Check kinetic integrals for shell types (3, 1) against PyQuante data."""
    alphas0 = np.array([ 1.3859 ,  1.45007,  0.72838])
    alphas1 = np.array([ 1.92734,  0.62468,  1.02123])
    r0 = np.array([-0.20822,  0.25475,  0.35607])
    r1 = np.array([ 0.04812,  0.59736, -0.93359])
    scales0 = np.array([ 1.68254,  1.26832,  1.46192,  0.61266,  0.72222,  0.78278,  0.92126,  0.78023,  0.9645 ,  1.00688])
    scales1 = np.array([ 0.99839,  1.29352,  0.55317])
    expected = np.array([[ 0.46027255, -0.10779032,  0.17351635],
                         [ 0.07607686,  0.12934742,  0.06561057],
                         [-0.33008252,  0.17684106, -0.20082226],
                         [ 0.06794769,  0.04566878,  0.02684077],
                         [-0.13137715, -0.11873031, -0.01581792],
                         [ 0.58498744, -0.09558014,  0.02563804],
                         [-0.04780901,  0.37061864,  0.13326834],
                         [ 0.06169353, -0.29260474, -0.11723323],
                         [-0.09089861,  0.86457147,  0.04222125],
                         [ 0.4742713 ,  0.82126521,  0.04670602]])
    check_kinetic(alphas0, alphas1, r0, r1, scales0, scales1, 3, 1, expected)


def test_kinetic_3_2():
    """Check kinetic integrals for shell types (3, 2) against PyQuante data."""
    alphas0 = np.array([ 1.48542,  0.79169,  1.34211])
    alphas1 = np.array([ 1.24285,  0.28268,  1.87086])
    r0 = np.array([ 0.6894 ,  0.55164, -0.28557])
    r1 = np.array([-0.72617, -0.16973, -0.32865])
    scales0 = np.array([ 1.41548,  0.79948,  1.6947 ,  1.68438,  1.52736,  1.60386,  1.43968,  0.98679,  1.2037 ,  1.76271])
    scales1 = np.array([ 0.977  ,  0.6379 ,  1.33194,  1.71525,  0.71392,  1.81895])
    expected = np.array([[ 2.52788123,  0.1282514 ,  0.01599233, -1.46868389, -0.05494013,  0.77808606],
                         [-0.16895599,  0.12266631, -0.01106863,  1.47903372,  0.01414347,  0.02905072],
                         [-0.02138831, -0.01123691,  0.9344087 , -0.0359525 ,  0.77676525,  0.24035566],
                         [ 1.22661584, -0.01659763, -0.01606463, -1.38557896, -0.00103593,  0.32945928],
                         [ 0.00300844,  0.00069503,  0.10595709,  0.03025624, -0.31248034,  0.04134263],
                         [ 1.11526834, -0.08160655,  0.00315003, -0.27383354,  0.00282892, -2.08153282],
                         [-1.80465494,  1.71328638, -0.0859725 ,  2.91256472,  0.05835408,  0.49447633],
                         [-0.02930605, -0.00054228,  1.14489772, -0.0247036 ,  0.00904122,  0.10115098],
                         [-0.45563978,  0.3871701 ,  0.00396102,  1.44404904, -0.00181716, -0.79608931],
                         [-0.11945322, -0.04666628,  5.17828171, -0.02742448,  1.41441641,  0.22148981]])
    check_kinetic(alphas0, alphas1, r0, r1, scales0, scales1, 3, 2, expected)


def test_kinetic_3_3():
    """Check kinetic integrals for shell types (3, 3) against PyQuante data."""
    alphas0 = np.array([ 0.24381,  0.23783,  0.99507])
    alphas1 = np.array([ 1.50572,  1.85677,  1.2733 ])
    r0 = np.array([ 0.72564,  0.84368, -0.36619])
    r1 = np.array([ 0.94312, -0.40269, -0.38761])
    scales0 = np.array([ 1.92006,  0.69307,  1.83838,  1.95554,  1.86517,  1.87671,  0.52899,  1.96087,  1.9599 ,  0.59324])
    scales1 = np.array([ 0.92372,  0.73361,  1.5476 ,  1.33584,  1.70224,  0.67454,  1.32857,  1.99426,  1.54813,  0.75661])
    expected = np.array([[-0.36336191,  0.1451802 ,  0.00526348, -0.6004123 ,  0.0039414 , -0.39403405, -0.09995606,  0.00059715, -0.0716949 , -0.00180703],
                         [-0.23771798, -0.04700782, -0.00429495, -0.12666777, -0.00021514, -0.08073866, -0.17731161,  0.00041064, -0.07101734,  0.00099062],
                         [-0.01083659, -0.00540036,  0.39965712, -0.00969768,  0.14540191, -0.0016994 ,  0.00167653, -0.28417365,  0.00426847, -0.42898238],
                         [ 1.91328778, -0.09791611, -0.00008584,  0.63998663, -0.00420059,  0.45720829, -0.70166359,  0.00345994, -0.28442249,  0.00551906],
                         [ 0.04159644, -0.00046273, -0.09611402,  0.01590596, -0.04461169,  0.00395212, -0.0022684 , -0.24876613, -0.00100919, -0.35194648],
                         [-0.59847537, -0.02102334, -0.00408665, -0.24605568, -0.01439029,  0.17174458, -0.03994016, -0.01504648,  0.1220779 , -0.00909649],
                         [ 0.08135957,  0.29846857, -0.00461913,  0.04747975, -0.00040337,  0.01949473,  1.33956911, -0.00698129,  0.62616674, -0.006459  ],
                         [ 0.00681928,  0.01332067,  1.10970391,  0.00232393,  0.04299804, -0.00008343,  0.06929648,  0.99291596,  0.02038613,  1.59835482],
                         [-0.04745016, -0.09741801,  0.05285409, -0.0132973 , -0.00064616,  0.05554872, -0.52537749,  0.05346241, -0.29674912,  0.05444251],
                         [-0.00074187,  0.00109813, -0.32111421,  0.00006704, -0.01021031,  0.00072821,  0.00374424, -0.34732652, -0.00957819, -0.12308774]])
    check_kinetic(alphas0, alphas1, r0, r1, scales0, scales1, 3, 3, expected)


def test_kinetic_4_0():
    """Check kinetic integrals for shell types (4, 0) against PyQuante data."""
    alphas0 = np.array([ 1.13597,  1.35446,  1.28955])
    alphas1 = np.array([ 1.57931,  0.42536,  0.24026])
    r0 = np.array([ 0.69848,  0.18202,  0.08076])
    r1 = np.array([-0.12811,  0.45661, -0.68514])
    scales0 = np.array([ 1.12509,  0.79513,  1.79369,  1.61313,  1.02925,  1.00163,  1.65755,  0.99287,  1.09147,  1.64997,  0.9976 ,  1.87067,  1.6571 ,  0.60683,  1.79732])
    scales1 = np.array([ 1.51279])
    expected = np.array([[ 1.45735285],
                         [-0.16598107],
                         [ 1.04437188],
                         [ 0.42741402],
                         [-0.1224337 ],
                         [ 0.55500839],
                         [-0.21595088],
                         [ 0.13823637],
                         [-0.12650628],
                         [ 0.90327565],
                         [ 0.38644339],
                         [-0.22582262],
                         [ 0.40384259],
                         [-0.11035852],
                         [ 2.01813692]])
    check_kinetic(alphas0, alphas1, r0, r1, scales0, scales1, 4, 0, expected)


def test_kinetic_4_1():
    """Check kinetic integrals for shell types (4, 1) against PyQuante data."""
    alphas0 = np.array([ 0.48846,  0.11109,  0.96246])
    alphas1 = np.array([ 0.21453,  0.75976,  1.40583])
    r0 = np.array([-0.56211,  0.92747,  0.95125])
    r1 = np.array([-0.08782,  0.22151,  0.63397])
    scales0 = np.array([ 1.64348,  1.84322,  1.66204,  1.08815,  1.82327,  1.81657,  1.82976,  1.56129,  0.76587,  1.13837,  0.7917 ,  1.30602,  1.56535,  1.01148,  1.08262])
    scales1 = np.array([ 1.82595,  0.78174,  0.60821])
    expected = np.array([[-0.45681214,  2.44328001,  0.85433277],
                         [-3.95061972, -0.05074994, -0.1595038 ],
                         [-1.60100293, -0.1848605 ,  0.21977392],
                         [ 1.77173617, -0.12845114,  0.23002238],
                         [ 1.94377444, -0.16132923, -0.46540169],
                         [-0.48094063,  0.85985151, -0.15832253],
                         [-2.50714443,  2.48185656, -0.19096334],
                         [-1.16791942,  0.67683007,  0.63044617],
                         [-0.1675691 ,  0.01152346,  0.2781729 ],
                         [ 0.03859379, -0.11475725,  0.23176439],
                         [-2.40086128, -0.9627868 ,  0.53497129],
                         [-0.40920508, -1.1850348 , -0.70864675],
                         [-1.33780046,  0.37566839, -0.60071929],
                         [-0.23816643,  0.09794847, -0.3065185 ],
                         [-2.22815282,  1.41988952,  0.31720508]])
    check_kinetic(alphas0, alphas1, r0, r1, scales0, scales1, 4, 1, expected)


def test_kinetic_4_2():
    """Check kinetic integrals for shell types (4, 2) against PyQuante data."""
    alphas0 = np.array([ 0.48502,  1.44324,  0.82621])
    alphas1 = np.array([ 1.82098,  0.2687 ,  1.8275 ])
    r0 = np.array([ 0.11711,  0.79594,  0.18567])
    r1 = np.array([-0.74717,  0.94565, -0.51744])
    scales0 = np.array([ 1.7328 ,  1.58101,  0.69871,  0.83793,  0.93788,  1.64038,  1.38135,  1.47861,  1.66126,  1.66495,  0.58413,  1.01951,  1.47233,  1.53884,  1.32106])
    scales1 = np.array([ 0.68468,  1.88328,  1.63165,  0.62252,  1.39735,  0.56771])
    expected = np.array([[ 0.92435227,  0.18051621, -0.73451521, -0.38295241, -0.18070661, -0.02006706],
                         [-0.05546955,  1.07017098,  0.095939  ,  0.04001782, -0.37306699, -0.04476061],
                         [ 0.11513025,  0.04893791,  0.21965911,  0.06714136,  0.01850592, -0.04407295],
                         [-0.0449795 , -0.08174348,  0.05480114,  0.10160832,  0.02500517, -0.04470255],
                         [-0.01627476,  0.17513955, -0.01556046,  0.0044475 ,  0.12926124, -0.0010177 ],
                         [ 0.03957011,  0.00744553,  0.51356254, -0.03306082,  0.02710952,  0.10021321],
                         [ 0.03230652,  0.19609489,  0.01170556, -0.07620247, -0.18182412, -0.00014996],
                         [-0.05221476, -0.01762434,  0.04123325,  0.18096224, -0.03790677, -0.04944872],
                         [-0.00203066,  0.13428713, -0.01217184,  0.0118815 ,  0.29470096, -0.03000974],
                         [-0.11452422,  0.02204987,  0.26638479,  0.0906101 ,  0.09268121,  0.27170194],
                         [-0.07886027,  0.02689077,  0.23741267,  0.09347838,  0.01623164, -0.09972681],
                         [-0.00260263, -0.18086267,  0.01792558, -0.04575363,  0.16317086,  0.01815165],
                         [-0.05632008,  0.04348975,  0.07867847,  0.10362926, -0.10438378, -0.05899034],
                         [-0.03429256, -0.40883837,  0.10002436,  0.03334327,  0.69053916, -0.04349921],
                         [-0.0312356 , -0.18704025, -0.48207014, -0.37461544,  0.07151296,  0.49623294]])
    check_kinetic(alphas0, alphas1, r0, r1, scales0, scales1, 4, 2, expected)


def test_kinetic_4_3():
    """Check kinetic integrals for shell types (4, 3) against PyQuante data."""
    alphas0 = np.array([ 0.3345 ,  1.64669,  0.929  ])
    alphas1 = np.array([ 1.27381,  1.33581,  0.31581])
    r0 = np.array([-0.94447, -0.74489, -0.69538])
    r1 = np.array([ 0.03678,  0.00172,  0.69426])
    scales0 = np.array([ 1.77234,  1.81103,  0.90533,  0.5529 ,  0.52613,  0.67168,  1.05098,  1.25991,  1.87356,  1.74584,  1.24985,  1.504  ,  1.75061,  0.71086,  1.91046])
    scales1 = np.array([ 0.95997,  0.90913,  1.13314,  1.84706,  1.01587,  1.08516,  1.95369,  1.03864,  1.03212,  0.68402])
    expected = np.array([[-1.19705745, -0.92059647, -2.13567612,  1.8671679 ,  0.70103653,  2.08845246,  0.62605675, -0.58451014, -1.8731403 , -1.51764175],
                         [ 1.26053951, -0.22136545,  0.13103956, -0.90811448, -0.56484751,  0.90920166,  1.46135761,  0.81635229,  0.68860832, -0.74576666],
                         [ 1.17286016,  0.05255646, -0.0511965 ,  0.74546035,  0.05350109, -0.08841558, -0.60215749, -0.13567086,  0.16111641,  0.43549521],
                         [-0.13405354,  0.08340719, -0.10349697,  0.15284299, -0.05146173,  0.03952825, -0.20011853, -0.17165176,  0.10426819, -0.17394836],
                         [ 0.12300317,  0.06564591, -0.03479819,  0.12575332, -0.00159245,  0.08581165,  0.48263837,  0.00533259, -0.04604211,  0.07454911],
                         [ 0.04505641, -0.13486632,  0.04792085,  0.44479063, -0.04825856,  0.16571381, -0.34919987,  0.06894654, -0.09110029, -0.03886915],
                         [ 0.36264841, -0.40314165,  0.43254344, -0.15302165, -0.21911208,  0.14072085,  1.57217926,  0.0671955 ,  0.5995979 , -0.34105766],
                         [ 0.42408011,  0.16262529,  0.01244903,  0.17785609, -0.00298866, -0.10755156,  0.65571298, -0.09570932,  0.22025453,  0.21403187],
                         [ 0.79468035, -0.23459974,  0.25532909, -0.34046512, -0.01924828,  0.07691668,  2.28404592,  0.246688  ,  0.21853562,  0.18714778],
                         [ 1.91189383,  0.54201798,  0.1419181 ,  0.36654603, -0.09625734, -0.26749951, -2.05415371,  0.7106361 ,  0.04759638,  1.55590396],
                         [ 0.22568194,  0.54925744, -0.74042218, -2.23115416,  0.5813364 , -1.38419999, -2.20833857, -1.77679063,  1.04295856, -0.56096164],
                         [-0.43706544,  0.62666275, -0.19264964,  0.17100529,  0.12868126,  0.31254849,  3.18623705, -0.05485677, -0.42382346,  0.53177137],
                         [-0.44865373,  0.55479109, -0.01361964, -1.0934655 , -0.11421405, -0.2078602 ,  0.01332442,  0.05808793,  0.3995725 , -0.1980759 ],
                         [-0.47230149,  0.25829064,  0.26920432,  0.56638082, -0.08569426,  0.02037592,  1.38752232,  0.01370847, -0.1143199 ,  0.48203276],
                         [-2.07067988, -1.51621235,  2.40090937, -2.33028665,  0.35488133, -0.79926827, -1.82358308,  1.99988938, -0.57841976,  0.63749859]])
    check_kinetic(alphas0, alphas1, r0, r1, scales0, scales1, 4, 3, expected)


def test_kinetic_4_4():
    """Check kinetic integrals for shell types (4, 4) against PyQuante data."""
    alphas0 = np.array([ 0.91296,  0.5249 ,  1.00904])
    alphas1 = np.array([ 1.59961,  0.41264,  1.40902])
    r0 = np.array([-0.34722,  0.32224, -0.94351])
    r1 = np.array([ 0.97377, -0.68191, -0.56333])
    scales0 = np.array([ 0.59055,  1.91352,  1.21212,  1.62118,  0.89282,  0.73801,  1.57987,  1.13435,  1.35365,  1.74871,  1.55493,  1.97258,  0.90919,  1.16725,  1.27353])
    scales1 = np.array([ 1.60633,  1.91703,  0.67265,  1.33285,  1.99531,  1.68639,  1.46872,  1.82796,  1.64229,  1.5316 ,  0.77546,  0.50652,  0.99187,  1.53397,  1.99627])
    expected = np.array([[  0.01454609,   0.5628389 ,  -0.07477131,  -0.20245295,  -0.65543507,  -1.5095628 ,   4.44528595,  -1.10259724,   1.38341534,  -1.31971274,   1.57338212,  -0.50250635,   0.01800541,  -1.0529002 ,  -3.53099477],
                         [ 12.28684952,   0.95086653,   0.58620124,   0.20795757,   0.06316635,   0.38157282,   4.6286143 ,  -0.85553374,   0.26922119,   0.85637737,  11.54504863,  -0.98819814,   1.70996786,  -0.64981072,  -1.59173342],
                         [ -2.94675361,   1.0582764 ,   1.05153003,  -0.04909832,  -0.01383473,  -0.09249454,   0.72270554,   1.82051812,  -0.71332802,   1.89121388,   1.13232305,   1.52071579,  -1.10879825,   3.77618512,  -6.84861474],
                         [  4.00873606,   1.77958066,   0.32538762,  -1.98394966,   0.33697799,  -1.85731414,   0.67619364,  -0.22427856,  -0.02090287,  -0.32647462,   4.40663123,  -0.40907794,  -0.53644056,   0.68515156,  -2.75965274],
                         [ -0.93984135,  -0.16501577,   0.37194692,   0.15085322,   0.04502695,   0.28224893,   0.09421378,   0.07545097,  -0.02415806,  -0.25063644,   0.82942073,   0.71154496,  -0.27173879,   0.47787684,   2.11609153],
                         [  0.0671019 ,  -0.59905831,  -0.17616495,  -0.6179649 ,   0.24277629,  -0.32146652,   0.36128626,   0.01391828,   0.27035113,  -0.13377041,  -0.02069301,   0.2187369 ,   1.00281186,  -0.92662908,   4.84697588],
                         [ 16.60483486,   6.61479606,   0.71815902,   0.28046146,   0.74858766,   1.28054353,  -1.59168411,  -0.03132672,  -0.47796563,  -0.27247876,   6.30588413,  -0.67129568,  -0.07852528,  -1.20479602,  -0.72877036],
                         [ -2.0628045 ,  -0.30796516,   0.9217338 ,  -0.20926715,   0.23838212,   0.35595217,   0.00362702,  -0.22957898,   0.02188366,  -0.46202668,   0.83281142,   0.53128046,  -0.37432891,  -0.85037981,  -2.65000941],
                         [  4.24103241,   0.73504773,  -0.22074491,  -0.44468517,  -0.12024642,   1.24472031,  -0.32937241,   0.03973551,  -0.12352405,   0.04142307,   2.10347079,   0.24144563,   1.25154729,  -0.55537516,  -6.68672359],
                         [ -5.85289468,   1.92749853,   2.63465519,  -0.66198436,  -2.45920422,  -0.32193778,  -0.25599472,  -0.78397558,  -0.00479019,  -1.78721639,   0.99643182,   1.64649657,   0.45525451,   9.23485156,  -9.50161491],
                         [ 11.18796955,  22.00859905,   1.71789918,   6.3004833 ,   4.05760061,  -0.20650275,   1.42991677,   3.24759191,   4.28597035,   1.9428923 ,   6.8119998 ,   0.14192473,  -0.92713865,   3.10109865, -10.41389098],
                         [ -4.72239259,  -3.25799217,   3.52168212,  -0.20370955,   6.16760228,   2.87715576,  -2.430354  ,   0.65415303,   1.94987453,   5.21306081,   2.26594308,   1.98571242,   0.19837589,   3.8136846 ,   7.29979204],
                         [  0.68013726,   1.3546239 ,  -0.43230333,  -0.16201482,  -0.16548443,   2.3748352 ,  -1.08680331,  -0.44867829,   0.9107388 ,   1.16359037,   0.34409332,   0.15918897,   0.34564642,   0.38115263,   3.4769    ],
                         [ -2.17924671,  -0.68362942,   1.83846862,   0.77900439,   1.39901579,  -0.54328461,  -1.26737831,  -2.5337324 ,  -0.18866358,   6.15466614,   1.46447806,   1.20306891,   0.18462164,   2.22867503,   4.82106044],
                         [ -3.11502094,  -3.259664  ,  -2.66168097,  -0.32464083,   3.74427532,  11.68910013,  -1.61932582,  -3.18178915,  -8.93757788,  -0.96609349,  -2.71093157,   1.24326   ,   3.87719066,   0.73551194,  37.54293115]])
    check_kinetic(alphas0, alphas1, r0, r1, scales0, scales1, 4, 4, expected)


def check_nuclear_attraction(alphas0, alphas1, r0, r1, scales0, scales1, charges, centers, shell_type0, shell_type1, result0):
    """Compare HORTON nuclear-attraction integrals with PyQuante reference data.

    Parameters
    ----------
    alphas0, alphas1
        Exponents for the first and second shell. Each pair
        ``(alphas0[i], alphas1[i])`` is accumulated as one contribution with
        coefficient 1.0.
    r0, r1
        Positions of the two shells, passed to ``reset``.
    scales0, scales1
        Scale factors for the basis functions of each shell, passed to ``add``.
    charges, centers
        Charges and their positions (one row of ``centers`` per charge), given
        to the ``GB2NuclearAttractionIntegral`` constructor.
    shell_type0, shell_type1
        Shell types of the two shells.
    result0
        Reference block of shape ``(nbasis0, nbasis1)``, computed with
        PyQuante. It has the opposite sign of HORTON's result, so the check
        is ``result1 == -result0`` within an absolute tolerance of 1e-6.
    """
    max_shell_type = 4
    max_nbasis = get_shell_nbasis(max_shell_type)
    gb2i = GB2NuclearAttractionIntegral(max_shell_type, charges, centers)
    assert gb2i.max_nbasis == max_nbasis

    nbasis0 = get_shell_nbasis(shell_type0)
    nbasis1 = get_shell_nbasis(shell_type1)
    assert result0.shape == (nbasis0, nbasis1)
    # Clear the working memory
    gb2i.reset(shell_type0, shell_type1, r0, r1)
    # Add a few contributions:
    for alpha0, alpha1 in zip(alphas0, alphas1):
        gb2i.add(1.0, alpha0, alpha1, scales0, scales1)
    result1 = gb2i.get_work(nbasis0, nbasis1)
    # The reference data uses the opposite sign convention, so the deviation
    # from -result0 is |result1 + result0| (exactly equal to |result1 - -result0|).
    error = abs(result1 + result0).max()
    assert error < 1e-6, error


def test_nuclear_attraction_0_0():
    """Check nuclear-attraction integrals for shell types (0, 0) against PyQuante data."""
    alphas0 = np.array([ 1.1499 ,  0.80173,  1.75613])
    alphas1 = np.array([ 0.11444,  0.43894,  0.89604])
    r0 = np.array([-0.83928,  0.02661, -0.15484])
    r1 = np.array([-0.63628,  0.64107, -0.41656])
    scales0 = np.array([ 0.83566])
    scales1 = np.array([ 1.41376])
    charges = np.array([ 0.03953, -0.94018, -0.90911,  0.88274])
    centers = np.array([[-0.7696 , -0.8702 , -0.49915],
                        [ 0.69398, -0.79432, -0.04651],
                        [ 0.4377 , -0.56457,  0.81482],
                        [ 0.10819, -0.84377, -0.50199]])
    expected = np.array([[-3.720398]])
    check_nuclear_attraction(alphas0, alphas1, r0, r1, scales0, scales1,
                             charges, centers, 0, 0, expected)


def test_nuclear_attraction_1_0():
    """Check nuclear-attraction integrals for shell types (1, 0) against PyQuante data."""
    alphas0 = np.array([ 0.95595,  1.93405,  0.46025])
    alphas1 = np.array([ 1.46799,  0.26041,  1.86753])
    r0 = np.array([-0.4592 , -0.32629,  0.97071])
    r1 = np.array([ 0.63585, -0.45094, -0.4064 ])
    scales0 = np.array([ 1.92994,  1.72237,  0.95997])
    scales1 = np.array([ 0.60535])
    charges = np.array([ 0.95945, -0.60563, -0.6996 ,  0.57144])
    centers = np.array([[-0.44917, -0.27333,  0.724  ],
                        [-0.03626,  0.42363,  0.43008],
                        [ 0.22669, -0.50554, -0.49913],
                        [-0.88467,  0.08702,  0.85269]])
    expected = np.array([[-0.40084655],
                         [ 0.01259754],
                         [ 0.2063208 ]])
    check_nuclear_attraction(alphas0, alphas1, r0, r1, scales0, scales1,
                             charges, centers, 1, 0, expected)


def test_nuclear_attraction_1_1():
    """Check nuclear-attraction integrals for shell types (1, 1) against PyQuante data."""
    alphas0 = np.array([ 0.59912,  1.5902 ,  1.47962])
    alphas1 = np.array([ 0.11924,  0.36095,  0.88271])
    r0 = np.array([-0.0489 , -0.33057, -0.58192])
    r1 = np.array([-0.50679, -0.74802, -0.50857])
    scales0 = np.array([ 1.95   ,  1.79499,  1.00779])
    scales1 = np.array([ 0.81129,  1.31353,  0.92627])
    charges = np.array([ 0.55335,  0.42741,  0.31526, -0.23886])
    centers = np.array([[-0.207  , -0.48222, -0.09863],
                        [ 0.53182,  0.20242,  0.4351 ],
                        [ 0.38424,  0.6983 ,  0.85954],
                        [-0.40051, -0.21384,  0.94795]])
    expected = np.array([[  6.58632082,   0.11715378,   0.36320453],
                         [  0.08483758,  10.1088519 ,   0.30922763],
                         [  0.80222346,   1.19031454,   4.29384884]])
    check_nuclear_attraction(alphas0, alphas1, r0, r1, scales0, scales1,
                             charges, centers, 1, 1, expected)


def test_nuclear_attraction_2_0():
    """Check nuclear-attraction integrals for shell types (2, 0) against PyQuante data."""
    alphas0 = np.array([ 0.56361,  1.01941,  0.46388])
    alphas1 = np.array([ 1.92887,  1.32672,  0.54163])
    r0 = np.array([-0.27482, -0.64939, -0.85553])
    r1 = np.array([-0.98864, -0.09978,  0.62885])
    scales0 = np.array([ 1.7099 ,  0.75985,  0.53305,  1.47259,  1.69215,  1.2908 ])
    scales1 = np.array([ 0.62823])
    charges = np.array([-0.07271,  0.84005, -0.90672, -0.98748])
    centers = np.array([[-0.22538,  0.48992,  0.24625],
                        [-0.2713 , -0.56729, -0.12587],
                        [ 0.62638,  0.65665, -0.45222],
                        [-0.69105, -0.93802,  0.38737]])
    expected = np.array([[-1.38007038],
                         [ 0.04513791],
                         [ 0.32720976],
                         [-1.14911877],
                         [-0.51420671],
                         [-2.5054637 ]])
    check_nuclear_attraction(alphas0, alphas1, r0, r1, scales0, scales1,
                             charges, centers, 2, 0, expected)


def test_nuclear_attraction_2_1():
    """Check nuclear-attraction integrals for shell types (2, 1) against PyQuante data."""
    alphas0 = np.array([ 1.96842,  0.30347,  1.48238])
    alphas1 = np.array([ 1.36835,  0.26125,  1.79489])
    r0 = np.array([-0.93567, -0.77696, -0.48014])
    r1 = np.array([ 0.87862,  0.66487,  0.31839])
    scales0 = np.array([ 1.81556,  1.43739,  1.3201 ,  1.77607,  1.47075,  1.94402])
    scales1 = np.array([ 1.01597,  0.61246,  1.10889])
    charges = np.array([ 0.31244, -0.44517, -0.7148 ,  0.75846])
    centers = np.array([[-0.54616,  0.57517, -0.24002],
                        [-0.10344, -0.36722,  0.16586],
                        [ 0.18978,  0.8406 , -0.3029 ],
                        [-0.10373, -0.36703, -0.43995]])
    expected = np.array([[-0.23990603, -0.35763947, -0.06800012],
                         [ 0.4427801 , -0.25459448,  0.58787446],
                         [ 0.29081026,  0.23687337, -0.21718961],
                         [ 0.97949251, -0.85103387,  0.82432459],
                         [ 0.41384973, -0.01860388, -0.33831452],
                         [ 0.49364606,  0.13828694, -0.96183716]])
    check_nuclear_attraction(alphas0, alphas1, r0, r1, scales0, scales1,
                             charges, centers, 2, 1, expected)


def test_nuclear_attraction_2_2():
    """Check nuclear-attraction integrals for shell types (2, 2) against PyQuante data."""
    alphas0 = np.array([ 1.4928 ,  0.8887 ,  0.48535])
    alphas1 = np.array([ 0.84159,  1.22852,  1.6358 ])
    r0 = np.array([-0.91863, -0.40145, -0.02778])
    r1 = np.array([ 0.70613,  0.30642,  0.9048 ])
    scales0 = np.array([ 1.77404,  0.99202,  0.78898,  1.51593,  0.90145,  1.26582])
    scales1 = np.array([ 0.89204,  1.01132,  1.51105,  0.69261,  0.60815,  0.74161])
    charges = np.array([ 0.74195, -0.18199, -0.73651, -0.91895])
    centers = np.array([[-0.77067, -0.84769,  0.42186],
                        [ 0.04751,  0.02714, -0.38344],
                        [-0.63515, -0.54802,  0.49843],
                        [ 0.34683, -0.27743,  0.65987]])
    expected = np.array([[-0.59197241, -0.27862092, -0.39411938, -0.60244061, -0.18778631, -0.62336674],
                         [-0.12690367,  0.02104568, -0.08872218, -0.01046148,  0.01307119, -0.11131231],
                         [-0.16158146, -0.08137274, -0.01477491, -0.1313657 , -0.00777608, -0.0422372 ],
                         [-0.28136967,  0.03925771, -0.1820123 , -0.10210815,  0.01337943, -0.14417909],
                         [-0.10741863,  0.01786701, -0.00014094, -0.00613456, -0.00237334, -0.01623228],
                         [-0.38190797, -0.15863741,  0.01618758, -0.1757732 ,  0.0012421 , -0.11230951]])
    check_nuclear_attraction(alphas0, alphas1, r0, r1, scales0, scales1,
                             charges, centers, 2, 2, expected)


def test_nuclear_attraction_3_0():
    """Check nuclear-attraction integrals for shell types (3, 0) against PyQuante data."""
    alphas0 = np.array([ 0.59247,  1.96546,  1.07904])
    alphas1 = np.array([ 1.0072 ,  0.70814,  0.92547])
    r0 = np.array([ 0.4857 , -0.39667, -0.3998 ])
    r1 = np.array([ 0.19794, -0.82631, -0.50484])
    scales0 = np.array([ 1.58555,  0.79364,  1.16501,  0.95828,  0.82792,  0.95889,  1.54545,  1.9622 ,  0.93092,  1.44428])
    scales1 = np.array([ 0.75126])
    charges = np.array([ 0.16358,  0.36607, -0.85077, -0.55414])
    centers = np.array([[-0.31967,  0.32547,  0.04623],
                        [ 0.24747, -0.49427, -0.99233],
                        [-0.45277,  0.84931, -0.9229 ],
                        [-0.29599, -0.37722,  0.96455]])
    expected = np.array([[ 0.44282923],
                         [ 0.02652339],
                         [-0.05402371],
                         [ 0.09709459],
                         [-0.02390434],
                         [ 0.09157822],
                         [ 0.16242631],
                         [-0.05719273],
                         [ 0.02373123],
                         [-0.2632455 ]])
    check_nuclear_attraction(alphas0, alphas1, r0, r1, scales0, scales1,
                             charges, centers, 3, 0, expected)


def test_nuclear_attraction_3_1():
    """Check nuclear-attraction integrals for shell types (3, 1) against PyQuante data."""
    alphas0 = np.array([ 0.79517,  1.4921 ,  1.77045])
    alphas1 = np.array([ 1.42462,  0.72626,  1.80002])
    r0 = np.array([ 0.4837 , -0.80385, -0.21282])
    r1 = np.array([ 0.17651,  0.90621,  0.92473])
    scales0 = np.array([ 1.9472 ,  1.3943 ,  1.63229,  0.80412,  0.59167,  1.29869,  1.93069,  1.03061,  0.83169,  1.21334])
    scales1 = np.array([ 0.6603 ,  1.95907,  1.19918])
    charges = np.array([-0.5566 , -0.17837,  0.92798,  0.24499])
    centers = np.array([[-0.06608,  0.04099,  0.83998],
                        [-0.2522 ,  0.33108, -0.31661],
                        [ 0.65812,  0.54883,  0.57297],
                        [-0.70701,  0.672  , -0.47598]])
    expected = np.array([[ 0.01609283, -0.04743015, -0.00613079],
                         [ 0.0138956 , -0.01017928, -0.02448584],
                         [ 0.01371342,  0.00589694, -0.00403712],
                         [ 0.02185936, -0.03681533, -0.00703703],
                         [ 0.00693758, -0.02482844, -0.00371066],
                         [ 0.01060229, -0.06769765, -0.0051031 ],
                         [ 0.20061655, -0.32609629, -0.46074057],
                         [ 0.05799593, -0.09038838, -0.06421044],
                         [ 0.03380386, -0.04695774, -0.02344935],
                         [ 0.04404127, -0.05060601, -0.00608187]])
    check_nuclear_attraction(alphas0, alphas1, r0, r1, scales0, scales1,
                             charges, centers, 3, 1, expected)


def test_nuclear_attraction_3_2():
    """Check nuclear-attraction integrals for shell types (3, 2) against PyQuante data."""
    alphas0 = np.array([ 0.64589,  1.56386,  1.0447 ])
    alphas1 = np.array([ 0.60992,  1.09548,  1.86252])
    r0 = np.array([ 0.7588 , -0.93344,  0.43236])
    r1 = np.array([ 0.40788, -0.96689, -0.96929])
    scales0 = np.array([ 0.53937,  1.8209 ,  1.65581,  1.05339,  0.57303,  1.0209 ,  1.36056,  1.43388,  1.64072,  1.32519])
    scales1 = np.array([ 1.0943 ,  0.69348,  1.88678,  1.71102,  1.31969,  0.54817])
    charges = np.array([-0.88781, -0.83111,  0.8418 , -0.36735])
    centers = np.array([[ 0.67679, -0.54284, -0.02183],
                        [ 0.80029,  0.61345,  0.74594],
                        [ 0.91053, -0.97766, -0.92354],
                        [-0.45779, -0.82766, -0.65828]])
    expected = np.array([[ 0.20717507, -0.03736966, -0.62966998,  0.16015066,  0.0335251 ,  0.13834668],
                         [-0.14169283,  0.08782599,  0.0061612 , -0.32772481, -0.57739303, -0.10280709],
                         [ 0.74871014, -0.01707286, -0.19536232,  0.47922252,  0.0598764 ,  0.06846831],
                         [-0.01089697, -0.05782069, -0.37887496,  0.29392114,  0.05335422,  0.08873772],
                         [ 0.0068878 ,  0.04999432,  0.01707595, -0.03012055, -0.0251939 , -0.00327156],
                         [ 0.17490718, -0.04713846, -0.23490833,  0.22855108,  0.02308318,  0.08322237],
                         [-0.15080511, -0.13155257, -0.21051351, -1.2959771 , -1.39438012, -0.30405306],
                         [ 0.20978329,  0.03712039,  0.04568512,  1.05880842,  0.11378082,  0.03580113],
                         [-0.13448719, -0.02647456, -0.07171657, -0.59156898, -0.33817882, -0.09889612],
                         [ 0.40353597,  0.04431891,  0.05353972,  1.04828721,  0.06798425,  0.09869288]])
    check_nuclear_attraction(alphas0, alphas1, r0, r1, scales0, scales1,
                             charges, centers, 3, 2, expected)


def test_nuclear_attraction_3_3():
    """Check nuclear-attraction integrals for shell types (3, 3) against PyQuante data."""
    alphas0 = np.array([ 0.43043,  1.7863 ,  1.60762])
    alphas1 = np.array([ 1.91989,  1.29592,  1.30772])
    r0 = np.array([-0.8922 ,  0.2513 ,  0.53418])
    r1 = np.array([-0.19884,  0.58376, -0.45971])
    scales0 = np.array([ 0.86817,  1.22086,  1.9112 ,  1.84934,  0.90329,  0.84013,  1.87491,  0.69813,  0.52947,  1.32453])
    scales1 = np.array([ 0.59783,  1.93594,  0.83289,  1.81521,  1.72562,  1.58553,  1.39675,  1.09098,  0.75444,  1.03229])
    charges = np.array([ 0.10038,  0.91252,  0.70627,  0.93819])
    centers = np.array([[ 0.78751,  0.69459,  0.18841],
                        [ 0.9495 ,  0.21052, -0.06736],
                        [ 0.85328,  0.40186, -0.70412],
                        [ 0.4515 , -0.42787, -0.16876]])
    expected = np.array([[ 0.34088228, -0.19031243,  0.09917368,  0.30519868, -0.05949284,  0.29265389, -0.19585497,  0.09012035, -0.04668562,  0.239288  ],
                         [ 0.04039421,  0.18270908,  0.01711149, -0.02779487,  0.04134614,  0.04421631,  0.296214  , -0.0194554 ,  0.06373448,  0.05558897],
                         [-0.31069953,  0.23058826,  0.0062535 , -0.34084047,  0.01262938, -0.18488268,  0.28350987, -0.00635457,  0.03135252,  0.09956498],
                         [ 0.07404025,  0.03017311,  0.01524066,  0.14886828,  0.00617503,  0.06921038, -0.01551016,  0.08742889,  0.02112685,  0.12280132],
                         [-0.02316562, -0.05868407,  0.00251599,  0.01064597,  0.00047643, -0.01226654, -0.16651828,  0.005737  , -0.01791275,  0.01261166],
                         [ 0.11997389, -0.05664825, -0.02419945,  0.11360714,  0.00494627,  0.07752199, -0.11726056, -0.02774312, -0.01661499, -0.06332052],
                         [-0.02684942,  0.26749532,  0.0316273 , -0.03519875, -0.05146405, -0.04336663,  0.73051257,  0.02721388,  0.13187146,  0.11394667],
                         [ 0.01023628, -0.02587968, -0.00016514,  0.02034569,  0.0002602 ,  0.00395819, -0.0058442 ,  0.0008789 , -0.00677989,  0.01791367],
                         [-0.0059965 ,  0.07409827, -0.00575962, -0.00194471,  0.00052533, -0.0051086 ,  0.15500333, -0.00463763,  0.02327128, -0.01554138],
                         [ 0.0620966 ,  0.16383841,  0.13408279,  0.0635047 ,  0.00154304,  0.04312172,  0.29506236,  0.1667039 ,  0.04523992,  0.44739997]])
    check_nuclear_attraction(alphas0, alphas1, r0, r1, scales0, scales1,
                             charges, centers, 3, 3, expected)


def test_nuclear_attraction_4_0():
    """Compare ``check_nuclear_attraction`` for the ``4, 0`` case with stored values.

    NOTE(review): the helper is defined outside this view. Judging from the array
    sizes, the positional arguments are presumably: exponents, centers and scales
    of two shells, four charges with their 3D positions, the two shell types and
    the expected (15, 1) result -- confirm against the helper's signature.
    """
    check_nuclear_attraction(
        np.array([ 0.43286,  0.78487,  1.62566]), np.array([ 1.41406,  1.66238,  1.97132]),
        np.array([ 0.38072,  0.29672, -0.72526]), np.array([ 0.71378, -0.99468, -0.95029]),
        np.array([ 0.9257 ,  1.66412,  0.66727,  0.99949,  1.52209,  1.72865,  0.51615,  0.75286,  1.76842,  0.94215,  1.23075,  1.33678,  0.5485 ,  1.91084,  1.6127 ]),
        np.array([ 1.96168]),
        np.array([ 0.7194 ,  0.88196, -0.55991, -0.52554]),
        np.array([[ 0.80592,  0.4921 ,  0.15675],
                  [ 0.46617, -0.09085,  0.26806],
                  [-0.33277, -0.40421, -0.53184],
                  [-0.72137, -0.84458,  0.1552 ]]),
        4, 0,
        np.array([[ 0.32760253],
                  [-0.62591288],
                  [ 0.00547763],
                  [ 0.30512678],
                  [ 0.00807307],
                  [ 0.17494689],
                  [-0.36949683],
                  [ 0.00844852],
                  [-0.2183174 ],
                  [ 0.01205235],
                  [ 1.58868446],
                  [-0.02598786],
                  [ 0.13286064],
                  [-0.06617688],
                  [ 0.43677537]]))


def test_nuclear_attraction_4_1():
    """Compare ``check_nuclear_attraction`` for the ``4, 1`` case with stored values.

    NOTE(review): argument roles are inferred from array sizes only (presumably
    exponents, centers and scales of two shells, four charges with positions, two
    shell types and the expected (15, 3) result) -- confirm against the helper.
    """
    check_nuclear_attraction(
        np.array([ 1.85192,  1.96127,  0.87917]), np.array([ 0.96456,  1.34294,  1.6066 ]),
        np.array([-0.07608,  0.72871, -0.72547]), np.array([ 0.24925, -0.65673,  0.8172 ]),
        np.array([ 0.52959,  1.30494,  0.9852 ,  1.35485,  0.80144,  1.52196,  0.77957,  0.84893,  1.53579,  0.85214,  1.188  ,  1.59161,  0.97467,  1.15059,  0.53265]),
        np.array([ 1.59586,  1.75814,  0.85554]),
        np.array([ 0.91525, -0.34068, -0.65457, -0.25339]),
        np.array([[-0.37816, -0.15993,  0.73053],
                  [-0.73099, -0.60595,  0.03621],
                  [ 0.70889,  0.47144,  0.0716 ],
                  [ 0.2287 ,  0.39313,  0.89614]]),
        4, 1,
        np.array([[-0.0104936 , -0.01524375,  0.00690854],
                  [ 0.01506684,  0.0163139 , -0.00888917],
                  [-0.01453817, -0.02700533,  0.00827173],
                  [-0.01261253, -0.01097711,  0.0120404 ],
                  [ 0.00967868,  0.00793637, -0.00475146],
                  [-0.03365313, -0.04266877,  0.01291758],
                  [ 0.0081477 ,  0.00789043, -0.00594975],
                  [-0.00294355, -0.0124472 ,  0.00473904],
                  [-0.00009283,  0.03830578, -0.00957808],
                  [-0.0009783 , -0.05605144,  0.00873553],
                  [-0.01875715,  0.00585854,  0.03942251],
                  [ 0.04779419, -0.01433297, -0.01918349],
                  [-0.04464388,  0.01226269,  0.00721527],
                  [ 0.08273982, -0.02207834, -0.00697022],
                  [-0.06627396, -0.01414908,  0.0066148 ]]))


def test_nuclear_attraction_4_2():
    """Compare ``check_nuclear_attraction`` for the ``4, 2`` case with stored values.

    NOTE(review): argument roles are inferred from array sizes only (presumably
    exponents, centers and scales of two shells, four charges with positions, two
    shell types and the expected (15, 6) result) -- confirm against the helper.
    """
    check_nuclear_attraction(
        np.array([ 0.36021,  1.79442,  1.57187]), np.array([ 1.12257,  1.03095,  1.95063]),
        np.array([-0.29695,  0.44777, -0.77844]), np.array([-0.33327,  0.45354, -0.8477 ]),
        np.array([ 0.82225,  0.95734,  1.80011,  1.91371,  0.60327,  1.80693,  1.08726,  1.75153,  1.11105,  1.09023,  1.23719,  1.12201,  1.13849,  0.76953,  1.4393 ]),
        np.array([ 1.4637 ,  0.75032,  1.92972,  1.62525,  0.93952,  0.95921]),
        np.array([-0.78467,  0.11556, -0.33453, -0.09984]),
        np.array([[-0.42784, -0.17077,  0.63764],
                  [-0.81157, -0.36819, -0.43921],
                  [ 0.23838,  0.13993, -0.97486],
                  [-0.62996,  0.40268, -0.38323]]),
        4, 2,
        np.array([[-1.34854189,  0.01000777,  0.02490566, -0.29629721,  0.00347858, -0.18826372],
                  [ 0.02387177, -0.15928792,  0.00756381,  0.01428623,  0.00291333,  0.00118517],
                  [ 0.04864587,  0.00283832, -0.81396476,  0.0107821 ,  0.00232458,  0.02365635],
                  [-0.62227364,  0.01443932,  0.00907031, -0.68042615,  0.00893708, -0.14770675],
                  [ 0.00173986,  0.00160994,  0.00222336,  0.00318978, -0.04481623,  0.00519097],
                  [-0.61976439,  0.00293062,  0.04157873, -0.23185576,  0.01397476, -0.4443802 ],
                  [ 0.01726219, -0.17860168,  0.01010713,  0.0264142 ,  0.00284197,  0.00077092],
                  [ 0.00881871,  0.00414304, -0.26673401,  0.00961992,  0.00129273,  0.01046642],
                  [ 0.00426588, -0.06569658,  0.01669052,  0.00168988,  0.00639143, -0.00482759],
                  [ 0.01873512,  0.00567195, -0.53212704,  0.01072746, -0.00432915,  0.03786206],
                  [-0.39860556,  0.01774606,  0.0002031 , -2.19444803,  0.0377573 , -0.2946714 ],
                  [ 0.00506549,  0.00220976,  0.00534481,  0.04478312, -0.25662034,  0.03962101],
                  [-0.13174811,  0.00223928,  0.00846308, -0.44831812,  0.03634891, -0.30770588],
                  [ 0.00751688,  0.00260639, -0.00111099,  0.03880093, -0.20101936,  0.07689931],
                  [-0.53150382, -0.00038199,  0.06567489, -0.64496486,  0.13460586, -2.07339697]]))


def test_nuclear_attraction_4_3():
    """Compare ``check_nuclear_attraction`` for the ``4, 3`` case with stored values.

    NOTE(review): argument roles are inferred from array sizes only (presumably
    exponents, centers and scales of two shells, four charges with positions, two
    shell types and the expected (15, 10) result) -- confirm against the helper.
    """
    check_nuclear_attraction(
        np.array([ 0.56339,  0.71124,  1.62157]), np.array([ 0.19348,  1.5277 ,  1.80133]),
        np.array([ 0.32121, -0.5229 , -0.06592]), np.array([ 0.51652, -0.59954,  0.17978]),
        np.array([ 0.65806,  1.68055,  0.64869,  0.52775,  1.97388,  0.64465,  0.51532,  1.18736,  0.74923,  1.17013,  0.99703,  1.51968,  1.86413,  1.0914 ,  1.99955]),
        np.array([ 1.11945,  1.45372,  1.80845,  1.28122,  1.73378,  1.4425 ,  1.61124,  1.45099,  1.24217,  1.09274]),
        np.array([-0.44678,  0.36088,  0.01751, -0.20022]),
        np.array([[-0.11109, -0.88754, -0.96656],
                  [ 0.86625, -0.5558 ,  0.32802],
                  [ 0.1276 , -0.02469,  0.83612],
                  [-0.20433,  0.86768, -0.41741]]),
        4, 3,
        np.array([[  4.63346952,  -0.31957392,   2.11976097,   0.68539713,  -0.02364843,   0.89209133,  -0.24641214,   0.34008233,  -0.04652934,   0.80068539],
                  [ -0.12272846,   2.32444056,   0.05259011,  -0.4168462 ,   1.07764423,   0.03887916,   1.24680883,  -0.0517621 ,   0.35246465,   0.0391223 ],
                  [  0.68143974,   0.00928398,   1.16169446,   0.15259096,  -0.05804658,   0.86247367,  -0.00242904,   0.14075504,   0.00384304,   0.37238912],
                  [  0.65862541,  -0.1145501 ,   0.37926657,   0.41782883,  -0.01748812,   0.18136703,  -0.32191415,   0.30273147,  -0.02205313,   0.23064524],
                  [  0.05394195,   0.54253929,   0.05052189,  -0.00983586,   0.69933114,   0.07196346,   0.6055983 ,  -0.14386988,   0.76274784,   0.10011738],
                  [  0.80619623,  -0.05057314,   0.88267745,   0.16385214,   0.00795996,   0.62326958,  -0.07003733,   0.2256297 ,  -0.02901964,   1.06579757],
                  [ -0.05523076,   0.5916707 ,  -0.00932837,  -0.23279132,   0.37383133,  -0.00578396,   0.77704975,  -0.057942  ,   0.13056843,   0.00224111],
                  [  0.27027926,  -0.00216729,   0.57310649,   0.29694259,  -0.06274311,   0.56401437,  -0.06054708,   0.29926497,  -0.00122482,   0.25907335],
                  [  0.00826805,   0.27298395,   0.03959985,  -0.04987286,   0.32352798,   0.04830981,   0.20325324,  -0.00145921,   0.16513271,   0.04749938],
                  [  0.7791537 ,   0.0214819 ,   1.69648992,   0.27160254,  -0.07037385,   2.24238179,   0.00519538,   0.28347181,   0.04155366,   1.25056837],
                  [  1.1038507 ,  -0.49057262,   0.85557925,   2.37455635,  -0.17319217,   0.62235339,  -4.35254225,   3.11695875,  -0.26162699,   1.38961117],
                  [  0.01032487,   0.48075443,  -0.01833936,  -0.07625031,   0.95305786,   0.038467  ,   2.39559618,  -0.53933465,   1.85442663,   0.10959516],
                  [  0.65615085,  -0.0935971 ,   0.94893467,   0.80035678,   0.02056837,   1.0332704 ,  -0.89648608,   2.00975884,  -0.06843286,   3.17634643],
                  [  0.04177737,   0.3161773 ,   0.09007162,   0.00557593,   0.64922996,   0.14546375,   0.93407817,  -0.16510007,   1.74295377,   0.37350129],
                  [  2.11029844,  -0.10712227,   4.43985594,   0.82276822,   0.10698314,   5.44697235,  -0.53582233,   3.08460043,  -0.3088369 ,  20.48498288]]))


def test_nuclear_attraction_4_4():
    """Compare ``check_nuclear_attraction`` for the ``4, 4`` case with stored values.

    NOTE(review): argument roles are inferred from array sizes only (presumably
    exponents, centers and scales of two shells, four charges with positions, two
    shell types and the expected (15, 15) result) -- confirm against the helper.
    """
    check_nuclear_attraction(
        np.array([ 0.97292,  0.65951,  1.94103]), np.array([ 0.89762,  1.52249,  0.41737]),
        np.array([-0.17657, -0.96842, -0.57669]), np.array([ 0.98269, -0.57914, -0.19589]),
        np.array([ 1.49108,  1.47196,  1.08852,  1.72005,  1.32745,  0.9925 ,  1.321  ,  1.29231,  1.85347,  1.95019,  0.83019,  1.60779,  0.5771 ,  0.77665,  1.15768]),
        np.array([ 1.94145,  1.31199,  1.40738,  1.36325,  1.84822,  0.97724,  0.84666,  1.50865,  1.54769,  1.02591,  1.31714,  1.52794,  0.6883 ,  0.74998,  1.25922]),
        np.array([-0.11183, -0.02877,  0.29553, -0.50534]),
        np.array([[-0.9955 , -0.74737,  0.10906],
                  [-0.92434,  0.03378,  0.1684 ],
                  [-0.67273, -0.94425,  0.46248],
                  [ 0.07585, -0.24336,  0.4268 ]]),
        4, 4,
        np.array([[-0.53118279, -0.00401394, -0.01612239, -0.06201167, -0.01038275, -0.04576963,  0.00878683,  0.00422869,  0.00555994,  0.00897352, -0.17828355, -0.01616837, -0.03183558, -0.00804222, -0.17603212],
                  [ 0.05920394, -0.02854264,  0.00580099, -0.00372926, -0.00058634, -0.00472332, -0.0082669 ,  0.000186  , -0.00473875,  0.00244661,  0.01854354,  0.0138344 , -0.00142447,  0.00601089, -0.0373056 ],
                  [ 0.01053246, -0.00060774, -0.01934745, -0.00632062,  0.00174176, -0.0051292 ,  0.00084696, -0.00313679, -0.0001665 , -0.00641701, -0.02695024,  0.00999622, -0.0012583 ,  0.00533218,  0.00889274],
                  [-0.18115365,  0.01334967, -0.00889782, -0.03681077, -0.00122229, -0.01425267,  0.00131571, -0.00193796,  0.0030125 , -0.00128046, -0.14866989, -0.00450307, -0.01753647,  0.00075277, -0.04797476],
                  [-0.04098825,  0.00523866,  0.00992962, -0.00335816, -0.01093042, -0.00222025,  0.00177538,  0.00097706,  0.00052922,  0.00397553, -0.00070547, -0.02415759, -0.00182   , -0.01232901, -0.00238753],
                  [-0.07696838,  0.00159787, -0.00175031, -0.00909376, -0.0036222 , -0.01509686, -0.00018166,  0.00069433,  0.00023313,  0.00052543, -0.02705356,  0.00082157, -0.01035816, -0.0016726 , -0.09028418],
                  [ 0.07257664, -0.01977239,  0.00677115, -0.00931629, -0.00387574, -0.00343337,  0.01247653,  0.0014329 ,  0.00672918,  0.0012964 , -0.03848553,  0.01223381, -0.00747548,  0.0038234 , -0.02474452],
                  [ 0.01488805, -0.0008199 , -0.00348592, -0.00533798, -0.00424805, -0.00402551,  0.00213187,  0.00505007,  0.00275063,  0.00580116, -0.03051874, -0.00627293, -0.00162043, -0.00590333, -0.00051926],
                  [ 0.0349914 , -0.00339764,  0.00552735, -0.00303821, -0.00936948, -0.00768619,  0.00555021,  0.00311722,  0.0088985 ,  0.00817745,  0.00302179, -0.01569427, -0.00254734, -0.00542211, -0.05215689],
                  [-0.04726208, -0.01136831,  0.01176993, -0.01412471,  0.00388779, -0.03080514, -0.00082725,  0.01313391,  0.00043224,  0.03734277, -0.03558048,  0.01312788, -0.01221191,  0.00881638, -0.07888353],
                  [-0.31982004,  0.08770909, -0.01041743, -0.11144077, -0.00107576, -0.02795677,  0.03224786, -0.01049296,  0.02183301, -0.005061  , -0.22685331, -0.00258242, -0.02174002,  0.00156351, -0.03800167],
                  [-0.18939784,  0.06522039,  0.0685561 , -0.03931349, -0.06235835, -0.01363501,  0.02908588,  0.02257433,  0.00802004,  0.02889253, -0.04072822, -0.06735181, -0.00662313, -0.02683495, -0.0068066 ],
                  [-0.09125682,  0.01768906,  0.01778724, -0.01864111, -0.01150887, -0.016599  ,  0.00342623,  0.00599179,  0.00845479,  0.00503977, -0.02805109, -0.006316  , -0.00742313, -0.0035279 , -0.03256503],
                  [-0.10644479,  0.02422955,  0.04867129, -0.01054117, -0.03035538, -0.0191032 ,  0.0079627 ,  0.00613225,  0.01118281,  0.02655279, -0.00320426, -0.024138  , -0.00399992, -0.01758776, -0.02703443],
                  [-0.3222066 ,  0.02533274,  0.09355949, -0.0401891 , -0.03152081, -0.11587615, -0.00167675,  0.02221827,  0.00175327,  0.056431  , -0.04715164, -0.00075652, -0.03022785, -0.00856802, -0.34575507]]))


def test_nuclear_attraction_simple_0_0():
    """Simple ``0, 0`` nuclear-attraction case with unit-valued inputs.

    Single-element arrays are all 1, the two 3-vectors are (1, 1, 1) and
    (1, 1, -1), and one extra point sits at the origin. NOTE(review): argument
    roles are assumed to follow ``check_nuclear_attraction``'s positional
    order, which is defined outside this view.
    """
    check_nuclear_attraction(
        np.array([ 1.]), np.array([ 1.]),
        np.array([ 1.,  1.,  1.]), np.array([ 1.,  1., -1.]),
        np.array([ 1.]),
        np.array([ 1.]),
        np.array([ 1.]),
        np.array([[ 0.,  0.,  0.]]),
        0, 0,
        np.array([[ 0.18751654]]))


def test_nuclear_attraction_simple_1_0():
    """Simple ``1, 0`` nuclear-attraction case with unit-valued inputs.

    Same geometry as the ``0, 0`` simple case. The first shell now has three
    entries in its third array, and the expected result has shape (3, 1).
    NOTE(review): argument roles are assumed, not visible here.
    """
    check_nuclear_attraction(
        np.array([ 1.]), np.array([ 1.]),
        np.array([ 1.,  1.,  1.]), np.array([ 1.,  1., -1.]),
        np.array([ 1.,  1.,  1.]),
        np.array([ 1.]),
        np.array([ 1.]),
        np.array([[ 0.,  0.,  0.]]),
        1, 0,
        np.array([[-0.02246616],
                  [-0.02246616],
                  [-0.18751654]]))


def test_nuclear_attraction_simple_1_1():
    """Simple ``1, 1`` nuclear-attraction case with unit-valued inputs.

    Same geometry as the ``0, 0`` simple case, with an expected (3, 3) result.
    NOTE(review): argument roles are assumed, not visible here.
    """
    check_nuclear_attraction(
        np.array([ 1.]), np.array([ 1.]),
        np.array([ 1.,  1.,  1.]), np.array([ 1.,  1., -1.]),
        np.array([ 1.,  1.,  1.]),
        np.array([ 1.,  1.,  1.]),
        np.array([ 1.]),
        np.array([[ 0.,  0.,  0.]]),
        1, 1,
        np.array([[ 0.048714  ,  0.00745141, -0.02246616],
                  [ 0.00745141,  0.048714  , -0.02246616],
                  [ 0.02246616,  0.02246616, -0.14625394]]))


def test_gb4_erilibint_class():
    """Check the size attributes of ``GB4ElectronRepulsionIntegralLibInt``.

    The work array is expected to have ``max_nbasis**4`` entries, i.e. one per
    combination of four basis functions of the largest allowed shell type.
    """
    # The four centers that used to be defined here were never used: the
    # attributes checked below do not depend on geometry.
    max_shell_type = 4
    max_nbasis = get_shell_nbasis(max_shell_type)

    gb4i = GB4ElectronRepulsionIntegralLibInt(max_shell_type)
    assert gb4i.max_shell_type == max_shell_type
    assert gb4i.max_nbasis == max_nbasis
    assert gb4i.nwork == max_nbasis**4


def check_electron_repulsion(alphas0, alphas1, alphas2, alphas3, r0, r1, r2, r3, scales0, scales1, scales2, scales3, shell_type0, shell_type1, shell_type2, shell_type3, result0):
    """Compare HORTON electron repulsion integrals with PyQuante reference data.

    Parameters
    ----------
    alphas0, alphas1, alphas2, alphas3 : np.ndarray
        Primitive exponents. The four arrays are zipped, so their i-th entries
        together form one primitive quartet.
    r0, r1, r2, r3 : np.ndarray, shape=(3,)
        Centers of the four shells.
    scales0, scales1, scales2, scales3 : np.ndarray
        Scale factors handed unchanged to ``add``; in the tests each array has
        one entry per basis function of its shell.
    shell_type0, shell_type1, shell_type2, shell_type3 : int
        Shell types of the four shells.
    result0 : np.ndarray, shape=(nbasis0, nbasis1, nbasis2, nbasis3)
        Reference integrals; the maximum absolute deviation must stay below 3e-7.
    """
    max_shell_type = 4
    max_nbasis = get_shell_nbasis(max_shell_type)
    gb4i = GB4ElectronRepulsionIntegralLibInt(max_shell_type)
    assert gb4i.max_nbasis == max_nbasis
    assert gb4i.nwork == max_nbasis**4

    # Number of basis functions per shell, which is also the expected shape.
    shell_types = (shell_type0, shell_type1, shell_type2, shell_type3)
    shape = tuple(get_shell_nbasis(shell_type) for shell_type in shell_types)
    assert result0.shape == shape

    # Clear the work array for this shell quartet, then accumulate all
    # primitive contributions (leading argument 1.0, presumably a coefficient).
    gb4i.reset(shell_type0, shell_type1, shell_type2, shell_type3, r0, r1, r2, r3)
    for exponents in zip(alphas0, alphas1, alphas2, alphas3):
        alpha0, alpha1, alpha2, alpha3 = exponents
        gb4i.add(1.0, alpha0, alpha1, alpha2, alpha3, scales0, scales1, scales2, scales3)

    computed = gb4i.get_work(*shape)
    assert abs(computed - result0).max() < 3e-7


def test_electron_repulsion_0_0_0_0_simple0():
    """Shell types (0, 0, 0, 0): unit exponents and scales, all centers at the origin."""
    check_electron_repulsion(
        alphas0=np.array([1.]), alphas1=np.array([1.]),
        alphas2=np.array([1.]), alphas3=np.array([1.]),
        r0=np.array([0., 0., 0.]), r1=np.array([0., 0., 0.]),
        r2=np.array([0., 0., 0.]), r3=np.array([0., 0., 0.]),
        scales0=np.array([1.]), scales1=np.array([1.]),
        scales2=np.array([1.]), scales3=np.array([1.]),
        shell_type0=0, shell_type1=0, shell_type2=0, shell_type3=0,
        result0=np.array([[[[4.37335457]]]]))


def test_electron_repulsion_0_0_0_0_simple1():
    """Shell types (0, 0, 0, 0): unit exponents and scales, centers at 0 and (1, 1, 1)."""
    check_electron_repulsion(
        alphas0=np.array([1.]), alphas1=np.array([1.]),
        alphas2=np.array([1.]), alphas3=np.array([1.]),
        r0=np.array([0., 0., 0.]), r1=np.array([1., 1., 1.]),
        r2=np.array([0., 0., 0.]), r3=np.array([1., 1., 1.]),
        scales0=np.array([1.]), scales1=np.array([1.]),
        scales2=np.array([1.]), scales3=np.array([1.]),
        shell_type0=0, shell_type1=0, shell_type2=0, shell_type3=0,
        result0=np.array([[[[2.20567322]]]]))


def test_electron_repulsion_0_0_0_0_simple2():
    """Shell types (0, 0, 0, 0): unit exponents and scales, four distinct centers."""
    check_electron_repulsion(
        alphas0=np.array([1.]), alphas1=np.array([1.]),
        alphas2=np.array([1.]), alphas3=np.array([1.]),
        r0=np.array([0.57092, 0.29608, -0.758]),
        r1=np.array([-0.70841, 0.22864, 0.79589]),
        r2=np.array([0.83984, 0.65053, 0.36087]),
        r3=np.array([-0.62267, -0.83676, -0.75233]),
        scales0=np.array([1.]), scales1=np.array([1.]),
        scales2=np.array([1.]), scales3=np.array([1.]),
        shell_type0=0, shell_type1=0, shell_type2=0, shell_type3=0,
        result0=np.array([[[[0.19609589]]]]))


def test_electron_repulsion_0_0_0_0_simple3():
    """Shell types (0, 0, 0, 0): one primitive per shell with distinct exponents, unit scales."""
    check_electron_repulsion(
        alphas0=np.array([0.57283]), alphas1=np.array([1.74713]),
        alphas2=np.array([0.21032]), alphas3=np.array([1.60538]),
        r0=np.array([0.82197, 0.73226, -0.98154]),
        r1=np.array([0.57466, 0.17815, -0.25519]),
        r2=np.array([0.00425, -0.33757, 0.08556]),
        r3=np.array([-0.38717, 0.66721, 0.40838]),
        scales0=np.array([1.]), scales1=np.array([1.]),
        scales2=np.array([1.]), scales3=np.array([1.]),
        shell_type0=0, shell_type1=0, shell_type2=0, shell_type3=0,
        result0=np.array([[[[0.92553047]]]]))


def test_electron_repulsion_0_0_0_0_simple4():
    """Shell types (0, 0, 0, 0): one primitive per shell with non-unit scales."""
    check_electron_repulsion(
        alphas0=np.array([1.35491]), alphas1=np.array([0.9714]),
        alphas2=np.array([1.95585]), alphas3=np.array([1.77853]),
        r0=np.array([0.37263, -0.87382, 0.28078]),
        r1=np.array([-0.08946, -0.52616, 0.69184]),
        r2=np.array([-0.35128, 0.07017, 0.08193]),
        r3=np.array([0.14543, -0.29499, -0.09769]),
        scales0=np.array([1.61086]), scales1=np.array([1.19397]),
        scales2=np.array([1.8119]), scales3=np.array([1.55646]),
        shell_type0=0, shell_type1=0, shell_type2=0, shell_type3=0,
        result0=np.array([[[[1.65373353]]]]))


def test_electron_repulsion_0_0_0_0():
    """Shell types (0, 0, 0, 0) with four primitives per shell."""
    check_electron_repulsion(
        alphas0=np.array([1.63216, 1.25493, 1.46134, 0.48024]),
        alphas1=np.array([1.72365, 1.59905, 0.10447, 1.28324]),
        alphas2=np.array([1.4105, 0.27134, 1.51238, 0.7518]),
        alphas3=np.array([1.38488, 0.97611, 0.34149, 0.4326]),
        r0=np.array([0.61356, -0.85284, -0.37151]),
        r1=np.array([-0.63238, -0.81396, 0.40314]),
        r2=np.array([0.29559, 0.60342, 0.18878]),
        r3=np.array([-0.6893, 0.09175, -0.97283]),
        scales0=np.array([0.84965]), scales1=np.array([1.49169]),
        scales2=np.array([1.11046]), scales3=np.array([0.6665]),
        shell_type0=0, shell_type1=0, shell_type2=0, shell_type3=0,
        result0=np.array([[[[2.97326773]]]]))


def test_electron_repulsion_0_0_0_1():
    """Shell types (0, 0, 0, 1) with three primitives per shell."""
    check_electron_repulsion(
        alphas0=np.array([0.74579, 0.93686, 0.39742]),
        alphas1=np.array([1.01349, 1.46072, 0.22295]),
        alphas2=np.array([1.90756, 0.52423, 1.35586]),
        alphas3=np.array([0.9655, 0.73539, 0.51017]),
        r0=np.array([0.55177, 0.11232, -0.95152]),
        r1=np.array([0.79941, 0.80782, 0.02287]),
        r2=np.array([-0.52471, 0.59124, 0.434]),
        r3=np.array([0.40758, 0.96818, 0.59852]),
        scales0=np.array([1.38989]),
        scales1=np.array([1.20619]),
        scales2=np.array([1.25917]),
        scales3=np.array([0.70246, 1.69253, 1.5632]),
        shell_type0=0, shell_type1=0, shell_type2=0, shell_type3=1,
        result0=np.array([[[[0.03331693, -3.27978915, -5.8871596]]]]))


def test_electron_repulsion_0_0_1_0():
    """Shell types (0, 0, 1, 0) with three primitives per shell."""
    check_electron_repulsion(
        alphas0=np.array([0.88609, 0.76883, 0.56082]),
        alphas1=np.array([1.29216, 0.28671, 1.25389]),
        alphas2=np.array([1.36987, 0.90792, 0.30511]),
        alphas3=np.array([0.57079, 1.98163, 0.66835]),
        r0=np.array([0.7706, 0.99091, -0.21592]),
        r1=np.array([-0.00566, -0.37522, -0.3936]),
        r2=np.array([0.1527, -0.95347, 0.16682]),
        r3=np.array([-0.75347, -0.6388, -0.81567]),
        scales0=np.array([0.72776]),
        scales1=np.array([1.08088]),
        scales2=np.array([0.53874, 1.37722, 1.16945]),
        scales3=np.array([1.4472]),
        shell_type0=0, shell_type1=0, shell_type2=1, shell_type3=0,
        result0=np.array([[[[0.5110845], [5.86518324], [-1.47878266]]]]))


def test_electron_repulsion_0_0_1_1():
    """Shell types (0, 0, 1, 1) with three primitives per shell."""
    expected = np.array([[[[0.57999607, 0.04732015, 0.00079488],
                           [0.09549513, 0.42707461, 0.03630467],
                           [-0.15902635, -0.25704193, 0.12295133]]]])
    check_electron_repulsion(
        alphas0=np.array([0.94138, 0.23708, 1.33464]),
        alphas1=np.array([1.89753, 0.54214, 0.80346]),
        alphas2=np.array([1.04131, 1.6925, 0.81454]),
        alphas3=np.array([1.06467, 0.55116, 1.21121]),
        r0=np.array([0.6941, 0.3354, -0.49162]),
        r1=np.array([0.68756, 0.49975, -0.69756]),
        r2=np.array([0.60432, -0.01449, -0.26057]),
        r3=np.array([0.35763, -0.04674, -0.78137]),
        scales0=np.array([0.75847]),
        scales1=np.array([0.57683]),
        scales2=np.array([1.61747, 0.59289, 0.93361]),
        scales3=np.array([1.38523, 1.77715, 0.8249]),
        shell_type0=0, shell_type1=0, shell_type2=1, shell_type3=1,
        result0=expected)


def test_electron_repulsion_0_1_0_0():
    """Shell types (0, 1, 0, 0) with three primitives per shell."""
    check_electron_repulsion(
        alphas0=np.array([0.11308, 0.49861, 1.12215]),
        alphas1=np.array([0.6186, 1.93501, 1.72751]),
        alphas2=np.array([0.4644, 0.61371, 1.99408]),
        alphas3=np.array([1.98686, 0.49338, 0.88466]),
        r0=np.array([0.31794, 0.18412, 0.89808]),
        r1=np.array([0.35463, 0.17042, 0.0682]),
        r2=np.array([0.51676, -0.86674, -0.32785]),
        r3=np.array([-0.03453, -0.05741, -0.86135]),
        scales0=np.array([1.84487]),
        scales1=np.array([1.17293, 1.02836, 0.50605]),
        scales2=np.array([0.54734]),
        scales3=np.array([1.55774]),
        shell_type0=0, shell_type1=1, shell_type2=0, shell_type3=0,
        result0=np.array([[[[-2.98984233]], [[-2.16665085]], [[-3.19087757]]]]))


def test_electron_repulsion_0_1_0_1():
    """Shell types (0, 1, 0, 1) with three primitives per shell."""
    expected = np.array([[[[1.11620596, 0.60061237, 0.36843148]],
                          [[-0.05340867, 0.33119515, -0.70418275]],
                          [[-0.04504112, -1.01394262, 1.17313632]]]])
    check_electron_repulsion(
        alphas0=np.array([0.95345, 1.7616, 0.62144]),
        alphas1=np.array([0.60537, 0.78954, 0.17662]),
        alphas2=np.array([1.39946, 1.03161, 1.42837]),
        alphas3=np.array([1.05228, 1.80211, 1.37614]),
        r0=np.array([0.18086, -0.0927, -0.36495]),
        r1=np.array([0.48062, -0.97782, -0.05878]),
        r2=np.array([-0.55927, -0.95238, 0.33122]),
        r3=np.array([0.17856, 0.06077, 0.62697]),
        scales0=np.array([0.9876]),
        scales1=np.array([1.39633, 1.30787, 1.80682]),
        scales2=np.array([0.93201]),
        scales3=np.array([1.21516, 1.84023, 1.59345]),
        shell_type0=0, shell_type1=1, shell_type2=0, shell_type3=1,
        result0=expected)


def test_electron_repulsion_0_1_1_1():
    """Shell types (0, 1, 1, 1) with three primitives per shell."""
    expected = np.array([[[[-0.06633908, -0.13761956, -0.03005655],
                           [-0.023407, -0.07813472, -0.03489736],
                           [-0.02263273, -0.20143856, -0.03550443]],
                          [[-0.40044718, -0.35436776, 0.07827812],
                           [-0.39382673, -0.18295174, 0.10845718],
                           [-0.37310311, -0.34400264, 0.05152883]],
                          [[0.07743294, -0.04648822, -0.2043075],
                           [0.03540926, -0.00400861, -0.13446393],
                           [0.02364929, -0.01807209, -0.18079094]]]])
    check_electron_repulsion(
        alphas0=np.array([1.60961, 1.48434, 1.09022]),
        alphas1=np.array([1.49016, 0.78972, 1.01383]),
        alphas2=np.array([1.357, 1.6929, 1.46297]),
        alphas3=np.array([1.3126, 1.39773, 0.3295]),
        r0=np.array([-0.74441, 0.13168, 0.17287]),
        r1=np.array([-0.73242, 0.73598, -0.07688]),
        r2=np.array([0.06303, 0.61361, 0.92689]),
        r3=np.array([0.31395, 0.00081, -0.13425]),
        scales0=np.array([1.92653]),
        scales1=np.array([0.84324, 1.68215, 0.64055]),
        scales2=np.array([1.62317, 1.94784, 1.54325]),
        scales3=np.array([0.67873, 0.76053, 0.57816]),
        shell_type0=0, shell_type1=1, shell_type2=1, shell_type3=1,
        result0=expected)


def test_electron_repulsion_1_0_0_1():
    """Shell types (1, 0, 0, 1) with three primitives per shell."""
    expected = np.array([[[[0.16173756, 0.14265052, 0.05405344]]],
                         [[[-0.431925, -0.37295006, -0.1782411]]],
                         [[[-0.17915755, -0.20235955, 0.03526912]]]])
    check_electron_repulsion(
        alphas0=np.array([0.39834, 1.4798, 1.80662]),
        alphas1=np.array([1.9623, 0.88607, 0.93517]),
        alphas2=np.array([0.46864, 1.1317, 0.67625]),
        alphas3=np.array([1.52214, 0.93879, 0.71425]),
        r0=np.array([-0.04796, 0.70504, 0.36481]),
        r1=np.array([0.40599, 0.97607, 0.64758]),
        r2=np.array([0.66271, -0.64123, -0.17474]),
        r3=np.array([-0.60087, 0.25093, 0.32664]),
        scales0=np.array([0.68301, 1.18047, 1.44482]),
        scales1=np.array([0.97181]),
        scales2=np.array([1.18315]),
        scales3=np.array([0.79184, 1.41932, 1.32812]),
        shell_type0=1, shell_type1=0, shell_type2=0, shell_type3=1,
        result0=expected)


def test_electron_repulsion_1_1_1_1():
    """Shell types (1, 1, 1, 1) with three primitives per shell."""
    expected = np.array([[[[5.38092832, 0.67101024, 0.50643354],
                           [-0.36637823, -0.17128347, 0.00749151],
                           [-0.47015285, -0.00846274, -0.23514519]],
                          [[0.31412053, 1.85552661, -0.05096966],
                           [0.5668773, 0.04019152, -0.05803149],
                           [-0.02195855, 0.00256108, 0.03373068]],
                          [[0.26139911, -0.05908764, 1.34729127],
                           [0.03563575, 0.02599451, 0.0669569],
                           [0.6249628, -0.09012696, -0.02559206]]],
                         [[[-1.30079959, 0.06525516, -0.24130176],
                           [7.90805546, 0.5029288, 1.03164863],
                           [0.22531828, -0.01518479, -0.63472654]],
                          [[0.07758755, -0.30344079, 0.03679751],
                           [0.88274549, 3.43263474, -0.20761467],
                           [0.09249023, 0.10854722, 0.15741632]],
                          [[0.0082139, -0.00382022, -0.24202072],
                           [0.44155444, -0.06437548, 2.40552259],
                           [0.29276089, 0.01725224, 0.05956368]]],
                         [[[-1.45339037, -0.37266055, 0.25844897],
                           [0.41152374, -0.40525461, -0.16607501],
                           [14.23224926, 2.34068558, 0.65653732]],
                          [[-0.00776144, -0.38261119, -0.0073076],
                           [0.28311943, 0.14089539, 0.08426703],
                           [0.91304633, 5.92042353, -0.12886949]],
                          [[0.09807363, 0.06281554, -0.25920407],
                           [0.15636252, 0.10752926, 0.14182457],
                           [1.2142302, -0.38098265, 4.57694241]]]])
    check_electron_repulsion(
        alphas0=np.array([0.13992, 0.37329, 0.33259]),
        alphas1=np.array([0.64139, 1.73019, 0.13917]),
        alphas2=np.array([0.44337, 1.28161, 0.3277]),
        alphas3=np.array([1.24252, 1.27924, 1.45445]),
        r0=np.array([0.02582, 0.94923, -0.17438]),
        r1=np.array([-0.81301, 0.086, -0.77236]),
        r2=np.array([-0.67901, 0.6566, -0.45438]),
        r3=np.array([-0.02669, -0.13942, -0.98892]),
        scales0=np.array([1.01729, 0.83942, 1.15976]),
        scales1=np.array([1.92943, 1.10829, 0.87557]),
        scales2=np.array([0.58667, 0.97031, 1.31261]),
        scales3=np.array([1.57111, 0.74218, 0.68171]),
        shell_type0=1, shell_type1=1, shell_type2=1, shell_type3=1,
        result0=expected)


def test_electron_repulsion_0_2_1_0():
    """Shell types (0, 2, 1, 0) with three primitives per shell."""
    expected = np.array([[[[0.03940319], [0.05597157], [-0.32990373]],
                          [[-0.00066587], [0.00221213], [-0.00319745]],
                          [[-0.00035194], [-0.00011777], [0.0063613]],
                          [[0.00478058], [0.01592957], [-0.09687372]],
                          [[0.00002574], [-0.00009517], [-0.00166564]],
                          [[0.01578456], [0.05420504], [-0.32175899]]]])
    check_electron_repulsion(
        alphas0=np.array([1.36794, 1.14001, 1.97798]),
        alphas1=np.array([1.68538, 0.75019, 0.72741]),
        alphas2=np.array([1.55248, 0.78842, 1.84644]),
        alphas3=np.array([1.73266, 0.46153, 0.63621]),
        r0=np.array([0.05517, 0.27196, -0.98928]),
        r1=np.array([-0.20526, 0.27314, -0.16208]),
        r2=np.array([-0.00876, -0.47585, 0.88613]),
        r3=np.array([0.75034, 0.54371, -0.1464]),
        scales0=np.array([0.50974]),
        scales1=np.array([1.36246, 0.58913, 0.73488, 0.53568, 1.11864, 1.80388]),
        scales2=np.array([1.62815, 0.58942, 1.52452]),
        scales3=np.array([0.66094]),
        shell_type0=0, shell_type1=2, shell_type2=1, shell_type3=0,
        result0=expected)


def test_electron_repulsion_0_2_2_3():
    """Shell types (0, 2, 2, 3); the reference block is loaded from a JSON file."""
    reference_path = context.get_fn('test/electron_repulsion_0_2_2_3.json')
    with open(reference_path) as fh:
        expected = np.array(json.load(fh))
    check_electron_repulsion(
        alphas0=np.array([0.96867, 0.41743, 1.03509]),
        alphas1=np.array([1.84594, 0.83035, 1.20242]),
        alphas2=np.array([0.94861, 0.47292, 0.38655]),
        alphas3=np.array([1.3009, 1.10486, 1.4979]),
        r0=np.array([0.10017, 0.21708, 0.08942]),
        r1=np.array([-0.03049, 0.99486, -0.37959]),
        r2=np.array([-0.7765, 0.53988, 0.25643]),
        r3=np.array([0.60758, 0.85146, 0.15088]),
        scales0=np.array([1.14284]),
        scales1=np.array([1.39723, 1.77896, 0.72525, 0.99877, 1.5953, 0.69473]),
        scales2=np.array([0.56774, 1.69348, 1.8146, 0.85426, 1.35434, 1.87402]),
        scales3=np.array([0.99964, 1.45499, 1.35143, 1.9758, 0.58887, 1.40713,
                          0.55226, 1.44979, 0.57156, 0.71009]),
        shell_type0=0, shell_type1=2, shell_type2=2, shell_type3=3,
        result0=expected)


def test_electron_repulsion_4_3_2_1():
    """Shell types (4, 3, 2, 1); the reference block is loaded from a JSON file."""
    reference_path = context.get_fn('test/electron_repulsion_4_3_2_1.json')
    with open(reference_path) as fh:
        expected = np.array(json.load(fh))
    check_electron_repulsion(
        alphas0=np.array([0.94212, 1.71823, 0.3309]),
        alphas1=np.array([0.94854, 0.12816, 0.42016]),
        alphas2=np.array([0.46046, 0.43321, 1.0587]),
        alphas3=np.array([1.0089, 0.52286, 1.83539]),
        r0=np.array([-0.48859, 0.6043, -0.57858]),
        r1=np.array([0.74567, -0.82555, -0.30631]),
        r2=np.array([-0.5679, -0.08725, 0.7623]),
        r3=np.array([0.10338, 0.65407, -0.20172]),
        scales0=np.array([1.10904, 1.40637, 1.8707, 0.68295, 1.29692, 0.99892, 1.13936,
                          0.81258, 0.50325, 1.27698, 1.81192, 1.43415, 1.1686, 1.38063,
                          0.61592]),
        scales1=np.array([1.19368, 0.75291, 0.63535, 1.22654, 1.32848, 1.17482, 1.74897,
                          0.93964, 1.90303, 1.44528]),
        scales2=np.array([1.63343, 1.80498, 1.61313, 0.99992, 1.04505, 1.42297]),
        scales3=np.array([1.4825, 1.69421, 1.8635]),
        shell_type0=4, shell_type1=3, shell_type2=2, shell_type3=1,
        result0=expected)


def get_erf_repulsion(alphas0, alphas1, alphas2, alphas3, r0, r1, r2, r3,
                      scales0, scales1, scales2, scales3,
                      shell_type0, shell_type1, shell_type2, shell_type3, mu):
    """Compute Erf repulsion integrals for four contracted shells with HORTON.

    The integrals come from ``GB4ErfIntegralLibInt`` with range-separation
    parameter ``mu``. NOTE(review): the operator is presumably erf(mu*r)/r;
    confirm against the HORTON documentation.

    Parameters
    ----------
    alphas0, alphas1, alphas2, alphas3 : np.ndarray, dtype=float
        Exponents of the primitives. The four arrays are zipped, so the i-th
        entries together form one primitive quartet, added with leading
        argument 1.0. Entries beyond the length of the shortest array are
        ignored.
    r0, r1, r2, r3 : np.ndarray, shape=(3,), dtype=float
        Cartesian coordinates of the centers of the four shells.
    scales0, scales1, scales2, scales3 : np.ndarray, dtype=float
        Normalization prefactors, one per basis function of each shell.
    shell_type0, shell_type1, shell_type2, shell_type3 : int
        Shell types of the four shells. The integral object is created with
        ``max_shell_type = 4``.
    mu : float
        The range-separation parameter.

    Returns
    -------
    result : np.ndarray, dtype=float
        The work array returned by ``get_work(nbasis0, nbasis1, nbasis2, nbasis3)``,
        where ``nbasisX = get_shell_nbasis(shell_typeX)``.
    """
    max_shell_type = 4
    gb4i = GB4ErfIntegralLibInt(max_shell_type, mu)

    nbasis0 = get_shell_nbasis(shell_type0)
    nbasis1 = get_shell_nbasis(shell_type1)
    nbasis2 = get_shell_nbasis(shell_type2)
    nbasis3 = get_shell_nbasis(shell_type3)
    # Clear the working memory for this shell quartet.
    gb4i.reset(shell_type0, shell_type1, shell_type2, shell_type3, r0, r1, r2, r3)
    # Accumulate the contribution of every primitive quartet.
    for alpha0, alpha1, alpha2, alpha3 in zip(alphas0, alphas1, alphas2, alphas3):
        gb4i.add(1.0, alpha0, alpha1, alpha2, alpha3, scales0, scales1, scales2, scales3)
    return gb4i.get_work(nbasis0, nbasis1, nbasis2, nbasis3)


def check_erf_repulsion(alphas0, alphas1, alphas2, alphas3, r0, r1, r2, r3, scales0,
                        scales1, scales2, scales3, shell_type0, shell_type1, shell_type2,
                        shell_type3, result0, mu):
    """Compare output from HORTON Erf integrals with reference data.

    The reference data was generated with a Mathematica script of Julien Toulouse and
    Andreas Savin.

    Parameters
    ----------
    alphas0, alphas1, alphas2, alphas3 : np.ndarray, dtype=float
        Exponents of the primitives. ``get_erf_repulsion`` zips them, so the i-th
        entries together form one primitive quartet.
    r0, r1, r2, r3 : np.ndarray, shape=(3,), dtype=float
        Cartesian coordinates of the centers of the four shells.
    scales0, scales1, scales2, scales3 : np.ndarray, dtype=float
        Normalization prefactors, one per basis function of each shell.
    shell_type0, shell_type1, shell_type2, shell_type3 : int
        Shell types of the four shells.
    result0 : np.ndarray, shape=(nbasis0, nbasis1, nbasis2, nbasis3), dtype=float
        The expected result, with ``nbasisX = get_shell_nbasis(shell_typeX)``.
    mu : float
        The range-separation parameter.
    """
    # Validate the reference shape first, as check_electron_repulsion does.
    # The element-wise difference below broadcasts, so a mis-shaped reference
    # array could otherwise give a misleading pass or an obscure error.
    expected_shape = tuple(get_shell_nbasis(shell_type) for shell_type in
                           (shell_type0, shell_type1, shell_type2, shell_type3))
    assert result0.shape == expected_shape
    result1 = get_erf_repulsion(alphas0, alphas1, alphas2, alphas3, r0, r1, r2, r3,
                                scales0, scales1, scales2, scales3, shell_type0,
                                shell_type1, shell_type2, shell_type3, mu)
    assert abs(result1 - result0).max() < 3e-7


def test_erf_repulsion_0_0_0_0_simple0():
    """Erf integrals, shell types (0, 0, 0, 0): unit inputs, all centers at the origin, mu=0.3."""
    check_erf_repulsion(
        alphas0=np.array([1.]), alphas1=np.array([1.]),
        alphas2=np.array([1.]), alphas3=np.array([1.]),
        r0=np.array([0., 0., 0.]), r1=np.array([0., 0., 0.]),
        r2=np.array([0., 0., 0.]), r3=np.array([0., 0., 0.]),
        scales0=np.array([1.]), scales1=np.array([1.]),
        scales2=np.array([1.]), scales3=np.array([1.]),
        shell_type0=0, shell_type1=0, shell_type2=0, shell_type3=0,
        result0=np.array([[[[1.25667419]]]]), mu=0.3)


def test_erf_repulsion_0_0_0_0_simple1():
    """Erf integral of four unit-exponent s-shells on two centers, mu=0.3."""
    expected = np.array([[[[1.16018914]]]])
    check_erf_repulsion(
        alphas0=np.array([1.]),
        alphas1=np.array([1.]),
        alphas2=np.array([1.]),
        alphas3=np.array([1.]),
        r0=np.array([0., 0., 0.]),
        r1=np.array([1., 1., 1.]),
        r2=np.array([0., 0., 0.]),
        r3=np.array([1., 1., 1.]),
        scales0=np.array([1.]),
        scales1=np.array([1.]),
        scales2=np.array([1.]),
        scales3=np.array([1.]),
        shell_type0=0, shell_type1=0, shell_type2=0, shell_type3=0,
        result0=expected,
        mu=0.3)


def test_erf_repulsion_0_0_0_0_simple2():
    """Erf integral of four unit-exponent s-shells on four distinct centers, mu=0.3."""
    expected = np.array([[[[0.09691428]]]])
    check_erf_repulsion(
        alphas0=np.array([1.]),
        alphas1=np.array([1.]),
        alphas2=np.array([1.]),
        alphas3=np.array([1.]),
        r0=np.array([0.57092, 0.29608, -0.758]),
        r1=np.array([-0.70841, 0.22864, 0.79589]),
        r2=np.array([0.83984, 0.65053, 0.36087]),
        r3=np.array([-0.62267, -0.83676, -0.75233]),
        scales0=np.array([1.]),
        scales1=np.array([1.]),
        scales2=np.array([1.]),
        scales3=np.array([1.]),
        shell_type0=0, shell_type1=0, shell_type2=0, shell_type3=0,
        result0=expected,
        mu=0.3)


def test_erf_repulsion_0_0_0_0_simple3():
    """Erf integral of four s-shells with distinct exponents and centers, mu=1.2."""
    expected = np.array([[[[0.807980035]]]])
    check_erf_repulsion(
        alphas0=np.array([0.57283]),
        alphas1=np.array([1.74713]),
        alphas2=np.array([0.21032]),
        alphas3=np.array([1.60538]),
        r0=np.array([0.82197, 0.73226, -0.98154]),
        r1=np.array([0.57466, 0.17815, -0.25519]),
        r2=np.array([0.00425, -0.33757, 0.08556]),
        r3=np.array([-0.38717, 0.66721, 0.40838]),
        scales0=np.array([1.]),
        scales1=np.array([1.]),
        scales2=np.array([1.]),
        scales3=np.array([1.]),
        shell_type0=0, shell_type1=0, shell_type2=0, shell_type3=0,
        result0=expected,
        mu=1.2)


def test_erf_repulsion_0_0_0_0_simple4():
    """Erf integral of four s-shells with non-unit scale factors, mu=1.2."""
    expected = np.array([[[[1.16218348]]]])
    check_erf_repulsion(
        alphas0=np.array([1.35491]),
        alphas1=np.array([0.9714]),
        alphas2=np.array([1.95585]),
        alphas3=np.array([1.77853]),
        r0=np.array([0.37263, -0.87382, 0.28078]),
        r1=np.array([-0.08946, -0.52616, 0.69184]),
        r2=np.array([-0.35128, 0.07017, 0.08193]),
        r3=np.array([0.14543, -0.29499, -0.09769]),
        scales0=np.array([1.61086]),
        scales1=np.array([1.19397]),
        scales2=np.array([1.8119]),
        scales3=np.array([1.55646]),
        shell_type0=0, shell_type1=0, shell_type2=0, shell_type3=0,
        result0=expected,
        mu=1.2)


def test_erf_repulsion_0_0_0_0():
    """Erf integral of four contracted s-shells (four primitives each), mu=10000."""
    expected = np.array([[[[2.97326773]]]])
    check_erf_repulsion(
        alphas0=np.array([1.63216, 1.25493, 1.46134, 0.48024]),
        alphas1=np.array([1.72365, 1.59905, 0.10447, 1.28324]),
        alphas2=np.array([1.4105, 0.27134, 1.51238, 0.7518]),
        alphas3=np.array([1.38488, 0.97611, 0.34149, 0.4326]),
        r0=np.array([0.61356, -0.85284, -0.37151]),
        r1=np.array([-0.63238, -0.81396, 0.40314]),
        r2=np.array([0.29559, 0.60342, 0.18878]),
        r3=np.array([-0.6893, 0.09175, -0.97283]),
        scales0=np.array([0.84965]),
        scales1=np.array([1.49169]),
        scales2=np.array([1.11046]),
        scales3=np.array([0.6665]),
        shell_type0=0, shell_type1=0, shell_type2=0, shell_type3=0,
        result0=expected,
        mu=10000)


def test_erf_repulsion_0_0_0_1():
    """Erf integrals for shell types (0, 0, 0, 1), mu=10000."""
    expected = np.array([[[[0.03331693, -3.27978915, -5.8871596]]]])
    check_erf_repulsion(
        alphas0=np.array([0.74579, 0.93686, 0.39742]),
        alphas1=np.array([1.01349, 1.46072, 0.22295]),
        alphas2=np.array([1.90756, 0.52423, 1.35586]),
        alphas3=np.array([0.9655, 0.73539, 0.51017]),
        r0=np.array([0.55177, 0.11232, -0.95152]),
        r1=np.array([0.79941, 0.80782, 0.02287]),
        r2=np.array([-0.52471, 0.59124, 0.434]),
        r3=np.array([0.40758, 0.96818, 0.59852]),
        scales0=np.array([1.38989]),
        scales1=np.array([1.20619]),
        scales2=np.array([1.25917]),
        scales3=np.array([0.70246, 1.69253, 1.5632]),
        shell_type0=0, shell_type1=0, shell_type2=0, shell_type3=1,
        result0=expected,
        mu=10000)


def test_erf_repulsion_0_0_1_0():
    """Erf integrals for shell types (0, 0, 1, 0), mu=10000."""
    expected = np.array([[[[0.5110845], [5.86518324], [-1.47878266]]]])
    check_erf_repulsion(
        alphas0=np.array([0.88609, 0.76883, 0.56082]),
        alphas1=np.array([1.29216, 0.28671, 1.25389]),
        alphas2=np.array([1.36987, 0.90792, 0.30511]),
        alphas3=np.array([0.57079, 1.98163, 0.66835]),
        r0=np.array([0.7706, 0.99091, -0.21592]),
        r1=np.array([-0.00566, -0.37522, -0.3936]),
        r2=np.array([0.1527, -0.95347, 0.16682]),
        r3=np.array([-0.75347, -0.6388, -0.81567]),
        scales0=np.array([0.72776]),
        scales1=np.array([1.08088]),
        scales2=np.array([0.53874, 1.37722, 1.16945]),
        scales3=np.array([1.4472]),
        shell_type0=0, shell_type1=0, shell_type2=1, shell_type3=0,
        result0=expected,
        mu=10000)


def test_erf_repulsion_0_0_1_1():
    """Erf integrals for shell types (0, 0, 1, 1), mu=10000."""
    expected = np.array([[[[0.57999607, 0.04732015, 0.00079488],
                           [0.09549513, 0.42707461, 0.03630467],
                           [-0.15902635, -0.25704193, 0.12295133]]]])
    check_erf_repulsion(
        alphas0=np.array([0.94138, 0.23708, 1.33464]),
        alphas1=np.array([1.89753, 0.54214, 0.80346]),
        alphas2=np.array([1.04131, 1.6925, 0.81454]),
        alphas3=np.array([1.06467, 0.55116, 1.21121]),
        r0=np.array([0.6941, 0.3354, -0.49162]),
        r1=np.array([0.68756, 0.49975, -0.69756]),
        r2=np.array([0.60432, -0.01449, -0.26057]),
        r3=np.array([0.35763, -0.04674, -0.78137]),
        scales0=np.array([0.75847]),
        scales1=np.array([0.57683]),
        scales2=np.array([1.61747, 0.59289, 0.93361]),
        scales3=np.array([1.38523, 1.77715, 0.8249]),
        shell_type0=0, shell_type1=0, shell_type2=1, shell_type3=1,
        result0=expected,
        mu=10000)


def test_erf_repulsion_0_1_0_0():
    """Erf integrals for shell types (0, 1, 0, 0), mu=10000."""
    expected = np.array([[[[-2.98984233]], [[-2.16665085]], [[-3.19087757]]]])
    check_erf_repulsion(
        alphas0=np.array([0.11308, 0.49861, 1.12215]),
        alphas1=np.array([0.6186, 1.93501, 1.72751]),
        alphas2=np.array([0.4644, 0.61371, 1.99408]),
        alphas3=np.array([1.98686, 0.49338, 0.88466]),
        r0=np.array([0.31794, 0.18412, 0.89808]),
        r1=np.array([0.35463, 0.17042, 0.0682]),
        r2=np.array([0.51676, -0.86674, -0.32785]),
        r3=np.array([-0.03453, -0.05741, -0.86135]),
        scales0=np.array([1.84487]),
        scales1=np.array([1.17293, 1.02836, 0.50605]),
        scales2=np.array([0.54734]),
        scales3=np.array([1.55774]),
        shell_type0=0, shell_type1=1, shell_type2=0, shell_type3=0,
        result0=expected,
        mu=10000)


def test_erf_repulsion_0_1_0_1():
    """Erf integrals for shell types (0, 1, 0, 1), mu=10000."""
    expected = np.array([[[[1.11620596, 0.60061237, 0.36843148]],
                          [[-0.05340867, 0.33119515, -0.70418275]],
                          [[-0.04504112, -1.01394262, 1.17313632]]]])
    check_erf_repulsion(
        alphas0=np.array([0.95345, 1.7616, 0.62144]),
        alphas1=np.array([0.60537, 0.78954, 0.17662]),
        alphas2=np.array([1.39946, 1.03161, 1.42837]),
        alphas3=np.array([1.05228, 1.80211, 1.37614]),
        r0=np.array([0.18086, -0.0927, -0.36495]),
        r1=np.array([0.48062, -0.97782, -0.05878]),
        r2=np.array([-0.55927, -0.95238, 0.33122]),
        r3=np.array([0.17856, 0.06077, 0.62697]),
        scales0=np.array([0.9876]),
        scales1=np.array([1.39633, 1.30787, 1.80682]),
        scales2=np.array([0.93201]),
        scales3=np.array([1.21516, 1.84023, 1.59345]),
        shell_type0=0, shell_type1=1, shell_type2=0, shell_type3=1,
        result0=expected,
        mu=10000)


def test_erf_repulsion_0_1_1_1():
    """Erf integrals for shell types (0, 1, 1, 1), mu=10000."""
    expected = np.array([[[[-0.06633908, -0.13761956, -0.03005655],
                           [-0.023407, -0.07813472, -0.03489736],
                           [-0.02263273, -0.20143856, -0.03550443]],
                          [[-0.40044718, -0.35436776, 0.07827812],
                           [-0.39382673, -0.18295174, 0.10845718],
                           [-0.37310311, -0.34400264, 0.05152883]],
                          [[0.07743294, -0.04648822, -0.2043075],
                           [0.03540926, -0.00400861, -0.13446393],
                           [0.02364929, -0.01807209, -0.18079094]]]])
    check_erf_repulsion(
        alphas0=np.array([1.60961, 1.48434, 1.09022]),
        alphas1=np.array([1.49016, 0.78972, 1.01383]),
        alphas2=np.array([1.357, 1.6929, 1.46297]),
        alphas3=np.array([1.3126, 1.39773, 0.3295]),
        r0=np.array([-0.74441, 0.13168, 0.17287]),
        r1=np.array([-0.73242, 0.73598, -0.07688]),
        r2=np.array([0.06303, 0.61361, 0.92689]),
        r3=np.array([0.31395, 0.00081, -0.13425]),
        scales0=np.array([1.92653]),
        scales1=np.array([0.84324, 1.68215, 0.64055]),
        scales2=np.array([1.62317, 1.94784, 1.54325]),
        scales3=np.array([0.67873, 0.76053, 0.57816]),
        shell_type0=0, shell_type1=1, shell_type2=1, shell_type3=1,
        result0=expected,
        mu=10000)


def test_erf_repulsion_1_0_0_1():
    """Erf integrals for shell types (1, 0, 0, 1), mu=10000."""
    expected = np.array([[[[0.16173756, 0.14265052, 0.05405344]]],
                         [[[-0.431925, -0.37295006, -0.1782411]]],
                         [[[-0.17915755, -0.20235955, 0.03526912]]]])
    check_erf_repulsion(
        alphas0=np.array([0.39834, 1.4798, 1.80662]),
        alphas1=np.array([1.9623, 0.88607, 0.93517]),
        alphas2=np.array([0.46864, 1.1317, 0.67625]),
        alphas3=np.array([1.52214, 0.93879, 0.71425]),
        r0=np.array([-0.04796, 0.70504, 0.36481]),
        r1=np.array([0.40599, 0.97607, 0.64758]),
        r2=np.array([0.66271, -0.64123, -0.17474]),
        r3=np.array([-0.60087, 0.25093, 0.32664]),
        scales0=np.array([0.68301, 1.18047, 1.44482]),
        scales1=np.array([0.97181]),
        scales2=np.array([1.18315]),
        scales3=np.array([0.79184, 1.41932, 1.32812]),
        shell_type0=1, shell_type1=0, shell_type2=0, shell_type3=1,
        result0=expected,
        mu=10000)


def test_erf_repulsion_1_1_1_1():
    """Erf integrals for four p-shells, shell types (1, 1, 1, 1), mu=10000."""
    expected = np.array([[[[5.38092832, 0.67101024, 0.50643354],
                           [-0.36637823, -0.17128347, 0.00749151],
                           [-0.47015285, -0.00846274, -0.23514519]],
                          [[0.31412053, 1.85552661, -0.05096966],
                           [0.5668773, 0.04019152, -0.05803149],
                           [-0.02195855, 0.00256108, 0.03373068]],
                          [[0.26139911, -0.05908764, 1.34729127],
                           [0.03563575, 0.02599451, 0.0669569],
                           [0.6249628, -0.09012696, -0.02559206]]],
                         [[[-1.30079959, 0.06525516, -0.24130176],
                           [7.90805546, 0.5029288, 1.03164863],
                           [0.22531828, -0.01518479, -0.63472654]],
                          [[0.07758755, -0.30344079, 0.03679751],
                           [0.88274549, 3.43263474, -0.20761467],
                           [0.09249023, 0.10854722, 0.15741632]],
                          [[0.0082139, -0.00382022, -0.24202072],
                           [0.44155444, -0.06437548, 2.40552259],
                           [0.29276089, 0.01725224, 0.05956368]]],
                         [[[-1.45339037, -0.37266055, 0.25844897],
                           [0.41152374, -0.40525461, -0.16607501],
                           [14.23224926, 2.34068558, 0.65653732]],
                          [[-0.00776144, -0.38261119, -0.0073076],
                           [0.28311943, 0.14089539, 0.08426703],
                           [0.91304633, 5.92042353, -0.12886949]],
                          [[0.09807363, 0.06281554, -0.25920407],
                           [0.15636252, 0.10752926, 0.14182457],
                           [1.2142302, -0.38098265, 4.57694241]]]])
    check_erf_repulsion(
        alphas0=np.array([0.13992, 0.37329, 0.33259]),
        alphas1=np.array([0.64139, 1.73019, 0.13917]),
        alphas2=np.array([0.44337, 1.28161, 0.3277]),
        alphas3=np.array([1.24252, 1.27924, 1.45445]),
        r0=np.array([0.02582, 0.94923, -0.17438]),
        r1=np.array([-0.81301, 0.086, -0.77236]),
        r2=np.array([-0.67901, 0.6566, -0.45438]),
        r3=np.array([-0.02669, -0.13942, -0.98892]),
        scales0=np.array([1.01729, 0.83942, 1.15976]),
        scales1=np.array([1.92943, 1.10829, 0.87557]),
        scales2=np.array([0.58667, 0.97031, 1.31261]),
        scales3=np.array([1.57111, 0.74218, 0.68171]),
        shell_type0=1, shell_type1=1, shell_type2=1, shell_type3=1,
        result0=expected,
        mu=10000)


def test_erf_repulsion_0_2_1_0():
    """Erf integrals for shell types (0, 2, 1, 0), mu=10000."""
    expected = np.array([[[[0.03940319], [0.05597157], [-0.32990373]],
                          [[-0.00066587], [0.00221213], [-0.00319745]],
                          [[-0.00035194], [-0.00011777], [0.0063613]],
                          [[0.00478058], [0.01592957], [-0.09687372]],
                          [[0.00002574], [-0.00009517], [-0.00166564]],
                          [[0.01578456], [0.05420504], [-0.32175899]]]])
    check_erf_repulsion(
        alphas0=np.array([1.36794, 1.14001, 1.97798]),
        alphas1=np.array([1.68538, 0.75019, 0.72741]),
        alphas2=np.array([1.55248, 0.78842, 1.84644]),
        alphas3=np.array([1.73266, 0.46153, 0.63621]),
        r0=np.array([0.05517, 0.27196, -0.98928]),
        r1=np.array([-0.20526, 0.27314, -0.16208]),
        r2=np.array([-0.00876, -0.47585, 0.88613]),
        r3=np.array([0.75034, 0.54371, -0.1464]),
        scales0=np.array([0.50974]),
        scales1=np.array([1.36246, 0.58913, 0.73488, 0.53568, 1.11864, 1.80388]),
        scales2=np.array([1.62815, 0.58942, 1.52452]),
        scales3=np.array([0.66094]),
        shell_type0=0, shell_type1=2, shell_type2=1, shell_type3=0,
        result0=expected,
        mu=10000)


def test_erf_repulsion_0_2_2_3():
    """Erf integrals for shell types (0, 2, 2, 3), mu=10000; reference read from JSON."""
    reference_fn = context.get_fn('test/erf_repulsion_0_2_2_3.json')
    with open(reference_fn) as f:
        expected = np.array(json.load(f))
    check_erf_repulsion(
        alphas0=np.array([0.96867, 0.41743, 1.03509]),
        alphas1=np.array([1.84594, 0.83035, 1.20242]),
        alphas2=np.array([0.94861, 0.47292, 0.38655]),
        alphas3=np.array([1.3009, 1.10486, 1.4979]),
        r0=np.array([0.10017, 0.21708, 0.08942]),
        r1=np.array([-0.03049, 0.99486, -0.37959]),
        r2=np.array([-0.7765, 0.53988, 0.25643]),
        r3=np.array([0.60758, 0.85146, 0.15088]),
        scales0=np.array([1.14284]),
        scales1=np.array([1.39723, 1.77896, 0.72525, 0.99877, 1.5953, 0.69473]),
        scales2=np.array([0.56774, 1.69348, 1.8146, 0.85426, 1.35434, 1.87402]),
        scales3=np.array([0.99964, 1.45499, 1.35143, 1.9758, 0.58887, 1.40713, 0.55226,
                          1.44979, 0.57156, 0.71009]),
        shell_type0=0, shell_type1=2, shell_type2=2, shell_type3=3,
        result0=expected,
        mu=10000)


def test_erf_repulsion_4_3_2_1():
    """Erf integrals for shell types (4, 3, 2, 1), mu=10000; reference read from JSON."""
    reference_fn = context.get_fn('test/erf_repulsion_4_3_2_1.json')
    with open(reference_fn) as f:
        expected = np.array(json.load(f))
    check_erf_repulsion(
        alphas0=np.array([0.94212, 1.71823, 0.3309]),
        alphas1=np.array([0.94854, 0.12816, 0.42016]),
        alphas2=np.array([0.46046, 0.43321, 1.0587]),
        alphas3=np.array([1.0089, 0.52286, 1.83539]),
        r0=np.array([-0.48859, 0.6043, -0.57858]),
        r1=np.array([0.74567, -0.82555, -0.30631]),
        r2=np.array([-0.5679, -0.08725, 0.7623]),
        r3=np.array([0.10338, 0.65407, -0.20172]),
        scales0=np.array([1.10904, 1.40637, 1.8707, 0.68295, 1.29692, 0.99892, 1.13936,
                          0.81258, 0.50325, 1.27698, 1.81192, 1.43415, 1.1686, 1.38063,
                          0.61592]),
        scales1=np.array([1.19368, 0.75291, 0.63535, 1.22654, 1.32848, 1.17482, 1.74897,
                          0.93964, 1.90303, 1.44528]),
        scales2=np.array([1.63343, 1.80498, 1.61313, 0.99992, 1.04505, 1.42297]),
        scales3=np.array([1.4825, 1.69421, 1.8635]),
        shell_type0=4, shell_type1=3, shell_type2=2, shell_type3=1,
        result0=expected,
        mu=10000)


def test_erf_repulsion_h2_sto3g():
    """Compare MO-basis Erf two-electron integrals of H2/STO-3G (mu=2.25) with a
    reference loaded from a Molpro FCIDUMP file."""
    reference = IOData.from_file(context.get_fn('test/FCIDUMP.molpro.h2-erf'))

    # Two hydrogen atoms, 1.4 apart along the z-axis.
    h2 = IOData(title='h2')
    h2.coordinates = np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 1.4]])
    h2.numbers = np.array([1, 1])

    basis = get_gobasis(h2.coordinates, h2.numbers, 'sto-3g')
    factory = DenseLinalgFactory(basis.nbasis)

    # Integrals in the atomic-orbital basis.
    overlap = basis.compute_overlap(factory)
    kinetic = basis.compute_kinetic(factory)
    nuc_attr = basis.compute_nuclear_attraction(h2.coordinates, h2.pseudo_numbers, factory)
    erf_rep = basis.compute_erf_repulsion(factory, 2.25)

    # Orbitals from a core-Hamiltonian guess.
    orbitals = factory.create_expansion()
    guess_core_hamiltonian(overlap, kinetic, nuc_attr, orbitals)
    h2.exp_alpha = orbitals

    # One-body operator (kinetic + nuclear attraction) used for the transformation.
    core = kinetic.copy()
    core.iadd(nuc_attr)
    # Element [1][0] is the transformed two-body operator compared below.
    two_mo = transform_integrals(core, erf_rep, 'tensordot', h2.exp_alpha)[1][0]
    assert abs(reference.two_mo._array - two_mo._array).max() < 1e-10


def check_gauss_repulsion(alphas0, alphas1, alphas2, alphas3, r0, r1, r2, r3, scales0,
                          scales1, scales2, scales3, shell_type0, shell_type1,
                          shell_type2, shell_type3, result0, c, alpha):
    """Compare output from HORTON Gauss 4-center integrals with reference data.

    The reference data was generated either with a Mathematica script of Julien Toulouse
    and Andreas Savin, or from a finite-difference derivative of the Erf integrals with
    respect to mu (e.g. test_gauss_repulsion_0_0_0_1).

    Parameters
    ----------
    alphas0, alphas1, alphas2, alphas3 : np.ndarray, shape=(nprim,), dtype=float
        Exponents of the primitives of the four shells. The arrays are zipped: entries
        with the same index together form one contribution, so all four must have the
        same length.
    r0, r1, r2, r3 : np.ndarray, shape=(3,), dtype=float
        Cartesian coordinates of the centers of the four shells.
    scales0, scales1, scales2, scales3 : np.ndarray, dtype=float
        Normalization prefactors, one per basis function of the corresponding shell.
    shell_type0, shell_type1, shell_type2, shell_type3 : int
        Shell types of the four shells.
    result0 : np.ndarray, shape=(nbasis0, nbasis1, nbasis2, nbasis3), dtype=float
        The expected result.
    c : float
        Coefficient of the gaussian.
    alpha : float
        Exponential parameter of the gaussian.

    Raises
    ------
    AssertionError
        When the integral object is not set up as requested, the primitive arrays have
        different lengths, ``result0`` has the wrong shape, or any element deviates by
        more than 3e-7 from the reference.
    """
    max_shell_type = 4
    max_nbasis = get_shell_nbasis(max_shell_type)
    gb4i = GB4GaussIntegralLibInt(max_shell_type, c, alpha)
    assert gb4i.max_nbasis == max_nbasis
    assert gb4i.alpha == alpha
    assert gb4i.c == c
    assert gb4i.nwork == max_nbasis**4

    nbasis0 = get_shell_nbasis(shell_type0)
    nbasis1 = get_shell_nbasis(shell_type1)
    nbasis2 = get_shell_nbasis(shell_type2)
    nbasis3 = get_shell_nbasis(shell_type3)
    assert result0.shape == (nbasis0, nbasis1, nbasis2, nbasis3)
    # zip() silently truncates to the shortest input, which would drop primitives.
    assert len(alphas0) == len(alphas1) == len(alphas2) == len(alphas3)
    # Clear the working memory
    gb4i.reset(shell_type0, shell_type1, shell_type2, shell_type3, r0, r1, r2, r3)
    # Add one contribution per set of primitive exponents.
    for alpha0, alpha1, alpha2, alpha3 in zip(alphas0, alphas1, alphas2, alphas3):
        gb4i.add(1.0, alpha0, alpha1, alpha2, alpha3, scales0, scales1, scales2, scales3)
    result1 = gb4i.get_work(nbasis0, nbasis1, nbasis2, nbasis3)
    assert abs(result1 - result0).max() < 3e-7


def test_gauss_repulsion_0_0_0_0_simple0():
    """Gauss integral of four unit-exponent s-shells at the origin, alpha=4/3."""
    expected = np.array([[[[-2.45402796]]]])
    check_gauss_repulsion(
        alphas0=np.array([1.]),
        alphas1=np.array([1.]),
        alphas2=np.array([1.]),
        alphas3=np.array([1.]),
        r0=np.array([0., 0., 0.]),
        r1=np.array([0., 0., 0.]),
        r2=np.array([0., 0., 0.]),
        r3=np.array([0., 0., 0.]),
        scales0=np.array([1.]),
        scales1=np.array([1.]),
        scales2=np.array([1.]),
        scales3=np.array([1.]),
        shell_type0=0, shell_type1=0, shell_type2=0, shell_type3=0,
        result0=expected,
        # c is approximately -4/sqrt(pi).
        c=-2.256758334191,
        alpha=4./3.)


def test_gauss_repulsion_0_0_0_0_simple1():
    """Gauss integral of four unit-exponent s-shells at the origin, alpha=1.5."""
    expected = np.array([[[[1.1063809195]]]])
    check_gauss_repulsion(
        alphas0=np.array([1.]),
        alphas1=np.array([1.]),
        alphas2=np.array([1.]),
        alphas3=np.array([1.]),
        r0=np.array([0., 0., 0.]),
        r1=np.array([0., 0., 0.]),
        r2=np.array([0., 0., 0.]),
        r3=np.array([0., 0., 0.]),
        scales0=np.array([1.]),
        scales1=np.array([1.]),
        scales2=np.array([1.]),
        scales3=np.array([1.]),
        shell_type0=0, shell_type1=0, shell_type2=0, shell_type3=0,
        result0=expected,
        # c is approximately 2/sqrt(pi).
        c=-0.5*-2.25675833419102,
        alpha=1.5)


def test_gauss_repulsion_0_0_0_0_simple2():
    """Gauss integral of four unit-exponent s-shells on two centers, alpha=4/3."""
    expected = np.array([[[[-0.44195157099]]]])
    check_gauss_repulsion(
        alphas0=np.array([1.]),
        alphas1=np.array([1.]),
        alphas2=np.array([1.]),
        alphas3=np.array([1.]),
        r0=np.array([0., 0., 0.]),
        r1=np.array([1., 1., 1.]),
        r2=np.array([0., 0., 0.]),
        r3=np.array([1., 1., 1.]),
        scales0=np.array([1.]),
        scales1=np.array([1.]),
        scales2=np.array([1.]),
        scales3=np.array([1.]),
        shell_type0=0, shell_type1=0, shell_type2=0, shell_type3=0,
        result0=expected,
        # c is approximately -4/sqrt(pi).
        c=-2.25675833419102,
        alpha=4./3.)


def test_gauss_repulsion_0_0_0_0_simple3():
    """Gauss integral of four unit-exponent s-shells on four centers, alpha=4/3."""
    expected = np.array([[[[-0.0476487494]]]])
    check_gauss_repulsion(
        alphas0=np.array([1.]),
        alphas1=np.array([1.]),
        alphas2=np.array([1.]),
        alphas3=np.array([1.]),
        r0=np.array([0.57092, 0.29608, -0.758]),
        r1=np.array([-0.70841, 0.22864, 0.79589]),
        r2=np.array([0.83984, 0.65053, 0.36087]),
        r3=np.array([-0.62267, -0.83676, -0.75233]),
        scales0=np.array([1.]),
        scales1=np.array([1.]),
        scales2=np.array([1.]),
        scales3=np.array([1.]),
        shell_type0=0, shell_type1=0, shell_type2=0, shell_type3=0,
        result0=expected,
        # c is approximately -4/sqrt(pi).
        c=-2.25675833419102,
        alpha=4./3.)


def test_gauss_repulsion_0_0_0_0_simple4():
    """Gauss integral of four s-shells with distinct exponents and centers, alpha=4/3."""
    expected = np.array([[[[-0.352255229]]]])
    check_gauss_repulsion(
        alphas0=np.array([0.57283]),
        alphas1=np.array([1.74713]),
        alphas2=np.array([0.21032]),
        alphas3=np.array([1.60538]),
        r0=np.array([0.82197, 0.73226, -0.98154]),
        r1=np.array([0.57466, 0.17815, -0.25519]),
        r2=np.array([0.00425, -0.33757, 0.08556]),
        r3=np.array([-0.38717, 0.66721, 0.40838]),
        scales0=np.array([1.]),
        scales1=np.array([1.]),
        scales2=np.array([1.]),
        scales3=np.array([1.]),
        shell_type0=0, shell_type1=0, shell_type2=0, shell_type3=0,
        result0=expected,
        # c is approximately -4/sqrt(pi).
        c=-2.25675833419102,
        alpha=4./3.)


def test_gauss_repulsion_0_0_0_0_simple5():
    """Gauss integral of four s-shells with non-unit scale factors, alpha=4/3."""
    expected = np.array([[[[-1.03673858]]]])
    check_gauss_repulsion(
        alphas0=np.array([1.35491]),
        alphas1=np.array([0.9714]),
        alphas2=np.array([1.95585]),
        alphas3=np.array([1.77853]),
        r0=np.array([0.37263, -0.87382, 0.28078]),
        r1=np.array([-0.08946, -0.52616, 0.69184]),
        r2=np.array([-0.35128, 0.07017, 0.08193]),
        r3=np.array([0.14543, -0.29499, -0.09769]),
        scales0=np.array([1.61086]),
        scales1=np.array([1.19397]),
        scales2=np.array([1.8119]),
        scales3=np.array([1.55646]),
        shell_type0=0, shell_type1=0, shell_type2=0, shell_type3=0,
        result0=expected,
        # c is approximately -4/sqrt(pi).
        c=-2.25675833419102,
        alpha=4./3.)


def test_gauss_repulsion_0_0_0_1():
    """Check Gauss integrals (shell types 0, 0, 0, 1) against d/dmu of Erf integrals.

    Since d/dmu [erf(mu*r)/r] = 2/sqrt(pi)*exp(-mu**2*r**2), the mu-derivative of the Erf
    integrals at mu=2 must equal the Gauss integrals with c=2/sqrt(pi) and
    alpha=mu**2=4.
    """
    args = (
        np.array([0.74579, 0.93686, 0.39742]), np.array([1.01349, 1.46072, 0.22295]),
        np.array([1.90756, 0.52423, 1.35586]), np.array([0.9655, 0.73539, 0.51017]),
        np.array([0.55177, 0.11232, -0.95152]), np.array([0.79941, 0.80782, 0.02287]),
        np.array([-0.52471, 0.59124, 0.434]), np.array([0.40758, 0.96818, 0.59852]),
        np.array([1.38989]),
        np.array([1.20619]),
        np.array([1.25917]),
        np.array([0.70246, 1.69253, 1.5632]),
        0, 0, 0, 1)
    mu = 2.0
    epsilon = 0.000001
    # Central difference: a forward difference has an O(epsilon) truncation error
    # (epsilon/2 times the second derivative), which can reach the 3e-7 tolerance of
    # check_gauss_repulsion. The central difference reduces it to O(epsilon**2).
    e_plus = get_erf_repulsion(*(args + (mu + epsilon,)))
    e_minus = get_erf_repulsion(*(args + (mu - epsilon,)))
    egauss = (e_plus - e_minus)/(2*epsilon)
    # c = 2/sqrt(pi), alpha = mu**2
    check_gauss_repulsion(*(args + (egauss, 1.1283791670955126, 4.0)))


def test_gauss_repulsion_0_0_1_0():
    """Check Gauss integrals (shell types 0, 0, 1, 0) against d/dmu of Erf integrals.

    Since d/dmu [erf(mu*r)/r] = 2/sqrt(pi)*exp(-mu**2*r**2), the mu-derivative of the Erf
    integrals at mu=2 must equal the Gauss integrals with c=2/sqrt(pi) and
    alpha=mu**2=4.
    """
    args = (
        np.array([0.88609, 0.76883, 0.56082]), np.array([1.29216, 0.28671, 1.25389]),
        np.array([1.36987, 0.90792, 0.30511]), np.array([0.57079, 1.98163, 0.66835]),
        np.array([0.7706, 0.99091, -0.21592]), np.array([-0.00566, -0.37522, -0.3936]),
        np.array([0.1527, -0.95347, 0.16682]), np.array([-0.75347, -0.6388, -0.81567]),
        np.array([0.72776]),
        np.array([1.08088]),
        np.array([0.53874, 1.37722, 1.16945]),
        np.array([1.4472]),
        0, 0, 1, 0)
    mu = 2.0
    epsilon = 0.000001
    # Central difference: a forward difference has an O(epsilon) truncation error
    # (epsilon/2 times the second derivative), which can reach the 3e-7 tolerance of
    # check_gauss_repulsion. The central difference reduces it to O(epsilon**2).
    e_plus = get_erf_repulsion(*(args + (mu + epsilon,)))
    e_minus = get_erf_repulsion(*(args + (mu - epsilon,)))
    egauss = (e_plus - e_minus)/(2*epsilon)
    # c = 2/sqrt(pi), alpha = mu**2
    check_gauss_repulsion(*(args + (egauss, 1.1283791670955126, 4.0)))


def test_gauss_repulsion_0_0_1_1():
    """Check Gauss integrals (shell types 0, 0, 1, 1) against d/dmu of Erf integrals.

    Since d/dmu [erf(mu*r)/r] = 2/sqrt(pi)*exp(-mu**2*r**2), the mu-derivative of the Erf
    integrals at mu=2 must equal the Gauss integrals with c=2/sqrt(pi) and
    alpha=mu**2=4.
    """
    args = (
        np.array([0.94138, 0.23708, 1.33464]), np.array([1.89753, 0.54214, 0.80346]),
        np.array([1.04131, 1.6925, 0.81454]), np.array([1.06467, 0.55116, 1.21121]),
        np.array([0.6941, 0.3354, -0.49162]), np.array([0.68756, 0.49975, -0.69756]),
        np.array([0.60432, -0.01449, -0.26057]), np.array([0.35763, -0.04674, -0.78137]),
        np.array([0.75847]),
        np.array([0.57683]),
        np.array([1.61747, 0.59289, 0.93361]),
        np.array([1.38523, 1.77715, 0.8249]),
        0, 0, 1, 1)
    mu = 2.0
    epsilon = 0.000001
    # Central difference: a forward difference has an O(epsilon) truncation error
    # (epsilon/2 times the second derivative), which can reach the 3e-7 tolerance of
    # check_gauss_repulsion. The central difference reduces it to O(epsilon**2).
    e_plus = get_erf_repulsion(*(args + (mu + epsilon,)))
    e_minus = get_erf_repulsion(*(args + (mu - epsilon,)))
    egauss = (e_plus - e_minus)/(2*epsilon)
    # c = 2/sqrt(pi), alpha = mu**2
    check_gauss_repulsion(*(args + (egauss, 1.1283791670955126, 4.0)))


def test_gauss_repulsion_0_1_0_0():
    """Check Gauss integrals (shell types 0, 1, 0, 0) against d/dmu of Erf integrals.

    Since d/dmu [erf(mu*r)/r] = 2/sqrt(pi)*exp(-mu**2*r**2), the mu-derivative of the Erf
    integrals at mu=2 must equal the Gauss integrals with c=2/sqrt(pi) and
    alpha=mu**2=4.
    """
    args = (
        np.array([0.11308, 0.49861, 1.12215]), np.array([0.6186, 1.93501, 1.72751]),
        np.array([0.4644, 0.61371, 1.99408]), np.array([1.98686, 0.49338, 0.88466]),
        np.array([0.31794, 0.18412, 0.89808]), np.array([0.35463, 0.17042, 0.0682]),
        np.array([0.51676, -0.86674, -0.32785]), np.array([-0.03453, -0.05741, -0.86135]),
        np.array([1.84487]),
        np.array([1.17293, 1.02836, 0.50605]),
        np.array([0.54734]),
        np.array([1.55774]),
        0, 1, 0, 0)
    mu = 2.0
    epsilon = 0.000001
    # Central difference: a forward difference has an O(epsilon) truncation error
    # (epsilon/2 times the second derivative), which can reach the 3e-7 tolerance of
    # check_gauss_repulsion. The central difference reduces it to O(epsilon**2).
    e_plus = get_erf_repulsion(*(args + (mu + epsilon,)))
    e_minus = get_erf_repulsion(*(args + (mu - epsilon,)))
    egauss = (e_plus - e_minus)/(2*epsilon)
    # c = 2/sqrt(pi), alpha = mu**2
    check_gauss_repulsion(*(args + (egauss, 1.1283791670955126, 4.0)))


def test_gauss_repulsion_0_1_0_1():
    """Check Gauss integrals (shell types 0, 1, 0, 1) against d/dmu of Erf integrals.

    Since d/dmu [erf(mu*r)/r] = 2/sqrt(pi)*exp(-mu**2*r**2), the mu-derivative of the Erf
    integrals at mu=2 must equal the Gauss integrals with c=2/sqrt(pi) and
    alpha=mu**2=4.
    """
    args = (
        np.array([0.95345, 1.7616, 0.62144]), np.array([0.60537, 0.78954, 0.17662]),
        np.array([1.39946, 1.03161, 1.42837]), np.array([1.05228, 1.80211, 1.37614]),
        np.array([0.18086, -0.0927, -0.36495]), np.array([0.48062, -0.97782, -0.05878]),
        np.array([-0.55927, -0.95238, 0.33122]), np.array([0.17856, 0.06077, 0.62697]),
        np.array([0.9876]),
        np.array([1.39633, 1.30787, 1.80682]),
        np.array([0.93201]),
        np.array([1.21516, 1.84023, 1.59345]),
        0, 1, 0, 1)
    mu = 2.0
    epsilon = 0.000001
    # Central difference: a forward difference has an O(epsilon) truncation error
    # (epsilon/2 times the second derivative), which can reach the 3e-7 tolerance of
    # check_gauss_repulsion. The central difference reduces it to O(epsilon**2).
    e_plus = get_erf_repulsion(*(args + (mu + epsilon,)))
    e_minus = get_erf_repulsion(*(args + (mu - epsilon,)))
    egauss = (e_plus - e_minus)/(2*epsilon)
    # c = 2/sqrt(pi), alpha = mu**2
    check_gauss_repulsion(*(args + (egauss, 1.1283791670955126, 4.0)))


def test_gauss_repulsion_0_1_1_1():
    """Check Gauss integrals (shell types 0, 1, 1, 1) against d/dmu of Erf integrals.

    Since d/dmu [erf(mu*r)/r] = 2/sqrt(pi)*exp(-mu**2*r**2), the mu-derivative of the Erf
    integrals at mu=2 must equal the Gauss integrals with c=2/sqrt(pi) and
    alpha=mu**2=4.
    """
    args = (
        np.array([1.60961, 1.48434, 1.09022]), np.array([1.49016, 0.78972, 1.01383]),
        np.array([1.357, 1.6929, 1.46297]), np.array([1.3126, 1.39773, 0.3295]),
        np.array([-0.74441, 0.13168, 0.17287]), np.array([-0.73242, 0.73598, -0.07688]),
        np.array([0.06303, 0.61361, 0.92689]), np.array([0.31395, 0.00081, -0.13425]),
        np.array([1.92653]),
        np.array([0.84324, 1.68215, 0.64055]),
        np.array([1.62317, 1.94784, 1.54325]),
        np.array([0.67873, 0.76053, 0.57816]),
        0, 1, 1, 1)
    mu = 2.0
    epsilon = 0.000001
    # Central difference: a forward difference has an O(epsilon) truncation error
    # (epsilon/2 times the second derivative), which can reach the 3e-7 tolerance of
    # check_gauss_repulsion. The central difference reduces it to O(epsilon**2).
    e_plus = get_erf_repulsion(*(args + (mu + epsilon,)))
    e_minus = get_erf_repulsion(*(args + (mu - epsilon,)))
    egauss = (e_plus - e_minus)/(2*epsilon)
    # c = 2/sqrt(pi), alpha = mu**2
    check_gauss_repulsion(*(args + (egauss, 1.1283791670955126, 4.0)))


def test_gauss_repulsion_1_0_0_1():
    """Check Gauss integrals (shell types 1, 0, 0, 1) against d/dmu of Erf integrals.

    Since d/dmu [erf(mu*r)/r] = 2/sqrt(pi)*exp(-mu**2*r**2), the mu-derivative of the Erf
    integrals at mu=2 must equal the Gauss integrals with c=2/sqrt(pi) and
    alpha=mu**2=4.
    """
    args = (
        np.array([0.39834, 1.4798, 1.80662]), np.array([1.9623, 0.88607, 0.93517]),
        np.array([0.46864, 1.1317, 0.67625]), np.array([1.52214, 0.93879, 0.71425]),
        np.array([-0.04796, 0.70504, 0.36481]), np.array([0.40599, 0.97607, 0.64758]),
        np.array([0.66271, -0.64123, -0.17474]), np.array([-0.60087, 0.25093, 0.32664]),
        np.array([0.68301, 1.18047, 1.44482]),
        np.array([0.97181]),
        np.array([1.18315]),
        np.array([0.79184, 1.41932, 1.32812]),
        1, 0, 0, 1)
    mu = 2.0
    epsilon = 0.000001
    # Central difference: a forward difference has an O(epsilon) truncation error
    # (epsilon/2 times the second derivative), which can reach the 3e-7 tolerance of
    # check_gauss_repulsion. The central difference reduces it to O(epsilon**2).
    e_plus = get_erf_repulsion(*(args + (mu + epsilon,)))
    e_minus = get_erf_repulsion(*(args + (mu - epsilon,)))
    egauss = (e_plus - e_minus)/(2*epsilon)
    # c = 2/sqrt(pi), alpha = mu**2
    check_gauss_repulsion(*(args + (egauss, 1.1283791670955126, 4.0)))


def test_gauss_repulsion_1_1_1_1():
    """Check Gaussian repulsion integrals for shell types (1, 1, 1, 1).

    A forward finite difference of ``get_erf_repulsion`` in its last argument
    (2.0 -> 2.0 + step) serves as the reference passed to ``check_gauss_repulsion``.
    """
    def make_args():
        # Fresh arrays on every call, exactly like the three literal argument lists
        # of the straightforward version.
        return (
            np.array([0.13992, 0.37329, 0.33259]), np.array([0.64139, 1.73019, 0.13917]),
            np.array([0.44337, 1.28161, 0.3277]), np.array([1.24252, 1.27924, 1.45445]),
            np.array([0.02582, 0.94923, -0.17438]), np.array([-0.81301, 0.086, -0.77236]),
            np.array([-0.67901, 0.6566, -0.45438]), np.array([-0.02669, -0.13942, -0.98892]),
            np.array([1.01729, 0.83942, 1.15976]),
            np.array([1.92943, 1.10829, 0.87557]),
            np.array([0.58667, 0.97031, 1.31261]),
            np.array([1.57111, 0.74218, 0.68171]),
            1, 1, 1, 1)

    step = 0.000001
    e_low = get_erf_repulsion(*(make_args() + (2.0,)))
    e_high = get_erf_repulsion(*(make_args() + (2.0+step,)))
    egauss = (e_high - e_low)/step
    check_gauss_repulsion(*(make_args() + (egauss, 1.1283791670955126, 4.0)))


def test_gauss_repulsion_0_2_1_0():
    """Check Gaussian repulsion integrals for shell types (0, 2, 1, 0).

    A forward finite difference of ``get_erf_repulsion`` in its last argument
    (2.0 -> 2.0 + step) serves as the reference passed to ``check_gauss_repulsion``.
    """
    def make_args():
        # Fresh arrays on every call, exactly like the three literal argument lists
        # of the straightforward version.
        return (
            np.array([1.36794, 1.14001, 1.97798]), np.array([1.68538, 0.75019, 0.72741]),
            np.array([1.55248, 0.78842, 1.84644]), np.array([1.73266, 0.46153, 0.63621]),
            np.array([0.05517, 0.27196, -0.98928]), np.array([-0.20526, 0.27314, -0.16208]),
            np.array([-0.00876, -0.47585, 0.88613]), np.array([0.75034, 0.54371, -0.1464]),
            np.array([0.50974]),
            np.array([1.36246, 0.58913, 0.73488, 0.53568, 1.11864, 1.80388]),
            np.array([1.62815, 0.58942, 1.52452]),
            np.array([0.66094]),
            0, 2, 1, 0)

    step = 0.000001
    e_low = get_erf_repulsion(*(make_args() + (2.0,)))
    e_high = get_erf_repulsion(*(make_args() + (2.0+step,)))
    egauss = (e_high - e_low)/step
    check_gauss_repulsion(*(make_args() + (egauss, 1.1283791670955126, 4.0)))


def test_gauss_repulsion_0_2_2_3():
    """Check Gaussian repulsion integrals for shell types (0, 2, 2, 3).

    A forward finite difference of ``get_erf_repulsion`` in its last argument
    (2.0 -> 2.0 + step) serves as the reference passed to ``check_gauss_repulsion``.
    """
    def make_args():
        # Fresh arrays on every call, exactly like the three literal argument lists
        # of the straightforward version.
        return (
            np.array([0.96867, 0.41743, 1.03509]), np.array([1.84594, 0.83035, 1.20242]),
            np.array([0.94861, 0.47292, 0.38655]), np.array([1.3009, 1.10486, 1.4979]),
            np.array([0.10017, 0.21708, 0.08942]), np.array([-0.03049, 0.99486, -0.37959]),
            np.array([-0.7765, 0.53988, 0.25643]), np.array([0.60758, 0.85146, 0.15088]),
            np.array([1.14284]),
            np.array([1.39723, 1.77896, 0.72525, 0.99877, 1.5953, 0.69473]),
            np.array([0.56774, 1.69348, 1.8146, 0.85426, 1.35434, 1.87402]),
            np.array([0.99964, 1.45499, 1.35143, 1.9758, 0.58887, 1.40713, 0.55226,
                      1.44979, 0.57156, 0.71009]),
            0, 2, 2, 3)

    step = 0.000001
    e_low = get_erf_repulsion(*(make_args() + (2.0,)))
    e_high = get_erf_repulsion(*(make_args() + (2.0+step,)))
    egauss = (e_high - e_low)/step
    check_gauss_repulsion(*(make_args() + (egauss, 1.1283791670955126, 4.0)))


def test_gauss_repulsion_4_3_2_1():
    """Check Gaussian repulsion integrals for shell types (4, 3, 2, 1).

    A forward finite difference of ``get_erf_repulsion`` in its last argument
    (2.0 -> 2.0 + step) serves as the reference passed to ``check_gauss_repulsion``.
    This case uses a smaller step (1e-7) than the lower-angular-momentum tests.
    """
    def make_args():
        # Fresh arrays on every call, exactly like the three literal argument lists
        # of the straightforward version.
        return (
            np.array([0.94212, 1.71823, 0.3309]), np.array([0.94854, 0.12816, 0.42016]),
            np.array([0.46046, 0.43321, 1.0587]), np.array([1.0089, 0.52286, 1.83539]),
            np.array([-0.48859, 0.6043, -0.57858]), np.array([0.74567, -0.82555, -0.30631]),
            np.array([-0.5679, -0.08725, 0.7623]), np.array([0.10338, 0.65407, -0.20172]),
            np.array([1.10904, 1.40637, 1.8707, 0.68295, 1.29692, 0.99892, 1.13936, 0.81258,
                      0.50325, 1.27698, 1.81192, 1.43415, 1.1686, 1.38063, 0.61592]),
            np.array([1.19368, 0.75291, 0.63535, 1.22654, 1.32848, 1.17482, 1.74897, 0.93964,
                      1.90303, 1.44528]),
            np.array([1.63343, 1.80498, 1.61313, 0.99992, 1.04505, 1.42297]),
            np.array([1.4825, 1.69421, 1.8635]),
            4, 3, 2, 1)

    step = 0.0000001
    e_low = get_erf_repulsion(*(make_args() + (2.0,)))
    e_high = get_erf_repulsion(*(make_args() + (2.0+step,)))
    egauss = (e_high - e_low)/step
    check_gauss_repulsion(*(make_args() + (egauss, 1.1283791670955126, 4.0)))


def check_ralpha_repulsion(alphas0, alphas1, alphas2, alphas3, r0, r1, r2, r3, scales0,
                           scales1, scales2, scales3, shell_type0, shell_type1,
                           shell_type2, shell_type3, result0, alpha):
    """Compare output from HORTON r^alpha repulsion integrals with reference data.

    The reference data was generated with a Mathematica script of Julien Toulouse and
    Andreas Savin.

    Parameters
    ----------
    alphas0, alphas1, alphas2, alphas3 : np.ndarray, shape=(nprim,), dtype=float
        Exponents of the primitives in the four shells. The arrays are traversed in
        lockstep: index k contributes one primitive quadruple
        (alphas0[k], alphas1[k], alphas2[k], alphas3[k]) with coefficient 1.0. All four
        arrays must therefore have the same length.
    r0, r1, r2, r3 : np.ndarray, shape=(3,), dtype=float
        Cartesian coordinates of the centers of the four shells.
    scales0, scales1, scales2, scales3 : np.ndarray, dtype=float
        Normalization prefactors for the basis functions of each shell (in the test data,
        one entry per basis function of the shell).
    shell_type0, shell_type1, shell_type2, shell_type3 : int
        Shell types of the four shells.
    result0 : np.ndarray, shape=(nbasis0, nbasis1, nbasis2, nbasis3), dtype=float
        The expected result.
    alpha : float
        The interaction is r to the power alpha.
    """
    max_shell_type = 4
    max_nbasis = get_shell_nbasis(max_shell_type)
    gb4i = GB4RAlphaIntegralLibInt(max_shell_type, alpha)
    assert gb4i.max_nbasis == max_nbasis
    assert gb4i.nwork == max_nbasis**4
    assert gb4i.alpha == alpha

    nbasis0 = get_shell_nbasis(shell_type0)
    nbasis1 = get_shell_nbasis(shell_type1)
    nbasis2 = get_shell_nbasis(shell_type2)
    nbasis3 = get_shell_nbasis(shell_type3)
    assert result0.shape == (nbasis0, nbasis1, nbasis2, nbasis3)
    # zip() would silently drop trailing exponents if the arrays had different lengths,
    # which would make the test compare against an incomplete set of contributions.
    assert len(alphas0) == len(alphas1) == len(alphas2) == len(alphas3)
    # Clear the working memory
    gb4i.reset(shell_type0, shell_type1, shell_type2, shell_type3, r0, r1, r2, r3)
    # Accumulate one primitive quadruple per index, each with coefficient 1.0.
    for alpha0, alpha1, alpha2, alpha3 in zip(alphas0, alphas1, alphas2, alphas3):
        gb4i.add(1.0, alpha0, alpha1, alpha2, alpha3, scales0, scales1, scales2, scales3)
    result1 = gb4i.get_work(nbasis0, nbasis1, nbasis2, nbasis3)
    assert abs(result1 - result0).max() < 3e-7


def test_ralpha_simple0():
    """r^alpha integral (alpha=-1) of four coincident unit primitives at the origin."""
    check_ralpha_repulsion(
        alphas0=np.array([1.]), alphas1=np.array([1.]),
        alphas2=np.array([1.]), alphas3=np.array([1.]),
        r0=np.array([0., 0., 0.]), r1=np.array([0., 0., 0.]),
        r2=np.array([0., 0., 0.]), r3=np.array([0., 0., 0.]),
        scales0=np.array([1.]), scales1=np.array([1.]),
        scales2=np.array([1.]), scales3=np.array([1.]),
        shell_type0=0, shell_type1=0, shell_type2=0, shell_type3=0,
        result0=np.array([[[[4.37335457]]]]), alpha=-1.)


def test_ralpha_simple1():
    """r^alpha integral (alpha=2) of four coincident unit primitives at the origin."""
    check_ralpha_repulsion(
        alphas0=np.array([1.]), alphas1=np.array([1.]),
        alphas2=np.array([1.]), alphas3=np.array([1.]),
        r0=np.array([0., 0., 0.]), r1=np.array([0., 0., 0.]),
        r2=np.array([0., 0., 0.]), r3=np.array([0., 0., 0.]),
        scales0=np.array([1.]), scales1=np.array([1.]),
        scales2=np.array([1.]), scales3=np.array([1.]),
        shell_type0=0, shell_type1=0, shell_type2=0, shell_type3=0,
        result0=np.array([[[[5.81367687]]]]), alpha=2.)


def test_ralpha_simple2():
    """r^alpha integral (alpha=-2) of four coincident unit primitives at the origin."""
    check_ralpha_repulsion(
        alphas0=np.array([1.]), alphas1=np.array([1.]),
        alphas2=np.array([1.]), alphas3=np.array([1.]),
        r0=np.array([0., 0., 0.]), r1=np.array([0., 0., 0.]),
        r2=np.array([0., 0., 0.]), r3=np.array([0., 0., 0.]),
        scales0=np.array([1.]), scales1=np.array([1.]),
        scales2=np.array([1.]), scales3=np.array([1.]),
        shell_type0=0, shell_type1=0, shell_type2=0, shell_type3=0,
        result0=np.array([[[[7.75156917]]]]), alpha=-2.)


def test_ralpha_repulsion_0_0_0_0_simple1():
    """r^alpha integral (alpha=-1), unit primitives on two centers (origin and (1,1,1))."""
    check_ralpha_repulsion(
        alphas0=np.array([1.]), alphas1=np.array([1.]),
        alphas2=np.array([1.]), alphas3=np.array([1.]),
        r0=np.array([0., 0., 0.]), r1=np.array([1., 1., 1.]),
        r2=np.array([0., 0., 0.]), r3=np.array([1., 1., 1.]),
        scales0=np.array([1.]), scales1=np.array([1.]),
        scales2=np.array([1.]), scales3=np.array([1.]),
        shell_type0=0, shell_type1=0, shell_type2=0, shell_type3=0,
        result0=np.array([[[[2.20567322]]]]), alpha=-1.)


def test_ralpha_repulsion_0_0_0_0_simple2():
    """r^alpha integral (alpha=-1), unit primitives on four distinct centers."""
    check_ralpha_repulsion(
        alphas0=np.array([1.]), alphas1=np.array([1.]),
        alphas2=np.array([1.]), alphas3=np.array([1.]),
        r0=np.array([0.57092, 0.29608, -0.758]),
        r1=np.array([-0.70841, 0.22864, 0.79589]),
        r2=np.array([0.83984, 0.65053, 0.36087]),
        r3=np.array([-0.62267, -0.83676, -0.75233]),
        scales0=np.array([1.]), scales1=np.array([1.]),
        scales2=np.array([1.]), scales3=np.array([1.]),
        shell_type0=0, shell_type1=0, shell_type2=0, shell_type3=0,
        result0=np.array([[[[0.19609589]]]]), alpha=-1.)


def test_ralpha_repulsion_0_0_0_0_simple3():
    """r^alpha integral (alpha=-1), single primitives with distinct exponents, unit scales."""
    check_ralpha_repulsion(
        alphas0=np.array([0.57283]), alphas1=np.array([1.74713]),
        alphas2=np.array([0.21032]), alphas3=np.array([1.60538]),
        r0=np.array([0.82197, 0.73226, -0.98154]),
        r1=np.array([0.57466, 0.17815, -0.25519]),
        r2=np.array([0.00425, -0.33757, 0.08556]),
        r3=np.array([-0.38717, 0.66721, 0.40838]),
        scales0=np.array([1.]), scales1=np.array([1.]),
        scales2=np.array([1.]), scales3=np.array([1.]),
        shell_type0=0, shell_type1=0, shell_type2=0, shell_type3=0,
        result0=np.array([[[[0.92553047]]]]), alpha=-1.)


def test_ralpha_repulsion_0_0_0_0_simple4():
    """r^alpha integral (alpha=-1), single primitives with non-unit scales."""
    check_ralpha_repulsion(
        alphas0=np.array([1.35491]), alphas1=np.array([0.9714]),
        alphas2=np.array([1.95585]), alphas3=np.array([1.77853]),
        r0=np.array([0.37263, -0.87382, 0.28078]),
        r1=np.array([-0.08946, -0.52616, 0.69184]),
        r2=np.array([-0.35128, 0.07017, 0.08193]),
        r3=np.array([0.14543, -0.29499, -0.09769]),
        scales0=np.array([1.61086]), scales1=np.array([1.19397]),
        scales2=np.array([1.8119]), scales3=np.array([1.55646]),
        shell_type0=0, shell_type1=0, shell_type2=0, shell_type3=0,
        result0=np.array([[[[1.65373353]]]]), alpha=-1.)


def test_ralpha_repulsion_0_0_0_0():
    """r^alpha integral (alpha=-1), shell types (0, 0, 0, 0) with four primitives each."""
    check_ralpha_repulsion(
        alphas0=np.array([1.63216, 1.25493, 1.46134, 0.48024]),
        alphas1=np.array([1.72365, 1.59905, 0.10447, 1.28324]),
        alphas2=np.array([1.4105, 0.27134, 1.51238, 0.7518]),
        alphas3=np.array([1.38488, 0.97611, 0.34149, 0.4326]),
        r0=np.array([0.61356, -0.85284, -0.37151]),
        r1=np.array([-0.63238, -0.81396, 0.40314]),
        r2=np.array([0.29559, 0.60342, 0.18878]),
        r3=np.array([-0.6893, 0.09175, -0.97283]),
        scales0=np.array([0.84965]), scales1=np.array([1.49169]),
        scales2=np.array([1.11046]), scales3=np.array([0.6665]),
        shell_type0=0, shell_type1=0, shell_type2=0, shell_type3=0,
        result0=np.array([[[[2.97326773]]]]), alpha=-1.)


def test_ralpha_repulsion_0_0_0_1():
    """r^alpha integrals (alpha=-1) for shell types (0, 0, 0, 1)."""
    expected = np.array([[[[0.03331693, -3.27978915, -5.8871596]]]])
    check_ralpha_repulsion(
        alphas0=np.array([0.74579, 0.93686, 0.39742]),
        alphas1=np.array([1.01349, 1.46072, 0.22295]),
        alphas2=np.array([1.90756, 0.52423, 1.35586]),
        alphas3=np.array([0.9655, 0.73539, 0.51017]),
        r0=np.array([0.55177, 0.11232, -0.95152]),
        r1=np.array([0.79941, 0.80782, 0.02287]),
        r2=np.array([-0.52471, 0.59124, 0.434]),
        r3=np.array([0.40758, 0.96818, 0.59852]),
        scales0=np.array([1.38989]),
        scales1=np.array([1.20619]),
        scales2=np.array([1.25917]),
        scales3=np.array([0.70246, 1.69253, 1.5632]),
        shell_type0=0, shell_type1=0, shell_type2=0, shell_type3=1,
        result0=expected, alpha=-1.)


def test_ralpha_repulsion_0_0_1_0():
    """r^alpha integrals (alpha=-1) for shell types (0, 0, 1, 0)."""
    expected = np.array([[[[0.5110845], [5.86518324], [-1.47878266]]]])
    check_ralpha_repulsion(
        alphas0=np.array([0.88609, 0.76883, 0.56082]),
        alphas1=np.array([1.29216, 0.28671, 1.25389]),
        alphas2=np.array([1.36987, 0.90792, 0.30511]),
        alphas3=np.array([0.57079, 1.98163, 0.66835]),
        r0=np.array([0.7706, 0.99091, -0.21592]),
        r1=np.array([-0.00566, -0.37522, -0.3936]),
        r2=np.array([0.1527, -0.95347, 0.16682]),
        r3=np.array([-0.75347, -0.6388, -0.81567]),
        scales0=np.array([0.72776]),
        scales1=np.array([1.08088]),
        scales2=np.array([0.53874, 1.37722, 1.16945]),
        scales3=np.array([1.4472]),
        shell_type0=0, shell_type1=0, shell_type2=1, shell_type3=0,
        result0=expected, alpha=-1.0)


def test_ralpha_repulsion_0_0_1_1():
    """r^alpha integrals (alpha=-1) for shell types (0, 0, 1, 1)."""
    expected = np.array([[[[0.57999607, 0.04732015, 0.00079488],
                           [0.09549513, 0.42707461, 0.03630467],
                           [-0.15902635, -0.25704193, 0.12295133]]]])
    check_ralpha_repulsion(
        alphas0=np.array([0.94138, 0.23708, 1.33464]),
        alphas1=np.array([1.89753, 0.54214, 0.80346]),
        alphas2=np.array([1.04131, 1.6925, 0.81454]),
        alphas3=np.array([1.06467, 0.55116, 1.21121]),
        r0=np.array([0.6941, 0.3354, -0.49162]),
        r1=np.array([0.68756, 0.49975, -0.69756]),
        r2=np.array([0.60432, -0.01449, -0.26057]),
        r3=np.array([0.35763, -0.04674, -0.78137]),
        scales0=np.array([0.75847]),
        scales1=np.array([0.57683]),
        scales2=np.array([1.61747, 0.59289, 0.93361]),
        scales3=np.array([1.38523, 1.77715, 0.8249]),
        shell_type0=0, shell_type1=0, shell_type2=1, shell_type3=1,
        result0=expected, alpha=-1.0)


def test_ralpha_repulsion_0_1_0_0():
    """r^alpha integrals (alpha=-1) for shell types (0, 1, 0, 0)."""
    expected = np.array([[[[-2.98984233]], [[-2.16665085]], [[-3.19087757]]]])
    check_ralpha_repulsion(
        alphas0=np.array([0.11308, 0.49861, 1.12215]),
        alphas1=np.array([0.6186, 1.93501, 1.72751]),
        alphas2=np.array([0.4644, 0.61371, 1.99408]),
        alphas3=np.array([1.98686, 0.49338, 0.88466]),
        r0=np.array([0.31794, 0.18412, 0.89808]),
        r1=np.array([0.35463, 0.17042, 0.0682]),
        r2=np.array([0.51676, -0.86674, -0.32785]),
        r3=np.array([-0.03453, -0.05741, -0.86135]),
        scales0=np.array([1.84487]),
        scales1=np.array([1.17293, 1.02836, 0.50605]),
        scales2=np.array([0.54734]),
        scales3=np.array([1.55774]),
        shell_type0=0, shell_type1=1, shell_type2=0, shell_type3=0,
        result0=expected, alpha=-1.0)


def test_ralpha_repulsion_0_1_0_1():
    """r^alpha integrals (alpha=-1) for shell types (0, 1, 0, 1)."""
    expected = np.array([[[[1.11620596, 0.60061237, 0.36843148]],
                          [[-0.05340867, 0.33119515, -0.70418275]],
                          [[-0.04504112, -1.01394262, 1.17313632]]]])
    check_ralpha_repulsion(
        alphas0=np.array([0.95345, 1.7616, 0.62144]),
        alphas1=np.array([0.60537, 0.78954, 0.17662]),
        alphas2=np.array([1.39946, 1.03161, 1.42837]),
        alphas3=np.array([1.05228, 1.80211, 1.37614]),
        r0=np.array([0.18086, -0.0927, -0.36495]),
        r1=np.array([0.48062, -0.97782, -0.05878]),
        r2=np.array([-0.55927, -0.95238, 0.33122]),
        r3=np.array([0.17856, 0.06077, 0.62697]),
        scales0=np.array([0.9876]),
        scales1=np.array([1.39633, 1.30787, 1.80682]),
        scales2=np.array([0.93201]),
        scales3=np.array([1.21516, 1.84023, 1.59345]),
        shell_type0=0, shell_type1=1, shell_type2=0, shell_type3=1,
        result0=expected, alpha=-1.0)


def test_ralpha_repulsion_0_1_1_1():
    """r^alpha integrals (alpha=-1) for shell types (0, 1, 1, 1)."""
    expected = np.array([[[[-0.06633908, -0.13761956, -0.03005655],
                           [-0.023407, -0.07813472, -0.03489736],
                           [-0.02263273, -0.20143856, -0.03550443]],
                          [[-0.40044718, -0.35436776, 0.07827812],
                           [-0.39382673, -0.18295174, 0.10845718],
                           [-0.37310311, -0.34400264, 0.05152883]],
                          [[0.07743294, -0.04648822, -0.2043075],
                           [0.03540926, -0.00400861, -0.13446393],
                           [0.02364929, -0.01807209, -0.18079094]]]])
    check_ralpha_repulsion(
        alphas0=np.array([1.60961, 1.48434, 1.09022]),
        alphas1=np.array([1.49016, 0.78972, 1.01383]),
        alphas2=np.array([1.357, 1.6929, 1.46297]),
        alphas3=np.array([1.3126, 1.39773, 0.3295]),
        r0=np.array([-0.74441, 0.13168, 0.17287]),
        r1=np.array([-0.73242, 0.73598, -0.07688]),
        r2=np.array([0.06303, 0.61361, 0.92689]),
        r3=np.array([0.31395, 0.00081, -0.13425]),
        scales0=np.array([1.92653]),
        scales1=np.array([0.84324, 1.68215, 0.64055]),
        scales2=np.array([1.62317, 1.94784, 1.54325]),
        scales3=np.array([0.67873, 0.76053, 0.57816]),
        shell_type0=0, shell_type1=1, shell_type2=1, shell_type3=1,
        result0=expected, alpha=-1.0)


def test_ralpha_repulsion_1_0_0_1():
    """r^alpha integrals (alpha=-1) for shell types (1, 0, 0, 1)."""
    expected = np.array([[[[0.16173756, 0.14265052, 0.05405344]]],
                         [[[-0.431925, -0.37295006, -0.1782411]]],
                         [[[-0.17915755, -0.20235955, 0.03526912]]]])
    check_ralpha_repulsion(
        alphas0=np.array([0.39834, 1.4798, 1.80662]),
        alphas1=np.array([1.9623, 0.88607, 0.93517]),
        alphas2=np.array([0.46864, 1.1317, 0.67625]),
        alphas3=np.array([1.52214, 0.93879, 0.71425]),
        r0=np.array([-0.04796, 0.70504, 0.36481]),
        r1=np.array([0.40599, 0.97607, 0.64758]),
        r2=np.array([0.66271, -0.64123, -0.17474]),
        r3=np.array([-0.60087, 0.25093, 0.32664]),
        scales0=np.array([0.68301, 1.18047, 1.44482]),
        scales1=np.array([0.97181]),
        scales2=np.array([1.18315]),
        scales3=np.array([0.79184, 1.41932, 1.32812]),
        shell_type0=1, shell_type1=0, shell_type2=0, shell_type3=1,
        result0=expected, alpha=-1.0)


def test_ralpha_repulsion_1_1_1_1():
    """r^alpha integrals (alpha=-1) for shell types (1, 1, 1, 1)."""
    expected = np.array([[[[5.38092832, 0.67101024, 0.50643354],
                           [-0.36637823, -0.17128347, 0.00749151],
                           [-0.47015285, -0.00846274, -0.23514519]],
                          [[0.31412053, 1.85552661, -0.05096966],
                           [0.5668773, 0.04019152, -0.05803149],
                           [-0.02195855, 0.00256108, 0.03373068]],
                          [[0.26139911, -0.05908764, 1.34729127],
                           [0.03563575, 0.02599451, 0.0669569],
                           [0.6249628, -0.09012696, -0.02559206]]],
                         [[[-1.30079959, 0.06525516, -0.24130176],
                           [7.90805546, 0.5029288, 1.03164863],
                           [0.22531828, -0.01518479, -0.63472654]],
                          [[0.07758755, -0.30344079, 0.03679751],
                           [0.88274549, 3.43263474, -0.20761467],
                           [0.09249023, 0.10854722, 0.15741632]],
                          [[0.0082139, -0.00382022, -0.24202072],
                           [0.44155444, -0.06437548, 2.40552259],
                           [0.29276089, 0.01725224, 0.05956368]]],
                         [[[-1.45339037, -0.37266055, 0.25844897],
                           [0.41152374, -0.40525461, -0.16607501],
                           [14.23224926, 2.34068558, 0.65653732]],
                          [[-0.00776144, -0.38261119, -0.0073076],
                           [0.28311943, 0.14089539, 0.08426703],
                           [0.91304633, 5.92042353, -0.12886949]],
                          [[0.09807363, 0.06281554, -0.25920407],
                           [0.15636252, 0.10752926, 0.14182457],
                           [1.2142302, -0.38098265, 4.57694241]]]])
    check_ralpha_repulsion(
        alphas0=np.array([0.13992, 0.37329, 0.33259]),
        alphas1=np.array([0.64139, 1.73019, 0.13917]),
        alphas2=np.array([0.44337, 1.28161, 0.3277]),
        alphas3=np.array([1.24252, 1.27924, 1.45445]),
        r0=np.array([0.02582, 0.94923, -0.17438]),
        r1=np.array([-0.81301, 0.086, -0.77236]),
        r2=np.array([-0.67901, 0.6566, -0.45438]),
        r3=np.array([-0.02669, -0.13942, -0.98892]),
        scales0=np.array([1.01729, 0.83942, 1.15976]),
        scales1=np.array([1.92943, 1.10829, 0.87557]),
        scales2=np.array([0.58667, 0.97031, 1.31261]),
        scales3=np.array([1.57111, 0.74218, 0.68171]),
        shell_type0=1, shell_type1=1, shell_type2=1, shell_type3=1,
        result0=expected, alpha=-1.0)


def test_ralpha_repulsion_0_2_1_0():
    """r^alpha integrals (alpha=-1) for shell types (0, 2, 1, 0)."""
    expected = np.array([[[[0.03940319], [0.05597157], [-0.32990373]],
                          [[-0.00066587], [0.00221213], [-0.00319745]],
                          [[-0.00035194], [-0.00011777], [0.0063613]],
                          [[0.00478058], [0.01592957], [-0.09687372]],
                          [[0.00002574], [-0.00009517], [-0.00166564]],
                          [[0.01578456], [0.05420504], [-0.32175899]]]])
    check_ralpha_repulsion(
        alphas0=np.array([1.36794, 1.14001, 1.97798]),
        alphas1=np.array([1.68538, 0.75019, 0.72741]),
        alphas2=np.array([1.55248, 0.78842, 1.84644]),
        alphas3=np.array([1.73266, 0.46153, 0.63621]),
        r0=np.array([0.05517, 0.27196, -0.98928]),
        r1=np.array([-0.20526, 0.27314, -0.16208]),
        r2=np.array([-0.00876, -0.47585, 0.88613]),
        r3=np.array([0.75034, 0.54371, -0.1464]),
        scales0=np.array([0.50974]),
        scales1=np.array([1.36246, 0.58913, 0.73488, 0.53568, 1.11864, 1.80388]),
        scales2=np.array([1.62815, 0.58942, 1.52452]),
        scales3=np.array([0.66094]),
        shell_type0=0, shell_type1=2, shell_type2=1, shell_type3=0,
        result0=expected, alpha=-1.0)


def test_ralpha_repulsion_0_2_2_3():
    """r^alpha integrals (alpha=-1) for shell types (0, 2, 2, 3).

    The expected block is too large to inline and is loaded from a JSON fixture.
    """
    fn_json = context.get_fn('test/electron_repulsion_0_2_2_3.json')
    with open(fn_json) as f:
        expected = np.array(json.load(f))
    check_ralpha_repulsion(
        alphas0=np.array([0.96867, 0.41743, 1.03509]),
        alphas1=np.array([1.84594, 0.83035, 1.20242]),
        alphas2=np.array([0.94861, 0.47292, 0.38655]),
        alphas3=np.array([1.3009, 1.10486, 1.4979]),
        r0=np.array([0.10017, 0.21708, 0.08942]),
        r1=np.array([-0.03049, 0.99486, -0.37959]),
        r2=np.array([-0.7765, 0.53988, 0.25643]),
        r3=np.array([0.60758, 0.85146, 0.15088]),
        scales0=np.array([1.14284]),
        scales1=np.array([1.39723, 1.77896, 0.72525, 0.99877, 1.5953, 0.69473]),
        scales2=np.array([0.56774, 1.69348, 1.8146, 0.85426, 1.35434, 1.87402]),
        scales3=np.array([0.99964, 1.45499, 1.35143, 1.9758, 0.58887, 1.40713, 0.55226,
                          1.44979, 0.57156, 0.71009]),
        shell_type0=0, shell_type1=2, shell_type2=2, shell_type3=3,
        result0=expected, alpha=-1.0)


def test_ralpha_repulsion_4_3_2_1():
    """r^alpha integrals (alpha=-1) for shell types (4, 3, 2, 1).

    The expected block is too large to inline and is loaded from a JSON fixture.
    """
    fn_json = context.get_fn('test/electron_repulsion_4_3_2_1.json')
    with open(fn_json) as f:
        expected = np.array(json.load(f))
    check_ralpha_repulsion(
        alphas0=np.array([0.94212, 1.71823, 0.3309]),
        alphas1=np.array([0.94854, 0.12816, 0.42016]),
        alphas2=np.array([0.46046, 0.43321, 1.0587]),
        alphas3=np.array([1.0089, 0.52286, 1.83539]),
        r0=np.array([-0.48859, 0.6043, -0.57858]),
        r1=np.array([0.74567, -0.82555, -0.30631]),
        r2=np.array([-0.5679, -0.08725, 0.7623]),
        r3=np.array([0.10338, 0.65407, -0.20172]),
        scales0=np.array([1.10904, 1.40637, 1.8707, 0.68295, 1.29692, 0.99892, 1.13936,
                          0.81258, 0.50325, 1.27698, 1.81192, 1.43415, 1.1686, 1.38063,
                          0.61592]),
        scales1=np.array([1.19368, 0.75291, 0.63535, 1.22654, 1.32848, 1.17482, 1.74897,
                          0.93964, 1.90303, 1.44528]),
        scales2=np.array([1.63343, 1.80498, 1.61313, 0.99992, 1.04505, 1.42297]),
        scales3=np.array([1.4825, 1.69421, 1.8635]),
        shell_type0=4, shell_type1=3, shell_type2=2, shell_type3=1,
        result0=expected, alpha=-1.0)


def check_g09_overlap(fn_fchk):
    """Compare HORTON overlap integrals with the reference values stored in ``mol.olp``.

    Parameters
    ----------
    fn_fchk : str
        The FCHK filename. The companion log file is located by replacing the last five
        characters of the name (the ``.fchk`` suffix) with ``.log``.
    """
    fn_log = fn_fchk[:-5] + '.log'
    mol = IOData.from_file(fn_fchk, fn_log)
    olp1 = mol.obasis.compute_overlap(mol.lf)
    olp2 = mol.olp
    expect = olp1._array
    # Skip tiny elements so the relative error stays meaningful.
    mask = abs(expect) > 1e-5
    delta = expect - olp2._array
    # Take the absolute value: a signed max() would let large negative deviations pass.
    error = abs(delta[mask]/expect[mask]).max()
    assert error < 1e-5


def test_overlap_water_sto3g_hf():
    check_g09_overlap(fn_fchk=context.get_fn('test/water_sto3g_hf_g03.fchk'))


def test_overlap_water_ccpvdz_pure_hf():
    check_g09_overlap(fn_fchk=context.get_fn('test/water_ccpvdz_pure_hf_g03.fchk'))


def test_overlap_water_ccpvdz_cart_hf():
    check_g09_overlap(fn_fchk=context.get_fn('test/water_ccpvdz_cart_hf_g03.fchk'))


def test_overlap_co_ccpv5z_pure_hf():
    check_g09_overlap(fn_fchk=context.get_fn('test/co_ccpv5z_pure_hf_g03.fchk'))


def test_overlap_co_ccpv5z_cart_hf():
    check_g09_overlap(fn_fchk=context.get_fn('test/co_ccpv5z_cart_hf_g03.fchk'))


def check_g09_kinetic(fn_fchk):
    """Compare HORTON kinetic-energy integrals with the reference values in ``mol.kin``.

    Parameters
    ----------
    fn_fchk : str
        The FCHK filename. The companion log file is located by replacing the last five
        characters of the name (the ``.fchk`` suffix) with ``.log``.
    """
    fn_log = fn_fchk[:-5] + '.log'
    mol = IOData.from_file(fn_fchk, fn_log)
    kin1 = mol.obasis.compute_kinetic(mol.lf)
    kin2 = mol.kin
    expect = kin1._array
    # Skip tiny elements so the relative error stays meaningful.
    mask = abs(expect) > 1e-5
    delta = expect - kin2._array
    # Take the absolute value: a signed max() would let large negative deviations pass.
    error = abs(delta[mask]/expect[mask]).max()
    assert error < 1e-5


def test_kinetic_water_sto3g_hf():
    check_g09_kinetic(fn_fchk=context.get_fn('test/water_sto3g_hf_g03.fchk'))


def test_kinetic_water_ccpvdz_pure_hf():
    check_g09_kinetic(fn_fchk=context.get_fn('test/water_ccpvdz_pure_hf_g03.fchk'))


def test_kinetic_water_ccpvdz_cart_hf():
    check_g09_kinetic(fn_fchk=context.get_fn('test/water_ccpvdz_cart_hf_g03.fchk'))


def test_kinetic_co_ccpv5z_pure_hf():
    check_g09_kinetic(fn_fchk=context.get_fn('test/co_ccpv5z_pure_hf_g03.fchk'))


def test_kinetic_co_ccpv5z_cart_hf():
    check_g09_kinetic(fn_fchk=context.get_fn('test/co_ccpv5z_cart_hf_g03.fchk'))


def check_g09_nuclear_attraction(fn_fchk):
    """Compare HORTON nuclear-attraction integrals with the reference values in ``mol.na``.

    Parameters
    ----------
    fn_fchk : str
        The FCHK filename. The companion log file is located by replacing the last five
        characters of the name (the ``.fchk`` suffix) with ``.log``.
    """
    fn_log = fn_fchk[:-5] + '.log'
    mol = IOData.from_file(fn_fchk, fn_log)
    na1 = mol.obasis.compute_nuclear_attraction(mol.coordinates, mol.pseudo_numbers, mol.lf)
    na2 = mol.na
    expect = na1._array
    result = na2._array
    # Skip tiny elements so the relative error stays meaningful.
    mask = abs(expect) > 1e-5
    # The reference values are compared against -expect (opposite sign convention).
    delta = -expect - result
    # Take the absolute value: a signed max() would let large negative deviations pass.
    error = abs(delta[mask]/expect[mask]).max()
    assert error < 4e-5


def test_nuclear_attraction_water_sto3g_hf():
    check_g09_nuclear_attraction(fn_fchk=context.get_fn('test/water_sto3g_hf_g03.fchk'))


def test_nuclear_attraction_water_ccpvdz_pure_hf():
    fn_fchk = context.get_fn('test/water_ccpvdz_pure_hf_g03.fchk')
    check_g09_nuclear_attraction(fn_fchk=fn_fchk)


def test_nuclear_attraction_water_ccpvdz_cart_hf():
    fn_fchk = context.get_fn('test/water_ccpvdz_cart_hf_g03.fchk')
    check_g09_nuclear_attraction(fn_fchk=fn_fchk)


def test_nuclear_attraction_co_ccpv5z_pure_hf():
    fn_fchk = context.get_fn('test/co_ccpv5z_pure_hf_g03.fchk')
    check_g09_nuclear_attraction(fn_fchk=fn_fchk)


def test_nuclear_attraction_co_ccpv5z_cart_hf():
    fn_fchk = context.get_fn('test/co_ccpv5z_cart_hf_g03.fchk')
    check_g09_nuclear_attraction(fn_fchk=fn_fchk)


def check_g09_dipole(fn_fchk, dipole_values):
    """Compare the dipole moment from wavefunction and nuclei with reference values.

    For each Cartesian direction, the electronic part is minus the contraction of the
    first-order multipole-moment integrals (about the origin) with the full density
    matrix. The nuclear part adds pseudo_number * coordinate for every atom.

    Parameters
    ----------
    fn_fchk : str
        The FCHK filename.
    dipole_values : array, shape=(3,)
        Three components of the expected dipole moment.
    """
    mol = IOData.from_file(fn_fchk)
    powers_per_component = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
    origin = np.zeros(3)
    mol.dm_full = mol.get_dm_full()
    computed = []
    for powers in powers_per_component:
        moment_ints = mol.obasis.compute_multipole_moment(powers, origin, mol.lf)
        component = -moment_ints.contract_two('ab,ab', mol.dm_full)
        for charge, coord in zip(mol.pseudo_numbers, mol.coordinates):
            component += charge * pow(coord[0], powers[0]) * \
                pow(coord[1], powers[1]) * pow(coord[2], powers[2])
        computed.append(component)
    np.testing.assert_almost_equal(computed, dipole_values, decimal=6)


def test_dipole_water_sto3g_hf():
    expected = np.array([5.46423145e-01, -1.25137695e-16, 3.86381228e-01])
    check_g09_dipole(fn_fchk=context.get_fn('test/water_sto3g_hf_g03.fchk'),
                     dipole_values=expected)


def test_dipole_water_ccpvdz_pure_hf():
    expected = np.array([6.46132274e-01, 3.28892045e-03, 3.40563176e-01])
    check_g09_dipole(fn_fchk=context.get_fn('test/water_ccpvdz_pure_hf_g03.fchk'),
                     dipole_values=expected)


def test_dipole_ccpvdz_cart_hf():
    expected = np.array([6.46475310e-01, 3.32969714e-03, 3.41075744e-01])
    check_g09_dipole(fn_fchk=context.get_fn('test/water_ccpvdz_cart_hf_g03.fchk'),
                     dipole_values=expected)


def test_dipole_co_ccpv5z_pure_hf():
    expected = np.array([-2.25401400e+00, -3.22002009e-01, 3.22002009e-01])
    check_g09_dipole(fn_fchk=context.get_fn('test/co_ccpv5z_pure_hf_g03.fchk'),
                     dipole_values=expected)


def test_dipole_co_ccpv5z_cart_hf():
    expected = np.array([-2.25364754e+00, -3.21949654e-01, 3.21949654e-01])
    check_g09_dipole(fn_fchk=context.get_fn('test/co_ccpv5z_cart_hf_g03.fchk'),
                     dipole_values=expected)


def check_g09_quadrupole(fn_fchk, quadrupole_values):
    """Compare the quadrupole moment from wavefunction and nuclei with reference values.

    For each component, the electronic part is minus the contraction of the second-order
    multipole-moment integrals (about the origin) with the full density matrix. The
    nuclear part adds pseudo_number * x^a y^b z^c for every atom. Afterwards, one third
    of the trace is subtracted from the three diagonal components.

    Parameters
    ----------
    fn_fchk : str
        The FCHK filename.
    quadrupole_values : array, shape=(6,)
        Six components of the expected quadrupole moment: x^2, y^2, z^2, xy, xz, yz,
        with one third of the trace already removed from the first three.
    """
    mol = IOData.from_file(fn_fchk)
    xyz_array = np.array([[2, 0, 0], [0, 2, 0], [0, 0, 2], [1, 1, 0], [1, 0, 1], [0, 1, 1]])
    center = np.zeros(3)
    mol.dm_full = mol.get_dm_full()
    quadrupole = []
    for xyz in xyz_array:
        quadrupole_ints = mol.obasis.compute_multipole_moment(xyz, center, mol.lf)
        quad_v = -quadrupole_ints.contract_two('ab,ab', mol.dm_full)
        for i, n in enumerate(mol.pseudo_numbers):
            quad_v += n * pow(mol.coordinates[i, 0], xyz[0]) * \
                          pow(mol.coordinates[i, 1], xyz[1]) * \
                          pow(mol.coordinates[i, 2], xyz[2])
        quadrupole.append(quad_v)
    # Use an array: "list[:3] -= scalar" only works by accident when the scalar is a NumPy
    # type and raises TypeError for a plain Python float.
    quadrupole = np.array(quadrupole, dtype=float)
    # removing trace:
    trace = (quadrupole[0] + quadrupole[1] + quadrupole[2])/3.0
    quadrupole[:3] -= trace
    np.testing.assert_almost_equal(quadrupole, quadrupole_values, decimal=6)


def test_quadrupole_ch3_hf_sto3g():
    expected = np.array([-3.00591674e-03, 1.50295837e-03, 1.50295837e-03,
                         -1.32772907e-01, -1.32772907e-01, -1.33146521e-01])
    check_g09_quadrupole(fn_fchk=context.get_fn('test/ch3_hf_sto3g.fchk'),
                         quadrupole_values=expected)


def test_quadrupole_li_h_321g_hf_g09():
    expected = np.array([-6.75277790e-01, -6.75277790e-01, 1.35055558e+00,
                         0.00000000e+00, 0.00000000e+00, 0.00000000e+00])
    check_g09_quadrupole(fn_fchk=context.get_fn('test/li_h_3-21G_hf_g09.fchk'),
                         quadrupole_values=expected)


def check_g09_electron_repulsion(fn_fchk, check_g09_zeros=False):
    """Compare HORTON electron-repulsion integrals with the reference values in ``mol.er``.

    Parameters
    ----------
    fn_fchk : str
        The FCHK filename. The companion log file is located by replacing the last five
        characters of the name (the ``.fchk`` suffix) with ``.log``.
    check_g09_zeros : bool
        When True, also require that exactly the same elements are 0.0 in both arrays.
    """
    fn_log = fn_fchk[:-5] + '.log'
    mol = IOData.from_file(fn_fchk, fn_log)
    er1 = mol.obasis.compute_electron_repulsion(mol.lf)
    er2 = mol.er
    expect = er1._array
    got = er2._array
    # Skip tiny elements so the relative error stays meaningful.
    mask = abs(expect) > 1e-6
    if check_g09_zeros:
        assert ((expect == 0.0) == (got == 0.0)).all()
    delta = expect - got
    # Take the absolute value: a signed max() would let large negative deviations pass.
    error = abs(delta[mask]/expect[mask]).max()
    assert error < 1e-5


def test_electron_repulsion_water_sto3g_hf():
    check_g09_electron_repulsion(fn_fchk=context.get_fn('test/water_sto3g_hf_g03.fchk'),
                                 check_g09_zeros=True)


def test_electron_repulsion_water_ccpvdz_pure_hf():
    fn_fchk = context.get_fn('test/water_ccpvdz_pure_hf_g03.fchk')
    check_g09_electron_repulsion(fn_fchk=fn_fchk)


def test_electron_repulsion_water_ccpvdz_cart_hf():
    fn_fchk = context.get_fn('test/water_ccpvdz_cart_hf_g03.fchk')
    check_g09_electron_repulsion(fn_fchk=fn_fchk)


def check_g09_grid_fn(fn_fchk):
    """Check that the grid-integrated density matches the contraction of S with the DM.

    Parameters
    ----------
    fn_fchk : str
        The FCHK filename.
    """
    data = IOData.from_file(fn_fchk)
    becke = BeckeMolGrid(data.coordinates, data.numbers, data.pseudo_numbers, 'tv-13.7-4',
                         random_rotate=False)
    dm = data.get_dm_full()
    density = data.obasis.compute_grid_density_dm(dm, becke.points)
    population = becke.integrate(density)
    overlap = data.obasis.compute_overlap(data.lf)
    nelec = overlap.contract_two('ab,ab', dm)
    assert abs(population-nelec) < 2e-3


def test_grid_fn_h_sto3g():
    check_g09_grid_fn(fn_fchk=context.get_fn('test/h_sto3g.fchk'))


def test_grid_fn_lih_321g_hf():
    check_g09_grid_fn(fn_fchk=context.get_fn('test/li_h_3-21G_hf_g09.fchk'))


def test_grid_fn_water_sto3g_hf_T():
    check_g09_grid_fn(fn_fchk=context.get_fn('test/water_sto3g_hf_g03.fchk'))


def test_grid_fn_co_ccpv5z_pure_hf_T():
    check_g09_grid_fn(fn_fchk=context.get_fn('test/co_ccpv5z_pure_hf_g03.fchk'))


def test_grid_fn_co_ccpv5z_cart_hf_T():
    check_g09_grid_fn(fn_fchk=context.get_fn('test/co_ccpv5z_cart_hf_g03.fchk'))
