/* -*- mode: c; c-basic-offset: 4; -*- */
/*******************************************************
  Matlab mex wrapper for NDS.
  
  This allows access to full sample rate data (i.e., not
  trend data).
******************************************************/

#if HAVE_CONFIG_H
#include "daq_config.h"
#endif /* HAVE_CONFIG_H */

#include <string.h>
#include <math.h>
#include <stdio.h>
#include "daqc.h"
#include "trench.h"
#include "nds_log_matlab.h"
#include "nds_logging.h"
#include "nds_mex_utils.h"

/* Untyped pointers to the real and imaginary data buffers of a numeric
 * mxArray (the values returned by mxGetData() / mxGetImagData()). */
typedef void* mxDataReal_type;
typedef void* mxDataImag_type;


/* Maximum number of channels accepted in a single request. */
#define MAX_TOTAL_REQUEST_CHANNELS 256
/* Upper limit on the estimated size (rate * bytes per sample * duration)
 * of the data requested for any single channel.
 * K. Thorne 2011-11-20 Increase from 128MB to 1GB for commissioner queries */
/* #define MAX_BYTES_PER_QUERY 134217728 */
#define MAX_BYTES_PER_QUERY 1073741824

/*  Internal functions */

mxArray* mlwrapper_get_channel_data(const mxArray*, const mxArray*, 
				    double, double, const char*, int);
int ml_get_channel_data(mxArray*, mwSize, time_t, time_t, const char*, int,
			enum nds_version*);

/* Size of the buffers used to format warning/error messages. */
#define NDS_MSG_LEN 256
static const size_t ndsmsglen = NDS_MSG_LEN;

/* FMT_MWSIZE / CAST_MWSIZE: printf conversion and matching cast used to
 * print an mwSize value.  They are selected from SIZEOF_MWSIZE, which is
 * presumably supplied by the configure-generated daq_config.h (TODO
 * confirm); when it is undefined or 0, mwSize is assumed to be an int. */
#if SIZEOF_MWSIZE == 0
#undef SIZEOF_MWSIZE
#define SIZEOF_MWSIZE SIZEOF_INT
#endif
#if SIZEOF_MWSIZE == SIZEOF_INT
#define FMT_MWSIZE "%d"
#define CAST_MWSIZE int
#elif SIZEOF_MWSIZE == SIZEOF_LONG
#define FMT_MWSIZE "%ld"
#define CAST_MWSIZE long
#elif SIZEOF_MWSIZE == SIZEOF_LONG_LONG
#define FMT_MWSIZE "%lld"
#define CAST_MWSIZE long long
#else
#error "Unable to determine size of mwSize"
#endif /* SIZEOF_MWSIZE */

/*
 * MEX entry point.
 *
 * Matlab usage:
 *   data = func(request_chans, start_time, duration, host, server_chans)
 *
 *   prhs[0]  request_chans  cell array of channel-name strings
 *                           (at most MAX_TOTAL_REQUEST_CHANNELS elements)
 *   prhs[1]  start_time     numeric scalar, GPS start time in seconds
 *   prhs[2]  duration       numeric scalar, duration in seconds (>= 0)
 *   prhs[3]  host           "host" or "host:port" string
 *   prhs[4]  server_chans   struct array as returned by NDS_GetChannels()
 *
 * plhs[0] receives the struct array built by mlwrapper_get_channel_data().
 * Any validation failure, or a NULL result, aborts via mexErrMsgTxt().
 */
