/*
 * Decompiled with CFR 0.152.
 */
package ghidra.app.plugin.assembler.sleigh.sem;

import com.google.common.collect.Sets;
import ghidra.app.plugin.assembler.sleigh.expr.MaskedLong;
import ghidra.app.plugin.assembler.sleigh.expr.NeedsBackfillException;
import ghidra.app.plugin.assembler.sleigh.expr.RecursiveDescentSolver;
import ghidra.app.plugin.assembler.sleigh.grammars.AssemblyGrammar;
import ghidra.app.plugin.assembler.sleigh.grammars.AssemblyProduction;
import ghidra.app.plugin.assembler.sleigh.sem.AssemblyConstructorSemantic;
import ghidra.app.plugin.assembler.sleigh.sem.AssemblyContextGraph;
import ghidra.app.plugin.assembler.sleigh.sem.AssemblyPatternBlock;
import ghidra.app.plugin.assembler.sleigh.sem.AssemblyResolution;
import ghidra.app.plugin.assembler.sleigh.sem.AssemblyResolutionResults;
import ghidra.app.plugin.assembler.sleigh.sem.AssemblyResolvedBackfill;
import ghidra.app.plugin.assembler.sleigh.sem.AssemblyResolvedConstructor;
import ghidra.app.plugin.assembler.sleigh.sem.AssemblyResolvedError;
import ghidra.app.plugin.assembler.sleigh.symbol.AssemblyNonTerminal;
import ghidra.app.plugin.assembler.sleigh.symbol.AssemblyNumericTerminal;
import ghidra.app.plugin.assembler.sleigh.symbol.AssemblySymbol;
import ghidra.app.plugin.assembler.sleigh.tree.AssemblyParseBranch;
import ghidra.app.plugin.assembler.sleigh.tree.AssemblyParseNumericToken;
import ghidra.app.plugin.assembler.sleigh.tree.AssemblyParseTreeNode;
import ghidra.app.plugin.assembler.sleigh.util.DbgTimer;
import ghidra.app.plugin.processors.sleigh.Constructor;
import ghidra.app.plugin.processors.sleigh.SleighLanguage;
import ghidra.app.plugin.processors.sleigh.expression.PatternExpression;
import ghidra.app.plugin.processors.sleigh.symbol.OperandSymbol;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Deque;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import org.apache.commons.collections4.IteratorUtils;

public class AssemblyTreeResolver {
    /** Key in {@link #vals} under which the constructor stores the word offset of the instruction start. */
    public static final String INST_START = "inst_start";
    /** Key in {@link #vals} under which {@link #resolve()} stores the word offset just past the instruction. */
    public static final String INST_NEXT = "inst_next";

    /** Shared expression solver, used for operand solving and for backfilling. */
    protected static final RecursiveDescentSolver solver = RecursiveDescentSolver.getSolver();
    /** Debug printer/timer; bound to {@code DbgTimer.INACTIVE} (presumably a no-op sink -- confirm in DbgTimer). */
    protected static final DbgTimer dbg = DbgTimer.INACTIVE;

    /** Language the instruction is being assembled for. */
    protected final SleighLanguage lang;
    /** Address offset at which the assembled instruction starts. */
    protected final long instStart;
    /** Named values (e.g. {@link #INST_START}, {@link #INST_NEXT}) handed to the solver. */
    protected final Map<String, Long> vals = new HashMap<>();
    /** Root of the parse tree being resolved. */
    protected final AssemblyParseBranch tree;
    /** Grammar taken from the parse tree's root. */
    protected final AssemblyGrammar grammar;
    /** Context supplied at construction, after {@code fillMask()} has been applied to it. */
    protected final AssemblyPatternBlock context;
    /** Context graph queried for constructor paths when resolving purely recursive productions. */
    protected final AssemblyContextGraph ctxGraph;

