/* -*- Mode: C; c-basic-offset:4 ; indent-tabs-mode:nil ; -*- */
/*
 *  (C) 2014 Mellanox Technologies, Inc.
 *
 */


#ifdef USE_PMI2_API
#include "pmi2.h"
#else
#include "pmi.h"
#endif

#include "mpid_nem_impl.h"
#include "mxm_impl.h"

/* Netmod function table exported to the nemesis channel for the MXM netmod.
 * The initializer is positional, so the entry order must follow the field
 * order of MPID_nem_netmod_funcs_t.  The three NULL slots compiled in under
 * ENABLE_CHECKPOINTING are presumably checkpoint hooks that this netmod does
 * not implement (NOTE(review): confirm against the struct declaration). */
MPID_nem_netmod_funcs_t MPIDI_nem_mxm_funcs = {
    MPID_nem_mxm_init,
    MPID_nem_mxm_finalize,
#ifdef ENABLE_CHECKPOINTING
    NULL,
    NULL,
    NULL,
#endif
    MPID_nem_mxm_poll,
    MPID_nem_mxm_get_business_card,
    MPID_nem_mxm_connect_to_root,
    MPID_nem_mxm_vc_init,
    MPID_nem_mxm_vc_destroy,
    MPID_nem_mxm_vc_terminate,
    MPID_nem_mxm_anysource_iprobe,
    MPID_nem_mxm_anysource_improbe
};

/* Per-VC communication overrides; MPID_nem_mxm_vc_init installs this table
 * on every MXM VC via vc->comm_ops.  The initializer is positional and must
 * follow the field order of MPIDI_Comm_ops_t.  The NULL entries (persistent
 * request operations) are left to the generic CH3 code path, presumably;
 * NOTE(review): confirm the NULL fallback semantics in CH3. */
static MPIDI_Comm_ops_t comm_ops = {
    MPID_nem_mxm_recv,  /* recv_posted */

    MPID_nem_mxm_send,  /* send */
    MPID_nem_mxm_send,  /* rsend */
    MPID_nem_mxm_ssend, /* ssend */
    MPID_nem_mxm_isend, /* isend */
    MPID_nem_mxm_isend, /* irsend */
    MPID_nem_mxm_issend,        /* issend */

    NULL,       /* send_init */
    NULL,       /* bsend_init */
    NULL,       /* rsend_init */
    NULL,       /* ssend_init */
    NULL,       /* startall */

    MPID_nem_mxm_cancel_send,   /* cancel_send */
    MPID_nem_mxm_cancel_recv,   /* cancel_recv */

    MPID_nem_mxm_probe, /* probe */
    MPID_nem_mxm_iprobe,        /* iprobe */
    MPID_nem_mxm_improbe        /* improbe */
};


/* Module-global MXM state.  _mxm_obj holds the MXM context, endpoint, default
 * message queue, request free-list and the per-rank endpoint table (indexed
 * by pg_rank).  mxm_obj is the non-static handle to it; it is assigned only at
 * the end of a successful _mxm_init(). */
static MPID_nem_mxm_module_t _mxm_obj;
MPID_nem_mxm_module_t *mxm_obj;

/* Internal helpers; definitions are at the end of this file. */
static int _mxm_init(int rank, int size);
static int _mxm_fini(void);
static int _mxm_connect(MPID_nem_mxm_ep_t * ep, const char *business_card,
                        MPID_nem_mxm_vc_area * vc_area);
static int _mxm_disconnect(MPID_nem_mxm_ep_t * ep);
static int _mxm_add_comm(MPID_Comm * comm, void *param);
static int _mxm_del_comm(MPID_Comm * comm, void *param);


#undef FUNCNAME
#define FUNCNAME MPID_nem_mxm_post_init
#undef FCNAME
#define FCNAME MPIDI_QUOTE(FUNCNAME)
/* Init-completion callback registered by MPID_nem_mxm_init.
 *
 * This only applies when built against MXM API 3.1 or newer and when bulk
 * connection was enabled in _mxm_init (runtime version >= 3.2).  In that case
 * it calls mxm_ep_wireup() on the local endpoint, presumably to finish the
 * connections started in vc_init as one batch.
 *
 * NOTE(review): the mxm_ep_wireup() status is not checked.
 * Returns MPI_SUCCESS. */
