/*
 * Copyright (c) 2014, 2015  Machine Zone, Inc.
 * 
 * Original author: Lev Walkin <lwalkin@machinezone.com>
 * 
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.

 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
 * SUCH DAMAGE.
 */
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <stdio.h>
#include <math.h>
#include <assert.h>

#include "tcpkali_data.h"
#include "tcpkali_expr.h"
#include "tcpkali_websocket.h"
#include "tcpkali_transport.h"


/*
 * qsort() comparator for message collection snippets.
 * Snippets are ordered by purpose (ascending MSK_PURPOSE() value) and,
 * within the same purpose, by their sort_index (insertion order).
 * Returns -1, 0 or 1.
 */
static int snippet_compare_cb(const void *ap, const void *bp) {
    const struct message_collection_snippet *lhs = ap;
    const struct message_collection_snippet *rhs = bp;
    int lhs_purpose = MSK_PURPOSE(lhs);
    int rhs_purpose = MSK_PURPOSE(rhs);

    if(lhs_purpose != rhs_purpose)
        return (lhs_purpose < rhs_purpose) ? -1 : 1;

    /* Same purpose: fall back to insertion order. */
    return (lhs->sort_index > rhs->sort_index)
         - (lhs->sort_index < rhs->sort_index);
}

/*
 * Finalize an embryonic message collection.
 *
 * When as_websocket is non-zero, a WebSocket upgrade request (HTTP GET)
 * built from hostport and path is added as an MSK_PURPOSE_HTTP_HEADER
 * snippet and the collection is marked MC_FINALIZED_WEBSOCKET. Otherwise
 * it is marked MC_FINALIZED_PLAIN_TCP. In both cases the snippets are
 * then sorted by snippet_compare_cb(): headers first, then first
 * messages, then regular messages, keeping insertion order within a group.
 *
 * A NULL hostport or path is treated as an empty string.
 */
void
message_collection_finalize(struct message_collection *mc, int as_websocket, const char *hostport, const char *path) {
    static const char ws_http_headers_fmt[] =
        "GET /%s HTTP/1.1\r\n"
        "Host: %s\r\n"
        "Upgrade: websocket\r\n"
        "Connection: Upgrade\r\n"
        "Sec-WebSocket-Key: x3JJHMbDL1EzLkh9GBhXDw==\r\n"
        "Sec-WebSocket-Version: 13\r\n"
        "\r\n";

    assert(mc->state == MC_EMBRYONIC);

    if(as_websocket) {
        /* Passing NULL for %s is undefined behavior. */
        if(!hostport) hostport = "";
        if(!path) path = "";

        /* Measure first: snprintf(NULL, 0, ...) writes nothing. */
        int needed = snprintf(NULL, 0, ws_http_headers_fmt, path, hostport);
        if(needed < 0) {
            fprintf(stderr, "Cannot format WebSocket upgrade request\n");
            exit(1);
        }

        /*
         * Heap rather than a VLA: hostport and path may be arbitrarily
         * long, so the stack is not a safe place for this buffer.
         */
        size_t buf_size = (size_t)needed + 1;
        char *http_headers = malloc(buf_size);
        assert(http_headers);
        int h_size = snprintf(http_headers, buf_size,
                              ws_http_headers_fmt, path, hostport);
        assert(h_size == needed);

        const int DISABLE_UNESCAPE = 0;
        /* message_collection_add() copies the data, so ours can be freed. */
        message_collection_add(mc, MSK_PURPOSE_HTTP_HEADER,
                               http_headers, (size_t)h_size, DISABLE_UNESCAPE);
        free(http_headers);

        mc->state = MC_FINALIZED_WEBSOCKET;
    } else {
        mc->state = MC_FINALIZED_PLAIN_TCP;
    }

    /* Order hdr > first_msg > msg. */
    qsort(mc->snippets, mc->snippets_count, sizeof(mc->snippets[0]),
          snippet_compare_cb);
}