    public AssemblyTreeResolver(SleighLanguage lang, long instStart, AssemblyParseBranch tree, AssemblyPatternBlock context, AssemblyContextGraph ctxGraph) {
        this.lang = lang;
        this.instStart = instStart;
        this.vals.put(INST_START, lang.getDefaultSpace().getAddressableWordOffset(instStart));
        this.tree = tree;
        this.grammar = tree.getGrammar();
        this.context = context.fillMask();
        this.ctxGraph = ctxGraph;
    }

    /**
     * Resolves the whole parse tree and finalizes every candidate.
     *
     * <p>Each resolution produced for the root branch yields exactly one entry in the returned
     * results: either the finalized constructor resolution or an error describing why it was
     * rejected.
     *
     * @return the finalized resolutions and errors
     */
    public AssemblyResolutionResults resolve() {
        AssemblyResolutionResults finished = new AssemblyResolutionResults();
        for (AssemblyResolution candidate : this.resolveBranch(this.tree)) {
            assert !(candidate instanceof AssemblyResolvedBackfill);
            finished.add(this.finalizeResolution(candidate));
        }
        return finished;
    }

    /**
     * Finalizes one root-level resolution: publishes {@code inst_next}, performs the last
     * backfill, merges the selected context and checks the forbidden patterns.
     *
     * @param candidate a resolution of the root branch
     * @return the finalized resolution, or the error that rejected it
     */
    private AssemblyResolution finalizeResolution(AssemblyResolution candidate) {
        if (candidate.isError()) {
            return candidate;
        }
        AssemblyResolvedConstructor rc = (AssemblyResolvedConstructor) candidate;

        // The instruction length is known now, so inst_next can be offered to the solver.
        this.vals.put(INST_NEXT, this.lang.getDefaultSpace().getAddressableWordOffset(this.instStart + (long) rc.getInstructionLength()));
        if (rc.hasBackfills()) {
            dbg.println("Backfilling: " + rc);
        }
        AssemblyResolution filled = rc.backfill(solver, this.vals);
        dbg.println("Backfilled final: " + filled);
        if (filled.isError()) {
            return filled;
        }

        AssemblyResolvedConstructor complete = (AssemblyResolvedConstructor) filled;
        if (complete.hasBackfills()) {
            return AssemblyResolution.error("Solution is incomplete", "failed backfill", List.of(complete));
        }

        AssemblyResolvedConstructor ctx = AssemblyResolution.contextOnly(this.context, "Selecting context", null);
        AssemblyResolvedConstructor merged = complete.combine(ctx);
        if (merged == null) {
            return AssemblyResolution.error("Incompatible context", "resolving", List.of(complete));
        }

        AssemblyResolution forbidCheck = merged.checkNotForbidden();
        if (forbidCheck.isError()) {
            return forbidCheck;
        }
        return (AssemblyResolvedConstructor) forbidCheck;
    }

    /**
     * Resolves a branch of the parse tree, choosing between the recursive and the
     * non-recursive strategy.
     *
     * <p>The recursive strategy is used only when the grammar reports a pure recursion for the
     * left-hand side of the branch's production and the branch has no parent.
     *
     * @param branch the branch to resolve
     * @return the resolutions and errors for the branch
     */
    protected AssemblyResolutionResults resolveBranch(AssemblyParseBranch branch) {
        AssemblyNonTerminal lhs = (AssemblyNonTerminal) branch.getProduction().getLHS();
        AssemblyProduction pureRecursion = this.grammar.getPureRecursion(lhs);
        if (pureRecursion == null || branch.getParent() != null) {
            return this.resolveBranchNonRecursive(branch);
        }
        return this.resolveBranchRecursive(branch, pureRecursion);
    }

