#!/usr/bin/env python3
# pylint: disable=C0103,C0114,C0115,C0116,C0123,C0301,R0902,R0913,R0914,R0912,R0915,W0621
######################################################################

import argparse
import glob
import re
import sys
# from pprint import pprint, pformat

# Module-level registries filled in by the read_*() functions below.

# Not referenced anywhere else in this file.
Types = []
# Class name without its "Ast" prefix -> parent class name without the
# prefix ("" for AstNode itself); filled by read_types().
Classes = {}
# Parent class name -> {child class name: 1}; filled by read_types().
Children = {}
# Full class name ("AstFoo") -> {'newed': {basename: 1}, 'used': {basename: 1}}
# recording which source files mention the class; filled by read_refs().
ClassRefs = {}
# Stage file name ("V3Foo.cpp") -> ordinal (counting from 100) in order of
# first appearance; filled by read_stages().
Stages = {}


class Cpt:
    """Expand the TREEOP directives embedded in one C++ source file.

    process() copies the input file to the output file line by line.  A line
    consisting of leading whitespace followed by ``TREE...`` is commented out
    and handed to tree_line(); the generated ``match_*`` functions and
    ``visit`` methods are emitted at the position of the first TREEOP.
    """

    def __init__(self):
        # True once the hook that emits the generated code has been queued
        self.did_out_tree = False
        self.in_filename = ""
        # Input line currently being parsed; reported by error()
        self.in_linenum = 1
        self.out_filename = ""
        # Output line counter, advanced for every newline written
        self.out_linenum = 1
        # Pending output: either strings, or callables taking this object
        # that print their text when the output file is written
        self.out_lines = []
        # Type names listed in TREE_SKIP_VISIT("...")
        self.tree_skip_visit = {}
        # Type name -> list of TREEOP records built by tree_line()
        self.treeop = {}
        # Scratch state of treeop_exec_func(): "$sym" -> "arg<N>p"
        self._exec_nsyms = 0
        self._exec_syms = {}

    def error(self, txt):
        """Exit the program with an error message tagged with the input position."""
        sys.exit("%%Error: %s:%d: %s" %
                 (self.in_filename, self.in_linenum, txt))

    def print(self, txt):
        """Queue the string txt for output."""
        self.out_lines.append(txt)

    def output_func(self, func):
        """Queue a callable that is invoked with this object at output time."""
        self.out_lines.append(func)

    def _output_line(self):
        """Print a #line directive that points into the output file itself."""
        self.print("#line " + str(self.out_linenum + 2) + " \"" +
                   self.out_filename + "\"\n")

    def process(self, in_filename, out_filename):
        """Read in_filename, expand its TREE directives, write out_filename.

        Exits through error() on a malformed directive.
        """
        self.in_filename = in_filename
        self.out_filename = out_filename
        didln = False  # False when a #line back into the input is needed

        # Read the file and parse into list of functions that generate output
        with open(self.in_filename) as fhi:
            for ln, line in enumerate(fhi, start=1):
                # Track every line so that all error() calls made while
                # parsing report the line that actually caused them
                self.in_linenum = ln
                if not didln:
                    self.print("#line " + str(ln) + " \"" + self.in_filename +
                               "\"\n")
                    didln = True
                match = re.match(r'^\s+(TREE.*)$', line)
                if match:
                    self.print("//" + line)
                    self.output_func(lambda self: self._output_line())
                    self.tree_line(match.group(1))
                    didln = False
                elif not re.match(r'^\s*/[/\*]\s*TREE', line) and re.search(
                        r'\s+TREE', line):
                    self.error("Unknown astgen line: " + line)
                else:
                    self.print(line)

        # Put out the resultant file, if the list has a reference to a
        # function, then call that func to generate output
        with open_file(self.out_filename) as fho:
            togen = self.out_lines
            for line in togen:
                if isinstance(line, str):
                    self.out_lines = [line]
                else:
                    self.out_lines = []
                    line(self)  # deferred generator, refills self.out_lines
                for out in self.out_lines:
                    self.out_linenum += out.count("\n")
                    fho.write(out)

    def tree_line(self, func):
        """Parse one TREEOP*(...) or TREE_SKIP_VISIT(...) directive.

        A TREEOP adds a record to self.treeop for its node type and, the first
        time one is seen, queues the hook that emits all generated code.
        A TREE_SKIP_VISIT records the type in self.tree_skip_visit.
        Anything else, or an unknown node type, exits through error().
        """
        func = re.sub(r'\s*//.*$', '', func)
        func = re.sub(r'\s*;\s*$', '', func)

        # doflag "S" indicates an op specifying short-circuiting for a type.
        match = re.search(
            #       1   2                 3                  4
            r'TREEOP(1?)([ACSV]?)\s*\(\s*\"([^\"]*)\"\s*,\s*\"([^\"]*)\"\s*\)',
            func)
        match_skip = re.search(r'TREE_SKIP_VISIT\s*\(\s*\"([^\"]*)\"\s*\)',
                               func)

        if match:
            order, doflag, fromn, to = match.groups()
            if not self.did_out_tree:
                self.did_out_tree = True
                self.output_func(lambda self: self.tree_match_base())
            match = re.search(r'Ast([a-zA-Z0-9]+)\s*\{(.*)\}\s*$', fromn)
            if not match:
                self.error("Can't parse from function: " + func)
            typen = match.group(1)
            subnodes = match.group(2)
            # Check membership first so that an unknown type yields the
            # error below instead of a KeyError from the lookup
            if typen not in Classes or not subclasses_of(typen):
                self.error("Unknown AstNode typen: " + typen + ": in " + func)

            # Guard flag that starts the match condition
            mif = ""
            if doflag == '':
                mif = "m_doNConst"
            elif doflag == 'A':
                mif = ""
            elif doflag == 'C':
                mif = "m_doCpp"
            elif doflag == 'S':
                mif = "m_doNConst"  # Not just for m_doGenerate
            elif doflag == 'V':
                mif = "m_doV"
            else:
                self.error("Unknown flag: " + doflag)

            # ",," stands for a literal comma inside a single condition
            subnodes = re.sub(r',,', '__ESCAPEDCOMMA__', subnodes)
            for subnode in re.split(r'\s*,\s*', subnodes):
                subnode = re.sub(r'__ESCAPEDCOMMA__', ',', subnode)
                if re.match(r'^\$([a-zA-Z0-9]+)$', subnode):
                    continue  # "$lhs" is just a comment that this op has a lhs
                subnodeif = re.sub(
                    r'\$([a-zA-Z0-9]+)\.cast([A-Z][A-Za-z0-9]+)$',
                    r'VN_IS(nodep->\1(),\2)', subnode)
                subnodeif = re.sub(r'\$([a-zA-Z0-9]+)\.([a-zA-Z0-9]+)$',
                                   r'nodep->\1()->\2()', subnodeif)
                subnodeif = self.add_nodep(subnodeif)
                if mif != "" and subnodeif != "":
                    mif += " && "
                mif += subnodeif

            exec_func = self.treeop_exec_func(to)
            exec_func = re.sub(
                r'([-()a-zA-Z0-9_>]+)->cast([A-Z][A-Za-z0-9]+)\(\)',
                r'VN_CAST(\1,\2)', exec_func)

            records = self.treeop.setdefault(typen, [])
            records.append({
                'order': order,
                'comment': func,
                'match_func': "match_" + typen + "_" + str(len(records)),
                'match_if': mif,
                'exec_func': exec_func,
                'uinfo': re.sub(r'[ \t\"\{\}]+', ' ', func),
                'uinfo_level': (0 if to.startswith('!') else 7),
                'short_circuit': (doflag == 'S'),
            })

        elif match_skip:
            typen = match_skip.group(1)
            self.tree_skip_visit[typen] = 1
            if typen not in Classes:
                self.error("Unknown node type: " + typen)

        else:
            self.error("Unknown astgen op: " + func)

    @staticmethod
    def add_nodep(strg):
        """Return strg with every ``$name`` rewritten to ``nodep->name()``."""
        return re.sub(r'\$([a-zA-Z0-9]+)', r'nodep->\1()', strg)

    def _exec_syms_recurse(self, aref):
        """Give each distinct ``$sym`` in the nested list aref an ``arg<N>p`` name."""
        for sym in aref:
            if isinstance(sym, list):
                self._exec_syms_recurse(sym)
            elif sym.startswith("$") and sym not in self._exec_syms:
                self._exec_nsyms += 1
                self._exec_syms[sym] = "arg" + str(self._exec_nsyms) + "p"

    def _exec_new_recurse(self, aref):
        """Return the C++ ``new`` expression for the nested list aref.

        aref[0] is the class name; the remaining items are constructor
        arguments: nested lists, ``$sym`` references or literal text.
        """
        args = ["nodep->fileline()"]
        for sym in aref[1:]:
            if isinstance(sym, list):
                args.append(self._exec_new_recurse(sym))
            elif sym.startswith("$"):
                args.append(self._exec_syms[sym])
            else:
                args.append(sym)
        return "new " + aref[0] + "(" + ", ".join(args) + ")"

    def treeop_exec_func(self, func):
        """Translate the second TREEOP argument into C++ statements.

        Accepted forms (after an optional leading "!"): a function call, an
        ``AstType{...}`` construction, ``NEVER`` or ``DONE``.  Any other text
        exits through error().
        """
        out = ""
        func = re.sub(r'^!', '', func)

        if re.match(r'^\s*[a-zA-Z0-9]+\s*\(', func):  # Function call
            out += re.sub(r'\$([a-zA-Z0-9]+)', r'nodep->\1()', func) + ";"
        elif re.match(r'^\s*Ast([a-zA-Z0-9]+)\s*\{\s*(.*)\s*\}$', func):
            aref = None
            # Recursive array with structure to form
            astack = []
            forming = ""
            argtext = func + "\000"  # EOF character
            for tok in argtext:
                if tok == "\000" or re.match(r'\s+', tok):
                    continue
                if tok == "{":
                    # Open a nested list headed by the name collected so far
                    newref = [forming]
                    if not aref:
                        aref = []
                    aref.append(newref)
                    astack.append(aref)
                    aref = newref
                    forming = ""
                elif tok == "}":
                    if forming:
                        aref.append(forming)
                    if len(astack) == 0:
                        self.error("Too many } in execution function: " + func)
                    aref = astack.pop()
                    forming = ""
                elif tok == ",":
                    if forming:
                        aref.append(forming)
                    forming = ""
                else:
                    forming += tok
            if not (aref and len(aref) == 1):
                self.error("Badly formed execution function: " + func)
            aref = aref[0]

            # Assign numbers to each $ symbol
            self._exec_syms = {}
            self._exec_nsyms = 0
            self._exec_syms_recurse(aref)

            for sym in sorted(self._exec_syms.keys(),
                              key=lambda val: self._exec_syms[val]):
                argnp = self._exec_syms[sym]
                arg = self.add_nodep(sym)
                out += "AstNode* " + argnp + " = " + arg + "->unlinkFrBack();\n"

            out += "AstNode* newp = " + self._exec_new_recurse(aref) + ";\n"
            out += "nodep->replaceWith(newp);"
            out += "VL_DO_DANGLING(nodep->deleteTree(), nodep);"
        elif func == "NEVER":
            out += "nodep->v3fatalSrc(\"Executing transform that was NEVERed\");"
        elif func == "DONE":
            pass  # Nothing to execute
        else:
            self.error("Unknown execution function format: " + func + "\n")
        return out

    def tree_match_base(self):
        """Print all generated code: match functions, then visitors."""
        self.tree_match()
        self.tree_base()

    def tree_match(self):
        """Print one ``match_<Type>_<n>`` C++ function per TREEOP record."""
        self.print(
            "    // TREEOP functions, each return true if they matched & transformed\n"
        )
        for base in sorted(self.treeop.keys()):
            for typefunc in self.treeop[base]:
                self.print("    // Generated by astgen\n")
                self.print("    bool " + typefunc['match_func'] + "(Ast" +
                           base + "* nodep) {\n")
                self.print("\t// " + typefunc['comment'] + "\n")
                self.print("\tif (" + typefunc['match_if'] + ") {\n")
                self.print("\t    UINFO(" + str(typefunc['uinfo_level']) +
                           ",cvtToHex(nodep)" + "<<\" " + typefunc['uinfo'] +
                           "\\n\");\n")
                self.print("\t    " + typefunc['exec_func'] + "\n")
                self.print("\t    return true;\n")
                self.print("\t}\n")
                self.print("\treturn false;\n")
                self.print("    }\n")

    def tree_base(self):
        """Print a ``visit`` method for every type that has applicable matches.

        For each type the match functions of its base classes are called
        first (root first), then its own.
        """
        self.print("    // TREEOP visitors, call each base type's match\n")
        self.print(
            "    // Bottom class up, as more simple transforms are generally better\n"
        )
        for typen in sorted(Classes.keys()):
            out_for_type_sc = []
            out_for_type = []
            for base in subclasses_of(typen) + [typen]:
                for typefunc in self.treeop.get(base, []):
                    call = ("        if (" + typefunc['match_func'] +
                            "(nodep)) return;\n")
                    if typefunc['short_circuit']:  # short-circuit match fn
                        out_for_type_sc.append(call)
                    elif typefunc['order']:  # TREEOP1's go in front of others
                        out_for_type.insert(0, call)
                    else:  # Standard match fn
                        out_for_type.append(call)

            text = self._visit_text(typen, out_for_type_sc, out_for_type)
            if text:
                self.print(text)

    def _visit_text(self, typen, out_for_type_sc, out_for_type):
        """Return the C++ visitor for typen, or "" if it has no matches.

        out_for_type_sc and out_for_type are the already formatted calls to
        the short-circuit and the standard match functions respectively.
        """
        # We need to deal with two cases. For short circuited functions we
        # evaluate the LHS, then apply the short-circuit matches, then
        # evaluate the RHS and possibly THS (ternary operators may
        # short-circuit) and apply all the other matches.

        # For types without short-circuits, we just use iterateChildren, which
        # saves one comparison.
        if out_for_type_sc:  # Short-circuited types
            # The tail is emitted even when there are no standard matches so
            # that the generated method is always closed
            return ("    // Generated by astgen with short-circuiting\n" +
                    "    virtual void visit(Ast" + typen +
                    "* nodep) override {\n" +
                    "      iterateAndNextNull(nodep->lhsp());\n" +
                    "".join(out_for_type_sc) +
                    "      iterateAndNextNull(nodep->rhsp());\n" +
                    "      AstNodeTriop *tnp = VN_CAST(nodep, NodeTriop);\n" +
                    "      if (tnp && tnp->thsp()) iterateAndNextNull(tnp->thsp());\n"
                    + "".join(out_for_type) + "    }\n")
        if out_for_type:  # Other types with something to print
            skip = typen in self.tree_skip_visit
            gen = "Gen" if skip else ""
            override = "" if skip else " override"
            return ("    // Generated by astgen\n" + "    virtual void visit" +
                    gen + "(Ast" + typen + "* nodep)" + override + " {\n" +
                    ("" if skip else "        iterateChildren(nodep);\n") +
                    "".join(out_for_type) + "    }\n")
        return ""


