#! /usr/bin/env python3
# pylint: disable=wrong-import-position

import argparse
import os
import re
import sys

try:
    # `fractions.gcd` was deprecated in Python 3.5 and removed in 3.9;
    # `math.gcd` is its replacement.
    from math import gcd
except ImportError:
    from fractions import gcd


def parse_args():
    '''Build the command line parser, and return the parsed `sys.argv`.'''
    parser = argparse.ArgumentParser(
        description="Compare benches generated by 'vcsn score'.",
        epilog='Install the `colorama` Python module to get colored output.')
    add = parser.add_argument
    add('file', nargs='+', type=str, default=None,
        help='''Bench file (from vcsn score) to compare.
        Files whose base name are generated by `git describe`
        (e.g., `v2.2-110-g406fef6`, or `v2.2-110-g406fef6.2`)
        will be annotated by the corresponding `git summary`.
        Directories will be traversed (shallowly: inner directories
        are skipped).''')
    add('-a', '--all', action='store_true',
        help='Report also benches with no differences')
    add('-c', '--color', dest='color', action='store', default='auto',
        choices=['auto', 'always', 'never'],
        help='Whether to use colors in the output')
    # `type` is also applied to the string default, so `only` is always
    # a compiled regex.
    add('-O', '--only', metavar='RE', type=re.compile, default='.*',
        help='Report only benches whose title is matched by RE')
    add('-t', '--threshold', metavar='PERCENT', type=float, default=10,
        help='''Highlight good and bad scores with associated
        threshold.  Defaults to 10%%.''')
    add('-n', '--lines', type=int, metavar='NUM', action='store',
        help="process only the last NUM files")
    add('--no-git', action='store_true',
        help="don't try to use git")
    add('--output-format', type=str, default='auto',
        choices=['auto', 'csv', 'latex', 'text'],
        help='''Select the output format.  If `auto`, guess from
        the output file name, or default to `text`.''')
    add('-o', '--output', metavar='FILE', type=argparse.FileType('w'),
        default='-', help='The output file')
    parsed = parser.parse_args()
    return parsed

args = parse_args()

# Decide the output format.  When `auto`, guess it from the extension
# of the output file name; unknown extensions give `text`.
if args.output_format == 'auto':
    ext = args.output.name.split(".")[-1]
    exts = dict(csv='csv', ltx='latex', tex='latex')
    args.output_format = exts[ext] if ext in exts else 'text'


# Access to the git repository, used to annotate the bench files.
# `repo` stays None when git is disabled (`--no-git`), when gitpython
# cannot be imported, or when `git.Repo()` reports an invalid repository.
repo = None
if not args.no_git:
    try:
        import git
        import gitdb
        repo = git.Repo()
    except ImportError:
        import warnings
        warnings.warn('you should install gitpython for Python')
    except git.exc.InvalidGitRepositoryError:
        pass


def git_summary(desc):
    '''From a git describe string, recover the commit title.

    `desc` is a file base name such as `v2.2-110-g406fef6` or
    `v2.2-110-g406fef6.2`.  Return:
    - '' when no git repository is available;
    - '???' when `desc` does not look like `git describe` output;
    - 'error' when looking the commit up raises gitdb's `BadName`;
    - the commit summary otherwise.
    '''
    if not repo:
        return ''
    m = re.match(r'^v(?:.*)-\d+-g([\da-f]+)(?:\.\d+)?(?:-dirty)?$', desc)
    if m is None:
        # Not a `git describe` name.  Checked explicitly rather than by
        # catching AttributeError, which would also hide unrelated bugs.
        return '???'
    try:
        return repo.commit(m.group(1)).summary
    except gitdb.exc.BadName:
        return 'error'


