/*
 * A little client for Mercurial command server using socket
 *
 * Copyright (c) 2011 Yuya Nishihara <yuya@tcha.org>
 *
 * This software may be used and distributed according to the terms of the
 * GNU General Public License version 2 or any later version.
 *
 * About Mercurial Command Server:
 * http://mercurial.selenic.com/wiki/CommandServer
 */

#include <arpa/inet.h>  // for ntohl(), htonl()
#include <assert.h>
#include <ctype.h>
#include <errno.h>
#include <fcntl.h>
#include <signal.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <sys/un.h>
#include <termios.h>
#include <unistd.h>

#include "hgclient.h"
#include "util.h"

// Connection life cycle, stored in hgclient_tag_.state
enum {
    // value of a freshly zeroed client, i.e. while hgc_open() is still
    // setting the connection up
    STATE_INITIALIZING = 0,
    // set once hgc_open() has finished
    STATE_RUNNING = 1,
    // set by hgc_kill() after the worker has been signalled
    STATE_KILLING = 2
};

// Bit flags describing the client's terminal, stored in hgclient_tag_.uiattrs
enum {
    UIATTR_FORMATTED   = 1 << 0,  // stdout is a tty
    UIATTR_INTERACTIVE = 1 << 1,  // stdin is a tty
    UIATTR_PROGRESSTTY = 1 << 2,  // stderr is a tty
};

// Bit flags for the capabilities advertised in the server's hello message,
// stored in hgclient_tag_.capflags
enum {
    CAP_GETENCODING  = 1 << 0,
    CAP_RUNCOMMAND   = 1 << 1,
    // cHg extension (kept apart in the second byte):
    CAP_CHDIR        = 1 << 8,
    CAP_GETPAGER     = 1 << 9,
    CAP_GETPID       = 1 << 10,
    CAP_SETCHGUIATTR = 1 << 11,
    CAP_SETENV       = 1 << 12,
};

// Associates a capability name, as spelled in the server's hello message,
// with the CAP_* bit that records it.
typedef struct {
    // capability name; NULL in the terminating entry of captable
    const char *name;
    // corresponding CAP_* bit; 0 in the terminating entry
    unsigned int flag;
} cappair_t;

static const cappair_t captable[] = {
    { "getencoding",  CAP_GETENCODING },
    { "runcommand",   CAP_RUNCOMMAND },
    { "chdir",        CAP_CHDIR },
    { "getpager",     CAP_GETPAGER },
    { "getpid",       CAP_GETPID },
    { "setchguiattr", CAP_SETCHGUIATTR },
    { "setenv",       CAP_SETENV },
    { NULL,           0 },  // terminator
};

// I/O buffer for one block exchanged with the cmdserver: the block most
// recently read by readchannel(), or the one about to be sent by
// writeblock().
typedef struct {
    // channel identifier of the last block read
    char ch;
    // heap buffer holding the payload; NULL after freecontext()
    char *data;
    // allocated capacity of data, in bytes
    size_t maxdatasize;
    // number of payload bytes; for an input request, the maximum size the
    // server asked for
    size_t datasize;
} context_t;

// State of one client connection to a cmdserver.
struct hgclient_tag_ {
    // connected UNIX domain socket
    int sockfd;
    // pid obtained through the "getpid" command; stays 0 when the server
    // lacks that capability
    pid_t workerpid;
    // one of the STATE_* values
    unsigned int state;
    // buffer shared by all requests and responses
    context_t ctx;
    // bitwise OR of UIATTR_* flags
    unsigned int uiattrs;
    // bitwise OR of CAP_* flags parsed from the hello message
    unsigned int capflags;
};

// Initial capacity of a context buffer, and the unit to which
// enlargecontext() rounds requested sizes up.
static const size_t defaultdatasize = 4096;

// Prepare an empty context backed by a default-sized buffer
//
// A failed allocation is not fatal here: the context then starts out with
// a NULL buffer and zero capacity.
static void initcontext(context_t *ctx)
{
    char *buf = malloc(defaultdatasize);

    ctx->ch = '\0';
    ctx->data = buf;
    ctx->datasize = 0;
    if (buf)
        ctx->maxdatasize = defaultdatasize;
    else
        ctx->maxdatasize = 0;

    debugmsg("initialize context buffer with size %zu", ctx->maxdatasize);
}