    /**
     * Applies a sequence of constructors, taken from the tail of {@code path} first, on top of
     * an already resolved child.
     *
     * <p>A single constructor application may produce several results, so each step
     * (1) collects every result of applying the step's constructor to the resolutions carried
     * over from the previous step, (2) moves the errors to the final result, and (3) carries the
     * successful resolutions into the next step.
     *
     * <p>The per-step collection is created fresh for every step. Sharing one collection across
     * steps would feed the outputs of earlier steps into later ones again, so resolutions that
     * had only part of the path applied would end up in the returned results.
     *
     * @param path the constructors to apply; consumed (emptied) by this method
     * @param branch the branch used as the single substitution of the recursive production
     * @param rec the purely recursive production being applied
     * @param child the resolution to start from
     * @return the errors of every step plus the resolutions surviving the last step; just
     *         {@code child} when {@code path} is empty
     */
    protected AssemblyResolutionResults applyRecursionPath(Deque<AssemblyConstructorSemantic> path, AssemblyParseBranch branch, AssemblyProduction rec, AssemblyResolvedConstructor child) {
        AssemblyResolutionResults result = new AssemblyResolutionResults();
        LinkedHashSet<AssemblyResolvedConstructor> intoNext = new LinkedHashSet<>();
        intoNext.add(child);
        while (!path.isEmpty()) {
            AssemblyConstructorSemantic sem = path.pollLast();
            List<AssemblyParseTreeNode> substs = List.of(branch);

            // (1) Collect the results of this step only.
            AssemblyResolutionResults collected = new AssemblyResolutionResults();
            for (AssemblyResolvedConstructor carried : intoNext) {
                List<AssemblyResolvedConstructor> sel = List.of(carried);
                collected.absorb(this.resolveSelectedChildren(rec, substs, sel, List.of(sem)));
            }
            intoNext.clear();

            for (AssemblyResolution ar : collected) {
                if (ar.isError()) {
                    // (2) Report the failure.
                    result.add(ar);
                } else {
                    // (3) Feed the success into the next constructor of the path.
                    intoNext.add((AssemblyResolvedConstructor) ar);
                }
            }
        }
        result.addAll(intoNext);
        return result;
    }

    /**
     * Resolves a branch whose production table has a purely recursive production.
     *
     * <p>The branch is first resolved non-recursively. For every successful resolution, the
     * context graph is asked for constructor paths leading from this resolver's context to the
     * context of that resolution within the branch's table, and each path is applied through
     * {@link #applyRecursionPath}. Errors of the non-recursive resolution are passed through.
     *
     * @param branch the branch to resolve
     * @param rec the purely recursive production of the branch's table
     * @return the resolutions and errors for the branch
     */
    protected AssemblyResolutionResults resolveBranchRecursive(AssemblyParseBranch branch, AssemblyProduction rec) {
        try (DbgTimer.DbgCtx dc = dbg.start("Resolving (recursive) branch: " + branch.getProduction())) {
            AssemblyResolutionResults result = new AssemblyResolutionResults();
            for (AssemblyResolution ar : this.resolveBranchNonRecursive(branch)) {
                if (ar.isError()) {
                    result.add(ar);
                    continue;
                }
                AssemblyResolvedConstructor rc = (AssemblyResolvedConstructor) ar;
                AssemblyPatternBlock dst = rc.getContext();
                AssemblyPatternBlock src = this.context;
                // Source and destination are both in the table named after the production.
                String table = branch.getProduction().getName();
                dbg.println("Finding paths from " + this.context + " to " + ar.lineToString());
                Collection<Deque<AssemblyConstructorSemantic>> paths = this.ctxGraph.computeOptimalApplications(src, table, dst, table);
                dbg.println("Found " + paths.size());
                for (Deque<AssemblyConstructorSemantic> path : paths) {
                    dbg.println("  " + path);
                    result.absorb(this.applyRecursionPath(path, branch, rec, rc));
                }
            }
            return result;
        }
    }