######################################################################
######################################################################


def read_types(filename):
    """Record the Ast class hierarchy declared in the header filename.

    Every ``class``/``struct`` line whose ``public`` base contains "Ast" (and
    AstNode itself) is stored in Classes as name -> parent, both without the
    "Ast" prefix, and registered under its parent in Children.
    """
    with open(filename) as fh:
        for raw in fh:
            text = re.sub(r'//.*$', '', raw)
            if re.match(r'^\s*$', text):
                continue
            decl = re.search(r'^\s*(class|struct)\s*(\S+)', text)
            if not decl:
                continue
            classn = decl.group(2)
            parent = re.search(r':\s*public\s+(\S+)', text)
            inh = parent.group(1) if parent else ""
            # The root class is recorded without a parent
            if classn == "AstNode":
                inh = ""
            if not (re.search(r'Ast', inh) or classn == "AstNode"):
                continue
            classn = re.sub(r'^Ast', '', classn)
            inh = re.sub(r'^Ast', '', inh)
            Classes[classn] = inh
            if inh != '':
                Children.setdefault(inh, {})[classn] = 1


def read_stages(filename):
    """Number the stages named in filename in order of first appearance.

    A line beginning with ``Name::`` registers "Name.cpp" in Stages; the
    numbers handed out by one call start at 100.
    """
    with open(filename) as fh:
        next_order = 100
        for raw in fh:
            code = re.sub(r'//.*$', '', raw)
            if re.match(r'^\s*$', code):
                continue
            found = re.match(r'^\s*([A-Za-z0-9]+)::', code)
            if not found:
                continue
            stage = found.group(1) + ".cpp"
            if stage in Stages:
                continue
            Stages[stage] = next_order
            next_order += 1