void 
mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
    char hpbuf[MAX_HOST_LENGTH];
    char msg[NDS_MSG_LEN];
    double start, duration;
    short port = 0;
    mxArray* retval = NULL;

    /*** Input Validation ***/
    if(nlhs != 1) {
	mexErrMsgTxt("This function only returns one parameter.");
    }

    if(nrhs != 5) {
	mexErrMsgTxt("This function requires five arguments.\n \
                      channel list to query, \n \
                      start time (GPS seconds)\n \
                      duration (seconds)\n \
                      hostport_string\n \
                      chan list from NDS_GetChannels()\n");
    }

    /* Copy the host string into a C buffer; fails if it is not a string
     * or does not fit. */
    if (mxGetString(prhs[3], hpbuf, (mwSize)(sizeof(hpbuf))))
        mexErrMsgTxt("Hostname should be of the form \"host\" or \"host:port\""
                     " and shorter than " STRINGIFY(MAX_HOST_LENGTH) " characters.");

    /* NOTE(review): parse_host_name() presumably extracts an optional
     * ":port" suffix into `port` (leaving it 0 otherwise) - confirm. */
    parse_host_name(hpbuf, &port);

    /* start: a finite numeric scalar */
    if (!mxIsNumeric(prhs[1]) || mxGetNumberOfElements(prhs[1]) != 1) {
	mexErrMsgTxt("Start time must be a scalar.");
    }
    start = mxGetScalar(prhs[1]);
    if (!mxIsFinite(start)) {
	mexErrMsgTxt("Start time must be a finite GPS time.");
    }

    /* duration: a finite, non-negative numeric scalar.  Negative or
     * non-finite values would otherwise reach floor/ceil -> time_t casts
     * and array-size computations downstream. */
    if (!mxIsNumeric(prhs[2]) || mxGetNumberOfElements(prhs[2]) != 1) {
	mexErrMsgTxt("Duration must be a scalar.");
    }
    duration = mxGetScalar(prhs[2]);
    if (!mxIsFinite(duration) || duration < 0.0) {
	mexErrMsgTxt("Duration must be a finite, non-negative number of seconds.");
    }

    /* channel list must be a Cell type */
    if( ! mxIsCell(prhs[0]) ) {
	mexErrMsgTxt("Channel list must be a cell type, i.e., use {'Chan1';'Chan2';Chan3} etc.");
    }

    /* Count every element: the recommended column cell {'A';'B'} has
     * mxGetN() == 1, so checking only the column count never triggered. */
    if(mxGetNumberOfElements(prhs[0]) > MAX_TOTAL_REQUEST_CHANNELS) {
	SNPRINTF(msg, ndsmsglen, 
		 "The maximum number of channels per request is %d",
		 MAX_TOTAL_REQUEST_CHANNELS);
	msg[ndsmsglen-1] = '\0';
	mexErrMsgTxt(msg);
    }

    /* The server channel list is indexed as a struct array downstream. */
    if (!mxIsStruct(prhs[4])) {
	mexErrMsgTxt("Server channel list must be the struct array returned by NDS_GetChannels().");
    }
    /*** END Input Validation ***/

    /*  Loop over channels */
    retval = mlwrapper_get_channel_data(prhs[0], prhs[4], start, duration, 
					hpbuf, port);
    if(retval == NULL) {
	mexErrMsgTxt("Fatal Error getting channel data.");
    }
    plhs[0] = retval;
}