    /**
     * Resolves a production for one particular selection of child resolutions, trying every
     * given constructor semantic.
     *
     * <p>For each semantic the method
     * <ol>
     * <li>records the value of every operand (numeric token value or selected child
     * resolution), keyed by operand index;</li>
     * <li>walks the operands again, solving each numeric operand's defining expression (or
     * deferring it as a backfill) and combining each child resolution, shifted by the operand's
     * computed offset, into the running result;</li>
     * <li>solves the semantic's context changes, and</li>
     * <li>combines the result with each of the semantic's patterns and checks it against the
     * forbidden patterns.</li>
     * </ol>
     * A conflict in step 2 or a non-constructor outcome in step 3 records an error and moves on
     * to the next semantic. Finally, backfills that can already be solved are applied via
     * {@link #tryResolveBackfills}.
     *
     * @param prod the production being resolved
     * @param substs the parse tree nodes substituted for the production's symbols
     * @param sel the selected resolution for each constructor child, in order of appearance
     * @param semantics the candidate constructor semantics for the production
     * @return the resolutions and errors obtained for this selection
     */
    protected AssemblyResolutionResults resolveSelectedChildren(AssemblyProduction prod, List<AssemblyParseTreeNode> substs, List<AssemblyResolvedConstructor> sel, Collection<AssemblyConstructorSemantic> semantics) {
        try (DbgTimer.DbgCtx dc = dbg.start("Selecting: " + IteratorUtils.toString(sel.iterator(), rc -> rc.lineToString()))) {
            AssemblyResolutionResults results = new AssemblyResolutionResults();

            // The selected children must agree on the context they require.
            AssemblyPatternBlock combCtx = AssemblyPatternBlock.nop();
            for (AssemblyResolvedConstructor child : sel) {
                AssemblyPatternBlock check = combCtx.combine(child.getContext());
                if (check == null) {
                    results.add(AssemblyResolution.error("Incompatible context requirements among selected children", "Resolving " + prod, sel));
                    return results;
                }
                combCtx = check;
            }
            dbg.println("Combined context: " + combCtx);

            AssemblyResolvedConstructor res = AssemblyResolution.nop("Resolving " + prod, sel);
            nextSemantic: for (AssemblyConstructorSemantic sem : semantics) {
                try (DbgTimer.DbgCtx dc2 = dbg.start("Trying: " + sem)) {
                    Constructor cons = sem.getConstructor();
                    AssemblyResolvedConstructor subres = res.copyAppendDescription("Applying constructor: " + sem);

                    // Pass 1: gather operand values so offsets and expressions can refer to them.
                    Map<Integer, Object> opvals = new HashMap<>();
                    Iterator<Integer> opidxit = sem.getOperandIndexIterator();
                    Iterator<AssemblyResolvedConstructor> selit = sel.iterator();
                    for (int i = 0; i < prod.size(); ++i) {
                        AssemblyParseTreeNode child = substs.get(i);
                        AssemblySymbol sym = (AssemblySymbol) prod.get(i);
                        if (!sym.takesOperandIndex()) {
                            continue;
                        }
                        int opidx = opidxit.next();
                        if (child.isNumeric()) {
                            AssemblyParseNumericToken num = (AssemblyParseNumericToken) child;
                            opvals.put(opidx, num.getNumericValue());
                        } else if (child.isConstructor()) {
                            opvals.put(opidx, selit.next());
                        }
                    }

                    // Pass 2: write each operand into the running resolution.
                    opidxit = sem.getOperandIndexIterator();
                    Iterator<AssemblyResolvedConstructor> subit = sel.iterator();
                    for (int i = 0; i < prod.size(); ++i) {
                        AssemblyParseTreeNode child = substs.get(i);
                        AssemblySymbol sym = (AssemblySymbol) prod.get(i);
                        if (!sym.takesOperandIndex()) {
                            continue;
                        }
                        dbg.println("Current: " + subres.lineToString());
                        int opidx = opidxit.next();
                        OperandSymbol subsym = cons.getOperand(opidx);
                        int shift = AssemblyTreeResolver.computeOffset(subsym, cons, opvals);
                        String symname = subsym.getName();
                        dbg.println("Processing symbol: " + symname);

                        if (child.isNumeric()) {
                            // A bit size of 0 means "unrestricted" to solveOrBackfill.
                            int bitsize = 0;
                            if (sym instanceof AssemblyNumericTerminal) {
                                AssemblyNumericTerminal numeric = (AssemblyNumericTerminal) sym;
                                bitsize = numeric.getBitSize();
                            }
                            Long opval = (Long) opvals.get(opidx);
                            PatternExpression symexp = subsym.getDefiningExpression();
                            if (symexp == null) {
                                symexp = subsym.getDefiningSymbol().getPatternExpression();
                            }
                            String desc = "Solution to " + sym + " := " + Long.toHexString(opval) + " = " + symexp + " (immediate op:" + opidx + ",shift:" + shift + ")";
                            dbg.println("Writing: " + desc);
                            AssemblyResolution sol = AssemblyTreeResolver.solveOrBackfill(symexp, opval, bitsize, this.vals, opvals, null, desc);
                            dbg.println("Solution: " + sol);
                            if (sol == null) {
                                throw new AssertionError("Who returned a null solution!? Throw an exception or return an error result, please!");
                            }
                            if (sol.isError()) {
                                AssemblyResolvedError err = (AssemblyResolvedError) sol;
                                results.add(AssemblyResolution.error(err.getError(), subres));
                                continue nextSemantic;
                            }
                            if (sol instanceof AssemblyResolvedConstructor) {
                                AssemblyResolvedConstructor solcon = (AssemblyResolvedConstructor) sol;
                                AssemblyResolvedConstructor check = subres.combine(solcon.shift(shift));
                                if (check == null) {
                                    results.add(AssemblyResolution.error("Conflict: Immediate operand (token " + i + ") " + sol, subres));
                                    continue nextSemantic;
                                }
                                subres = check;
                                continue;
                            }
                            // Neither an error nor a solved pattern: the solution was deferred.
                            AssemblyResolvedBackfill solbf = (AssemblyResolvedBackfill) sol;
                            subres = subres.combine(solbf.shift(shift));
                            continue;
                        }

                        if (child.isConstructor()) {
                            AssemblyResolvedConstructor childrc = subit.next();
                            dbg.println("Writing subtable(opidx:" + opidx + "): " + symname + ": " + childrc.lineToString() + " (shift:" + shift + ")");
                            AssemblyResolvedConstructor check = subres.combine(childrc.shift(shift));
                            if (check == null) {
                                results.add(AssemblyResolution.error("Conflict: Subtable operand (token " + i + ")", subres));
                                continue nextSemantic;
                            }
                            subres = check;
                            continue;
                        }

                        dbg.println("Probably encountered a varnode production: " + child);
                    }

                    AssemblyResolution backctx = sem.solveContextChanges(subres, this.vals, opvals);
                    if (!(backctx instanceof AssemblyResolvedConstructor)) {
                        results.add(backctx);
                        continue nextSemantic;
                    }
                    subres = (AssemblyResolvedConstructor) backctx;
                    subres = subres.solveContextChangesForForbids(sem, this.vals, opvals);

                    dbg.println("Writing patterns:");
                    for (AssemblyResolvedConstructor pat : sem.getPatterns()) {
                        // Every pattern starts from the same operand-complete resolution.
                        AssemblyResolvedConstructor temp = subres;
                        dbg.println("  Pattern: " + pat.lineToString());
                        dbg.println("    Current: " + temp.lineToString());
                        AssemblyResolvedConstructor check = temp.combine(pat);
                        if (check == null) {
                            results.add(AssemblyResolution.error("The patterns conflict " + subres, temp));
                            continue;
                        }
                        temp = check;
                        dbg.println("    Final: " + temp.lineToString());
                        AssemblyResolution fcheck = temp.checkNotForbidden();
                        if (fcheck.isError()) {
                            results.add(fcheck);
                            continue;
                        }
                        temp = (AssemblyResolvedConstructor) fcheck;
                        results.add(temp);
                    }
                }
                catch (Exception e) {
                    // Note which semantic was being processed, then let the failure propagate.
                    dbg.println("While processing: " + sem);
                    throw e;
                }
            }
            return this.tryResolveBackfills(results);
        }
    }

