import ctypes
import numpy
from . import _lib

def sdarray_ndarray(array, nrows, ncols):
    """ Convert Lenstool square_double array to numpy.ndarray. The memory content of the
        input array is freed upon the call to this function (even if the copy fails),
        so the input must not be used afterwards.

        The C array is read as ``ncols`` row pointers, each pointing to ``nrows``
        doubles; the result is the transpose of that layout.

        Parameters
        ----------
            - array         Object of ctypes type square_double_t
            - nrows, ncols  Size of the first and second dimensions of the result

        Returns
        -------
            Array of type numpy.ndarray with shape (nrows, ncols), holding a copy
            of the data.
    """
    try:
        arr = numpy.empty((ncols, nrows), dtype=numpy.float64)
        for i in range(ncols):
            # Pass the row pointer itself: as_array() only honours ``shape`` for
            # ctypes pointers. ``array[i].contents`` would be a single c_double,
            # giving a 0-d array that broadcasts the first value over the row.
            # Assigning into ``arr`` copies the data out of C memory.
            arr[i] = numpy.ctypeslib.as_array(array[i], shape=(nrows,))
    finally:
        # The copy above owns the data, so the C array can always be released.
        _lib.free_square_double(array, ncols)
    return arr.T

class Lenstool():
    """
    A lenstool instance initialized with a parameter file, and providing
    some high-level functions of the C version.

    The simplest way to initialize a Lenstool object is to load a standard
    lenstool parameter file

        >>> lt = lenstool.Lenstool('test_galaxies.par')

    Make sure to have all the file dependencies mentioned in the parameter
    file available in your current directory.

    Note that a Lenstool instance is not thread-safe. The global variables
    initialized in memory by the underlying C library will be shared between
    multiple Lenstool instances.

    Parameters:
        infile:     The path to the parameter file
    """

    # Default MCMC output file names, used in error messages until
    # write_bayes_header() sets the names built from the suffix returned by C.
    _burnin_name = "burnin.dat"
    _bayes_name = "bayes.dat"

    def __init__(self, infile):
        self._constraints_read = False
        self._chains = None
        self._nvals = 0
        self._fburnin = None
        self._fbayes = None
        self._userCommon = None
        self._modelset = False

        _lib.init_cosmoratio()
        _lib.init_grille(bytes(infile, "ascii"), 1)
        _lib.checkpar()
        _lib.grid()
        self._nparams = self.get_nparams()

    def __del__(self):
        """ Release the chains array, open MCMC files and userCommon structure. """
        # getattr() guards against partially-initialized instances.
        if getattr(self, "_chains", None):
            _lib.free_square_double(self._chains, self._nparams)
            self._chains = None

        if (getattr(self, "_fburnin", None) or getattr(self, "_fbayes", None)
                or getattr(self, "_userCommon", None)):
            self.close_bayes()

    def readBayesModels(self):
        """ Read the content of the bayes.dat found in the current directory.

        Any chains previously read by this instance are freed first.

        Raises
        ------
            RuntimeError if the number of free parameters in bayes.dat differs
            from the model's.
        """
        if self._chains:
            _lib.free_square_double(self._chains, self._nparams)
            self._chains = None

        n_params = ctypes.c_int(0)
        n_vals = ctypes.c_long(0)
        chains = _lib.readBayesModels(n_params, n_vals)
        if n_params.value != self._nparams:
            # Free with the row count the C side actually allocated; __del__
            # would otherwise free it with self._nparams.
            if chains:
                _lib.free_square_double(chains, n_params.value)
            raise RuntimeError("Model and bayes.dat number of free parameters mismatch ({0} != {1})".format(self._nparams, n_params.value))
        self._chains = chains
        self._nvals = n_vals.value

    def _copy_chains(self):
        """ Return a (nparams, nvals) numpy copy of the internal chains array. """
        # NOTE(review): assumes the C array is indexed [param][sample], matching
        # free_square_double(self._chains, self._nparams) in __del__ -- confirm
        # against the C readBayesModels().
        out = numpy.empty((self._nparams, self._nvals), dtype=numpy.float64)
        for i in range(self._nparams):
            out[i] = numpy.ctypeslib.as_array(self._chains[i], shape=(self._nvals,))
        return out

    def get_chains(self, return_ndarray=True):
        """ Return the content of the bayes.dat. If return_ndarray is True, the
            returned numpy.ndarray contains a copy of the internal memory content,
            which remains owned by this instance.

        Parameters
        ----------
            return_ndarray    Return a numpy.ndarray if True, otherwise return the
                              content in the internal ctype square_double_t type.

        Returns
        -------
            A numpy.ndarray of shape (nvals, nparams) if return_ndarray is True,
            otherwise the tuple (chains, nparams, nvals).

        Raises
        ------
            RuntimeError if readBayesModels() has not been called.
        """
        if not self._chains:
            raise RuntimeError("bayes.dat file not read. Use readBayesModels first")

        if return_ndarray:
            # Copy without freeing: setBayesModel() and __del__ still use it.
            return self._copy_chains().T
        return self._chains, self._nparams, self._nvals

    def get_nparams(self):
        """ Return the number of free parameters """
        return _lib.getNParameters()

    def o_chi_lhood0(self):
        """ Return the chi2 and likelihood corresponding to the model and the constraints
        previously stored in memory.

        Returns
        -------
            error, chi, lhood0   The error value (0 is good, 1 is bad model), the chi2 and the normalization likelihood term
        """
        chi = ctypes.c_double(0)
        lhood0 = ctypes.c_double(0)
        np_b0 = ctypes.POINTER(ctypes.c_double)() # NULL pointer
        error = _lib.o_chi_lhood0(chi, lhood0, np_b0)
        return error, chi.value, lhood0.value

    def readConstraints(self):
        """ Read in memory the constraints files (multiple images, arclets,
            etc.) mentioned in the input parameter file.
        """
        _lib.readConstraints()
        self._constraints_read = True

    def setBayesModel(self, method=-4):
        """ Select a model from the bayes.dat samples read with readBayesModels() function.

        Constraints and bayes.dat are read first if not already in memory.

        Parameters:
            method     Specify the row from the bayes.dat from which a chires.dat file should be calculated. Special values are -4 to use the maximum likelihood row (default).
         """
        if not self._constraints_read:
            self.readConstraints()

        if not self._chains:
            self.readBayesModels()

        _lib.setBayesModel(method, self._nvals, self._chains)
        self._modelset = True

    def o_chires(self, filename='chires.dat'):
        """ Write a chires.dat file for the model stored in memory """
        if not self._modelset:
            raise RuntimeError("Model not set in memory. Use setBayesModel() before calling this function")

        _lib.o_chires(bytes(filename,"ascii"))

    def rescaleCube(self, cube, return_rescaled=True):
        """ Rescale the cube values to their corresponding parameter values in
            Lenstool.

            Parameters
                cube[nparams]    The array containing the parameters for all the potentials. Cube values must
                                 follow uniform distributions in the range [0,1].

                return_rescaled  If True, cube values are replaced in-place by their rescaled values.

            Returns
                Return 1 if the rescaled parameters are meaningful in the
                parameter space, 0 otherwise.

            Raises
                ValueError if the cube size differs from the number of free
                parameters or a value lies outside [0, 1].
        """
        n = len(cube)
        if n != self._nparams:
            raise ValueError("Cube size mismatch ({0} != {1})".format(n, self._nparams))

        if not all(0 <= v <= 1 for v in cube):
            raise ValueError("Cube value not in range [0, 1]", cube)

        array = (ctypes.c_double * n)(*cube)
        valid = _lib.rescaleCube_1Atom(array, n)

        if return_rescaled:
            for i, value in enumerate(array):
                cube[i] = value

        self._modelset = True
        return valid

    def write_bayes_header(self):
        """ Write the header of the bayes.dat file. Open the burnin.dat and bayes.dat files on the disk.

            Return the number of free parameters/dimensions
        """
        if self._fburnin or self._fbayes:
            raise RuntimeError("bayes.dat or burnin.dat files already open. Close them first with close_bayes().")

        nparams = _lib.bayesHeader()
        # Writable, zero-filled buffer the C library writes the file suffix
        # into: up to 13 characters plus the terminating NUL. A bytes object
        # must not be used, since Python bytes are immutable.
        suffix_buf = ctypes.create_string_buffer(14)

        # Need to cast here because by default ctypes converts returned c_void_p into class <int>
        # which is then truncated to 32 bits when passed to close_bayes(), and results in segfault
        self._fburnin = ctypes.cast(_lib.open_mcmc_files(ctypes.c_double(0), suffix_buf), ctypes.c_void_p)
        self._fbayes = ctypes.cast(_lib.open_mcmc_files(ctypes.c_double(1), suffix_buf), ctypes.c_void_p)
        suffix = suffix_buf.value.decode()
        self._burnin_name = "burnin." + suffix
        self._bayes_name = "bayes." + suffix
        self._userCommon = ctypes.cast(_lib.userCommon_init(self._nparams), ctypes.c_void_p)
        return nparams

    def _write_line(self, fh, name, iteration, likelihood, coord):
        """ Shared implementation of write_burnin_line() and write_bayes_line(). """
        # `not fh` also rejects a NULL c_void_p returned by open_mcmc_files().
        if not fh:
            raise RuntimeError(f"{name} file not open. Use write_bayes_header() function first.")

        if coord is not None and self.rescaleCube(coord) == 0:
            raise RuntimeError("Invalid model with rescaled cube", coord)

        if likelihood is None:
            _, chi2, lhood0 = self.o_chi_lhood0()
            likelihood = -0.5 * (chi2 + lhood0)

        _lib.write_bayes_line(fh, ctypes.c_double(likelihood), ctypes.c_double(0),
                              iteration, self._userCommon)

    def write_burnin_line(self, iteration, likelihood=None, coord=None):
        """ Write a line in the burnin.dat file with the parameter currently set in memory.

            If coord is given, it is first rescaled in-place with rescaleCube().
            Call o_chi_lhood0() function if likelihood is not set.
        """
        self._write_line(self._fburnin, self._burnin_name, iteration, likelihood, coord)

    def write_bayes_line(self, iteration, likelihood=None, coord=None):
        """ Write a line in the bayes.dat file with the parameter currently set in memory.

            If coord is given, it is first rescaled in-place with rescaleCube().
            Call o_chi_lhood0() function if likelihood is not set.
        """
        self._write_line(self._fbayes, self._bayes_name, iteration, likelihood, coord)

    def close_bayes(self):
        """ Close the burnin.dat and bayes.dat files and free the userCommon structure """
        if self._fburnin:
            _lib.close_mcmc_files(self._fbayes, self._fburnin)
            self._fburnin = None
            self._fbayes = None

        if self._userCommon:
            _lib.userCommon_free(self._userCommon)
            self._userCommon = None
    
