#!/usr/bin/python
import sys, os, time, signal, random, traceback, errno
import multiprocessing, multiprocessing.queues
from optparse import OptionParser

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', 'drivers', 'python')))
import rethinkdb as r

def call_ignore_interrupt(fun):
    """Call `fun()` and return its result, retrying while it is interrupted.

    When `fun` fails with an IOError or OSError whose errno is EINTR (a
    signal arrived during a blocking system call), `fun` is called again.
    Any other IOError/OSError, and any other exception type, propagates to
    the caller unchanged.
    """
    while True:
        try:
            return fun()
        except (IOError, OSError) as ex:
            # On Python 2 IOError and OSError are unrelated classes, and
            # either one may carry EINTR, so both are checked.
            if ex.errno != errno.EINTR:
                raise

def stress_client_proc(options, start_event, exit_event, stat_queue, host_offset, random_seed):
    """Entry point of a single stress client process.

    Picks a server from options["hosts"] (host_offset wraps around the list),
    seeds this process's random generator with random_seed, posts "ready" on
    stat_queue and then blocks until start_event is set.  From then on it
    sends throttled queries until exit_event is set, opening a new connection
    after every options["ops_per_conn"] queries.
    """
    hosts = options["hosts"]
    target = hosts[host_offset % len(hosts)]
    host = target[0]
    port = target[1]
    ops_per_conn = options["ops_per_conn"]

    random.seed(random_seed)

    # Begin part-way through the first connection's quota, so that the
    # clients do not all reconnect at the same moment
    ops_done = int(random.random() * ops_per_conn)

    runner = QueryThrottler(options, stat_queue)
    stat_queue.put("ready")

    start_event.wait()
    while not exit_event.is_set():
        with r.connect(host, port) as conn:
            # With ops_per_conn == 0 there is no quota: the connection is
            # kept until the exit event fires
            while ops_per_conn == 0 or ops_done < ops_per_conn:
                if exit_event.is_set():
                    break
                runner.send_query(conn)
                ops_done += 1
            ops_done = 0

def spawn_clients(options, start_event, exit_event, stat_queue):
    """Start options["clients"] client processes and wait until each is ready.

    The master seed is options["seed"], or a fresh random value when that is
    None.  It is reported on stderr and used to seed this process's random
    generator, from which one seed per client is drawn.  Clients are assigned
    increasing host offsets starting at 0.

    Returns the list of started multiprocessing.Process objects.
    Raises RuntimeError if anything other than "ready" is read from
    stat_queue while waiting for the clients.
    """
    num_clients = options["clients"]

    random_seed = options["seed"]
    if random_seed is None:
        random_seed = random.random()

    # %r prints enough digits to round-trip the float, so the reported value
    # can be passed back through --seed to reproduce the run ("%f" kept only
    # six decimals, which yields a different seed).
    sys.stderr.write("Random seed used: %r\n" % random_seed)
    random.seed(random_seed)

    client_procs = [ ]
    for host_offset in xrange(num_clients):
        proc = multiprocessing.Process(target=stress_client_proc,
                                       args=(options,
                                             start_event,
                                             exit_event,
                                             stat_queue,
                                             host_offset,
                                             random.random()))
        proc.start()
        client_procs.append(proc)

    # Each client posts "ready" once it has finished setting up
    for _ in xrange(num_clients):
        response = call_ignore_interrupt(stat_queue.get)
        if response != "ready":
            raise RuntimeError("Unexpected response from client: %s" % str(response))

    return client_procs

def stop_clients(exit_event, child_procs, num_clients, timeout):
    """Signal the clients to exit and terminate any that do not comply in time.

    Sets exit_event, then gives the processes in child_procs a shared budget
    of `timeout` seconds to exit on their own (polling every 0.1s).  Any
    process still alive after the deadline is terminated, and a notice is
    printed on stdout if that had to be done.  num_clients is accepted for
    interface compatibility and is not used.

    Returns the time at which the exit event was set.
    """
    exit_event.set()
    end_time = time.time()

    # All processes share one deadline, measured from the exit request
    kill_time = end_time + timeout
    timed_out = False

    for proc in child_procs:
        while time.time() < kill_time and proc.is_alive():
            time.sleep(0.1)
        if proc.is_alive():
            proc.terminate()
            timed_out = True

    # Report based on what actually happened rather than on the clock: a
    # process that exits during the final sleep did not time out.
    if timed_out:
        sys.stdout.write("Timed out waiting for processes to shut down\n")

    return end_time