    /**
     * Applies backfills to the given results for as long as doing so makes progress.
     *
     * <p>Errors are copied through unchanged. Every other entry is backfilled repeatedly until
     * it has no backfills left, until backfilling yields an error or a bare backfill, or until a
     * round leaves the resolution equal to what it was before. Whatever the entry has become at
     * that point is added to the output.
     *
     * @param results the resolutions to process
     * @return a new result set with one entry per input entry
     */
    protected AssemblyResolutionResults tryResolveBackfills(AssemblyResolutionResults results) {
        AssemblyResolutionResults settled = new AssemblyResolutionResults();
        for (AssemblyResolution entry : results) {
            AssemblyResolution current = entry;
            if (!current.isError()) {
                while (true) {
                    AssemblyResolvedConstructor pending = (AssemblyResolvedConstructor) current;
                    if (!pending.hasBackfills()) {
                        // Finished: nothing left to fill in.
                        break;
                    }
                    current = pending.backfill(solver, this.vals);
                    // Stop on failure, on a stand-alone backfill, or when nothing changed.
                    if (current.isError() || current.isBackfill() || current.equals(pending)) {
                        break;
                    }
                }
            }
            settled.add(current);
        }
        return settled;
    }

    /**
     * Resolves a branch by resolving its constructor children first and then trying every
     * combination of their successful resolutions.
     *
     * <p>Each child that takes an operand index and is a constructor is resolved through
     * {@link #resolveBranch}. Its successful resolutions form one set; its errors are gathered
     * separately. The cartesian product of those sets gives the child selections, each of which
     * is resolved with {@link #resolveSelectedChildren}. If any child reported an error, a
     * single "Child errors" entry wrapping all of them is appended.
     *
     * @param branch the branch to resolve
     * @return the resolutions and errors for the branch
     */
    protected AssemblyResolutionResults resolveBranchNonRecursive(AssemblyParseBranch branch) {
        try (DbgTimer.DbgCtx dc = dbg.start("Resolving (non-recursive) branch: " + branch.getProduction())) {
            AssemblyResolutionResults results = new AssemblyResolutionResults();
            AssemblyProduction prod = branch.getProduction();
            List<AssemblyParseTreeNode> substs = branch.getSubstitutions();
            assert prod.size() == substs.size();

            // One set of alternatives per constructor child, in order of appearance.
            List<HashSet<AssemblyResolvedConstructor>> childRes = new ArrayList<>();
            List<AssemblyResolvedError> childErr = new ArrayList<>();
            for (int i = 0; i < prod.size(); ++i) {
                AssemblySymbol sym = (AssemblySymbol) prod.get(i);
                if (!sym.takesOperandIndex()) {
                    continue;
                }
                AssemblyParseTreeNode child = substs.get(i);
                if (!child.isConstructor()) {
                    continue;
                }
                HashSet<AssemblyResolvedConstructor> childResElem = new HashSet<>();
                for (AssemblyResolution ar : this.resolveBranch((AssemblyParseBranch) child)) {
                    if (ar.isError()) {
                        childErr.add((AssemblyResolvedError) ar);
                    } else {
                        childResElem.add((AssemblyResolvedConstructor) ar);
                    }
                }
                childRes.add(childResElem);
            }

            Collection<AssemblyConstructorSemantic> semantics = this.grammar.getSemantics(prod);
            for (List<AssemblyResolvedConstructor> sel : Sets.cartesianProduct(childRes)) {
                results.absorb(this.resolveSelectedChildren(prod, substs, Collections.unmodifiableList(sel), semantics));
            }
            if (!childErr.isEmpty()) {
                results.add(AssemblyResolution.error("Child errors", "Resolving " + prod, Collections.unmodifiableList(childErr)));
            }
            return results;
        }
    }