def read_refs(filename):
    """Record in ClassRefs which Ast classes the file filename mentions.

    ``new AstFoo`` marks the file's basename under 'newed'; any occurrence
    of ``AstFoo`` marks it under 'used'.  Text after ``//`` is ignored.
    """
    basename = re.sub(r'.*/', '', filename)
    patterns = (
        ('newed', r'\bnew\s*(Ast[A-Za-z0-9_]+)'),
        ('used', r'\b(Ast[A-Za-z0-9_]+)'),
    )
    with open(filename) as fh:
        for raw in fh:
            code = re.sub(r'//.*$', '', raw)
            for how, pattern in patterns:
                for ref in re.findall(pattern, code):
                    if ref not in ClassRefs:
                        ClassRefs[ref] = {'newed': {}, 'used': {}}
                    ClassRefs[ref][how][basename] = 1


def open_file(filename):
    """Open filename for writing and return it with a banner line written.

    Names ending in ".txt" get a plain banner; every other file gets the
    banner with the editor mode line for C++.
    """
    fh = open(filename, "w")
    is_text = re.search(r'\.txt$', filename)
    banner = ("// Generated by astgen\n" if is_text else
              '// Generated by astgen // -*- mode: C++; c-file-style: "cc-mode" -*-'
              + "\n")
    fh.write(banner)
    return fh