/*
 * If the repeatable part of the payload (the bytes after the first
 * data->once_size bytes) is shorter than target_size, replicate it so
 * that the repeatable part becomes at least target_size bytes long.
 * The once_size prefix is kept as-is and is not replicated. The buffer
 * stays NUL-terminated after the last payload byte.
 *
 * Must be applied only once per spec: TDS_FLAG_REPLICATED is asserted
 * clear on entry and always set on return, even if nothing was copied.
 * Terminates the process with exit(1) if the buffer cannot be grown.
 */
void replicate_payload(struct transport_data_spec *data, size_t target_size) {
    size_t payload_size = data->total_size - data->once_size;

    assert(!(data->flags & TDS_FLAG_REPLICATED));

    if(!payload_size) {
        /* Can't blow up an empty buffer. */
    } else if(payload_size >= target_size) {
        /* Data is large enough to avoid blowing up. */
    } else {
        /*
         * Smallest n with n * payload_size >= target_size. Integer
         * ceiling avoids double-precision rounding on large sizes.
         */
        size_t n = target_size / payload_size
                 + (target_size % payload_size != 0);
        size_t once_offset = data->once_size;

        /* Make sure once_offset + n * payload_size + 1 fits in size_t. */
        if(n > (SIZE_MAX - once_offset - 1) / payload_size) {
            fprintf(stderr, "Replicated payload size overflows\n");
            exit(1);
        }
        size_t new_payload_size = n * payload_size;
        size_t alloc_size = once_offset + new_payload_size + 1;

        char *p = realloc(data->ptr, alloc_size);
        if(!p) {
            /* Check before any arithmetic on p; assert() may be compiled out. */
            fprintf(stderr, "Cannot allocate %zu bytes for replicated payload\n",
                    alloc_size);
            exit(1);
        }

        /* The original payload copy starts right after the once-prefix. */
        const char *msg_data = p + once_offset;
        for(size_t i = 1; i < n; i++) {
            memcpy(&p[once_offset + i * payload_size], msg_data, payload_size);
        }
        p[once_offset + new_payload_size] = '\0';
        data->ptr = p;
        data->total_size = once_offset + new_payload_size;
    }

    /*
     * Always mark as replicated, even if we have not increased the size.
     * At least, replication procedure was applied.
     */
    data->flags |= TDS_FLAG_REPLICATED;
}


/*
 * Append a new snippet to an embryonic (not yet finalized) collection.
 *
 * kind must be MSK_PURPOSE_HTTP_HEADER, MSK_PURPOSE_FIRST_MSG or
 * MSK_PURPOSE_MESSAGE. The latter two are additionally flagged
 * MSK_FRAMING_ALLOWED. The `size` bytes at `data` are copied into a new
 * NUL-terminated heap buffer stored in the snippet, so the caller keeps
 * ownership of `data`. If `unescape` is non-zero, the copy is passed
 * through unescape_data(), which may update `size`.
 *
 * The copy is then run through parse_expression():
 *   1  -> the parsed expression is attached (MSK_EXPRESSION_FOUND) and
 *         mc->expressions_found is incremented;
 *   0  -> the expression is discarded and the literal data is used;
 *  -1  -> the process exits (parse_expression() reports the reason).
 */
