/**
 * nds_ext.cpp
 * Python wrapper for John Zweizig's NDS1/NDS2 library
 *
 * Copyright (C) 2011  Leo Singer
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 *
 */

#include <boost/python.hpp>
#include <boost/python/raw_function.hpp>
#include <algorithm>
#include <cstring>
#include <ostream>
#include <stdexcept>
#include <sstream>
#include <vector>

extern "C" {
#include <daqc.h>
#include <daqc_response.h>
}

#include <numpy/arrayobject.h>

using namespace boost::python;

// From http://wiki.python.org/moin/boost.python/HowTo#MultithreadingSupportformyfunction
/* Scope guard that releases the Python GIL on construction
 * (PyEval_SaveThread) and re-acquires it on destruction
 * (PyEval_RestoreThread).  Code running inside the guarded scope must not
 * touch Python objects or call the Python C API.
 *
 * The guard is non-copyable: a copy would restore the same saved thread
 * state a second time. */
class ScopedGILRelease {
public:
    inline ScopedGILRelease()
    {
        m_thread_state = PyEval_SaveThread();
    }
    inline ~ScopedGILRelease()
    {
        PyEval_RestoreThread(m_thread_state);
        m_thread_state = NULL;
    }
private:
    // Declared but never defined (C++03 idiom) so that copying fails to compile.
    ScopedGILRelease(const ScopedGILRelease&);
    ScopedGILRelease& operator=(const ScopedGILRelease&);

    PyThreadState *m_thread_state;
};

/* Exception carrying a nonzero daqc return code; translated to a Python
 * RuntimeError by DaqErrorTranslator. */
struct DaqError {
public:
    // explicit: an arbitrary int must not silently become an exception object.
    explicit DaqError(int retval) : retval(retval) {}
    int retval;  // the daqc error code that was returned
};

/* Thrown when an allocation fails; translated to Python's MemoryError. */
struct NoMemoryError {};

/* Thrown by daq_iterator::next when iteration is over; translated to
 * Python's StopIteration. */
struct StopIterationException {};

static void DaqErrorTranslator(const DaqError& exc)
{
    char* s;
    asprintf(&s, "error %d: %s", exc.retval, daq_strerror(exc.retval));
    PyErr_SetString(PyExc_RuntimeError, s);
    free(s);
}

/* Boost.Python exception translator for NoMemoryError: sets MemoryError.
 * The exception object carries no data, so the parameter is left unnamed. */
static void NoMemoryErrorTranslator(const NoMemoryError&)
{
    PyErr_NoMemory();
}

/* Boost.Python exception translator for StopIterationException: sets
 * StopIteration with no value, ending a Python for-loop. */
static void StopIterationExceptionTranslator(const StopIterationException&)
{
    PyErr_SetNone(PyExc_StopIteration);
}

/* Map a daqc sample type to the NumPy type number used to allocate arrays.
 * _undefined and any unrecognised value map to NPY_VOID. */
static int numpy_typenum_for_daq_data_t(daq_data_t type)
{
    int typenum = NPY_VOID;
    switch (type)
    {
        case _16bit_integer: typenum = NPY_INT16;     break;
        case _32bit_integer: typenum = NPY_INT32;     break;
        case _64bit_integer: typenum = NPY_INT64;     break;
        case _32bit_float:   typenum = NPY_FLOAT32;   break;
        case _64bit_double:  typenum = NPY_FLOAT64;   break;
        case _32bit_complex: typenum = NPY_COMPLEX64; break;
        case _undefined:
        default:
            break;
    }
    return typenum;
}

struct _daq_t;

/* Python-visible iterator returned by daq.request_data.  Each call to next()
 * receives one block from the connection and returns one NumPy array per
 * requested channel.  `end` is the requested GPS end time (0 = open-ended);
 * has_next becomes false once the request is judged complete. */
