#!/usr/bin/python

import yaml
import json
from os import walk
from sys import argv, stdout, exit
import types
import re
import os
from subprocess import call
from test_util import RethinkDBTestServers

# A test script src language
# This class attempts to abstract some of the details that
# separate our source languages.
class SrcLang:
    """Default translation rules for a target source language.

    `comment` returns "#", `langrepr` quotes the stringified value, and the
    three `*translate` hooks return their input unchanged. Subclasses
    override whichever of these differ for their language.
    """

    def comment(self):
        """Return the appropriate line-comment marker."""
        marker = "#"
        return marker

    def langrepr(self, val):
        """Return `val` as a quoted literal: the repr of its str() form."""
        as_text = str(val)
        return repr(as_text)

    def nametranslate(self, name):
        """Translate names from the canonical (underscore) form to this
        language's naming convention."""
        return name

    def dicttranslate(self, dic):
        """Translate dictionary (object) syntax from the canonical form
        (e.g. "{'a':1}") to this language's representation."""
        return dic

    def nulltranslate(self, strng):
        """Translate 'null' to this language's equivalent."""
        return strng

    def translate_query(self, src):
        """Translate a generic code string: null first, then names, then
        dictionary syntax."""
        with_nulls = self.nulltranslate(src)
        with_names = self.nametranslate(with_nulls)
        return self.dicttranslate(with_names)

    def translate_expected(self, src):
        """Translate an expected value: null first, then dictionary syntax.
        Names are left untouched."""
        with_nulls = self.nulltranslate(src)
        return self.dicttranslate(with_nulls)


class PyLang(SrcLang):
    """Translation rules for Python test scripts."""

    def langrepr(self, val):
        """Return the repr of `val`, going through unicode() for unicode
        values and through str() for everything else."""
        if isinstance(val, types.UnicodeType):
            return repr(unicode(val))
        return repr(str(val))

    def nulltranslate(self, strng):
        """Replace each `null` that is not directly preceded or followed by
        a quote character with `None`."""
        unquoted_null = re.compile("(?<!\"|')null(?!\"|')")
        return unquoted_null.sub('None', strng)

class JsLang(SrcLang):
    """Translation rules for JavaScript test scripts."""

    def comment(self):
        """Return the JavaScript line-comment marker."""
        return "//"

    def langrepr(self, val):
        """Return `val` as a JavaScript literal string.

        Byte strings use their repr, unicode strings use their repr without
        the leading 'u', booleans become true/false, and anything else is
        stringified and quoted.
        """
        if isinstance(val, types.StringType):
            return repr(val)
        if isinstance(val, types.UnicodeType):
            # Get rid of the 'u' prefix
            return repr(val)[1:]
        if isinstance(val, types.BooleanType):
            if val:
                return "true"
            return "false"
        return repr(str(val))

    def nametranslate(self, name):
        """Convert the canonical form (underscore separation) to camel case.

        When the input contains a double underscore no camel-casing is done;
        each double underscore is collapsed to a single one instead.
        """
        if re.search('__', name):
            return re.sub('__', '_', name)

        def upper_after_underscore(found):
            # found is "_x"; keep only the letter, upper-cased
            return found.group()[1].upper()

        return re.sub('_[a-z]', upper_after_underscore, name)

class RbLang(SrcLang):
    """Translation rules for Ruby test scripts."""

    def rb_str_repr(self, string):
        """Return `string` wrapped in double quotes, with every backslash
        doubled and every double quote backslash-escaped."""
        doubled_backslashes = re.sub("\\\\", "\\\\\\\\", string)
        escaped_quotes = re.sub("\"", "\\\"", doubled_backslashes)
        return '"' + escaped_quotes + '"'

    def langrepr(self, val):
        """Return `val` as a Ruby literal string.

        Byte strings go through `rb_str_repr`; unicode strings use their
        repr without the leading 'u' and with `$` backslash-escaped;
        booleans become true/false; anything else is stringified and quoted.
        """
        if isinstance(val, types.StringType):
            return self.rb_str_repr(val)
        if isinstance(val, types.UnicodeType):
            without_prefix = repr(val)[1:]  # Get rid of the 'u'
            return re.sub("\\$","\\$", without_prefix)
        if isinstance(val, types.BooleanType):
            if val:
                return "true"
            return "false"
        return repr(str(val))

    def dicttranslate(self, dic):
        """Rewrite `'key':` (optionally with backslash-escaped quotes) as
        `:key=>`, then turn a trailing literal backslash-n into a newline."""
        with_symbols = re.sub("\\\\?'([^']*[^'\\\\])\\\\?':",":\\1=>", dic)
        return re.sub("\\\\n$", "\n", with_symbols)

    def nulltranslate(self, strng):
        """Replace every occurrence of `null` with `nil`."""
        return re.sub('null', 'nil', strng)