void
message_collection_add(struct message_collection *mc,
                            enum mc_snippet_kind kind,
                            void *data, size_t size,
                            int unescape) {

    assert(mc->state == MC_EMBRYONIC);

    /* Verify that messages are properly kinded. */
    switch(kind) {
    case MSK_PURPOSE_HTTP_HEADER:
        break;
    case MSK_PURPOSE_FIRST_MSG:
    case MSK_PURPOSE_MESSAGE:
        kind |= MSK_FRAMING_ALLOWED;
        break;
    default:
        assert(!"Cannot add message with non-MSK_PURPOSE_ kind");
        return; /* Unreachable */
    }

    /* Grow the snippets array geometrically, if needed. */
    if(mc->snippets_count >= mc->snippets_size) {
        size_t new_size = 2 * (mc->snippets_size ? mc->snippets_size : 8);
        struct message_collection_snippet *ptr = realloc(mc->snippets,
                        new_size * sizeof(mc->snippets[0]));
        if(!ptr) {
            fprintf(stderr, "Too many --message "
                            "or --first-message arguments\n");
            exit(1);
        }
        /* Zero the fresh tail so unused slots are in a known state. */
        memset(&ptr[mc->snippets_count], 0,
               (new_size - mc->snippets_count) * sizeof(ptr[0]));
        /* Commit the new capacity only after the allocation succeeded. */
        mc->snippets = ptr;
        mc->snippets_size = new_size;
    }

    /* Private NUL-terminated copy; +1 for the terminator. */
    char *p = malloc(size + 1);
    assert(p);
    memcpy(p, data, size);
    p[size] = 0;

    if(unescape) unescape_data(p, &size);

    struct message_collection_snippet *snip;
    snip = &mc->snippets[mc->snippets_count++];
    snip->data = p;
    snip->size = size;
    snip->expr = 0;
    snip->flags = kind;
    /* Insertion order, used as a tie-breaker when sorting by purpose. */
    snip->sort_index = mc->snippets_count;

    const int ENABLE_DEBUG = 1;
    tk_expr_t *expr = 0;
    switch(parse_expression(&expr, p, size, ENABLE_DEBUG)) {
    case 0:
        /* Trivial expression, does not change wrt. environment. */
        free_expression(expr);
        /* Just use the data instead. */
        break;
    case 1:
        snip->expr = expr;
        snip->flags |= MSK_EXPRESSION_FOUND;
        mc->expressions_found++;
        break;
    case -1:
        /* parse_expression() would have already printed the failure reason */
        exit(1);
    }
}

/*
 * Give the largest size the message can possibly occupy.
 */
size_t
message_collection_estimate_size(struct message_collection *mc,
                                 enum mc_snippet_kind kind_and,
                                 enum mc_snippet_kind kind_equal) {
    size_t total_size = 0;
    size_t i;

    assert(mc->state != MC_EMBRYONIC);

    for(i = 0; i < mc->snippets_count; i++) {
        struct message_collection_snippet *snip = &mc->snippets[i];

        /* Match pattern */
        if((snip->flags & kind_and) != kind_equal)
            continue;

        if(snip->flags & MSK_EXPRESSION_FOUND) {
            total_size += snip->expr->estimate_size;
        } else {
            total_size += snip->size;
        }
        total_size +=
                (mc->state == MC_FINALIZED_WEBSOCKET
                && (snip->flags & MSK_FRAMING_ALLOWED))
                        ? WEBSOCKET_MAX_FRAME_HDR_SIZE : 0;
    }
    return total_size;
}

/*
 * Concatenate all snippets of a finalized message collection into a
 * single NUL-terminated buffer described by a transport_data_spec.
 *
 * Expression snippets are evaluated in place with eval_expression(). This
 * requires optional_cb: if the collection contains expressions and no
 * callback is given, NULL is returned and nothing is allocated.
 *
 * For MC_FINALIZED_WEBSOCKET collections, snippets flagged
 * MSK_FRAMING_ALLOWED are prefixed with a WebSocket frame header for the
 * given side, and the HTTP upgrade header snippets are skipped entirely
 * on the server side.
 *
 * If out_spec is non-NULL it is filled in (it is expected to be
 * zero-filled). Otherwise a new spec is calloc()ed. Size bookkeeping:
 * total_size covers the whole buffer; once_size covers HTTP header and
 * first-message snippets; ws_hdr_size covers HTTP header snippets;
 * single_message_size covers regular message snippets (framing included).
 */