int MPID_nem_mxm_post_init(void)
{
    int mpi_errno = MPI_SUCCESS;

#if MXM_API >= MXM_VERSION(3,1)
    if (!_mxm_obj.conf.bulk_connect)
        goto fn_exit;

    mxm_ep_wireup(_mxm_obj.mxm_ep);
#endif

  fn_exit:
    return mpi_errno;
  fn_fail:
    goto fn_exit;
}


#undef FUNCNAME
#define FUNCNAME MPID_nem_mxm_init
#undef FCNAME
#define FCNAME MPIDI_QUOTE(FUNCNAME)
/* Netmod entry point: bring up the MXM netmod for this process.
 *
 * The steps, in order:
 *   1. Create the MXM context, queue and endpoint (_mxm_init).
 *   2. Publish this rank's endpoint address into the business card.
 *   3. Register the any-source notifications and the init-completion
 *      callback (MPID_nem_mxm_post_init).
 *   4. Install the communicator create/destroy hooks that manage a
 *      per-communicator MXM queue.
 *
 * pg_p         - process group; its size bounds the endpoint table.
 * pg_rank      - rank of this process within pg_p.
 * bc_val_p     - in/out cursor into the business-card buffer.
 * val_max_sz_p - in/out remaining space in that buffer.
 * Returns MPI_SUCCESS or the first error encountered. */
int MPID_nem_mxm_init(MPIDI_PG_t * pg_p, int pg_rank, char **bc_val_p, int *val_max_sz_p)
{
    int mpi_errno = MPI_SUCCESS;

    MPIDI_STATE_DECL(MPID_STATE_MXM_INIT);
    MPIDI_FUNC_ENTER(MPID_STATE_MXM_INIT);

    /* the netmod-private VC and request data must fit the reserved areas */
    MPIU_Assert(sizeof(MPID_nem_mxm_vc_area) <= MPID_NEM_VC_NETMOD_AREA_LEN);
    MPIU_Assert(sizeof(MPID_nem_mxm_req_area) <= MPID_NEM_REQ_NETMOD_AREA_LEN);

    mpi_errno = _mxm_init(pg_rank, pg_p->size);
    if (mpi_errno != MPI_SUCCESS) {
        MPIU_ERR_POP(mpi_errno);
    }

    mpi_errno = MPID_nem_mxm_get_business_card(pg_rank, bc_val_p, val_max_sz_p);
    if (mpi_errno != MPI_SUCCESS) {
        MPIU_ERR_POP(mpi_errno);
    }

    mpi_errno = MPIDI_CH3I_Register_anysource_notification(MPID_nem_mxm_anysource_posted,
                                                           MPID_nem_mxm_anysource_matched);
    if (mpi_errno != MPI_SUCCESS) {
        MPIU_ERR_POP(mpi_errno);
    }

    mpi_errno = MPID_nem_register_initcomp_cb(MPID_nem_mxm_post_init);
    if (mpi_errno != MPI_SUCCESS) {
        MPIU_ERR_POP(mpi_errno);
    }

    /* each communicator gets its own MXM queue, created/destroyed by hooks */
    mpi_errno = MPIDI_CH3U_Comm_register_create_hook(_mxm_add_comm, NULL);
    if (mpi_errno != MPI_SUCCESS) {
        MPIU_ERR_POP(mpi_errno);
    }

    mpi_errno = MPIDI_CH3U_Comm_register_destroy_hook(_mxm_del_comm, NULL);
    if (mpi_errno != MPI_SUCCESS) {
        MPIU_ERR_POP(mpi_errno);
    }

  fn_exit:
    MPIDI_FUNC_EXIT(MPID_STATE_MXM_INIT);
    return mpi_errno;
  fn_fail:
    goto fn_exit;
}

#undef FUNCNAME
#define FUNCNAME MPID_nem_mxm_finalize
#undef FCNAME
#define FCNAME MPIDI_QUOTE(FUNCNAME)
/* Netmod entry point: tear down everything created by MPID_nem_mxm_init.
 * The MXM-level cleanup is delegated to _mxm_fini().
 * Returns MPI_SUCCESS or the error reported by _mxm_fini(). */