class daq_iterator {
public:
    daq_iterator(_daq_t *connection, time_t gps_start, time_t gps_end)
        : daq(connection), start(gps_start), end(gps_end), has_next(true) {}
    list next();

private:
    _daq_t *daq;        // connection; not owned (kept alive by custodian_and_ward)
    time_t start, end;  // requested GPS interval
    bool has_next;      // false once the final block has been returned
};

/* Python-visible NDS connection ("nds.daq").  Extends the C daq_t connection
 * record with methods that convert nonzero daqc return codes into DaqError,
 * releasing the GIL around calls that may block on the server. */
struct _daq_t : daq_t {
public:

    /* Connect to the NDS server at host:port with the given protocol version.
     * Throws DaqError if daq_connect fails. */
    _daq_t(const std::string& host, int port, nds_version version) : daq_t()
    {
        int retval;
        {
            // daq_connect presumably performs blocking network I/O.  Like the
            // other blocking calls in this file, it runs without the GIL.
            // Neither call touches Python state.
            ScopedGILRelease scoped;
            daq_startup();
            retval = daq_connect(this, host.c_str(), port, version);
        }
        if (retval) throw DaqError(retval);
    }

    /* Disconnect when the object is destroyed; the return code is ignored. */
    ~_daq_t()
    {
        // NOTE(review): this also runs after an explicit disconnect().
        // Confirm that daq_disconnect tolerates being called twice.
        ScopedGILRelease scoped;
        daq_disconnect(this);
    }

    /* Close the connection.  Throws DaqError on failure. */
    void disconnect()
    {
        int retval;
        {
            ScopedGILRelease scoped;
            retval = daq_disconnect(this);
        }
        if (retval) throw DaqError(retval);
    }

    /* Reset the list of requested channels.  Throws DaqError on failure. */
    void clear_channel_list()
    {
        int retval = daq_clear_channel_list(this);
        if (retval) throw DaqError(retval);
    }

    /* Request data for the currently requested channels from GPS `start` to
     * `end` with stride `dt`.  Returns a newly allocated iterator over the
     * received blocks; the caller (Boost.Python, via manage_new_object)
     * takes ownership.  Throws DaqError on failure. */
    daq_iterator* request_data(time_t start, time_t end, time_t dt)
    {
        int retval;
        {
            ScopedGILRelease scoped;
            retval = daq_request_data(this, start, end, dt);
        }
        if (retval) throw DaqError(retval);

        return new daq_iterator(this, start, end);
    }

    /* Add a channel taken from a received channel list to the request list. */
    void request_channel_from_chanlist(daq_channel_t* channel)
    {
        int retval = daq_request_channel_from_chanlist(this, channel);
        if (retval) throw DaqError(retval);
    }

    /* Add a channel by name (optionally with type and rate) to the request list. */
    void request_channel(const std::string& name, chantype type = cUnknown, double rate = 0.0)
    {
        int retval = daq_request_channel(this, name.c_str(), type, rate);
        if (retval) throw DaqError(retval);
    }

    /* Return a Python list with a copy of each entry of chan_req_list. */
    list get_requested_channels()
    {
        list l;
        for (chan_req_t* channel = this->chan_req_list; channel < &this->chan_req_list[this->num_chan_request]; channel++)
            l.append(channel);
        return l;
    }

    /* Return (gps_seconds, gps_nanoseconds) of the last received block.
     * Throws DaqError(DAQD_ERROR) when `tb` is null (presumably: no block
     * received yet). */
    tuple get_timestamp()
    {
        if (!this->tb)
            throw DaqError(DAQD_ERROR);
        return make_tuple(daq_get_block_gps(this), daq_get_block_gpsn(this));
    }