/*  Get channel data function */
mxArray* 
mlwrapper_get_channel_data(const mxArray* chan_query_list,
			   const mxArray* server_chans, 
			   double start, double duration, 
			   const char* hostname, int host_port)
{
    /* Go through each element of indices and
       attempts to get channel data for that channel */
    char msg[NDS_MSG_LEN];

    /* Get Server Field numbers */
    int srv_name = mxGetFieldNumber(server_chans,"name");
    int srv_group_num = mxGetFieldNumber(server_chans,"group_num");
    int srv_rate = mxGetFieldNumber(server_chans,"rate");
    int srv_tpnum = mxGetFieldNumber(server_chans,"tpnum");
    /* int srv_bps = mxGetFieldNumber(server_chans,"bps"); */
    int srv_data_type = mxGetFieldNumber(server_chans,"data_type");
    int srv_signal_gain = mxGetFieldNumber(server_chans,"signal_gain");
    int srv_signal_offset = mxGetFieldNumber(server_chans,"signal_offset");
    int srv_signal_slope = mxGetFieldNumber(server_chans,"signal_slope");
    int srv_signal_units = mxGetFieldNumber(server_chans,"signal_units");
    /* fputs("OK 1.5\n",stderr); */

    /*  Construct the array_root matlab array
     */
    const char* fields[] = {
	"name", "group_num", "rate", "tpnum", "bps", "data_type", 
	"signal_gain", "signal_offset", "signal_slope", "signal_units", 
	"start_gps_sec", "duration_sec", "data", "exists"};

    mwSize dims = mxGetM(chan_query_list) * mxGetN(chan_query_list);
    mxArray* array_root = mxCreateStructArray(1, &dims, 14, fields);
    if(array_root == NULL) {
	mexWarnMsgTxt("mlwrapper_get_channel_data(): Could not allocate \
                       initial array.");
	return NULL;
    }


    {
	enum nds_version protocol = nds_try;
	mxArray* cur_list = (mxArray*)NULL;
	char tempstring[MAX_LONG_CHANNEL_NAME_LENGTH];
	struct trench_struct tst;
	int inx = -1;
	mwSize i;
	mwSize j;
	double cur_rate;
	int cur_bps;
	int err;

	/* Get destination array field numbers */
	int field_name = mxGetFieldNumber(array_root,"name");
	int field_group_num = mxGetFieldNumber(array_root,"group_num");
	int field_rate = mxGetFieldNumber(array_root,"rate");
	int field_tpnum = mxGetFieldNumber(array_root,"tpnum");
	int field_bps = mxGetFieldNumber(array_root,"bps");
	int field_data_type = mxGetFieldNumber(array_root,"data_type");
	int field_signal_gain = mxGetFieldNumber(array_root,"signal_gain");
	int field_signal_offset = mxGetFieldNumber(array_root,"signal_offset");
	int field_signal_slope = mxGetFieldNumber(array_root,"signal_slope");
	int field_signal_units = mxGetFieldNumber(array_root,"signal_units");
	int field_exists = mxGetFieldNumber(array_root,"exists");
	int field_start_gps_sec = mxGetFieldNumber(array_root,"start_gps_sec");
	int field_duration_sec = mxGetFieldNumber(array_root,"duration_sec");
	/* fputs("OK 1.6\n",stderr); */

	/* Get integer second start and duration */
	double tStart    = floor(start);
	double tDuration = ceil(duration);

	/* Make sure i/f is ready, guess about protocol */
	daq_startup();

	/* Now, try to fill in data */
	for(i=0; i < dims; i++) {

	    /*---------------  Get the next channel query. Validate it */
	    cur_list = mxGetCell(chan_query_list, (int)i);
	    if( ! mxIsChar(cur_list)) {
		SNPRINTF(msg, ndsmsglen, 
			 "Channel name index " FMT_MWSIZE " is not a string.", (CAST_MWSIZE)(i + i));
		msg[ndsmsglen-1] = '\0';
		mexWarnMsgTxt(msg);
		put_mxarray_bool(array_root, field_exists, i, 0);
		continue;
	    }

	    mxGetString(cur_list, tempstring, MAX_LONG_CHANNEL_NAME_LENGTH);

	    trench_init(&tst);
	    trench_parse(&tst, tempstring);

	    for (j=0; j<mxGetM(server_chans); ++j) {
		int cmpstat = 0;
		mxGetString(mxGetFieldByNumber(server_chans, j, srv_name),
			    tempstring, MAX_LONG_CHANNEL_NAME_LENGTH);

		if (!strcmp(tst.str, tempstring)) cmpstat=1;
		else if (inx >= 0 || tst.styp == trch_base) continue; 
		else if (!trench_cmp_base(&tst, tempstring)) cmpstat=2;
		if (cmpstat != 0) {
		    int dti, cti;
		    double ratei;
		    inx = j;
		    ratei = get_mxarray_float(server_chans, srv_rate,    j);
		    dti   = get_mxarray_int(server_chans, srv_data_type, j);
		    cti   = get_mxarray_int(server_chans, srv_group_num, j);
		    trench_infer_chan_info(&tst, cti, ratei, dti);
		    if (cmpstat == 1) break;
		}
	    }

	    /* See if channel actually exists, and mark whether or not it exists. 
	     * If it does not exist, "continue" on to next index, else, start 
	     * copy and perform data query 
	     */
	    if (inx < 0) {
		put_mxarray_bool(array_root, field_exists, i, 0);
		trench_destroy(&tst);
		continue;
	    }

	    j = inx;

	    /* Copy over the necessary info */
	    /* --- name */
	    put_mxarray_str(array_root, field_name, i, tst.str);

	    put_mxarray_int(array_root, field_group_num, i, (int)tst.ctype);

	    cur_rate = tst.rate;
	    put_mxarray_double(array_root, field_rate, i, cur_rate);

	    copy_mxarray(server_chans, srv_tpnum, j, 
			 array_root, field_tpnum, i);

	    cur_bps = data_type_size(tst.dtype);
	    put_mxarray_int(array_root, field_bps, i, cur_bps);

	    put_mxarray_int(array_root, field_data_type, i, (int)tst.dtype);

	    copy_mxarray(server_chans, srv_signal_gain, j, 
			 array_root, field_signal_gain, i);

	    copy_mxarray(server_chans, srv_signal_offset, j, 
			 array_root, field_signal_offset, i);

	    copy_mxarray(server_chans, srv_signal_slope, j, 
			 array_root, field_signal_slope, i);

	    copy_mxarray(server_chans, srv_signal_units, j, 
			 array_root, field_signal_units, i);
	
	    /* Set start time and duration fields, even if channel does not "exist"
	     */
	    if (put_mxarray_double(array_root, field_start_gps_sec, i, tStart)) {
		mexWarnMsgTxt("Could not create numeric array for gps start time.");
		trench_destroy(&tst);
		mxDestroyArray(array_root);
		return NULL;
	    }

	    if (put_mxarray_double(array_root, field_duration_sec, i, tDuration)) {
		mexWarnMsgTxt("Could not create numeric array for gps duration.");
		trench_destroy(&tst);
		mxDestroyArray(array_root);
		return NULL;
	    }

	    /* now see if data query will be too big */
	    if((double)(cur_rate*cur_bps*ceil(duration)) > MAX_BYTES_PER_QUERY ) {
		mxArray* temp;

		SNPRINTF(msg, ndsmsglen,
			 "Data query for channel %s exceeds maximum of %0.lf MB", 
			 tst.str, (double)MAX_BYTES_PER_QUERY/(double)(1<<20));
		mexWarnMsgTxt(msg);
		/* replace "exists" with "not exists" */
		temp=mxGetFieldByNumber(array_root, (int)i, field_exists);
		mxDestroyArray(temp);
		put_mxarray_bool(array_root, field_exists, i, 0);
		trench_destroy(&tst);
		continue;
	    }
	
	    /* Ok, now perform the query */
	    err = ml_get_channel_data(array_root, i, (time_t)floor(start), 
					  (time_t)ceil(duration), hostname, 
					  host_port, &protocol);
	    if (err < 0) {
		mxArray* temp;

		/* get string again, just in case tempstring was not treated 
		 * as const 
		 */
		SNPRINTF(msg, ndsmsglen, 
			 "Error in query for %s: %s", 
			 tst.str, daq_strerror(err));
		mexWarnMsgTxt(msg);
		/* replace "exists" with "not exists" */
		temp=mxGetFieldByNumber(array_root, (int)i, field_exists);
		mxDestroyArray(temp);
		put_mxarray_bool(array_root, field_exists, i, 0);
	    } else {
		put_mxarray_bool(array_root, field_exists, i, 1);
	    }
	    trench_destroy(&tst);
	}
    }
    return array_root;
}