struct transport_data_spec *
transport_spec_from_message_collection(struct transport_data_spec *out_spec, struct message_collection *mc, expr_callback_f optional_cb, void *expr_cb_key, enum transport_websocket_side tws_side) {

    /*
     * If expressions found we can not create a transport data specification
     * from this collection directly. Need to go through expression evaluator.
     */
    if(mc->expressions_found) {
        if(!optional_cb)
            return NULL;
    }

    size_t estimate_size = message_collection_estimate_size(mc, 0, 0);

    struct transport_data_spec *data_spec;
    /* out_spec is expected to be 0-filled, if given. */
    data_spec = out_spec ? out_spec : calloc(1, sizeof(*data_spec));
    assert(data_spec);
    /* +1 for the trailing NUL written after the loop. */
    data_spec->ptr = malloc(estimate_size + 1);
    assert(data_spec->ptr);

    enum websocket_side ws_side =
        (tws_side == TWS_SIDE_CLIENT) ? WS_SIDE_CLIENT : WS_SIDE_SERVER;

    size_t i;
    for(i = 0; i < mc->snippets_count; i++) {
        struct message_collection_snippet *snip = &mc->snippets[i];

        void *data = snip->data;
        size_t size = snip->size;

        if(snip->flags & MSK_EXPRESSION_FOUND) {
            /* Evaluate straight into the output buffer at the write position. */
            ssize_t reified_size;
            char *tptr = (char *)data_spec->ptr + data_spec->total_size;
            assert(estimate_size >= data_spec->total_size
                                    + snip->expr->estimate_size);
            reified_size = eval_expression(&tptr,
                            estimate_size - data_spec->total_size,
                            snip->expr, optional_cb, expr_cb_key, 0);
            assert(reified_size >= 0);
            data = 0;   /* Already in place; nothing to copy below. */
            size = reified_size;
        }

        size_t ws_frame_size = 0;
        if(mc->state == MC_FINALIZED_WEBSOCKET) {
            /*
             * Do not construct WebSocket/HTTP header on the server side.
             * Purposes are enumerated values extracted with MSK_PURPOSE()
             * (see the switch below), not independent bits, so compare
             * for equality instead of testing with '&'.
             */
            if((ws_side == WS_SIDE_SERVER)
               && (MSK_PURPOSE(snip) == MSK_PURPOSE_HTTP_HEADER))
                continue;

            if(snip->flags & MSK_FRAMING_ALLOWED) {
                if(snip->flags & MSK_EXPRESSION_FOUND) {
                    uint8_t tmpbuf[WEBSOCKET_MAX_FRAME_HDR_SIZE];
                    /* Save the websocket frame elsewhere temporarily */
                    ws_frame_size = websocket_frame_header(size,
                                            tmpbuf, sizeof(tmpbuf),
                                            ws_side);
                    /* The shifted payload must still fit in the buffer. */
                    assert(data_spec->total_size + ws_frame_size + size
                           <= estimate_size);
                    /* Move the data to the right to make space for framing */
                    memmove((char *)data_spec->ptr + data_spec->total_size
                                                   + ws_frame_size,
                            (char *)data_spec->ptr + data_spec->total_size,
                            size);
                    /* Prepend the websocket frame */
                    memcpy((char *)data_spec->ptr + data_spec->total_size,
                           tmpbuf, ws_frame_size);
                } else {
                    ws_frame_size = websocket_frame_header(size,
                                (uint8_t *)data_spec->ptr+data_spec->total_size,
                                estimate_size - data_spec->total_size, ws_side);
                }
            }
        }

        /*
         * We only add data if it has not already been added.
         */
        size_t framed_snippet_size = ws_frame_size + size;
        if(data) {  /* Data is not there if expression is used. */
            memcpy((char *)data_spec->ptr
                   + data_spec->total_size + ws_frame_size, data, size);
        }
        data_spec->total_size += framed_snippet_size;

        switch(MSK_PURPOSE(snip)) {
        case MSK_PURPOSE_HTTP_HEADER:
            data_spec->ws_hdr_size += framed_snippet_size;
            data_spec->once_size   += framed_snippet_size;
            break;
        case MSK_PURPOSE_FIRST_MSG:
            data_spec->once_size   += framed_snippet_size;
            break;
        case MSK_PURPOSE_MESSAGE:
            data_spec->single_message_size += framed_snippet_size;
            break;
        default:
            assert(!"No recognized snippet purpose");
            /* Do not leak what we allocated on this error path. */
            free(data_spec->ptr);
            data_spec->ptr = NULL;
            if(!out_spec) free(data_spec);
            return NULL;
        }
    }
    assert(data_spec->total_size <= estimate_size);
    ((char *)data_spec->ptr)[data_spec->total_size] = '\0';

    return data_spec;
}