    /* Fetch the server's channel list, filtered by `channeltype`, as a Python
     * list of channel descriptors.  The first call asks only for the count;
     * the second fills a buffer of that size. */
    list recv_channel_list(chantype channeltype)
    {
        int nchannels_received;
        int retval;

        {
            ScopedGILRelease scoped;
            retval = daq_recv_channel_list(this, NULL, 0, &nchannels_received, 0, channeltype);
        }

        if (retval)
            throw DaqError(retval);

        // A zero-filled buffer, as calloc gave, that is released automatically.
        // An empty list is not an error here: calloc(0, ...) may return NULL,
        // which the previous code misreported as out of memory.  Allocation
        // failure raises std::bad_alloc, which Boost.Python maps to MemoryError.
        std::vector<daq_channel_t> channels(nchannels_received > 0 ? nchannels_received : 0);
        const int capacity = static_cast<int>(channels.size());
        {
            ScopedGILRelease scoped;
            retval = daq_recv_channel_list(this, capacity ? &channels[0] : NULL, capacity, &nchannels_received, 0, channeltype);
        }

        if (retval)
            throw DaqError(retval);

        // The reported count may exceed the buffer (e.g. the list grew
        // between the two calls), so never read past what was allocated.
        const int count = std::min(nchannels_received, capacity);

        list l;
        for (int i = 0; i < count; i++)
            l.append(channels[i]);
        return l;
    }

};

/* Receive the next block and return a list with one NumPy array per
 * requested channel, holding that channel's samples from the block.
 *
 * After each block, iteration continues if the request is open-ended
 * (end == 0), or if any channel delivered fewer samples than remain until
 * `end`.  Otherwise this is the last block, and the next call raises
 * StopIteration.
 *
 * Throws StopIterationException when exhausted, DaqError on library or
 * per-channel (negative status) errors, and NoMemoryError if an array cannot
 * be allocated. */
list daq_iterator::next() {
    if (!has_next)
        throw StopIterationException();

    int retval;
    {
        ScopedGILRelease scoped;
        retval = daq_recv_next(daq);
    }
    if (retval) throw DaqError(retval);

    // An open-ended request never runs out of data on its own.
    bool keep_going = (end == 0);
    time_t block_gps = daq_get_block_gps(daq);
    time_t remaining_time = 0;
    if (block_gps < end)
        remaining_time = end - block_gps;

    list l;
    for (chan_req_t *channel = daq->chan_req_list; channel < &daq->chan_req_list[daq->num_chan_request]; channel++)
    {
        if (channel->status < 0)
            throw DaqError(- channel->status);
        int data_length = channel->status;  // payload size in bytes
        int sample_size = data_type_size(channel->data_type);
        npy_intp nsamples = data_length / sample_size;
        PyObject *array_obj = PyArray_SimpleNew(1, &nsamples, numpy_typenum_for_daq_data_t(channel->data_type));
        if (!array_obj) throw NoMemoryError();
        handle<> array(array_obj);  // take ownership immediately
        // Never copy more than the array holds, so a trailing partial sample
        // (or an itemsize mismatch) cannot overflow it.
        size_t nbytes = std::min<size_t>(static_cast<size_t>(data_length), PyArray_NBYTES(array_obj));
        memcpy(PyArray_DATA(array_obj), daq_get_block_data(daq) + channel->offset, nbytes);
        l.append(array);
        // Compare samples with samples.  The old code compared the byte count
        // with a sample count and stopped early for multi-byte sample types.
        if (!keep_going && nsamples < remaining_time * channel->rate)
            keep_going = true;
    }

    if (!keep_going)
    {
        has_next = false;

        /* NDS1 transfers end with a 'termination block', an empty block that
         * is indistinguishable from a 'data not found' condition. If this is
         * an NDS1 connection, we must digest the termination block. */
        if (daq->nds_versn == nds_v1)
        {
            {
                ScopedGILRelease scoped;
                retval = daq_recv_next(daq);
            }
            if (retval != DAQD_NOT_FOUND)
                throw DaqError(retval);
        }
    }

    return l;
}