    /**
     * Computes the offset of an operand within its constructor.
     *
     * <p>The result is the operand's relative offset. When the operand names a base operand
     * (offset base other than -1), the length of that base and the base's own offset, computed
     * recursively, are added. The base's length is the instruction length of its resolution
     * when {@code res} holds one for it, otherwise the base operand's minimum length.
     *
     * @param opsym the operand whose offset is wanted
     * @param cons the constructor containing the operand
     * @param res operand values by operand index; constructor operands map to their resolutions
     * @return the computed offset
     */
    public static int computeOffset(OperandSymbol opsym, Constructor cons, Map<Integer, Object> res) {
        int relative = opsym.getRelativeOffset();
        int baseIndex = opsym.getOffsetBase();
        if (baseIndex == -1) {
            // No base operand: the relative offset is already the answer.
            return relative;
        }
        OperandSymbol base = cons.getOperand(baseIndex);
        Object resolved = res.get(baseIndex);
        int baseLength = resolved instanceof AssemblyResolvedConstructor
                ? ((AssemblyResolvedConstructor) resolved).getInstructionLength()
                : base.getMinimumLength();
        return relative + baseLength + AssemblyTreeResolver.computeOffset(base, cons, res);
    }

    /**
     * Solves an expression for a masked goal, or records it for later if it cannot be solved yet.
     *
     * @param exp the expression to solve
     * @param goal the value the expression must evaluate to
     * @param vals named values available to the solver
     * @param res operand values by operand index
     * @param cur the resolution passed through to the solver; may be {@code null}
     * @param description a description attached to the outcome
     * @return the solver's result, or, when the solver throws {@link NeedsBackfillException},
     *         a backfill record whose length is the solver-reported instruction length of
     *         {@code exp}
     */
    protected static AssemblyResolution solveOrBackfill(PatternExpression exp, MaskedLong goal, Map<String, Long> vals, Map<Integer, Object> res, AssemblyResolvedConstructor cur, String description) {
        AssemblyResolution outcome;
        try {
            outcome = solver.solve(exp, goal, vals, res, cur, description);
        }
        catch (NeedsBackfillException needsBackfill) {
            // Not solvable with the values known so far; defer it.
            outcome = AssemblyResolution.backfill(exp, goal, res, solver.getInstructionLength(exp, res), description);
        }
        return outcome;
    }