# Write stats to stderr, so they don't interfere with parsers
def print_stats(stats, start_time, end_time, num_clients):
    """Write a summary of the run to stderr.

    stats is a dict with "count" (number of operations), "latency" (summed
    latency of all operations) and "errors" (error message -> occurrences).
    start_time and end_time delimit the run; num_clients is the number of
    client processes.

    Rates are shown as "inf" for a zero-length run.  Figures that would need
    a division by a zero client count or zero total client time are shown as
    "n/a" instead of raising ZeroDivisionError.
    """
    write = sys.stderr.write
    duration = end_time - start_time
    op_count = stats["count"]
    total_latency = stats["latency"]

    write("Duration: %0.3f seconds\n" % duration)

    # Print stats table
    write("\n")
    write("Operations data: \n")

    if duration == 0.0:
        per_sec = "inf"
        per_sec_per_client = "inf"
    else:
        per_sec = "%0.3f" % (op_count / duration)
        if num_clients != 0:
            per_sec_per_client = "%0.3f" % (op_count / duration / num_clients)
        else:
            per_sec_per_client = "n/a"

    if op_count != 0:
        latency = "%0.6f" % (total_latency / op_count)
    else:
        latency = "%0.6f" % (0.0)

    table = [["total", "per sec", "per sec client avg", "avg latency"],
             [str(op_count), per_sec, per_sec_per_client, latency]]

    # Every column is as wide as its widest cell plus two characters of padding
    column_widths = [max(len(cell) for cell in column) + 2 for column in zip(*table)]

    # First column left-aligned, the remaining ones right-aligned
    format_str = ("{:<%d}" + ("{:>%d}" * (len(column_widths) - 1))) % tuple(column_widths)

    total_client_time = duration * num_clients
    if total_client_time != 0:
        percent_in_rql = "%.2f" % (100 * total_latency / total_client_time)
    else:
        percent_in_rql = "n/a"
    write("Percent time clients spent in ReQL space: %s\n" % percent_in_rql)

    for row in table:
        write(format_str.format(*row) + "\n")

    # Print errors
    if len(stats["errors"]) != 0:
        write("\n")
        write("Errors encountered:\n")
        for error, occurrences in stats["errors"].items():
            write("%s: %s\n" % (error, occurrences))

def interrupt_handler(signal, frame, exit_event, parent_pid):
    """Set exit_event, but only when running in the process with pid parent_pid.

    signal and frame are the usual signal-handler arguments and are not used.
    In any other process the call does nothing.
    """
    if os.getpid() != parent_pid:
        return
    exit_event.set()

class QueryThrottler:
    """Paces one client's queries and reports each result on a queue.

    The target rate options["ops_per_sec"] is shared by options["clients"]
    clients, so each client aims to send one query every
    clients / ops_per_sec seconds.
    """

    # Largest backlog, in operations, a client may accumulate against its schedule
    MAX_OPS_BEHIND = 10

    def __init__(self, options, stat_queue):
        self.workload = options["workload"]
        self.stat_queue = stat_queue
        self.secs_per_op = float(options["clients"]) / options["ops_per_sec"]

        # Time at which to send the next query
        self.next_query_time = None

    def send_query(self, conn):
        """Wait for this client's next slot, run one workload query on conn
        and put a result dict on the stat queue.

        The result holds "timestamp" and "latency", whatever the workload's
        run() returned, and "errors" when the query failed with a ReQL error
        or was interrupted (EINTR).  Other IOError/OSError are re-raised.
        """
        if self.next_query_time is None:
            # Desync from other clients by waiting a random amount of time less than one op's duration
            time.sleep(random.random() * self.secs_per_op)
            self.next_query_time = time.time()

        # Sync up with our schedule
        now = time.time()
        time_overdue = now - self.next_query_time
        max_overdue = self.MAX_OPS_BEHIND * self.secs_per_op

        if time_overdue > max_overdue:
            # Don't allow us to get more than MAX_OPS_BEHIND ops behind or we'll
            # overload if/when the system recovers.  The schedule is an absolute
            # time, so the cap has to be expressed relative to `now`.
            self.next_query_time = now - max_overdue
        elif time_overdue < 0.0:
            # Ahead of schedule: wait for our slot
            time.sleep(-time_overdue)
        self.next_query_time += self.secs_per_op

        start_time = time.time()
        result = { "timestamp": start_time }
        try:
            result.update(self.workload.run(conn))
        except (r.RqlError, r.RqlDriverError) as ex:
            result["errors"] = [ ex.message ]
        except (IOError, OSError) as ex:
            if ex.errno != errno.EINTR:
                raise
            result["errors"] = [ "Interrupted system call" ]
        result["latency"] = time.time() - start_time

        self.stat_queue.put(result)