list fetch(tuple args, dict kwargs)
{
    list l;
    int retval;
    _daq_t &daq = extract<_daq_t&>(args[0]);

    int start = extract<int>(args[1]);
    int stop = extract<int>(args[2]);
    int nchannels = len(args) - 3;

    if (start >= stop)
        throw std::invalid_argument("stop must be greater than start");

    daq.clear_channel_list();

    try {
        if (daq.nds_versn == nds_v2)
        {
            for (int i = 0; i < nchannels; i ++)
                daq.request_channel(extract<std::string>(args[i + 3]));
        } else /* daq.nds_versn == nds_v1 */ {
            int nchannels_received;

            {
                ScopedGILRelease scoped;
                retval = daq_recv_channel_list(&daq, NULL, 0, &nchannels_received, 0, cUnknown);
            }
            if (retval)
                throw DaqError(retval);

            daq_channel_t* channels = (daq_channel_t*) calloc(nchannels_received, sizeof(daq_channel_t));

            if (!channels)
                throw NoMemoryError();

            try {
                int old_nchannels_received = nchannels_received;
                {
                    ScopedGILRelease scoped;
                    retval = daq_recv_channel_list(&daq, channels, old_nchannels_received, &nchannels_received, 0, cUnknown);
                }

                if (retval)
                    throw DaqError(retval);

                for (int i = 0; i < nchannels; i ++)
                {
                    std::string desired_channel_name = extract<std::string>(args[i + 3]);
                    bool found = false;
                    for (daq_channel_t* channel = channels ; channel < &channels[nchannels_received]; channel++)
                    {
                        if (desired_channel_name.compare(channel->name) == 0)
                        {
                            if (found)
                            {
                                std::ostringstream stream;
                                stream << "Channel '" << desired_channel_name << "': has more than one matching entry";
                                throw std::runtime_error(stream.str());
                            }
                            daq.request_channel_from_chanlist(channel);
                            found = true;
                        }
                    }
                    if (!found)
                    {
                        std::ostringstream stream;
                        stream << "Channel '" << desired_channel_name << "': not found";
                        throw std::runtime_error(stream.str());
                    }
                }
            } catch (...) {
                free(channels);
                throw;
            }
            free(channels);
        }

        {
            ScopedGILRelease scoped;
            retval = daq_request_data(&daq, start, stop, 0);
        }
        if (retval) throw DaqError(retval);

        {
            ScopedGILRelease scoped;
            retval = daq_recv_next(&daq);
        }
        if (retval) throw DaqError(retval);

        char *data[nchannels];
        char *data_end[nchannels];

        for (int i = 0; i < nchannels; i++)
        {
            chan_req_t &channel = daq.chan_req_list[i];
            if (channel.status < 0)
                throw DaqError(- channel.status);
            npy_intp dim = (stop - start) * channel.rate;
            PyObject *array_obj = PyArray_SimpleNew(1, &dim, numpy_typenum_for_daq_data_t(channel.data_type));
            if (!array_obj) throw NoMemoryError();
            l.append(handle<>(array_obj));
            data[i] = (char *) PyArray_DATA(array_obj);
            data_end[i] = data[i] + PyArray_NBYTES(array_obj);
            memcpy(data[i], daq_get_block_data(&daq) + channel.offset, channel.status);
            data[i] += channel.status;
        }

        while (true)
        {
            bool keep_going = false;
            for (int i = 0; i < nchannels; i ++)
                if (data[i] < data_end[i])
                    keep_going = true;
            if (!keep_going)
                break;

            {
                ScopedGILRelease scoped;
                retval = daq_recv_next(&daq);
            }
            if (retval) throw DaqError(retval);

            for (int i = 0; i < nchannels; i ++)
            {
                chan_req_t &channel = daq.chan_req_list[i];
                if (channel.status < 0)
                    throw DaqError(- channel.status);
                memcpy(data[i], daq_get_block_data(&daq) + channel.offset, channel.status);
                data[i] += channel.status;
            }
        }

        /* NDS1 transfers end with a 'termination block', an empty block that
         * is indistinguisable from a 'data not found' condition. If this is
         * an NDS1 connection, we must digest the termination block. */
        if (daq.nds_versn == nds_v1)
        {
            {
                ScopedGILRelease scoped;
                retval = daq_recv_next(&daq);
            }
            if (retval != DAQD_NOT_FOUND)
                throw DaqError(retval);
        }
    } catch (...) {
        daq.clear_channel_list();
        throw;
    }
    daq.clear_channel_list();

    return l;
}