int MPID_nem_mxm_finalize(void)
{
    int mpi_errno = MPI_SUCCESS;

    MPIDI_STATE_DECL(MPID_STATE_MXM_FINALIZE);
    MPIDI_FUNC_ENTER(MPID_STATE_MXM_FINALIZE);

    mpi_errno = _mxm_fini();
    if (mpi_errno != MPI_SUCCESS) {
        MPIU_ERR_POP(mpi_errno);
    }

  fn_exit:
    MPIDI_FUNC_EXIT(MPID_STATE_MXM_FINALIZE);
    return mpi_errno;
  fn_fail:
    goto fn_exit;
}

#undef FUNCNAME
#define FUNCNAME MPID_nem_mxm_get_business_card
#undef FCNAME
#define FCNAME MPIDI_QUOTE(FUNCNAME)
/* Append this process's MXM endpoint address to the business card.
 *
 * The address is stored as a binary argument under MXM_MPICH_ENDPOINT_KEY.
 *
 * my_rank      - unused here.
 * bc_val_p     - in/out cursor into the business-card buffer.
 * val_max_sz_p - in/out remaining space in that buffer.
 *
 * Returns MPI_SUCCESS.  On failure it returns "**buscard_len" when the
 * buffer is too small, and "**buscard" for any other string-utility error. */
int MPID_nem_mxm_get_business_card(int my_rank, char **bc_val_p, int *val_max_sz_p)
{
    int mpi_errno = MPI_SUCCESS;
    int str_errno = MPIU_STR_SUCCESS;

    MPIDI_STATE_DECL(MPID_STATE_MXM_GET_BUSINESS_CARD);
    MPIDI_FUNC_ENTER(MPID_STATE_MXM_GET_BUSINESS_CARD);

    str_errno = MPIU_Str_add_binary_arg(bc_val_p, val_max_sz_p, MXM_MPICH_ENDPOINT_KEY,
                                        _mxm_obj.mxm_ep_addr, _mxm_obj.mxm_ep_addr_size);

    if (str_errno == MPIU_STR_NOMEM) {
        /* not enough room left in the business card */
        MPIU_ERR_SETANDJUMP(mpi_errno, MPI_ERR_OTHER, "**buscard_len");
    }
    else if (str_errno) {
        MPIU_ERR_SETANDJUMP(mpi_errno, MPI_ERR_OTHER, "**buscard");
    }

  fn_exit:
    MPIDI_FUNC_EXIT(MPID_STATE_MXM_GET_BUSINESS_CARD);
    return mpi_errno;
  fn_fail:
    goto fn_exit;
}

#undef FUNCNAME
#define FUNCNAME MPID_nem_mxm_connect_to_root
#undef FCNAME
#define FCNAME MPIDI_QUOTE(FUNCNAME)
/* Not implemented by the MXM netmod.
 * Both arguments are ignored, and the function always returns a fatal
 * MPI_ERR_OTHER error with the "**notimpl" message. */
int MPID_nem_mxm_connect_to_root(const char *business_card, MPIDI_VC_t * new_vc)
{
    int mpi_errno = MPI_SUCCESS;

    MPIDI_STATE_DECL(MPID_STATE_MXM_CONNECT_TO_ROOT);
    MPIDI_FUNC_ENTER(MPID_STATE_MXM_CONNECT_TO_ROOT);

    (void) business_card;
    (void) new_vc;

    MPIU_ERR_SETFATAL(mpi_errno, MPI_ERR_OTHER, "**notimpl");

  fn_exit:
    MPIDI_FUNC_EXIT(MPID_STATE_MXM_CONNECT_TO_ROOT);
    return mpi_errno;
  fn_fail:
    goto fn_exit;
}

#undef FUNCNAME
#define FUNCNAME MPID_nem_mxm_vc_init
#undef FCNAME
#define FCNAME MPIDI_QUOTE(FUNCNAME)
/* Initialize a VC for the MXM netmod.
 *
 * The steps are:
 *   1. Fetch the peer's business card through vc->pg->getConnInfo.
 *   2. Connect the local MXM endpoint to the address published there
 *      (_mxm_connect).
 *   3. Mark the VC active.
 *   4. Install the MXM send/receive overrides.
 *
 * vc - VC for a remote process; it must not be the local rank, which is
 *      asserted.
 *
 * Returns MPI_SUCCESS or an MPI error code.  The temporary business-card
 * buffer is released on every exit path, including failures. */