def subclasses_of(typen):
    """Return the ancestors of typen found in Classes, root class first.

    Despite the name, the list holds the classes typen derives from, not the
    classes derived from it; typen itself is not included.  An empty list is
    returned for the root class and for a name that is not in Classes, so
    callers can test the result to detect an unusable type.
    """
    cllist = []
    # .get() keeps an unknown typen from raising KeyError; None is never a
    # key of Classes, so the loop below is simply skipped
    parent = Classes.get(typen)
    while parent in Classes:
        cllist.append(parent)
        parent = Classes[parent]

    cllist.reverse()
    return cllist


def children_of(typen):
    """Return every descendant of typen recorded in Children.

    The result is in breadth-first order, the children of each class sorted
    by name; typen itself is not included.
    """
    # order doubles as the work queue: everything behind pos is still to be
    # expanded, and everything after the first entry is the result
    order = [typen]
    pos = 0
    while pos < len(order):
        current = order[pos]
        pos += 1
        if current in Children:
            order.extend(sorted(Children[current].keys()))

    return order[1:]


# ---------------------------------------------------------------------


def write_report(filename):
    """Write the text report: stage order, then one entry per Ast class.

    Each class entry lists its ancestors, its descendants and, when the class
    appears in ClassRefs, the files that create it and the files that use it.
    """

    def stage_rank(val):
        # Files that are not a known stage sort ahead of all stages
        return Stages[val] if (val in Stages) else -1

    with open_file(filename) as fh:
        fh.write(
            "Processing stages (approximate, based on order in Verilator.cpp):\n"
        )
        for classn in sorted(Stages.keys(), key=lambda val: Stages[val]):
            fh.write("  " + classn + "\n")

        fh.write("\nClasses:\n")
        for typen in sorted(Classes.keys()):
            parents = "".join("Ast%-12s " % subclass
                              for subclass in subclasses_of(typen)
                              if subclass != 'Node')
            childs = "".join("Ast%-12s " % subclass
                             for subclass in children_of(typen)
                             if subclass != 'Node')
            fh.write("  class Ast%-17s\n" % typen)
            fh.write("    parent: " + parents + "\n")
            fh.write("    childs:  " + childs + "\n")

            fullname = "Ast" + typen
            if fullname in ClassRefs:
                refs = ClassRefs[fullname]
                for label, how in (("    newed:  ", 'newed'),
                                   ("    used:   ", 'used')):
                    stages = sorted(refs[how].keys(), key=stage_rank)
                    fh.write(label + "".join(stage + "  "
                                             for stage in stages) + "\n")
            fh.write("\n")