/*
 * Fetch the full-rate data for element `idx` of `array_root` and store it
 * in that element's "data" field.
 *
 * The channel name, group_num, rate and data_type are read back from
 * array_root (filled in by mlwrapper_get_channel_data()).  A numeric array
 * sized for `duration` seconds is allocated, a connection to
 * hostname:port is opened with *protocol, the interval
 * [start_time, start_time + duration) is requested, and each received block
 * is appended to the array in order.  After a successful connection
 * *protocol is set to the version the server negotiated (daq.nds_versn).
 *
 * Returns 0 on success, or a negative code (most with a warning issued via
 * mexWarnMsgTxt()):
 *   -1  missing name / unknown data type   -2..-4  array allocation failures
 *   -5  connect        -6  channel request       -7  data request
 *   -8  receive        -9  no channel status    -10  bad channel status
 *   -11 unexpected block GPS time            -12  complex data unsupported
 *   -13 unsupported complex word size        -14  invalid rate or duration
 *   -15 server sent more data than the array holds
 * On failure the data array is freed and "data" is left unset.
 */
int
ml_get_channel_data(mxArray* array_root, mwSize idx, time_t start_time,
		    time_t duration, const char* hostname, int port,
		    enum nds_version *protocol) {
    int retval = 0;
    mwSize dims[1];
    mxComplexity complexity;
    mxClassID    dataclass;
    char channelname[MAX_LONG_CHANNEL_NAME_LENGTH + 1];
    char msg[NDS_MSG_LEN];
    time_t end_time = start_time + duration;
    mxArray* name_field;
    mxArray* dest_array = NULL;	/* owned here until stored in array_root */
    mxDataReal_type dest_real = (mxDataReal_type)NULL;
    mxDataImag_type dest_imag = (mxDataImag_type)NULL;
    size_t capacity = 0;	/* bytes available in the real data buffer */
    size_t copied = 0;		/* bytes already written to it */
    daq_t daq;

    int field_group_num = mxGetFieldNumber(array_root, "group_num");
    int field_rate      = mxGetFieldNumber(array_root, "rate");
    int field_data_type = mxGetFieldNumber(array_root, "data_type");

    /* Start from a known state so the daq cleanup below never operates on
     * stack garbage. */
    memset(&daq, 0, sizeof(daq));

    name_field = mxGetField(array_root, idx, "name");
    if (name_field == NULL) {
	mexWarnMsgTxt("ml_get_channel_data(): channel name is not set.");
	return -1;
    }
    mxGetString(name_field, channelname, MAX_LONG_CHANNEL_NAME_LENGTH);
    channelname[MAX_LONG_CHANNEL_NAME_LENGTH] = '\0';
    {
	daq_channel_t chan;
	int rc;
	time_t time_now = start_time;

	int    data_type   = get_mxarray_int(array_root, field_data_type, idx);
	double sample_rate = get_mxarray_float(array_root, field_rate, idx);
	int    chan_type   = get_mxarray_int(array_root, field_group_num, idx);

	daq_init_channel(&chan, channelname, chan_type, sample_rate, data_type);

	/* Map the NDS data type onto a Matlab numeric class. */
	switch(data_type) {
	case _16bit_integer:
	    dataclass = mxINT16_CLASS;
	    complexity = mxREAL;
	    break;
	case _32bit_integer:
	    dataclass = mxINT32_CLASS;
	    complexity = mxREAL;
	    break;
	case _32bit_float:
	    dataclass = mxSINGLE_CLASS;
	    complexity = mxREAL;
	    break;
	case _64bit_integer:
	    dataclass = mxINT64_CLASS;
	    complexity = mxREAL;
	    break;
	case _64bit_double:
	    dataclass = mxDOUBLE_CLASS;
	    complexity = mxREAL;
	    break;
	case _32bit_complex:
	    dataclass = mxSINGLE_CLASS;
	    complexity = mxCOMPLEX;
	    break;
#if WORKING
	case _64bit_complex:
	    dataclass = mxDOUBLE_CLASS;
	    complexity = mxCOMPLEX;
	    break;
#endif /* WORKING */
	default:
	    SNPRINTF(msg, ndsmsglen, 
		     "%s has an unknown data class.",
		     channelname);
	    mexWarnMsgTxt(msg);
	    retval = -1;
	    goto done;
	}

	/* A non-positive rate would divide by zero below, and a negative
	 * duration would produce a negative element count. */
	if (!(sample_rate > 0.0) || duration < 0) {
	    SNPRINTF(msg, ndsmsglen,
		     "%s has an invalid sample rate or duration.",
		     channelname);
	    mexWarnMsgTxt(msg);
	    retval = -14;
	    goto done;
	}

	/* Slow channels: one sample per whole-second period (rounded). The
	 * division is done in double to avoid overflowing an int period. */
	if (sample_rate < 1.0)
	    dims[0] = (mwSize)((double)duration / floor(1.0 / sample_rate + 0.5));
	else
	    dims[0] = (mwSize)(sample_rate * (double)duration);

	dest_array = mxCreateNumericArray(1, dims, dataclass, complexity);
	if(dest_array == NULL) {
	    SNPRINTF(msg, ndsmsglen,
		     "%s could not get data array pointer.",
		     channelname);
	    mexWarnMsgTxt(msg);
	    retval = -2;
	    goto done;
	}
	capacity = (size_t)mxGetNumberOfElements(dest_array)
	    * mxGetElementSize(dest_array);

	dest_real = mxGetData(dest_array);
	if(dest_real == NULL && capacity > 0) {
	    SNPRINTF(msg, ndsmsglen,
		     "%s could not get real data array pointer.",
		     channelname);
	    mexWarnMsgTxt(msg);
	    retval = -3;
	    goto done;
	}

	if(complexity == mxCOMPLEX) {
	    /* also get imaginary part */
	    dest_imag = mxGetImagData(dest_array);
	    if (!dest_imag) {
		SNPRINTF(msg, ndsmsglen,
			 "%s could not get imaginary data array pointer.",
			 channelname);
		mexWarnMsgTxt(msg);
		retval = -4;
		goto done;
	    }
	    /* Complex retrieval is disabled: the copy loop below does not
	     * advance the real/imaginary pointers between blocks. */
	    SNPRINTF(msg, ndsmsglen,
		     "%s: complex data is not supported.",
		     channelname);
	    mexWarnMsgTxt(msg);
	    retval = -12;
	    goto done;
	}

	/* Connect to the nds server... */
	rc = daq_connect (&daq, hostname, port, *protocol);
	if (rc) {
	    SNPRINTF(msg, ndsmsglen,
		     "Error trying to connect to %s: %s", 
		     hostname, daq_strerror(rc));
	    mexWarnMsgTxt(msg);
	    retval = -5;
	    goto no_connect_error;
	}
	*protocol = daq.nds_versn;

	rc = daq_request_channel_from_chanlist(&daq, &chan);
	if (rc) {
	    SNPRINTF(msg, ndsmsglen,
		     "Error adding channel %s to request list: %s",
		     channelname, daq_strerror(rc));
	    mexWarnMsgTxt(msg);
	    retval = -6;
	    goto post_connect_error;
	}

	rc = daq_request_data(&daq, start_time, end_time, duration);
	if (rc) {
	    SNPRINTF(msg, ndsmsglen,
		     "Error requesting data for channel %s: %s",
		     channelname, daq_strerror(rc));
	    mexWarnMsgTxt(msg);
	    retval = -7;
	    goto post_connect_error;
	}

	/*----------------------------  Loop over potential partial buffers    */
	while (time_now < end_time) {
	    chan_req_t* req;

	    /*------- Receive data buffer, check for success                    */
	    rc = daq_recv_next(&daq);
	    if (rc) {
		SNPRINTF(msg, ndsmsglen,
			 "Error receiving next %s data block: %s",
			 channelname, daq_strerror(rc));
		mexWarnMsgTxt(msg);
		retval = -8;
		goto post_connect_error;
	    }

	    /*------- Find channel request info, check that it was found        */
	    req = daq_get_channel_status(&daq, channelname);
	    if (!req) {
		SNPRINTF(msg, ndsmsglen,
			 "Can't get channel %s status.",
			 channelname);
		mexWarnMsgTxt(msg);
		retval = -9;
		goto post_connect_error;
	    }

	    /*-------  Check channel status, OK?                                */
	    if (req->status < 0) {
		SNPRINTF(msg, ndsmsglen,
			 "Channel %s error status  %s.",
			 channelname, daq_strerror(- req->status));
		mexWarnMsgTxt(msg);
		retval = -10;
		goto post_connect_error;
	    }

	    /*-------  Check the block starts where the previous one ended      */
	    if (time_now != daq_get_block_gps(&daq)) {
		SNPRINTF(msg, ndsmsglen,
			 "Channel %s incorrect time %ld (expected %ld).",
			 channelname, (long)daq_get_block_gps(&daq),
			 (long)time_now);
		mexWarnMsgTxt(msg);
		retval = -11;
		goto post_connect_error;
	    }

	    /*-------  Copy data to the right place in the buffer               */
	    {
		char* p = daq_get_block_data(&daq) + req->offset;
		/* a non-negative status is the channel's byte count */
		size_t nByt = (size_t)req->status;

		if ( complexity == mxCOMPLEX )
		{
		    switch( data_type_word( req->data_type ) )
		    {
		    case 4:
			{
			    static const mwSize word_size = 4;
			    mwSize elements = ( nByt / ( word_size * 2 ) );

			    stride_copy( dest_real, p, elements, word_size );
			    stride_copy( dest_imag, p + word_size, elements, word_size );
			}
			break;
		    default:
			retval = -13;
			goto post_connect_error;
		    }
		}
		else
		{
		    /* Never write past the array allocated above. */
		    if (nByt > capacity - copied) {
			SNPRINTF(msg, ndsmsglen,
				 "Channel %s returned more data than expected.",
				 channelname);
			mexWarnMsgTxt(msg);
			retval = -15;
			goto post_connect_error;
		    }
		    memcpy(dest_real, p, nByt);
		    dest_real = ((char*)dest_real + nByt);
		    copied += nByt;
		}
	    }

	    /*-------  Bump the current gps time                                */
	    time_now += daq_get_block_secs(&daq);
	}

	mxSetField(array_root, (int)idx, "data", dest_array);
	dest_array = NULL;	/* ownership passed to array_root */
    }

    /*-------------------------------------------------------------------
     * Success falls through the cleanup labels below as well.
     *-------------------------------------------------------------------*/

 post_connect_error:
    /* A connection was established: close it. */
    daq_disconnect(&daq);
 no_connect_error:
    /* daq_connect() was attempted: release receive-side resources. */
    daq_recv_shutdown (&daq);
 done:
    /* Failures before mxSetField() still own the data array. */
    if (dest_array != NULL) {
	mxDestroyArray(dest_array);
    }
    return retval;
}