int MPID_nem_mxm_vc_init(MPIDI_VC_t * vc)
{
    int mpi_errno = MPI_SUCCESS;
    MPIDI_CH3I_VC *vc_ch = &vc->ch;
    char *business_card = NULL;
    int val_max_sz = 0;
#ifndef USE_PMI2_API
    int pmi_errno = PMI_SUCCESS;
#endif

    MPIDI_STATE_DECL(MPID_STATE_MXM_VC_INIT);
    MPIDI_FUNC_ENTER(MPID_STATE_MXM_VC_INIT);

    /* local connection is used for any source communication */
    MPIU_Assert(MPID_nem_mem_region.rank != vc->lpid);
    MPIU_DBG_MSG_FMT(CH3_CHANNEL, VERBOSE,
                     (MPIU_DBG_FDEST,
                      "[%i]=== connecting  to  %i  \n", MPID_nem_mem_region.rank, vc->lpid));

#ifdef USE_PMI2_API
    val_max_sz = PMI2_MAX_VALLEN;
#else
    /* PMI reports failures with PMI_* codes, not MPI error codes: wrap them in
     * a real MPI error instead of forwarding the raw value. */
    pmi_errno = PMI_KVS_Get_value_length_max(&val_max_sz);
    MPIU_ERR_CHKANDJUMP1(pmi_errno != PMI_SUCCESS, mpi_errno, MPI_ERR_OTHER,
                         "**fail", "**fail %d", pmi_errno);
#endif

    business_card = (char *) MPIU_Malloc(val_max_sz);
    MPIU_ERR_CHKANDJUMP1(business_card == NULL, mpi_errno, MPI_ERR_OTHER,
                         "**nomem", "**nomem %s", "business card");

    mpi_errno = vc->pg->getConnInfo(vc->pg_rank, business_card, val_max_sz, vc->pg);
    if (mpi_errno)
        MPIU_ERR_POP(mpi_errno);

    VC_FIELD(vc, ctx) = vc;
    VC_FIELD(vc, mxm_ep) = &_mxm_obj.endpoint[vc->pg_rank];
    mpi_errno = _mxm_connect(&_mxm_obj.endpoint[vc->pg_rank], business_card, VC_BASE(vc));
    if (mpi_errno)
        MPIU_ERR_POP(mpi_errno);

    MPIDI_CHANGE_VC_STATE(vc, ACTIVE);

    VC_FIELD(vc, pending_sends) = 0;

    vc->rndvSend_fn = NULL;
    vc->rndvRecv_fn = NULL;
    vc->sendNoncontig_fn = MPID_nem_mxm_SendNoncontig;
    vc->comm_ops = &comm_ops;

    vc_ch->iStartContigMsg = MPID_nem_mxm_iStartContigMsg;
    vc_ch->iSendContig = MPID_nem_mxm_iSendContig;

  fn_exit:
    /* freed here so error paths taken via MPIU_ERR_POP do not leak it */
    if (business_card)
        MPIU_Free(business_card);
    MPIDI_FUNC_EXIT(MPID_STATE_MXM_VC_INIT);
    return mpi_errno;
  fn_fail:
    goto fn_exit;
}

#undef FUNCNAME
#define FUNCNAME MPID_nem_mxm_vc_destroy
#undef FCNAME
#define FCNAME MPIDI_QUOTE(FUNCNAME)
/* VC teardown hook; intentionally a no-op.
 *
 * Finalize runs before the VCs are destroyed, so the MXM endpoint can no
 * longer be torn down from here.  _mxm_fini() disconnects every entry of the
 * endpoint table instead.
 *
 * Always returns MPI_SUCCESS. */
int MPID_nem_mxm_vc_destroy(MPIDI_VC_t * vc)
{
    int mpi_errno = MPI_SUCCESS;

    MPIDI_STATE_DECL(MPID_STATE_MXM_VC_DESTROY);
    MPIDI_FUNC_ENTER(MPID_STATE_MXM_VC_DESTROY);

    (void) vc;

  fn_exit:
    MPIDI_FUNC_EXIT(MPID_STATE_MXM_VC_DESTROY);
    return mpi_errno;
  fn_fail:
    goto fn_exit;
}

#undef FUNCNAME
#define FUNCNAME MPID_nem_mxm_vc_terminate
#undef FCNAME
#define FCNAME MPIDI_QUOTE(FUNCNAME)
/* Terminate a VC.
 *
 * First it polls the MXM progress engine until the VC has no pending sends
 * left.  Then it reports MPIDI_VC_EVENT_TERMINATED to CH3.
 *
 * Returns MPI_SUCCESS or an MPI error code.  A failing poll is propagated
 * instead of being retried indefinitely. */
int MPID_nem_mxm_vc_terminate(MPIDI_VC_t * vc)
{
    int mpi_errno = MPI_SUCCESS;

    MPIDI_STATE_DECL(MPID_STATE_MXM_VC_TERMINATE);
    MPIDI_FUNC_ENTER(MPID_STATE_MXM_VC_TERMINATE);

    /* pending_sends is presumably decremented by send completions driven from
     * the poll; NOTE(review): confirm in the send path. */
    while (VC_FIELD(vc, pending_sends) > 0) {
        mpi_errno = MPID_nem_mxm_poll(FALSE);
        if (mpi_errno)
            MPIU_ERR_POP(mpi_errno);
    }

    mpi_errno = MPIDI_CH3U_Handle_connection(vc, MPIDI_VC_EVENT_TERMINATED);
    if (mpi_errno)
        MPIU_ERR_POP(mpi_errno);

  fn_exit:
    MPIDI_FUNC_EXIT(MPID_STATE_MXM_VC_TERMINATE);
    return mpi_errno;
  fn_fail:
    goto fn_exit;
}

#undef FUNCNAME
#define FUNCNAME _mxm_init
#undef FCNAME
#define FCNAME MPIDI_QUOTE(FUNCNAME)
/* Create all process-wide MXM resources and fill in _mxm_obj.
 *
 * The steps are:
 *   - Record the compile-time and runtime MXM versions.
 *   - Enable bulk connect/disconnect when the runtime is >= 3.2.
 *   - Read the MXM options, then create the context, the AM handler, the
 *     default queue and the endpoint, and query the endpoint address.
 *   - Allocate a zeroed endpoint table with `size` entries.
 *   - Pre-grow the module-wide request free-list.
 *
 * mxm_obj is published only once every step has succeeded.
 *
 * rank - rank of this process; size - number of ranks in the process group.
 * Returns MPI_SUCCESS or an MPI error code. */