    /**
     * Solves an expression for a plain {@code long} goal, or records it for later.
     *
     * <p>Converts the goal with {@code MaskedLong.fromLong} and delegates to the
     * {@link MaskedLong} overload.
     *
     * @param exp the expression to solve
     * @param goal the value the expression must evaluate to
     * @param vals named values available to the solver
     * @param res operand values by operand index
     * @param cur the resolution passed through to the solver; may be {@code null}
     * @param description a description attached to the outcome
     * @return the solution, an error, or a backfill record
     */
    protected static AssemblyResolution solveOrBackfill(PatternExpression exp, long goal, Map<String, Long> vals, Map<Integer, Object> res, AssemblyResolvedConstructor cur, String description) {
        MaskedLong maskedGoal = MaskedLong.fromLong(goal);
        return AssemblyTreeResolver.solveOrBackfill(exp, maskedGoal, vals, res, cur, description);
    }

    /**
     * Solves an expression for a goal of which only the low {@code bits} bits are significant,
     * or records it for later.
     *
     * <p>A {@code bits} value of 0, or of 64 and above, selects a mask of all ones.
     *
     * @param exp the expression to solve
     * @param goal the value the expression must evaluate to
     * @param bits the number of significant low-order bits in {@code goal}
     * @param vals named values available to the solver
     * @param res operand values by operand index
     * @param cur the resolution passed through to the solver; may be {@code null}
     * @param description a description attached to the outcome
     * @return the solution, an error, or a backfill record
     */
    protected static AssemblyResolution solveOrBackfill(PatternExpression exp, long goal, int bits, Map<String, Long> vals, Map<Integer, Object> res, AssemblyResolvedConstructor cur, String description) {
        long msk;
        if (bits == 0 || bits >= 64) {
            msk = -1L;
        } else {
            // Ones in the low `bits` positions, zeros above.
            msk = ~(-1L << bits);
        }
        MaskedLong maskedGoal = MaskedLong.fromMaskAndValue(msk, goal);
        return AssemblyTreeResolver.solveOrBackfill(exp, maskedGoal, vals, res, cur, description);
    }
}