# Really just used as a namespace here
class Langs:
    """Namespace mapping language codes to their translator instances."""

    langs = dict(py=PyLang(), js=JsLang(), rb=RbLang())

    @staticmethod
    def comment_for(lang):
        """Return the line-comment marker for `lang`."""
        translator = Langs.langs[lang]
        return translator.comment()

    @staticmethod
    def repr_query_for(lang, val):
        """Render `val` as a literal for `lang`, then apply the query
        translation rules to the result."""
        translator = Langs.langs[lang]
        return Langs.langs[lang].translate_query(translator.langrepr(val))

    @staticmethod
    def repr_expected_for(lang, val):
        """Render `val` as a literal for `lang`, then apply the
        expected-value translation rules to the result."""
        translator = Langs.langs[lang]
        return Langs.langs[lang].translate_expected(translator.langrepr(val))

    @staticmethod
    def translate_query_for(lang, src):
        """Apply the query translation rules of `lang` to `src`."""
        translator = Langs.langs[lang]
        return translator.translate_query(src)

    @staticmethod
    def translate_expected_for(lang, src):
        """Apply the expected-value translation rules of `lang` to `src`."""
        translator = Langs.langs[lang]
        return translator.translate_expected(src)



# Abstracts a set of tests given in a single file
class TestGroup:
    """The tests defined by a single yaml file.

    `path` is the list of path components naming the group, `group` is the
    parsed yaml document (its 'desc' and 'tests' entries are read here) and
    `shard` is a number: when it is greater than zero, a shard(...) call is
    emitted after every test case that creates a table.
    """

    def __init__(self, path, group, shard):
        self.path = path
        self.group = group
        self.shard = shard

    def write_group(self, out, lang):
        """Write this group's header comment and all of its tests to `out`
        in the syntax of `lang`."""

        # First, write header
        out.write(Langs.comment_for(lang)+' '+('/'.join(self.path))+' '+self.group['desc']+'\n\n')

        # Tests are numbered from 1 in the order they appear in the file
        for i, test in enumerate(self.group['tests'], 1):
            self.write_test(out, test, lang, str(i))

        out.write('\n')

    def write_test(self, out, test, lang, index):
        """Write one test entry to `out` as test(...) calls.

        Nothing is written when the test's name ("<path joined by '-'>-<index>")
        does not match the module-level `test_filter` regex, or when the
        entry defines no code for `lang`.
        """

        test_name = '-'.join(self.path) + '-' + str(index)
        if not re.match(test_filter, test_name):
            return

        # Does this test define a variable?
        if 'def' in test:
            self.write_def(out, test['def'], lang, index)

        # See if this test even defines a test for our language
        (code, runopts) = self.get_code(test, lang, index)
        for test_case in code:
            out.write("test("+test_case)
            expected = self.get_expected(test, lang)
            if not expected:
                expected = '""'
            out.write(', '+expected)
            out.write(', "' + test_name + '"')

            if runopts:
                runoptsrepr = Langs.translate_expected_for(lang, repr(runopts))
                out.write(', '+runoptsrepr)

            out.write(')\n')

            # We want to auto-shard tables that we create. There is still no
            # way to do this from rql so we have to hack the admin cli.

            if self.shard > 0:
                pattern = 'table_create\\(\'(\\w+)\'\\)'
                mo = re.search(pattern, test_case)
                if mo:
                    table_name = mo.group(1)
                    out.write('shard("'+table_name+'")\n')

    def write_def(self, out, defobj, lang, index):
        """Write a define(...) call for the first code variant of `defobj`,
        if it has any for `lang`."""
        (code, runopts) = self.get_code(defobj, lang, index + '-def')
        if len(code) > 0:
            out.write('define('+code[0]+')\n')

    # Tests may specify generic test strings valid for
    # all languages or language specific versions
    def get_code(self, obj, lang, index):
        """Return `(variants, runopts)` for `obj` in language `lang`.

        `variants` is a list of translated code strings (empty when `obj`
        defines no code for `lang`); `runopts` is the value of obj's
        'runopts' entry, or None. `index` is accepted but not used.
        """

        runopts = None
        if isinstance(obj, dict):
            if 'runopts' in obj:
                runopts = obj['runopts']

            # Try language specific version first
            if lang in obj:

                lang_specific = obj[lang]

                # lang_specific may be a dict giving a code attribute
                if isinstance(lang_specific, dict):
                    assert 'cd' in lang_specific
                    code = lang_specific['cd']
                else:
                    code = lang_specific

            elif 'cd' in obj:
                code = obj['cd']

            else:
                code = None

        else:

            # obj itself is the code
            code = obj

        # Code may be a string or a list of syntactic variants
        if code is None:
            code = []
        elif not isinstance(code, list):
            # Integers are wrapped even when falsy (0, False); testing plain
            # truthiness here used to let them through unwrapped and crash
            # the iteration below.
            if code or isinstance(code, types.IntType):
                assert isinstance(code, types.StringTypes) or isinstance(code, types.IntType)
                code = [code]
            else:
                # Any other empty value (e.g. '') yields no test cases
                code = []

        # Construct the appropriate representation for each test in this language
        return ([Langs.repr_query_for(lang, src) for src in code], runopts)

    # Get the expected result from the test object. Either a generic result or
    # a language specific result.
    def get_expected(self, obj, lang):
        """Return the translated expected value for `obj` in `lang`, or None
        when `obj` has no 'ot' entry.

        A language-specific `obj[lang]['ot']` wins over the generic
        `obj['ot']`; a dict-valued generic 'ot' is looked up by `lang` first
        and falls back to its 'cd' entry.
        """
        if lang in obj and isinstance(obj[lang], dict) and 'ot' in obj[lang]:
            return Langs.repr_expected_for(lang, obj[lang]['ot'])
        elif 'ot' in obj:
            expected = obj['ot']
            if isinstance(expected, dict):
                if lang in expected:
                    expected = expected[lang]
                else:
                    expected = expected['cd']
            return Langs.repr_expected_for(lang, expected)

        return None