static int _mxm_init(int rank, int size)
{
    int mpi_errno = MPI_SUCCESS;
    mxm_error_t ret = MXM_OK;
    unsigned long cur_ver;

    cur_ver = mxm_get_version();
    if (cur_ver != MXM_API) {
        /* cur_ver is unsigned long, so the runtime fields use %lu */
        MPIU_DBG_MSG_FMT(CH3_CHANNEL, VERBOSE,
                         (MPIU_DBG_FDEST,
                          "WARNING: MPICH was compiled with MXM version %d.%d but version %lu.%lu detected.",
                          MXM_VERNO_MAJOR,
                          MXM_VERNO_MINOR,
                          (cur_ver >> MXM_MAJOR_BIT) & 0xff, (cur_ver >> MXM_MINOR_BIT) & 0xff));
    }

    _mxm_obj.compiletime_version = MXM_VERNO_STRING;
#if MXM_API >= MXM_VERSION(3,0)
    _mxm_obj.runtime_version = MPIU_Strdup(mxm_get_version_string());
#else
    _mxm_obj.runtime_version = MPIU_Malloc(sizeof(MXM_VERNO_STRING) + 10);
    if (_mxm_obj.runtime_version)
        snprintf(_mxm_obj.runtime_version, sizeof(MXM_VERNO_STRING) + 10,
                 "%lu.%lu", (cur_ver >> MXM_MAJOR_BIT) & 0xff, (cur_ver >> MXM_MINOR_BIT) & 0xff);
#endif
    MPIU_ERR_CHKANDJUMP1(_mxm_obj.runtime_version == NULL, mpi_errno, MPI_ERR_OTHER,
                         "**nomem", "**nomem %s", "MXM runtime version string");

    /* bulk wireup/powerdown is only used with MXM runtime 3.2 or newer */
    if (cur_ver < MXM_VERSION(3, 2)) {
        _mxm_obj.conf.bulk_connect = 0;
        _mxm_obj.conf.bulk_disconnect = 0;
    }
    else {
        _mxm_obj.conf.bulk_connect = 1;
        _mxm_obj.conf.bulk_disconnect = 1;
    }

    ret = mxm_config_read_opts(&_mxm_obj.mxm_ctx_opts, &_mxm_obj.mxm_ep_opts, "MPICH2", NULL, 0);
    MPIU_ERR_CHKANDJUMP1(ret != MXM_OK,
                         mpi_errno, MPI_ERR_OTHER,
                         "**mxm_config_read_opts",
                         "**mxm_config_read_opts %s", mxm_error_string(ret));

    ret = mxm_init(_mxm_obj.mxm_ctx_opts, &_mxm_obj.mxm_context);
    MPIU_ERR_CHKANDJUMP1(ret != MXM_OK,
                         mpi_errno, MPI_ERR_OTHER,
                         "**mxm_init", "**mxm_init %s", mxm_error_string(ret));

    ret =
        mxm_set_am_handler(_mxm_obj.mxm_context, MXM_MPICH_HID_ADI_MSG, MPID_nem_mxm_get_adi_msg,
                           MXM_AM_FLAG_THREAD_SAFE);
    MPIU_ERR_CHKANDJUMP1(ret != MXM_OK, mpi_errno, MPI_ERR_OTHER, "**mxm_set_am_handler",
                         "**mxm_set_am_handler %s", mxm_error_string(ret));

    ret = mxm_mq_create(_mxm_obj.mxm_context, MXM_MPICH_MQ_ID, &_mxm_obj.mxm_mq);
    MPIU_ERR_CHKANDJUMP1(ret != MXM_OK,
                         mpi_errno, MPI_ERR_OTHER,
                         "**mxm_mq_create", "**mxm_mq_create %s", mxm_error_string(ret));

    ret = mxm_ep_create(_mxm_obj.mxm_context, _mxm_obj.mxm_ep_opts, &_mxm_obj.mxm_ep);
    MPIU_ERR_CHKANDJUMP1(ret != MXM_OK,
                         mpi_errno, MPI_ERR_OTHER,
                         "**mxm_ep_create", "**mxm_ep_create %s", mxm_error_string(ret));

    _mxm_obj.mxm_ep_addr_size = MXM_MPICH_MAX_ADDR_SIZE;
    ret = mxm_ep_get_address(_mxm_obj.mxm_ep, &_mxm_obj.mxm_ep_addr, &_mxm_obj.mxm_ep_addr_size);
    MPIU_ERR_CHKANDJUMP1(ret != MXM_OK,
                         mpi_errno, MPI_ERR_OTHER,
                         "**mxm_ep_get_address", "**mxm_ep_get_address %s", mxm_error_string(ret));

    _mxm_obj.mxm_rank = rank;
    _mxm_obj.mxm_np = size;
    _mxm_obj.endpoint =
        (MPID_nem_mxm_ep_t *) MPIU_Malloc(_mxm_obj.mxm_np * sizeof(MPID_nem_mxm_ep_t));
    MPIU_ERR_CHKANDJUMP1(_mxm_obj.endpoint == NULL, mpi_errno, MPI_ERR_OTHER,
                         "**nomem", "**nomem %s", "MXM endpoint table");
    /* zeroed so _mxm_disconnect skips slots that were never connected */
    memset(_mxm_obj.endpoint, 0, _mxm_obj.mxm_np * sizeof(MPID_nem_mxm_ep_t));

    list_init(&_mxm_obj.free_queue);
    list_grow_mxm_req(&_mxm_obj.free_queue);
    MPIU_Assert(list_length(&_mxm_obj.free_queue) == MXM_MPICH_MAX_REQ);

    mxm_obj = &_mxm_obj;

  fn_exit:
    return mpi_errno;
  fn_fail:
    goto fn_exit;
}