class Format:
    '''Base output format: comma-and-space separated plain text.

    Subclasses override `header`, `test`, `score`, `line` and `footer`
    to produce the CSV, LaTeX and text outputs.
    '''

    def header(self, _):
        '''Text written before the first benchmark line (none here).'''
        return ''

    def test(self, test):
        '''The test's information.'''
        return test

    def score(self, score, avg):  # pylint: disable=unused-argument
        '''Render a score as a string (no color in the base format).

        `line` joins the rendered scores, so this must return a string,
        not the raw float.
        '''
        return str(score)

    def color(self, score, avg):
        '''Name of the color for `score`, given the average score `avg`.

        'green' when `score` is at least `args.threshold` percent below
        `avg`, 'red' when at least that much above it, else 'black'.
        '''
        delta = args.threshold / 100 * avg
        if score <= avg - delta:
            return 'green'
        if avg + delta <= score:
            return 'red'
        return 'black'

    def line(self, test, repetitions, avg, scores):
        '''Print a whole benchmark line.'''
        print('{}, {}, {}'.format(self.test(test),
                                  repetitions,
                                  ', '.join([self.score(s, avg) for s in scores])),
              file=args.output)

    def footer(self, _):
        '''Text written after the last benchmark line (none here).'''
        return ''


class Csv(Format):
    '''Comma-separated values: one row per bench, one column per file.'''

    def header(self, files):
        '''The title row: command, setup, repetitions, then the base
        name of each bench file.'''
        return 'Command,Setup,Repetitions,'\
            + ','.join([os.path.basename(file) for file in files]) + '\n'

    def score(self, score, avg):  # pylint: disable=unused-argument
        '''A score with two decimals; 'N/A' and 'FAIL' are kept as is.'''
        if score in ['N/A', 'FAIL']:
            return score
        return '{:.2f}'.format(score)

    @staticmethod
    def _quote(text):
        '''Quote a CSV field, doubling its inner double quotes.'''
        return '"' + text.replace('"', '""') + '"'

    def test(self, test):
        '''The command and the setup of `test`, as two quoted fields.'''
        return self._quote(get_command(test)) + ',' + self._quote(get_setup(test))

    def line(self, test, repetitions, avg, scores):
        '''Print a whole benchmark row.

        Fields are separated by a bare comma, as in `header`.  The
        base-class line puts a space after each comma, and that space
        would end up inside the CSV fields.
        '''
        fields = [self.test(test), str(repetitions)]
        fields += [self.score(s, avg) for s in scores]
        print(','.join(fields), file=args.output)


class Latex(Format):
    r'''LaTeX output: a `tabular` with one score column per bench file.

    NOTE(review): the generated code uses `\midrule`/`\bottomrule`,
    `\num`, `\textcolor` and `\rotatebox`, presumably provided by the
    booktabs, siunitx, xcolor and graphicx packages of the including
    document -- confirm.
    '''

    # Escaped form of each token that is special in LaTeX math mode.
    _ESCAPES = {
        '->': r'\rightarrow',
        '^': r'\hat{}',
        '}': r'\}',
        '{': r'\{',
        '&': r'\&',
        '_': r'\_',
        '\\e': r'\backslash e',
    }
    # Matches any of the tokens above.
    _ESCAPE_RE = re.compile(r'->|\\e|[\^{}&_]')

    def escape(self, text):
        r'''Escape special LaTeX characters of `text` (for math mode).

        All the substitutions are made in a single pass.  Chaining
        `str.replace` would re-escape the braces of the `\hat{}`
        produced for `^`, yielding a broken `\hat\{\}`.
        '''
        return self._ESCAPE_RE.sub(lambda m: self._ESCAPES[m.group(0)], text)

    def header(self, files):
        r'''The beginning of the table: the `\titlerot` macro, the
        `tabular` opening, and the rotated column titles.'''
        title_rule = (r'\newcommand{\titlerot}[1]'
                      r'{\multicolumn{1}{c}{\rlap{\rotatebox{60}{#1}~}}}')
        tabular = r'\begin{{tabular}}{{llr *{{{}}}{{r}}}}'.format(len(files))
        filenames = ''.join([r'& \titlerot{{{}}} '.format(os.path.basename(file))
                             for file in files])
        # Use the macro defined above: `\titles` is not defined, and
        # would make LaTeX fail.
        titles = (r'Command & Setup & \titlerot{Repetitions} ' + filenames
                  + r'\\ \midrule')
        return '\n'.join([title_rule, tabular, titles]) + '\n'

    def test(self, test):
        '''The command and the setup, as two math-mode cells.'''
        command = '$' + self.escape(get_command(test)) + '$ & '
        setup = '$' + self.escape(get_setup(test)) + '$'
        return command + setup

    def score(self, score, avg):
        '''A score cell, colored when it is beyond the threshold.'''
        if score in ['N/A', 'FAIL']:
            return '{:>5}'.format(score)
        c = self.color(score, avg)
        if c == 'black':
            return r'\num{{{:5.2f}}}'.format(score)
        return (r'\textcolor{{{}}}{{\num{{{:5.2f}}}}}'
                .format(c, score))

    def line(self, test, repetitions, avg, scores):
        '''Print a whole benchmark line.'''
        print(r'{} & {} & {} \\'.format(self.test(test),
                                        repetitions,
                                        ' & '.join([self.score(s, avg) for s in scores])),
              file=args.output)

    def footer(self, _):
        '''Close the table.'''
        return '\\bottomrule\n\\end{tabular}\n'


class Text(Format):
    '''Plain-text output: the scores (colored when enabled), then the test.'''

    def header(self, files):
        '''One centered column number per file, trailing blanks removed.'''
        columns = ' '.join("{:^5}".format(n) for n in range(len(files)))
        return columns.rstrip() + '\n'

    def score(self, score, avg):
        '''A score on 5 columns, wrapped in its color codes.

        'N/A' and 'FAIL' are shown in blue; numbers get 2 decimals, in
        the color chosen by `color`.
        '''
        if score in ['N/A', 'FAIL']:
            shade, cell = color['blue'], '{:>5}'.format(score)
        else:
            shade, cell = color[self.color(score, avg)], '{:5.2f}'.format(score)
        return shade + cell + color['std']

    def line(self, test, repetitions, avg, scores):
        '''Print one benchmark: its scores, the test, and `, Nx` when it
        was run more than once.'''
        cells = ' '.join(self.score(s, avg) for s in scores)
        suffix = '' if repetitions == 1 else ', {}x'.format(repetitions)
        print('{} {}{}'.format(cells, self.test(test), suffix),
              file=args.output)

    def footer(self, files):
        '''The column numbers again, then one legend line per file: its
        number, its base name, and its `git_summary`.'''
        rows = [self.header(files)]
        for index, path in enumerate(files):
            name = os.path.basename(path)
            entry = "{:3}. {} {}".format(index, name, git_summary(name))
            rows.append(entry.rstrip() + '\n')
        return ''.join(rows)

# A printer, from the output format: CSV, LaTeX, or plain text for
# anything else.
printer = {'csv': Csv, 'latex': Latex}.get(args.output_format, Text)()


# Colors support.  The escape codes are only used by the text format.
# In `auto` mode, enable them only when the output actually goes to a
# terminal.  Testing `sys.stdout` instead would write escape codes into
# the `--output` file when the script runs from a terminal.
color = {'green': '', 'red': '', 'std': '', 'black': '', 'blue': ''}
if args.color == 'always' or (args.color == 'auto' and args.output.isatty()):
    try:
        from colorama import Fore, Style
        color['blue'] = Fore.BLUE
        color['green'] = Fore.GREEN + Style.BRIGHT
        color['red'] = Fore.RED + Style.BRIGHT
        color['std'] = Style.RESET_ALL
    except ImportError:
        import warnings
        warnings.warn('you should install colorama for Python')

# bench-id => {file name => {'value': score, 'num': repetitions}}, filled
# by read_file.  Scores are floats (seconds), or strings such as 'N/A'
# or 'FAIL'.
bench = {}
# bench-id => number of iterations (e.g., `20` for 20x).
number = dict()
# NOTE(review): not referenced anywhere in this file; presumably a
# leftover.
benc_csv = {}


def lcm(numbers):
    '''Least common multiple of the integers in `numbers` (1 if empty).'''
    acc = 1
    for n in numbers:
        # gcd(acc, n) divides acc, so dividing first keeps numbers small.
        acc = acc // gcd(acc, n) * n
    return acc


def normalize(k):
    '''Normalize a bench key, i.e., fix errors, update APIs etc.

    `k` is what follows the timing in a `vcsn score` line, e.g.
    `a.is_proper()        # a = "", 200000x`.  Historical spellings,
    typos and API changes are rewritten so that keys recorded by
    different versions compare equal.

    The result reads `COMMAND # SETUP`, with COMMAND left-justified on
    20 columns.  Only the first `#` separates the command from the
    setup; anything after it (including other `#`) is the setup.
    '''
    # Separate with ' # ' only.
    k = ' # '.join([str.strip(x) for x in k.split('#', 2)])

    # The right symbol for repeated &.
    k = re.sub(r'a\*\*(\d+) ', r'a & \1', k)
    # Fix: spello.
    k = re.sub(r'de_buijn', 'de_bruijn', k)
    # Fix: extraneous paren.
    k = re.sub(r'ladybird\(21\)\)', 'ladybird(21)', k)
    # Fix: Incorrect use of .format.
    k = re.sub(r'(a.(?:product|shuffle)\(a\) # a = std\(\{\}\).format\(r\))',
               lambda m: m.group(1).replace('{}', '[a-e]?{50})'),
               k)
    # Fix: now use 's' to denote a string, instead of 'a'.
    k = re.sub(r'read\(a\) # a =',
               r'read(s) # s =', k)
    # Now we display the number of repetitions.
    k = re.sub(r'(# a = de_bruijn\(150\))$', r'\1, 1000x', k)
    k = re.sub(r'(# e = "\(\\e\+a\)" \* 500)$', r'\1, 100x', k)
    k = re.sub(
        r'(# r = b\.expression\("\(\\e\+a\)" \* 500\))$', r'\1, 1000x', k)
    # Now, instead of "   on [a-z]  -> Z", ", c = [a-z] -> Z".
    k = re.sub(r' +on (\[.*?\][?*]?) *-> *([BQZ])',
               r', c = \1 -> \2', k)
    k = re.sub(r'a = lal\(a-zA-Z0-9\).ladybird\(18\)',
               r'a = ladybird(18), c = [a-zA-Z0-9] -> B', k)
    # We never worked on Q in score, it was a typo.  And working with
    # B is good enough anyway and more relevant.
    k = re.sub(r'(determinize.*de_bruijn\(\d+\)), c = \[abc\] -> [BQ]',
               r'\1', k)

    # derived_term.
    k = k.replace('derived_term()', 'derived_term("derivation")')
    k = k.replace('linear()', 'derived_term("expansion")')

    # For a while we displayed 'a.sort() # a = std([a-e]?{600})' but
    # were actually running 'a.shortest(5)'.
    k = re.sub(r'a.sort\(\) (# a = std\(\[a-e\]\?\{600\}\))',
               r'a.shortest(5) \1', k)
    # and we were not reporting the context, although it's not B.
    k = re.sub(r'(a\.shortest\(5\) # a = std\(\[a-e\]\?\{600\}\))$',
               r'\1, c = [a-e] -> Z',
               k)

    # The syntax of contexts has changed.
    k = re.sub(r'lal_char\(abc\)(_|, )b', '[abc] -> B', k)

    k = k.replace('ratexp', 'expression')

    k = k.replace('a.num_sccs', 'a.scc')

    k = k.replace('a.accessible ', 'a.accessible() ')

    # Now we display the context.
    k = re.sub(r'(a.minimize\("(moore|signature)"\) # a = std\(.*?\))$',
               r'\1, c = [a-k] -> B',
               k)

    k = k.replace('product', 'conjunction')

    k = k.replace('a.expression()', 'a.expression("associative")')
    k = re.sub(r'(a\.expression\("\w+")\)',
               r'\1, "naive")',
               k)

    # has_twins_property is benched on an expression using the
    # associative identities.
    k = re.sub(r'(a.has_twins_property.* # a = std\([^,]*?)\)',
               r'\1, "associative")', k)

    # has_twins_property was run on Zmin, but with Q displayed.
    k = re.sub(r'(a.has_twins_property.* # .*)Q,',
               r'\1Zmin,', k)

    # is_ambiguous and is_cycle_ambiguous run on Z, but with B
    # displayed.
    k = re.sub(r'(a.is_(cycle_)?ambiguous.* # .*)B,',
               r'\1Z,', k)

    # ZMIN -> Zmin.
    k = re.sub(r'([NRZ])MIN', r'\1min', k)

    # Nicer notation for tuples.
    k = k.replace("'(a, x)'{2000}'(b, y)'", "(a|x){2000}(b|y)")

    # Useless parens.
    k = k.replace("(['(a,x)'-'(b,y)']*){600}", "['(a,x)'-'(b,y)']*{600}")
    k = k.replace("(['(a,x)'-'(b,y)']{1000})*", "['(a,x)'-'(b,y)']{1000}*")

    # Use of `;` instead of `,`.
    k = re.sub(r'(a\.compose\(a2\).*);', r'\1,', k)

    # It makes more sense to run these algos on Nmin, because we can
    # use the fastest implementations without having to check for
    # preconditions.  In the past, we used the fastest implementations
    # blindly.
    k = re.sub(r'(lightest_automaton.* ->) Zmin', r'\1 Nmin', k)

    # For a while, we thought we were working on a different
    # expression, but it was reusing the previous one.  Easy to see
    # since the context is wrong: it does not go up to `z`.  And use
    # Nmin now.
    k = k.replace('a.lightest() # a = std([a-z]?{300}), c = [a-e] -> Z',
                  'a.lightest() # a = std([a-e]?{150}), c = [a-e] -> Nmin')

    # Now we display the context.
    k = re.sub(r'thompson(\([^\)])', r'thm\1', k)
    k = re.sub(r"std\(\['\(a,a\)'-'\(i,z\)'\]\{4\}\)$",
               "std(['(a,a)'-'(i,z)']{4}), c = [a-z]x[a-z] -> B",
               k)
    k = re.sub(r"thm\(\['\(a,a\)'-'\(i,z\)'\]\{4\}\)$",
               "thm(['(a,a)'-'(i,z)']{4}), c = [a-z]?x[a-z]? -> B",
               k)

    # a.lightest_automaton("a-star") -> a.lightest_automaton(1, "a-star")
    # The opening quote consumed by the pattern must be put back.
    k = re.sub(r'(a.lightest_automaton)\("', r'\1(1, "', k)
    k = k.replace('a.lightest(5)', 'a.lightest(5, "auto")')

    k = k.replace('infiltration', 'infiltrate')

    k = k.replace('proper(False)', 'proper(prune=False)')

    k = re.sub(r'\beval\b', 'evaluate', k)

    # Final format.  Split on the first `#` only: a `#` inside the setup
    # is kept (not silently dropped), and a key without `#` gets an
    # empty setup instead of raising IndexError.
    command, _, setup = k.partition('#')
    return '{:20s} # {}'.format(command.strip(), setup.strip())


def read_file(fn):
    '''Read one `vcsn score` generated file named `fn`.  Store in `bench`.

    Each line looks like:

        0.12s: a.is_proper()        # a = "", 200000x

    So split in `v` (0.12s) and `k` for the rest, normalized.  The
    repetition count (`200000x`) is removed from the key and stored as
    `num` (1 when absent).  Blank lines and `#` comments are skipped.
    '''
    with open(fn) as f:
        for line in f:
            # Lines keep their trailing newline: strip before testing for
            # emptiness, otherwise blank lines reach the `v, k` unpacking
            # below and raise ValueError.
            line = line.strip()
            # Skip empty lines and comments.
            if not line or line.startswith('#'):
                continue
            v, k = [str.strip(x) for x in line.split(':', 1)]
            # Get rid of "s", we know the unit.  And make it a float.
            # Other values (e.g., 'N/A', 'FAIL') are kept as strings.
            if v[-1] == 's':
                v = float(v[:-1])
            # Fix errors in algo descriptions.
            k = normalize(k)
            # Number of iterations of the test.
            num = re.search(', ([0-9]+)x', k)
            num = int(num.group(1)) if num else 1
            k = re.sub(', ([0-9]+)x', '', k)
            bench.setdefault(k, {})[fn] = {'value': v, 'num': num}


def read_files(files):
    '''Load every score file of `files` into `bench`, then make the
    scores of each bench comparable across files.

    A bench may have been run a different number of times in different
    files (e.g., 2x and then 5x).  Every valid score (not 'N/A' or
    'FAIL') is scaled to the least common multiple of these counts (10x
    here), which is also recorded in `number`.
    '''
    for name in files:
        read_file(name)
    for key, per_file in bench.items():
        valid = [entry for entry in per_file.values()
                 if entry['value'] not in ['N/A', 'FAIL']]
        common = lcm([entry['num'] for entry in valid])
        number[key] = common
        for entry in valid:
            entry['value'] *= common // entry['num']
            entry['num'] = common


def repetitions(nb):
    r'''Render `nb`, the number of runs behind the scores, for the
    selected output format.

    CSV gives the bare number and LaTeX gives `\num{NB}`.  Text gives
    `, NBx`, or '' when the test was run once.

    NOTE(review): apparently unused in this file; the printers format
    the count themselves.
    '''
    render = {
        'csv': str,
        'latex': lambda n: '\\num{' + str(n) + '}',
    }.get(args.output_format)
    if render is not None:
        return render(nb)
    return ', {}x'.format(nb) if nb != 1 else ''


def get_command(test):
    '''The command run by the test (e.g., `a.compose(b)`): everything
    before the first `#`, without trailing blanks.'''
    command, _, _ = test.partition('#')
    return command.rstrip()


def get_setup(test):
    '''The setup of the test, object initialization (e.g., `a = std([a-z]*)`).

    Everything after the first `#` of `test`, without leading blanks.
    It is '' when `test` has no `#`.  A `#` inside the setup is kept,
    and the first character after the separator is no longer dropped
    blindly.
    '''
    return test.partition('#')[2].lstrip()


def text(keys):
    '''Print the comparison table for the bench ids in `keys`.

    For each bench, gather its score in every file of `args.file`
    ('N/A' when the file lacks it) and print it with `printer`, between
    the printer's header and footer.  Unless `--all` is given, benches
    whose valid scores are all equal are skipped.
    '''
    missing = {'value': 'N/A', 'num': 0}
    args.output.write(printer.header(args.file))
    for k in keys:
        runs = bench[k]
        values = [runs.get(f, missing)['value'] for f in args.file]
        valid = [v for v in values if v not in ['N/A', 'FAIL']]
        # NOTE(review): `valid` excludes N/A and FAIL, so a bench whose
        # other entries are all missing or failed is skipped as well --
        # confirm this is intended.
        if not args.all and len(set(valid)) == 1:
            continue
        avg = sum(valid) / len(valid) if valid else 0
        printer.line(k, number[k], avg, values)
    args.output.write(printer.footer(args.file))

# Main.
# Directory arguments are replaced by the files they directly contain,
# sorted, skipping hidden entries and subdirectories.  Other arguments
# are kept as is.
files = []
for f in args.file:
    if not os.path.isdir(f):
        files.append(f)
        continue
    fs = sorted(os.path.join(f, e.name) for e in os.scandir(f)
                if not e.name.startswith('.') and e.is_file())
    files.extend(fs)
# With `-n NUM`, keep only the last NUM files.
if args.lines:
    files = files[-args.lines:]
args.file = files
read_files(args.file)

# The keys we are interested in: sorted, and matched by `--only`.
keys = sorted(key for key in bench if args.only.search(key))

# Print the score table.
text(keys)