/* Property getter for conversion.units: the signal_units field as a Python str. */
static str _signal_conv_get_units(const signal_conv_t* self)
{
    str units(self->signal_units);
    return units;
}

/* Property getter for channel.name: the name field as a Python str. */
static str _daq_channel_get_name(const daq_channel_t* self)
{
    str name(self->name);
    return name;
}

/* Property getter for channel_request.name: the name field, viewed as a C
 * string, converted to a Python str. */
static str _chan_req_get_name(const chan_req_t* self)
{
    const char* raw_name = (const char*)self->name;
    str name(raw_name);
    return name;
}

/* __str__ for channel-like descriptors: "name (rateHz, data_type, type)".
 * Works for any CHANNEL_T with name, rate, data_type and type members
 * (used for daq_channel_t and chan_req_t).  The rate is truncated to int. */
template<class CHANNEL_T>
static str _channel_as_str(const CHANNEL_T* c)
{
    const object format("%s (%dHz, %s, %s)");
    const tuple fields = make_tuple(str((const char*)c->name), int(c->rate), c->data_type, c->type);
    return str(format % fields);
}

/* __repr__ for channel-like descriptors: the __str__ text wrapped in <...>. */
template<class CHANNEL_T>
static str _channel_as_repr(const CHANNEL_T* c)
{
    const str body = _channel_as_str(c);
    return str(object("<%s>") % body);
}

/* Return the argument unchanged; bound as daq_iterator.__iter__. */
inline object identity(const object& self)
{
    object same(self);
    return same;
}

/* Module initialisation for `nds`.  It initialises threading and NumPy, sets
 * the module metadata, registers the enums, the descriptor classes, the
 * exception translators, the block iterator and the `daq` connection class. */