# Main loop for the main stress process
def stress_controller(options):
    stat_queue = multiprocessing.queues.SimpleQueue()
    start_event = multiprocessing.Event()
    exit_event = multiprocessing.Event()
    child_procs = [ ]

    # Register interrupt, now that we're spawning client processes
    parent_pid = os.getpid()
    signal.signal(signal.SIGINT, lambda a,b: interrupt_handler(a, b,
                                                               exit_event,
                                                               parent_pid))

    child_procs.extend(spawn_clients(options,
                                     start_event,
                                     exit_event,
                                     stat_queue))

    stats = { "count": 0, "latency": 0.0, "errors": { } }
    start_time = time.time()

    try:
        start_event.set()

        # Collect stats as they come in
        while not exit_event.is_set():
            if not stat_queue.empty():
                stat = call_ignore_interrupt(stat_queue.get)

                # Print new stat
                if not options["quiet"]:
                    print "%0.3f: %0.6f" % (stat["timestamp"], stat["latency"])
                    sys.stdout.flush()

                stats["count"] += 1
                stats["latency"] += stat["latency"]
                for error in stat.get("errors", [ ]):
                    print "  - %s" % error
                    stats["errors"][error] = stats["errors"].get(error, 0) + 1
            else:
                time.sleep(0.1)

    except:
        traceback.print_exc()
    finally:
        # Allow some time for operations to complete
        stop_timeout = max(2.0, 3.0 / options["ops_per_sec"])
        end_time = stop_clients(exit_event, child_procs, options["clients"], stop_timeout)

        while not stat_queue.empty():
            stat = call_ignore_interrupt(stat_queue.get)
            stats["count"] += 1
            stats["latency"] += stat["latency"]
            for error in stat.get("errors", [ ]):
                stats["errors"][error] = stats["errors"].get(error, 0) + 1

        print_stats(stats, start_time, end_time, options["clients"])

if __name__ == "__main__":
    def _die(message):
        # Every command-line problem is reported the same way: message on stdout, exit status 1
        sys.stdout.write(message + "\n")
        sys.exit(1)

    parser = OptionParser()
    parser.add_option("--seed", dest="seed", metavar="FLOAT", default=None, type="float")
    parser.add_option("--table", dest="db_table", metavar="DB.TABLE", default=None, type="string")
    parser.add_option("--ops-per-sec", dest="ops_per_sec", metavar="NUMBER", default=100, type="float")
    parser.add_option("--ops-per-conn", dest="ops_per_conn", metavar="NUMBER", default=0, type="int")
    parser.add_option("--clients", dest="clients", metavar="CLIENTS", default=64, type="int")
    parser.add_option("--workload", "-w", dest="workload", metavar="WORKLOAD", default=None, type="string")
    parser.add_option("--host", dest="hosts", metavar="HOST:PORT", action="append", default=[], type="string")
    parser.add_option("--quiet", dest="quiet", action="store_true", default=False)
    (parsed_options, args) = parser.parse_args()
    options = { "clients": parsed_options.clients,
                "ops_per_sec": parsed_options.ops_per_sec,
                "ops_per_conn": parsed_options.ops_per_conn,
                "quiet": parsed_options.quiet,
                "hosts": [ ] }

    if len(args) != 0:
        _die("no positional arguments supported")

    # Parse out host/port pairs
    for host_port in parsed_options.hosts:
        # Split on the last colon, so a host that itself contains colons still parses
        (host, separator, port) = host_port.rpartition(":")
        if not separator or not host or not port.isdigit():
            _die("invalid host (expected HOST:PORT): %s" % host_port)
        options["hosts"].append((host, int(port)))
    if len(options["hosts"]) == 0:
        _die("no host specified")

    if parsed_options.workload is None:
        _die("no workload specified")

    # Get table name, and make sure it exists on the server
    if parsed_options.db_table is None:
        _die("no table specified")

    (db_name, separator, table_name) = parsed_options.db_table.partition(".")
    if not db_name or not table_name or "." in table_name:
        _die("invalid table (expected DB.TABLE): %s" % parsed_options.db_table)
    options["db"] = db_name
    options["table"] = table_name

    # Reject rates and counts the rest of the script cannot work with: a
    # non-positive --ops-per-sec makes every client fail before it reports
    # "ready" (leaving the parent waiting forever), and without clients the
    # per-client statistics are undefined.
    if options["clients"] <= 0:
        _die("--clients must be a positive integer")
    if options["ops_per_sec"] <= 0:
        _die("--ops-per-sec must be a positive number")
    if options["ops_per_conn"] < 0:
        _die("--ops-per-conn must not be negative")

    with r.connect(options["hosts"][0][0], options["hosts"][0][1]) as connection:
        if options["db"] not in r.db_list().run(connection):
            r.db_create(options["db"]).run(connection)

        # An already existing table is used as it is
        if options["table"] not in r.db(options["db"]).table_list().run(connection):
            r.db(options["db"]).table_create(options["table"]).run(connection)

    # Parse out workload info
    # Add the stress_workload subdirectory to the import search path
    sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), 'stress_workloads')))
    options["workload"] = __import__(parsed_options.workload).Workload(options)
    options["seed"] = parsed_options.seed

    stress_controller(options)