def write_classes(filename):
    """Write a forward declaration for every Ast class to filename.

    Each declaration is followed by a comment listing the class's ancestors.
    """
    with open_file(filename) as fh:
        fh.write("class AstNode;\n")
        for typen in sorted(Classes.keys()):
            ancestors = "".join("Ast%-12s " % subclass
                                for subclass in subclasses_of(typen))
            declaration = "class Ast%-17s // " % (typen + ";")
            fh.write(declaration + ancestors + "\n")


def write_visitor(filename):
    """Write the default ``visit`` declarations for every Ast class.

    The root class gets a pure virtual visit; every other class gets a visit
    that forwards to the overload for its parent class.
    """
    with open_file(filename) as fh:
        for typen in sorted(Classes.keys()):
            if typen != "Node":
                body = ("* nodep) { visit((Ast" + Classes[typen] +
                        "*)(nodep)); }\n")
            else:
                body = "*) = 0;\n"
            fh.write("    virtual void visit(Ast" + typen + body)


def write_impl(filename):
    """Write the privateIs/privateCast/privateConstCast specializations.

    One specialization of each template is written per Ast class.  For a
    class whose name starts with "Node" the type test is a range check
    between its first/last enum bounds; otherwise it is an equality check.
    """
    names = sorted(Classes.keys())
    with open_file(filename) as fh:
        fh.write("\n")
        fh.write("    // These for use by VN_IS only\n")
        for typen in names:
            if typen == "Node":
                test = "return nodep != NULL; "
            elif re.search(r'^Node', typen):
                test = (
                    "return nodep && " +
                    "(static_cast<int>(nodep->type()) >= static_cast<int>(AstType::first"
                    + typen + ")) && " +
                    "(static_cast<int>(nodep->type()) <= static_cast<int>(AstType::last"
                    + typen + ")); ")
            else:
                test = (
                    "return nodep && " +
                    "(static_cast<int>(nodep->type()) == static_cast<int>(AstType::at"
                    + typen + ")); ")
            fh.write("template<> inline bool AstNode::privateIs<Ast" + typen +
                     ">(const AstNode* nodep) { " + test + "}\n")

        fh.write("    // These for use by VN_CAST macro only\n")
        for typen in names:
            if typen == "Node":
                cast = "return nodep; "
            else:
                cast = ("return AstNode::privateIs<Ast" + typen +
                        ">(nodep) ? " + "reinterpret_cast<Ast" + typen +
                        "*>(nodep) : NULL; ")
            fh.write("template<> inline Ast" + typen +
                     "* AstNode::privateCast<Ast" + typen +
                     ">(AstNode* nodep) { " + cast + "}\n")

        fh.write("    // These for use by VN_CAST_CONST macro only\n")
        for typen in names:
            if typen == "Node":
                cast = "return nodep; "
            else:
                cast = ("return AstNode::privateIs<Ast" + typen +
                        ">(nodep) ? " + "reinterpret_cast<const Ast" + typen +
                        "*>(nodep) : NULL; ")
            fh.write("template<> inline const Ast" + typen +
                     "* AstNode::privateConstCast<Ast" + typen +
                     ">(const AstNode* nodep) { " + cast + "}\n")