BOOST_PYTHON_MODULE(nds)
{
    docstring_options doc_options(true, true, false);

    PyEval_InitThreads();

    // import_array() expands to a `return` whose value depends on the Python
    // version (`return NULL;` under Python 3), which does not compile in this
    // void init function.  Call _import_array() directly and let Boost.Python
    // report the Python exception it sets on failure.
    if (_import_array() < 0)
        throw_error_already_set();

    scope().attr("__doc__") = "Wrapper for John Zweizig's NDS1/NDS2 library";
    scope().attr("__author__") = "Leo Singer <leo.singer@ligo.org>";
    scope().attr("__organization__") = make_tuple("LIGO", "California Institute of Technology");
    scope().attr("__copyright__") = "Copyright 2011, Leo Singer";
    scope().attr("__version__") = PACKAGE_VERSION;

    enum_<chantype>("channel_type", "Enumeration of possible NDS channel types")
        .value("unknown", cUnknown)
        .value("online", cOnline)
        .value("raw", cRaw)
        .value("reduced", cRDS)
        .value("second_trend", cSTrend)
        .value("minute_trend", cMTrend)
        .value("testpoint", cTestPoint)
        .value("static", cStatic);

    enum_<daq_data_t>("data_type", "Enumeration of possible NDS data series types")
        .value("int16", _16bit_integer)
        .value("int32", _32bit_integer)
        .value("int64", _64bit_integer)
        .value("float", _32bit_float)
        .value("double", _64bit_double)
        .value("complex", _32bit_complex)
        .value("undefined", _undefined);

    enum_<nds_version>("nds_version", "Enumeration of possible NDS protocol versions")
        .value("auto", nds_try)
        .value("v1", nds_v1)
        .value("v2", nds_v2);

    class_<signal_conv_t>("conversion", "Unit conversion information")
        .def_readonly("gain", &signal_conv_t::signal_gain)
        .def_readonly("slope", &signal_conv_t::signal_slope)
        .def_readonly("offset", &signal_conv_t::signal_offset)
        .add_property("units", _signal_conv_get_units);

    class_<daq_channel_t>("channel", "Channel descriptor")
        .add_property("name", _daq_channel_get_name)
        .def_readonly("rate", &daq_channel_t::rate)
        .def_readonly("data_type", &daq_channel_t::data_type)
        .def_readonly("type", &daq_channel_t::type)
        .def_readonly("conversion", &daq_channel_t::s)
        .def("__str__", &_channel_as_str<daq_channel_t>)
        .def("__repr__", &_channel_as_repr<daq_channel_t>);

    class_<chan_req_t>("channel_request", "Channel request descriptor")
        .add_property("name", _chan_req_get_name)
        .def_readonly("rate", &chan_req_t::rate)
        .def_readonly("data_type", &chan_req_t::data_type)
        .def_readonly("type", &chan_req_t::type)
        .def_readonly("conversion", &chan_req_t::s)
        .def("__str__", &_channel_as_str<chan_req_t>)
        .def("__repr__", &_channel_as_repr<chan_req_t>);

    register_exception_translator<DaqError>(DaqErrorTranslator);

    register_exception_translator<NoMemoryError>(NoMemoryErrorTranslator);

    register_exception_translator<StopIterationException>(StopIterationExceptionTranslator);

    // "next" serves the Python 2 iterator protocol and "__next__" the Python 3 one.
    class_<daq_iterator>("daq_iterator", "An iterator that provides data received from an NDS server block by block", no_init)
        .def("next", &daq_iterator::next, (arg("self")), "Retrieve a block of data from the server, and return a list of Numpy\narrays for each requested channel.")
        .def("__next__", &daq_iterator::next, (arg("self")), "Retrieve a block of data from the server, and return a list of Numpy\narrays for each requested channel.")
        .def("__iter__", &identity, (arg("self")), "Support the Python iterator protocol.  Return self.");

    // noncopyable: a by-value copy of a connection would disconnect the same
    // underlying daq_t twice (once per destructor).
    class_<_daq_t, boost::noncopyable>("daq", "Connection to an NDS server", init<const std::string&, int, nds_version>((arg("self"), arg("host"), arg("port") = 31200, arg("nds_version") = nds_try)))
        .def("disconnect", &_daq_t::disconnect, (arg("self")), "close connection to server")
        .def("recv_channel_list", &_daq_t::recv_channel_list, (arg("self"), arg("channeltype") = cUnknown), "get list of all available channels")
        .def("clear_channel_list", &_daq_t::clear_channel_list, (arg("self")), "reset list of requested channels")
        .def("fetch", raw_function(&fetch, 4),
             "fetch(start, stop, channel1[, channel2[, ...]]) -> list of Numpy arrays\n"
             "Convenience routine to fetch data from time start to time stop, for a\n"
             "number of channels named channel1, channel2, ... .")
        .def("request_data", &_daq_t::request_data, return_value_policy<manage_new_object, with_custodian_and_ward_postcall<0, 1> >(), (arg("self"), arg("gps_start") = 0, arg("gps_end") = 0, arg("stride") = 1))
        .def("request_channel", &_daq_t::request_channel_from_chanlist, (arg("self"), arg("channel")), "request a channel")
        .def("request_channel", &_daq_t::request_channel, (arg("self"), arg("name"), arg("type") = cUnknown, arg("rate") = 0.0), "request a channel")
        .add_property("requested_channels", &_daq_t::get_requested_channels, "list of requested channels")
        .add_property("timestamp", &_daq_t::get_timestamp, "tuple containing GPS seconds and nanoseconds of last received block")
        .def_readonly("nds_version", &_daq_t::nds_versn, "version of NDS protocol used for this connection");
}