#undef FUNCNAME
#define FUNCNAME _mxm_fini
#undef FCNAME
#define FCNAME MPIDI_QUOTE(FUNCNAME)
/* Release everything _mxm_init created.  This is a no-op when there is no
 * MXM context.
 *
 * The order is:
 *   1. Free the module request free-list.
 *   2. Power down the endpoint if bulk disconnect is enabled.
 *   3. Disconnect every endpoint-table slot and free the table.
 *   4. Destroy the endpoint and the default queue.
 *   5. Clean up the context.
 *   6. Free the option sets and the runtime version string.
 *
 * Per-slot disconnect failures are ignored: teardown is best-effort.
 * Destroyed or freed handles are reset so nothing is left dangling.
 * Returns MPI_SUCCESS. */
static int _mxm_fini(void)
{
    int mpi_errno = MPI_SUCCESS;

    if (_mxm_obj.mxm_context) {

        while (!list_is_empty(&_mxm_obj.free_queue)) {
            MPIU_Free(list_dequeue(&_mxm_obj.free_queue));
        }

#if MXM_API >= MXM_VERSION(3,1)
        if (_mxm_obj.conf.bulk_disconnect) {
            mxm_ep_powerdown(_mxm_obj.mxm_ep);
        }
#endif

        /* guard: the table may be missing if _mxm_init failed after mxm_init */
        if (_mxm_obj.endpoint) {
            while (_mxm_obj.mxm_np) {
                _mxm_disconnect(&(_mxm_obj.endpoint[--_mxm_obj.mxm_np]));
            }
            MPIU_Free(_mxm_obj.endpoint);
            _mxm_obj.endpoint = NULL;
        }

        if (_mxm_obj.mxm_ep) {
            mxm_ep_destroy(_mxm_obj.mxm_ep);
            _mxm_obj.mxm_ep = NULL;
        }

        if (_mxm_obj.mxm_mq) {
            mxm_mq_destroy(_mxm_obj.mxm_mq);
            _mxm_obj.mxm_mq = NULL;
        }

        mxm_cleanup(_mxm_obj.mxm_context);
        _mxm_obj.mxm_context = NULL;

        mxm_config_free_ep_opts(_mxm_obj.mxm_ep_opts);
        mxm_config_free_context_opts(_mxm_obj.mxm_ctx_opts);

        MPIU_Free(_mxm_obj.runtime_version);
        _mxm_obj.runtime_version = NULL;
    }

  fn_exit:
    return mpi_errno;
  fn_fail:
    goto fn_exit;
}

#undef FUNCNAME
#define FUNCNAME _mxm_connect
#undef FCNAME
#define FCNAME MPIDI_QUOTE(FUNCNAME)
/* Connect the local MXM endpoint to the peer described by a business card.
 *
 * FUNCNAME is defined above so error messages name this helper instead of
 * whichever function preceded it in the file.
 *
 * ep            - endpoint-table slot; receives the connection handle and a
 *                 pre-grown request free-list.
 * business_card - peer business card; the binary value stored under
 *                 MXM_MPICH_ENDPOINT_KEY is the peer's MXM address.
 * vc_area       - netmod VC area; vc_area->ctx is attached to the connection
 *                 with mxm_conn_ctx_set.
 *
 * Returns MPI_SUCCESS, "**buscard" if the address cannot be extracted, or an
 * mxm_ep_connect error. */
static int _mxm_connect(MPID_nem_mxm_ep_t * ep, const char *business_card,
                        MPID_nem_mxm_vc_area * vc_area)
{
    int mpi_errno = MPI_SUCCESS;
    int str_errno = MPIU_STR_SUCCESS;
    mxm_error_t ret = MXM_OK;
    char mxm_ep_addr[MXM_MPICH_MAX_ADDR_SIZE];
    int len = 0;

    str_errno =
        MPIU_Str_get_binary_arg(business_card, MXM_MPICH_ENDPOINT_KEY, mxm_ep_addr,
                                sizeof(mxm_ep_addr), &len);
    MPIU_ERR_CHKANDJUMP(str_errno, mpi_errno, MPI_ERR_OTHER, "**buscard");

    ret = mxm_ep_connect(_mxm_obj.mxm_ep, mxm_ep_addr, &ep->mxm_conn);
    MPIU_ERR_CHKANDJUMP1(ret != MXM_OK,
                         mpi_errno, MPI_ERR_OTHER,
                         "**mxm_ep_connect", "**mxm_ep_connect %s", mxm_error_string(ret));

    mxm_conn_ctx_set(ep->mxm_conn, vc_area->ctx);

    list_init(&ep->free_queue);
    list_grow_mxm_req(&ep->free_queue);
    MPIU_Assert(list_length(&ep->free_queue) == MXM_MPICH_MAX_REQ);

  fn_exit:
    return mpi_errno;
  fn_fail:
    goto fn_exit;
}

#undef FUNCNAME
#define FUNCNAME _mxm_disconnect
#undef FCNAME
#define FCNAME MPIDI_QUOTE(FUNCNAME)
/* Disconnect one endpoint-table slot and free its request free-list.
 *
 * Slots without a connection handle are skipped.  After a successful
 * disconnect the handle is reset to NULL, so a repeated call is a harmless
 * no-op instead of disconnecting a stale handle.
 *
 * ep - endpoint-table slot; must not be NULL (asserted).
 * Returns MPI_SUCCESS or an mxm_ep_disconnect error. */
static int _mxm_disconnect(MPID_nem_mxm_ep_t * ep)
{
    int mpi_errno = MPI_SUCCESS;
    mxm_error_t ret = MXM_OK;

    MPIU_Assert(ep);

    if (ep->mxm_conn) {
        ret = mxm_ep_disconnect(ep->mxm_conn);
        MPIU_ERR_CHKANDJUMP1(ret != MXM_OK,
                             mpi_errno, MPI_ERR_OTHER,
                             "**mxm_ep_disconnect",
                             "**mxm_ep_disconnect %s", mxm_error_string(ret));
        ep->mxm_conn = NULL;

        while (!list_is_empty(&ep->free_queue)) {
            MPIU_Free(list_dequeue(&ep->free_queue));
        }
    }

  fn_exit:
    return mpi_errno;
  fn_fail:
    goto fn_exit;
}

#undef FUNCNAME
#define FUNCNAME _mxm_add_comm
#undef FCNAME
#define FCNAME MPIDI_QUOTE(FUNCNAME)
/* Communicator-create hook registered in MPID_nem_mxm_init.
 *
 * Creates an MXM message queue keyed by the communicator's context_id and
 * stores it in comm->ch.netmod_comm.  FUNCNAME is defined above so a failure
 * is reported under this helper's name.
 *
 * param - unused (the hook is registered with NULL).
 * Returns MPI_SUCCESS or an mxm_mq_create error. */
static int _mxm_add_comm(MPID_Comm * comm, void *param)
{
    int mpi_errno = MPI_SUCCESS;
    mxm_error_t ret = MXM_OK;
    mxm_mq_h mxm_mq;

    (void) param;

    _dbg_mxm_output(6, "Add COMM comm %p (context %d rank %d) \n",
                    comm, comm->context_id, comm->rank);

    ret = mxm_mq_create(_mxm_obj.mxm_context, comm->context_id, &mxm_mq);
    MPIU_ERR_CHKANDJUMP1(ret != MXM_OK,
                         mpi_errno, MPI_ERR_OTHER,
                         "**mxm_mq_create", "**mxm_mq_create %s", mxm_error_string(ret));

    comm->ch.netmod_comm = (void *) mxm_mq;

  fn_exit:
    return mpi_errno;
  fn_fail:
    goto fn_exit;
}

/* Communicator-destroy hook registered in MPID_nem_mxm_init.
 *
 * Destroys the MXM queue that _mxm_add_comm stored in comm->ch.netmod_comm,
 * if there is one, and then clears the field.
 *
 * param - unused.
 * Always returns MPI_SUCCESS. */
static int _mxm_del_comm(MPID_Comm * comm, void *param)
{
    int mpi_errno = MPI_SUCCESS;
    mxm_mq_h comm_mq = (mxm_mq_h) comm->ch.netmod_comm;

    _dbg_mxm_output(6, "Del COMM comm %p (context %d rank %d) \n",
                    comm, comm->context_id, comm->rank);

    if (comm_mq) {
        mxm_mq_destroy(comm_mq);
    }
    comm->ch.netmod_comm = NULL;

  fn_exit:
    return mpi_errno;
  fn_fail:
    goto fn_exit;
}