// Make sure the context buffer can hold at least newsize bytes
//
// The buffer never shrinks. When it has to grow, the new capacity is
// newsize rounded up to a multiple of defaultdatasize. Aborts if the
// reallocation fails.
static void enlargecontext(context_t *ctx, size_t newsize)
{
    if (ctx->maxdatasize >= newsize)
        return;  // already large enough

    const size_t nchunks = (newsize + defaultdatasize - 1) / defaultdatasize;
    const size_t capacity = defaultdatasize * nchunks;

    char *grown = realloc(ctx->data, capacity);
    if (grown == NULL)
        abortmsg("failed to allocate buffer");

    ctx->maxdatasize = capacity;
    ctx->data = grown;
    debugmsg("enlarge context buffer to %zu", ctx->maxdatasize);
}

// Release the context buffer and reset the context to an empty state
static void freecontext(context_t *ctx)
{
    debugmsg("free context buffer");

    char *buf = ctx->data;
    ctx->data = NULL;
    ctx->datasize = 0;
    ctx->maxdatasize = 0;
    free(buf);
}

// Read channeled response from cmdserver
//
// Stores the one-byte channel identifier in hgc->ctx.ch and the 32-bit
// network-order length in hgc->ctx.datasize, growing the buffer to fit.
// For an input request (uppercase channel) nothing more is read; for any
// other channel exactly datasize payload bytes are read into hgc->ctx.data.
// Aborts on any read failure, including the peer closing the connection in
// the middle of a block.
static void readchannel(hgclient_t *hgc)
{
    assert(hgc);

    ssize_t rsize = recv(hgc->sockfd, &hgc->ctx.ch, sizeof(hgc->ctx.ch), 0);
    if (rsize != sizeof(hgc->ctx.ch)) abortmsg("failed to read channel");

    uint32_t datasize_n;
    rsize = recv(hgc->sockfd, &datasize_n, sizeof(datasize_n), 0);
    if (rsize != sizeof(datasize_n)) abortmsg("failed to read data size");

    // datasize denotes the maximum size to write if input request
    hgc->ctx.datasize = ntohl(datasize_n);
    enlargecontext(&hgc->ctx, hgc->ctx.datasize);

    // <ctype.h> functions are only defined for values representable as
    // unsigned char (or EOF), so do not pass a possibly negative char
    if (isupper((unsigned char) hgc->ctx.ch)) return;  // input request

    size_t cursize = 0;
    while (cursize < hgc->ctx.datasize) {
        rsize = recv(hgc->sockfd, hgc->ctx.data + cursize,
                hgc->ctx.datasize - cursize, 0);
        // 0 means the peer has closed the connection; counting it as
        // progress would make this loop spin forever
        if (rsize <= 0) abortmsg("failed to read data block");
        cursize += (size_t) rsize;
    }
}

// Send the whole buffer to the socket, repeating send() for as long as it
// accepts only part of the data. Aborts if send() reports an error.
static void sendall(int sockfd, const void *data, size_t datasize)
{
    const char *cursor = data;
    size_t remaining = datasize;

    while (remaining > 0) {
        const ssize_t sent = send(sockfd, cursor, remaining, 0);
        if (sent < 0)
            abortmsg("cannot communicate (errno = %d)", errno);
        cursor += sent;
        remaining -= (size_t) sent;
    }
}

// Write lengh-data block to cmdserver
static void writeblock(const hgclient_t *hgc)
{
    assert(hgc);

    const uint32_t datasize_n = htonl(hgc->ctx.datasize);
    sendall(hgc->sockfd, &datasize_n, sizeof(datasize_n));

    sendall(hgc->sockfd, hgc->ctx.data, hgc->ctx.datasize);
}

// Send the command name terminated by '\n', then the current context
// buffer as a length-data block
static void writeblockrequest(const hgclient_t *hgc, const char *chcmd)
{
    debugmsg("request %s, block size %zu", chcmd, hgc->ctx.datasize);

    // the command line is the name with '\n' in place of the terminating
    // '\0'; it goes out in a single send
    const size_t namelen = strlen(chcmd);
    char line[namelen + 1];
    memcpy(line, chcmd, namelen);
    line[namelen] = '\n';
    sendall(hgc->sockfd, line, namelen + 1);

    writeblock(hgc);
}