def write_type_enum(fh, typen, idx, processed, kind, indent):
    """Write the enum entries for typen and all its descendants to fh.

    The hierarchy in Children is walked depth first, children sorted by name.
    Classes whose name does not start with "Node" consume one index each.

    Args:
        fh: Writable text file object.
        typen: Class name (without "Ast") to start from.
        idx: Next free index.
        processed: Dict of names already handled; updated in place.
        kind: "concrete-enum" writes ``at<Type> = <idx>,`` per concrete class,
            "concrete-ascii" writes the upper-cased quoted name per concrete
            class, "abstract-enum" writes ``first<Type>``/``last<Type>`` bounds
            per "Node*" class.
        indent: Indentation level, in units of four spaces.

    Returns:
        A two-element list ``[next_idx, last_idx]``, where last_idx is the
        last index used at or below typen, or None if none was used.
    """
    # Skip this if it has already been processed.  The result keeps the same
    # two-element shape as the normal return so callers can always unpack it.
    if typen in processed:
        return [idx, None]
    # Mark processed
    processed[typen] = 1

    pad = " " * (indent * 4)
    is_abstract = bool(re.match(r'^Node', typen))

    # The last used index
    last = None

    if not is_abstract:
        last = idx
        if kind == "concrete-enum":
            fh.write(pad + "at" + typen + " = " + str(idx) + ",\n")
        elif kind == "concrete-ascii":
            fh.write(pad + "\"" + typen.upper() + "\",\n")
        idx += 1
    elif kind == "abstract-enum":
        fh.write(pad + "first" + typen + " = " + str(idx) + ",\n")

    for child in sorted(Children.get(typen, {})):
        idx, child_last = write_type_enum(fh, child, idx, processed, kind,
                                          indent)
        # A skipped child used no index and must not erase the running value
        if child_last is not None:
            last = child_last

    if is_abstract and kind == "abstract-enum":
        fh.write(pad + "last" + typen + " = " + str(last) + ",\n")

    return [idx, last]