# Build test scripts by walking a tree of yaml files
class TestBuilder:
    """Builds one test script per language from a tree of yaml files.

    Constructing an instance does all the work: it loads every test group
    found under `root` and then writes build/test.py, build/test.js and
    build/test.rb.
    """

    def __init__(self, root, shard):
        self.root = root
        self.shard = shard
        self.tests = self.get_test_structure()
        self.build_test_scripts()

    # Walk the test tree and get an object representing all
    # the test data from the yaml files. Each yaml file defines
    # a test group. The test groups themselves are nested by
    # their directory structure.
    def get_test_structure(self):
        """Return a list with one TestGroup per '.yaml' file under self.root."""

        tests = []
        for root, dirs, files in walk(self.root):

            # Find the nested group location from the root
            path = root.split('/')[1:]

            # Parse and insert the yaml structures
            for name in files:
                if not name.endswith('.yaml'):
                    continue

                # Close the file as soon as it is parsed instead of leaving
                # the handle to the garbage collector.
                # NOTE(security): yaml.load can construct arbitrary Python
                # objects; that is only acceptable because these are the
                # project's own test files. Use yaml.safe_load if untrusted
                # yaml can ever end up under the root.
                with open(root+'/'+name) as source:
                    test_group = yaml.load(source)

                leaf = name.partition('.')[0]
                tests.append(TestGroup(path+[leaf], test_group, self.shard))

        return tests

    def build_test_scripts(self):
        """Write the test script for each supported language."""
        self.build_script('py')
        self.build_script('js')
        self.build_script('rb')

    def build_script(self, lang):
        """Write build/test.<lang>: the contents of drivers/driver.<lang>,
        then every test group, then a final the_end() call."""
        with open('build/test.'+lang, 'w+') as script:
            with open('drivers/driver.'+lang) as driver:

                # The driver is the head of the file
                script.write(driver.read())

                for group in self.tests:
                    group.write_group(script, lang)

                script.write('the_end()')

def run_tests(lang, servers):
    """Run the generated build/test.<lang> script and exit on failure.

    `lang` is one of 'py', 'js' or 'rb'. `servers` must provide
    driver_port(), cluster_port() and executable(); their values are passed
    to the script as command-line arguments.

    Raises SystemExit when `lang` is not a known language or when the test
    script returns a non-zero status.
    """
    # Per language: the environment variable that may override the
    # interpreter, and the command used when it is unset or empty.
    interpreters = {
        'py': ('PYTHON', 'python'),
        'js': ('NODE', 'node'),
        'rb': ('RUBY', 'ruby'),
    }

    # Look the language up by value. Comparing strings with `is` only works
    # when both happen to be the same interned object.
    if lang not in interpreters:
        exit("Unkown language: " + lang)
    env_var, default = interpreters[lang]
    interpreter = os.getenv(env_var) or default

    print("Running "+lang+" tests")
    # Flush before the child process writes so the banner stays ahead of the
    # version line when stdout is buffered (e.g. piped to a file).
    stdout.flush()
    os.system(interpreter + " --version")
    print('')
    stdout.flush()
    res = call([interpreter, 'build/test.'+lang, "UNUSED", str(servers.driver_port()), str(servers.cluster_port()), servers.executable()])
    if res != 0:
        exit("Failed %s tests" % lang)
    print("Successfully Passed "+lang+" polyglot tests\n")
    stdout.flush()

# Run all language tests
def run(build, lang=None):
    """Start one test server from `build` and run the tests against it.

    With `lang` given, only that language's tests run. Otherwise the py, js
    and rb tests run in that order, with a server restart between
    consecutive languages.
    """
    with RethinkDBTestServers(1, server_build_dir=build) as servers:
        if lang:
            run_tests(lang, servers)
            return

        # Run all
        everything = ['py', 'js', 'rb']
        run_tests(everything[0], servers)
        for following in everything[1:]:
            servers.restart()
            run_tests(following, servers)

def attach(driver_port, cluster_port, rethinkdb_executable, lang='js'):
    """Run the `lang` tests using the given ports and executable path.

    The three values are wrapped in a RethinkDBRunningServer and handed to
    run_tests; no server is started here.
    """
    run_tests(
        lang,
        RethinkDBRunningServer(driver_port, cluster_port, rethinkdb_executable)
    )

class RethinkDBRunningServer:
    """Stores two ports and an executable path and returns them through
    driver_port(), cluster_port() and executable()."""

    def __init__(self, driver_port, cluster_port, rethinkdb_executable):
        self.rethinkdb_executable_ = rethinkdb_executable
        self.driver_port_, self.cluster_port_ = driver_port, cluster_port

    def driver_port(self):
        """Return the driver port given at construction."""
        port = self.driver_port_
        return port

    def cluster_port(self):
        """Return the cluster port given at construction."""
        port = self.cluster_port_
        return port

    def executable(self):
        """Return the executable path given at construction."""
        path = self.rethinkdb_executable_
        return path

def build_test_scripts(shard):
    """Build the test scripts from the yaml files under 'src'.

    The work happens as a side effect of constructing a TestBuilder; the
    instance itself is discarded and nothing is returned.
    """
    TestBuilder(root='src', shard=shard)

# Regex that a generated test's name must match (via re.match) for the test
# to be written; taken from the RQL_TEST environment variable, and matching
# everything when that is unset or empty.
test_filter = os.environ.get("RQL_TEST") or ".*"

# Command-line dispatch (e.g. "./test-runner run"): the first argument names
# a function in this module, and any further arguments are spliced in as its
# comma-separated argument list. They are evaluated as Python expressions,
# so string arguments must arrive already quoted.
# NOTE(security): eval() executes whatever is on the command line; this must
# only ever be invoked with trusted arguments.
if __name__ == '__main__':
    call_args = ', '.join(argv[2:])
    eval('%s(%s)' % (argv[1], call_args))