// Pack args into the context buffer as a '\0'-separated list
//
// A non-negative argsize gives the number of entries to pack; a negative
// argsize means the list is terminated by NULL. In either case packing
// stops at the first NULL entry. The separator after the last argument is
// not counted in ctx->datasize.
static void packcmdargs(context_t *ctx, const char * const args[],
        ssize_t argsize)
{
    // with a negative argsize the limit is NULL, which a cursor into a
    // real list never reaches, so only the NULL entry ends the loop
    const char * const *limit = NULL;
    if (argsize >= 0)
        limit = args + argsize;

    size_t used = 0;
    const char * const *cur = args;
    while (cur != limit && *cur) {
        const size_t len = strlen(*cur) + 1;  // include '\0'
        enlargecontext(ctx, used + len);
        memcpy(ctx->data + used, *cur, len);
        used += len;
        ++cur;
    }

    // strip last '\0'
    ctx->datasize = (used > 0) ? used - 1 : 0;
}

// Run the command held in the context buffer through system() and send
// the outcome back as a 32-bit integer in network byte order
//
// A non-negative system() status is reduced with WEXITSTATUS(); a negative
// one is sent as it is.
static void handlesystemrequest(hgclient_t *hgc)
{
    context_t *ctx = &hgc->ctx;

    // the payload is not '\0'-terminated, so make room for a terminator
    const size_t cmdlen = ctx->datasize;
    enlargecontext(ctx, cmdlen + 1);
    ctx->data[cmdlen] = '\0';

    int32_t status = system(ctx->data);
    if (status >= 0)
        status = WEXITSTATUS(status);

    const uint32_t status_n = htonl(status);
    ctx->datasize = sizeof(status_n);
    memcpy(ctx->data, &status_n, sizeof(status_n));
    writeblock(hgc);
}

// Read up to the requested number of bytes from stdin and send them back
//
// Nothing is sent if the client entered STATE_KILLING while the read was
// in progress.
static void handlereadrequest(hgclient_t *hgc)
{
    context_t *ctx = &hgc->ctx;
    const size_t nread = fread(ctx->data, sizeof(*ctx->data), ctx->datasize,
            stdin);

    if (hgc->state != STATE_KILLING) {
        ctx->datasize = nread;
        writeblock(hgc);
    }
}

// Read a single line from stdin and send it back
//
// If nothing can be read an empty block is sent. Nothing is sent if the
// client entered STATE_KILLING while the read was in progress.
static void handlereadlinerequest(hgclient_t *hgc)
{
    context_t *ctx = &hgc->ctx;

    char *line = fgets(ctx->data, ctx->datasize, stdin);
    if (line == NULL)
        ctx->data[0] = '\0';

    if (hgc->state == STATE_KILLING)
        return;

    ctx->datasize = strlen(ctx->data);
    writeblock(hgc);
}

// Read a single line from stdin with terminal echo switched off
//
// ECHO and ISIG are cleared while the line is read; the previous terminal
// attributes are put back afterwards and a newline is written to stdout.
static void handlereadpassrequest(hgclient_t *hgc)
{
    struct termios saved;
    if (tcgetattr(fileno(stdin), &saved) < 0) {
        abortmsg("failed to get terminal attribute (errno = %d)", errno);
    }

    struct termios silent = saved;
    silent.c_lflag = silent.c_lflag & ~(ECHO | ISIG);
    if (tcsetattr(fileno(stdin), TCSAFLUSH, &silent) < 0) {
        abortmsg("failed to change terminal attribute (errno = %d)", errno);
    }

    handlereadlinerequest(hgc);

    if (tcsetattr(fileno(stdin), TCSAFLUSH, &saved) < 0) {
        abortmsg("failed to rollback terminal attribute (errno = %d)", errno);
    }

    // the newline typed by the user was not echoed
    fputc('\n', stdout);
}

// Read response of command execution until receiving 'r'-esult
//
// Blocks read from the 'o' and 'e' channels are copied to stdout and
// stderr (an empty block flushes the stream instead); 's', 'I', 'L' and
// 'P' are served by their request handlers. On return the context holds
// the 'r' block. An unknown lowercase channel is ignored; an unknown
// uppercase channel aborts the client.
static void handleresponse(hgclient_t *hgc)
{
    for (;;) {
        readchannel(hgc);
        context_t *ctx = &hgc->ctx;
        debugmsg("response read from channel %c, size %zu",
                ctx->ch, ctx->datasize);
        switch (ctx->ch) {
        case 'o':
            if (ctx->datasize > 0)
                fwrite(ctx->data, sizeof(ctx->data[0]), ctx->datasize, stdout);
            else
                fflush(stdout);
            break;
        case 'e':
            if (ctx->datasize > 0)
                fwrite(ctx->data, sizeof(ctx->data[0]), ctx->datasize, stderr);
            else
                fflush(stderr);
            break;
        case 'r':
            return;
        case 's':
            handlesystemrequest(hgc);
            break;
        case 'I':
            handlereadrequest(hgc);
            break;
        case 'L':
            handlereadlinerequest(hgc);
            break;
        case 'P':
            handlereadpassrequest(hgc);
            break;
        default:
            // An uppercase channel is an input request: the server is
            // waiting for a block in reply. Ignoring one we cannot serve
            // would leave both sides waiting on each other forever.
            if (isupper((unsigned char) ctx->ch))
                abortmsg("cannot handle response (ch = %c)", ctx->ch);
            break;
        }
    }
}

// Translate the space-separated capability names in [s, e) into CAP_* flags
//
// Names missing from captable are skipped. The text must be
// '\0'-terminated at or after e, because words are located with strchr().
static unsigned int parsecapabilities(const char *s, const char *e)
{
    unsigned int found = 0;

    for (const char *word = s; word < e; ) {
        // a word ends at the next space, or at e, whichever comes first
        const char *wordend = strchr(word, ' ');
        if (wordend == NULL || wordend > e)
            wordend = e;
        const size_t wordlen = wordend - word;

        for (const cappair_t *entry = captable; entry->flag; ++entry) {
            // lengths must agree so that a mere prefix does not match
            if (strlen(entry->name) != wordlen)
                continue;
            if (strncmp(word, entry->name, wordlen) != 0)
                continue;
            found |= entry->flag;
            break;
        }

        word = wordend + 1;
    }
    return found;
}

// Read the hello message and record the advertised capabilities
//
// The message is scanned as "key: value" lines separated by '\n'; only the
// "capabilities" line is interpreted. Scanning stops at the first text
// that does not contain ": " after a key.
static void readhello(hgclient_t *hgc)
{
    readchannel(hgc);

    // terminate the payload so that the string functions below can be used
    context_t *ctx = &hgc->ctx;
    enlargecontext(ctx, ctx->datasize + 1);
    ctx->data[ctx->datasize] = '\0';
    debugmsg("hello received: %s (size = %zu)", ctx->data, ctx->datasize);

    static const char capkey[] = "capabilities:";
    const char * const msgend = ctx->data + ctx->datasize;

    for (const char *line = ctx->data; line < msgend; ) {
        const char *colon = strchr(line, ':');
        if (colon == NULL || colon[1] != ' ')
            break;

        const char *value = colon + 2;
        const char *lineend = strchr(value, '\n');
        if (lineend == NULL)
            lineend = msgend;

        // the comparison includes the ':' so the whole key has to match
        if (strncmp(line, capkey, colon - line + 1) == 0)
            hgc->capflags = parsecapabilities(value, lineend);

        line = lineend + 1;
    }
    debugmsg("capflags=0x%04x", hgc->capflags);
}

// Ask the server for its worker pid and remember it in hgc->workerpid
//
// The reply has to begin with a 32-bit pid in network byte order; a
// shorter reply aborts the client.
static void getpidofserver(hgclient_t *hgc)
{
    static const char request[] = "getpid\n";
    sendall(hgc->sockfd, request, sizeof(request) - 1);
    readchannel(hgc);

    const context_t *ctx = &hgc->ctx;
    int32_t pid_n;
    if (ctx->datasize < sizeof(pid_n))
        abortmsg("unexpected size of pid");
    memcpy(&pid_n, ctx->data, sizeof(pid_n));

    hgc->workerpid = ntohl(pid_n);
    debugmsg("pid of server: worker=%d", hgc->workerpid);
}

// Tell the server to change to the client's current working directory
//
// getcwd() is retried with a larger buffer for as long as it fails with
// ERANGE, so a path longer than the current buffer no longer aborts the
// client. Any other getcwd() failure aborts.
static void chdirtocwd(hgclient_t *hgc)
{
    context_t *ctx = &hgc->ctx;

    // initcontext() tolerates a failed malloc(), which leaves a NULL buffer
    // of size 0; make sure getcwd() is handed a real buffer
    enlargecontext(ctx, defaultdatasize);

    while (!getcwd(ctx->data, ctx->maxdatasize)) {
        if (errno != ERANGE)
            abortmsg("failed to getcwd (errno = %d)", errno);
        enlargecontext(ctx, ctx->maxdatasize + defaultdatasize);
    }

    ctx->datasize = strlen(ctx->data);
    writeblockrequest(hgc, "chdir");
}

// Send the NULL-terminated list envp to the server with the "setenv"
// command
static void setenvofserver(hgclient_t *hgc, const char * const envp[])
{
    // a negative argsize makes packcmdargs() stop at the NULL terminator
    const ssize_t untilnull = -1;
    context_t *ctx = &hgc->ctx;

    packcmdargs(ctx, envp, untilnull);
    writeblockrequest(hgc, "setenv");
}

// Set ui attributes according to the current environment
//
// Formats the tty flags recorded in hgc->uiattrs, the HGPLAIN and
// HGPLAINEXCEPT environment variables and the terminal width as
// "key: value" lines and sends them with the "setchguiattr" command.
//
// snprintf() returns the length the output would have had without
// truncation, so the buffer is enlarged and the text formatted again when
// it does not fit (HGPLAINEXCEPT can be arbitrarily long). Using the
// return value unchecked would make the block length exceed the buffer.
static void setuiattr(hgclient_t *hgc)
{
    debugmsg("set ui attributes");
    assert(hgc);

    const char *plain = getenv("HGPLAIN");
    const char *plainexcept = getenv("HGPLAINEXCEPT");

    context_t *ctx = &hgc->ctx;
    for (;;) {
        const int n = snprintf(ctx->data, ctx->maxdatasize,
                "progress.assume-tty: %d\n"
                "ui.formatted: %d\n"
                "ui.interactive: %d\n"
                "ui._plain: %d\n"
                "ui._plainexcept: %s\n"
                "ui._termwidth: %d\n",
                !!(hgc->uiattrs & UIATTR_PROGRESSTTY),
                !!(hgc->uiattrs & UIATTR_FORMATTED),
                !!(hgc->uiattrs & UIATTR_INTERACTIVE),
                !!plain,
                plainexcept ? plainexcept : "",
                gettermwidth());
        if (n < 0)
            abortmsg("failed to format ui attributes");
        if ((size_t) n < ctx->maxdatasize) {
            // everything fitted, including the terminating '\0'
            ctx->datasize = (size_t) n;
            break;
        }
        // truncated: make room for n characters plus '\0' and retry
        enlargecontext(ctx, (size_t) n + 1);
    }
    writeblockrequest(hgc, "setchguiattr");
}

/*!
 * Open connection to per-user cmdserver
 *
 * Connects to the UNIX domain socket at sockname, reads the hello message
 * and then issues the set-up commands the server advertises (getpid,
 * chdir, setenv, setchguiattr).
 *
 * This function does not start a server by itself: if nothing is listening
 * on sockname (connect() fails with ENOENT or ECONNREFUSED), NULL is
 * returned so that the caller can decide what to do. Any other failure
 * aborts.
 *
 * @param sockname  path of the socket; silently truncated if it does not
 *                  fit in sockaddr_un.sun_path
 * @param envp  list of environment variables in "NAME=VALUE" format,
 *              terminated by NULL. May be NULL to send nothing.
 * @return new client to be released with hgc_close(), or NULL if no server
 *         is reachable
 */
hgclient_t* hgc_open(const char *sockname, const char * const envp[])
{
    int fd = socket(AF_UNIX, SOCK_STREAM, 0);
    if (fd < 0) abortmsg("cannot create socket (errno = %d)", errno);

    // don't keep fd on fork(), so that it can be closed when the parent
    // process get terminated.
    int flags = fcntl(fd, F_GETFD);
    if (flags < 0) abortmsg("cannot get flags of socket (errno = %d)", errno);
    if (fcntl(fd, F_SETFD, flags | FD_CLOEXEC) < 0)
        abortmsg("cannot set flags of socket (errno = %d)", errno);

    struct sockaddr_un addr;
    // the whole structure is passed to connect(), so leave no member
    // uninitialized
    memset(&addr, 0, sizeof(addr));
    addr.sun_family = AF_UNIX;
    strncpy(addr.sun_path, sockname, sizeof(addr.sun_path));
    addr.sun_path[sizeof(addr.sun_path) - 1] = '\0';

    debugmsg("connect to %s", addr.sun_path);
    int r = connect(fd, (struct sockaddr*) &addr, sizeof(addr));
    if (r < 0) {
        // close() may overwrite errno, so keep the reason connect() failed
        const int connecterr = errno;
        close(fd);
        if (connecterr == ENOENT || connecterr == ECONNREFUSED)
            return NULL;
        abortmsg("cannot connect to %s (errno = %d)",
                addr.sun_path, connecterr);
    }

    hgclient_t *hgc = malloc(sizeof(*hgc));
    if (!hgc)
        abortmsg("failed to allocate hgclient object");
    // zeroing also leaves state at STATE_INITIALIZING and workerpid unknown
    memset(hgc, 0, sizeof(*hgc));
    hgc->sockfd = fd;
    initcontext(&hgc->ctx);

    if (isatty(fileno(stdout))) hgc->uiattrs |= UIATTR_FORMATTED;
    if (isatty(fileno(stdin))) hgc->uiattrs |= UIATTR_INTERACTIVE;
    if (isatty(fileno(stderr))) hgc->uiattrs |= UIATTR_PROGRESSTTY;

    readhello(hgc);
    if (!(hgc->capflags & CAP_RUNCOMMAND))
        abortmsg("insufficient capability: runcommand");
    if (hgc->capflags & CAP_GETPID)
        getpidofserver(hgc);
    if (hgc->capflags & CAP_CHDIR)
        chdirtocwd(hgc);
    if (envp && (hgc->capflags & CAP_SETENV))
        setenvofserver(hgc, envp);
    if (hgc->capflags & CAP_SETCHGUIATTR)
        setuiattr(hgc);

    hgc->state = STATE_RUNNING;
    return hgc;
}

/*!
 * Close connection and free allocated memory
 *
 * hgc must not be used after this call.
 */
void hgc_close(hgclient_t *hgc)
{
    assert(hgc);

    const int fd = hgc->sockfd;
    freecontext(&hgc->ctx);
    close(fd);
    free(hgc);
}

/*!
 * Send signal to the worker server
 *
 * Aborts if the worker pid is not known or the signal cannot be delivered.
 * Afterwards the client is put into STATE_KILLING and stdin is closed.
 */
void hgc_kill(hgclient_t *hgc, int sig)
{
    assert(hgc);

    const pid_t target = hgc->workerpid;
    if (target <= 0)
        abortmsg("worker pid not known (insufficient capability)");
    if (kill(target, sig) < 0)
        abortmsg("cannot kill %d (errno = %d)", target, errno);

    hgc->state = STATE_KILLING;
    fclose(stdin);  // so that blocking read can be ended
}

/*!
 * Execute the specified Mercurial command
 *
 * Sends args with the "runcommand" command and serves the server's output
 * and input requests until the result block arrives.
 *
 * @param args     command line arguments
 * @param argsize  number of entries in args
 * @return result code
 */
int hgc_runcommand(hgclient_t *hgc, const char * const args[], size_t argsize)
{
    assert(hgc);

    context_t *ctx = &hgc->ctx;
    packcmdargs(ctx, args, argsize);
    writeblockrequest(hgc, "runcommand");
    handleresponse(hgc);

    // the result block has to be exactly one 32-bit integer in network
    // byte order
    int32_t code_n;
    if (ctx->datasize != sizeof(code_n))
        abortmsg("unexpected size of exitcode");
    memcpy(&code_n, ctx->data, sizeof(code_n));
    return ntohl(code_n);
}

/*!
 * Get pager command for the given Mercurial command args
 *
 * The server is only asked when stdout is a tty and the server supports
 * the "getpager" command; otherwise NULL is returned right away.
 *
 * If no pager enabled, returns NULL. The returned string lives in the
 * client's internal buffer, so it becomes invalid once you run another
 * request to hgc.
 */
const char* hgc_getpager(hgclient_t *hgc, const char * const args[],
        size_t argsize)
{
    assert(hgc);

    const int formatted = (hgc->uiattrs & UIATTR_FORMATTED) != 0;
    const int supported = (hgc->capflags & CAP_GETPAGER) != 0;
    if (!(formatted && supported))
        return NULL;

    context_t *ctx = &hgc->ctx;
    packcmdargs(ctx, args, argsize);
    writeblockrequest(hgc, "getpager");
    handleresponse(hgc);

    // an empty reply means that no pager is to be used
    if (ctx->datasize == 0 || ctx->data[0] == '\0')
        return NULL;

    // the reply is not '\0'-terminated; terminate it for the caller
    enlargecontext(ctx, ctx->datasize + 1);
    ctx->data[ctx->datasize] = '\0';
    return ctx->data;
}