def write_types(filename):
    """Write the AstType enum, its bounds enum and the ascii() name table.

    All three parts are produced by walking the class hierarchy from "Node"
    with write_type_enum().
    """
    with open_file(filename) as fh:
        # Concrete types; the walk's first result is the next free index,
        # which becomes _ENUM_END
        fh.write("    enum en : uint16_t {\n")
        final = write_type_enum(fh, "Node", 0, {}, "concrete-enum", 2)[0]
        fh.write("        _ENUM_END = " + str(final) + "\n")
        fh.write("    };\n")

        # first/last bounds of the abstract types
        fh.write("    enum bounds : uint16_t {\n")
        write_type_enum(fh, "Node", 0, {}, "abstract-enum", 2)
        fh.write("        _BOUNDS_END\n")
        fh.write("    };\n")

        # Name table indexed by the enum value
        fh.write("    const char* ascii() const {\n")
        fh.write("        static const char* const names[_ENUM_END + 1] = {\n")
        write_type_enum(fh, "Node", 0, {}, "concrete-ascii", 3)
        fh.write("            \"_ENUM_END\"\n")
        fh.write("        };\n")
        fh.write("        return names[m_e];\n")
        fh.write("    }\n")


def write_header(filename):
    """Copy V3AstNodes.h from the include directory to filename, expanded.

    ``#define``/``#undef`` lines for ASTGEN_ macros are blanked and every
    ``ASTGEN_SUPER(`` is replaced by a call to the base class constructor of
    the most recently seen class, passing that class's AstType value.
    """
    in_filename = "V3AstNodes.h"
    # Class and base in effect for the line being copied
    typen = "None"
    base = "None"

    with open_file(filename) as fh, open(Args.I + "/" + in_filename) as ifh:
        fh.write("#line 1 \"../" + in_filename + "\"\n")

        for raw in ifh:
            # Drop expanded macro definitions - but keep empty line so compiler
            # message locations are accurate
            text = re.sub(r'^\s*#(define|undef)\s+ASTGEN_.*$', '', raw)

            # Track current node type and base class
            found = re.search(
                r'\s*class\s*Ast(\S+)\s*(final|VL_NOT_FINAL)?\s*:\s*(public)?\s*(AstNode\S*)',
                text)
            if found:
                typen, base = found.group(1, 4)

            # Substitute macros and emit the line
            fh.write(
                re.sub(r'\bASTGEN_SUPER\s*\(',
                       base + "(AstType::at" + typen + ", ", text))


######################################################################
# main

# Command line: -I names the directory holding the V3*.h/.cpp/.y sources,
# --classes requests the V3Ast__gen_* files, and every positional argument is
# a .cpp file whose TREEOP directives are expanded into <name>__gen.cpp.
parser = argparse.ArgumentParser(
    allow_abbrev=False,
    formatter_class=argparse.RawDescriptionHelpFormatter,
    description="""Generate V3Ast headers to reduce C++ code duplication.""",
    epilog=
    """Copyright 2002-2021 by Wilson Snyder. This program is free software; you
can redistribute it and/or modify it under the terms of either the GNU
Lesser General Public License Version 3 or the Perl Artistic License
Version 2.0.

SPDX-License-Identifier: LGPL-3.0-only OR Artistic-2.0""")

# Every path below is built from -I, so a missing value is reported by
# argparse instead of failing later on string concatenation with None
parser.add_argument('-I',
                    action='store',
                    required=True,
                    help='source code include directory')
parser.add_argument('--classes',
                    action='store_true',
                    help='makes class declaration files')
parser.add_argument('--debug', action='store_true', help='enable debug')

parser.add_argument('infiles', nargs='*', help='list of input .cpp filenames')

Args = parser.parse_args()

# Build the class hierarchy
read_types(Args.I + "/V3Ast.h")
read_types(Args.I + "/V3AstNodes.h")
for typen in sorted(Classes.keys()):
    # Check all leaves are not AstNode* and non-leaves are AstNode*
    children = children_of(typen)
    if re.match(r'^Node', typen):
        if len(children) == 0:
            sys.exit(
                "%Error: Final AstNode subclasses must not be named AstNode*: Ast"
                + typen)
    else:
        if len(children) != 0:
            sys.exit(
                "%Error: Non-final AstNode subclasses must be named AstNode*: Ast"
                + typen)

read_stages(Args.I + "/Verilator.cpp")

# Collect which source files create or mention each class
source_files = glob.glob(Args.I + "/*.y")
source_files.extend(glob.glob(Args.I + "/*.h"))
source_files.extend(glob.glob(Args.I + "/*.cpp"))
for filename in source_files:
    read_refs(filename)

if Args.classes:
    write_report("V3Ast__gen_report.txt")
    write_classes("V3Ast__gen_classes.h")
    write_visitor("V3Ast__gen_visitor.h")
    write_impl("V3Ast__gen_impl.h")
    write_types("V3Ast__gen_types.h")
    write_header("V3AstNodes__gen.h")

for cpt in Args.infiles:
    # Compare against the literal suffix: as a regex, an unescaped "." would
    # also accept names such as "V3Constcpp"
    if not cpt.endswith(".cpp"):
        sys.exit("%Error: Expected argument to be .cpp file: " + cpt)
    cpt = cpt[:-len(".cpp")]
    Cpt().process(in_filename=Args.I + "/" + cpt + ".cpp",
                  out_filename=cpt + "__gen.cpp")

######################################################################
# Local Variables:
# compile-command: "cd obj_dbg && ../astgen -I.. V3Const.cpp"
# End:
