import inspect
import sys
import time
import types
import warnings
from collections import OrderedDict
from io import StringIO
from typing import TYPE_CHECKING, Dict, Optional, Set

import numpy as np

import aesara
from aesara.configdefaults import config
from aesara.graph.basic import Variable, io_toposort
from aesara.graph.utils import InconsistencyError


if TYPE_CHECKING:
    from aesara.graph.basic import Apply


class AlreadyThere(Exception):
    """
    Raised by a `Feature`'s ``on_attach`` callback when the `FunctionGraph`
    attempting to attach the feature already has a functionally identical
    feature.

    The features in this module also raise it when an attribute they want to
    install on the `FunctionGraph` is already present, i.e. when they may be
    in conflict with another plugin.

    """


class ReplacementDidNotRemoveError(Exception):
    """
    Raised by ``replace_all_validate_remove`` when an optimization wanted to
    remove a `Variable` or a node from the graph, but the replacements it
    supplied left that `Variable` or node in the graph.

    """


class BadOptimization(Exception):
    """
    Exception: some variable and its substitute take different runtime values.

    Parameters
    ----------
    old_r
        The `Variable` that was replaced, or a plain string holding a
        ready-made error message (see the note below).
    new_r
        The `Variable` that replaced `old_r`.
    old_r_val
        The value computed for `old_r`.
    new_r_val
        The value computed for `new_r`.
    reason
        Why `old_r` was turned into `new_r`; by convention the name of the
        optimization that requested the replacement.
    old_graph
        Optional, already rendered text of the graph leading to `old_r`.
        When omitted, the text is rendered from `old_r` if it is a
        `Variable`, and is ``None`` otherwise.
    new_graph
        Same as `old_graph`, for `new_r`.

    Note: If there is only 1 parameter and it is a string, we will use
    it as the error message. This is needed when we catch, extend and
    reraise an error.

    """

    new_r = None
    """
    A `Variable` instance that took a different value from `old_r`,
    but which replaced `old_r`.

    """

    old_r = None
    """
    A `Variable` instance that was replaced by `new_r`.

    """

    old_r_val = None
    """
    The value computed for `old_r`.

    """

    new_r_val = None
    """
    The value computed for `new_r`.

    """

    reason = None
    """
    An object that indicates why old_r was turned into new_r.

    Convention is that this is the name of the optimization that
    requested the replacement.

    """

    old_graph = ""
    """
    A multiline string representation of the graph leading to
    old_r, at the time of the replacement.

    """

    new_graph = ""
    """
    A multiline string representation of the graph leading to
    new_r, at the time of the replacement.

    """

    def __init__(
        self,
        old_r,
        new_r=None,
        old_r_val=None,
        new_r_val=None,
        reason=None,
        old_graph=None,
        new_graph=None,
    ):
        super().__init__()

        self.old_r = old_r
        self.new_r = new_r
        self.old_r_val = old_r_val
        self.new_r_val = new_r_val
        self.reason = reason

        # To allow extending the error message of an existing error.
        self.full_err = None
        if isinstance(old_r, str):
            assert (
                new_r is None
                and old_r_val is None
                and new_r_val is None
                and reason is None
                and old_graph is None
                and new_graph is None
            )
            self.full_err = old_r
            # A pre-made message has no graphs to show.
            self.old_graph = None
            self.new_graph = None
            return

        # `done` and `used_ids` are shared by both renderings so that both
        # graph texts are produced with the same printing state.
        done = dict()
        used_ids = dict()
        self.old_graph = self._render_graph(old_r, old_graph, done, used_ids)
        self.new_graph = self._render_graph(new_r, new_graph, done, used_ids)

    @staticmethod
    def _render_graph(r, given, done, used_ids):
        """Return `given` if provided, else the debug-print text of `r`.

        Returns ``None`` when no text was given and `r` is not a `Variable`.
        """
        if given is not None:
            # Honor the caller's text: it may describe the graph as it was
            # at the time of the replacement.
            return given
        if not isinstance(r, Variable):
            return None
        return aesara.printing._debugprint(
            r,
            prefix="  ",
            depth=6,
            file=StringIO(),
            done=done,
            print_type=True,
            used_ids=used_ids,
        ).getvalue()

    def __str__(self):
        return self.str_diagnostic()

    @staticmethod
    def _write_value(sio, label, value, limit):
        """Write the array metadata (when available) and text of `value`.

        `label` is ``"Old"`` or ``"New"``; the printed text of the value is
        truncated to `limit` characters.
        """
        try:
            ssio = StringIO()
            print(f"  {label} Value shape, dtype, strides:", end=" ", file=ssio)
            print(value.shape, end=" ", file=ssio)
            print(value.dtype, end=" ", file=ssio)
            print(value.strides, file=ssio)
            # Only if all of the above succeeded do we add anything to sio.
            print(ssio.getvalue(), file=sio)
        except Exception:
            # Best effort: the value may not be array-like.
            pass

        value_text = str(value)
        if len(value_text) > limit:
            print(f"  {label} Value: ", value_text[:limit], "...", file=sio)
        else:
            print(f"  {label} Value: ", value_text, file=sio)

    def str_diagnostic(self):
        """
        Return a pretty multiline string representing the cause of the exception.
        """
        # We have a pre-made message
        if getattr(self, "full_err", None) is not None:
            return self.full_err
        sio = StringIO()
        val_str_len_limit = 800
        print("BadOptimization Error", super().__str__(), file=sio)
        print("  Variable: id", id(self.new_r), self.new_r, file=sio)
        # `new_r` defaults to None, so it may have no `owner` attribute.
        print("  Op", getattr(self.new_r, "owner", None), file=sio)
        print("  Value Type:", type(self.new_r_val), file=sio)

        self._write_value(sio, "Old", self.old_r_val, val_str_len_limit)
        self._write_value(sio, "New", self.new_r_val, val_str_len_limit)

        try:
            ov = np.asarray(self.old_r_val)
            nv = np.asarray(self.new_r_val)
            ssio = StringIO()
            abs_diff = np.absolute(nv - ov)
            print("  Max Abs Diff: ", np.max(abs_diff), file=ssio)
            print("  Mean Abs Diff: ", np.mean(abs_diff), file=ssio)
            print("  Median Abs Diff: ", np.median(abs_diff), file=ssio)
            print("  Std Abs Diff: ", np.std(abs_diff), file=ssio)
            arg_max_val = np.argmax(abs_diff)
            values_at_max = (nv.flatten()[arg_max_val], ov.flatten()[arg_max_val])
            print("  Value at Max Diff: ", values_at_max, file=ssio)

            # N.B. the maximum(..., 1e-8) protects against div by 0 when
            #      nv == ov == 0
            reldiff = abs_diff / np.maximum(np.absolute(nv) + np.absolute(ov), 1e-8)
            print("  Max Rel Diff: ", np.max(reldiff), file=ssio)
            print("  Mean Rel Diff: ", np.mean(reldiff), file=ssio)
            print("  Median Rel Diff: ", np.median(reldiff), file=ssio)
            print("  Std Rel Diff: ", np.std(reldiff), file=ssio)
            arg_max_val = np.argmax(reldiff)
            values_at_max = (nv.flatten()[arg_max_val], ov.flatten()[arg_max_val])
            # Labelled distinctly from the absolute-difference line above.
            print("  Value at Max Rel Diff: ", values_at_max, file=ssio)
            # Only if all of the above succeeded do we add anything to sio.
            print(ssio.getvalue(), file=sio)
        except Exception:
            # Best effort: the values may not support numeric comparison.
            pass

        print("  Reason: ", str(self.reason), file=sio)
        print("  Old Graph:", file=sio)
        print(self.old_graph, file=sio)
        print("  New Graph:", file=sio)
        print(self.new_graph, file=sio)
        print("", file=sio)
        print("Hint: relax the tolerance by setting tensor__cmp_sloppy=1", file=sio)
        print("  or even tensor__cmp_sloppy=2 for less-strict comparison", file=sio)
        return sio.getvalue()


class Feature:
    """
    Base class for FunctionGraph extensions.

    A Feature is an object with several callbacks that are triggered
    by various operations on FunctionGraphs. It can be used to enforce
    graph properties at all stages of graph optimization.

    Every callback defined here is a no-op by default; subclasses override
    only the ones they need.

    See Also
    --------
    aesara.graph.features : for common extensions.

    """

    def on_attach(self, fgraph):
        """
        Called by `FunctionGraph.attach_feature`, the method that attaches the
        feature to the `FunctionGraph`. Since this is called after the
        `FunctionGraph` is initially populated, this is where you should run
        checks on the initial contents of the `FunctionGraph`.

        The on_attach method may raise the `AlreadyThere` exception to cancel
        the attach operation if it detects that another Feature instance
        implementing the same functionality is already attached to the
        `FunctionGraph`.

        The feature has great freedom in what it can do with the `fgraph`: it
        may, for example, add methods to it dynamically.

        """

    def on_detach(self, fgraph):
        """
        Called by `FunctionGraph.remove_feature`.  Should remove any
        dynamically-added functionality that it installed into the fgraph.

        """

    def on_import(self, fgraph, node: "Apply", reason: Optional[str]):
        """
        Called whenever a node is imported into `fgraph`, which is just before
        the node is actually connected to the graph.

        Note: this is not called when the graph is created. If you want to
        detect the first nodes to be imported into the graph, you should do
        this by implementing `on_attach`.

        """

    def on_change_input(
        self,
        fgraph,
        node: "Apply",
        i: int,
        var: "Variable",
        new_var: "Variable",
        reason: Optional[str] = None,
    ):
        """
        Called whenever ``node.inputs[i]`` is changed from `var` to `new_var`.
        At the moment the callback is done, the change has already taken place.

        If you raise an exception in this function, the state of the graph
        might be broken for all intents and purposes.

        """

    def on_prune(self, fgraph, node: "Apply", reason: Optional[str]) -> None:
        """
        Called whenever a node is pruned (removed) from the `fgraph`, after it
        is disconnected from the graph.

        """

    def orderings(self, fgraph, ordered: bool = True) -> Dict["Apply", Set["Apply"]]:
        """
        Called by `FunctionGraph.toposort`. It should return a dictionary of
        ``{node: predecessors}`` where ``predecessors`` is a set of
        nodes that should be computed before the key node.

        This base implementation imposes no extra ordering: it ignores
        `ordered` and returns an empty mapping.

        If you raise an exception in this function, the state of the graph
        might be broken for all intents and purposes.

        """
        return OrderedDict()

    def clone(self):
        """Create a clone that can be attached to a new `FunctionGraph`.

        This default implementation returns `self`, which carries the
        assumption that the `Feature` is essentially stateless.  If a subclass
        has state of its own that is in any way relative to a given
        `FunctionGraph`, this method should be overridden with an
        implementation that actually creates a fresh copy.
        """
        return self


class Bookkeeper(Feature):
    """Replay import/prune callbacks over the whole graph on attach/detach.

    Attaching calls `on_import` for every node between the graph's inputs and
    outputs, in the order given by `io_toposort`; detaching calls `on_prune`
    for every such node.
    """

    def on_attach(self, fgraph):
        ordered_nodes = io_toposort(fgraph.inputs, fgraph.outputs)
        for apply_node in ordered_nodes:
            self.on_import(fgraph, apply_node, "Bookkeeper.on_attach")

    def on_detach(self, fgraph):
        ordered_nodes = io_toposort(fgraph.inputs, fgraph.outputs)
        for apply_node in ordered_nodes:
            self.on_prune(fgraph, apply_node, "Bookkeeper.detach")


class LambdaExtract:
    """Callable that undoes one input change on a `FunctionGraph`.

    Calling the instance sets input `i` of `node` back to `r` through
    ``fgraph.change_node_input``, with the reason ``("Revert", reason)`` and
    ``check=False``.
    """

    def __init__(self, fgraph, node, i, r, reason=None):
        self.fgraph, self.node = fgraph, node
        self.i, self.r = i, r
        self.reason = reason

    def __call__(self):
        revert_reason = ("Revert", self.reason)
        restore = self.fgraph.change_node_input
        return restore(self.node, self.i, self.r, reason=revert_reason, check=False)


class History(Feature):
    """Keep a history of changes to a `FunctionGraph`.

    A `FunctionGraph` can be reverted up to the last checkpoint using this
    `Feature`.  It can revert to only one point in the past.  This limit was
    added to lower memory usage.

    Attaching installs ``checkpoint`` and ``revert`` methods on the
    `FunctionGraph`, together with the private ``_history_*`` attributes that
    hold the recorded changes.

    """

    def on_attach(self, fgraph):
        """Install the history state and methods on `fgraph`.

        Raises
        ------
        AlreadyThere
            If `fgraph` already has a ``checkpoint`` or ``revert`` attribute.

        """
        if hasattr(fgraph, "checkpoint") or hasattr(fgraph, "revert"):
            raise AlreadyThere(
                "History feature is already present or in"
                " conflict with another plugin."
            )
        fgraph._history_is_reverting = False
        fgraph._history_nb = 0
        fgraph._history_history = []
        fgraph.checkpoint = types.MethodType(self.checkpoint, fgraph)
        fgraph.revert = types.MethodType(self.revert, fgraph)

    def clone(self):
        """Return a fresh instance; all state is stored on the `FunctionGraph`."""
        return type(self)()

    def on_detach(self, fgraph):
        """Remove everything `on_attach` installed on `fgraph`."""
        del fgraph.checkpoint
        del fgraph.revert
        del fgraph._history_history
        del fgraph._history_is_reverting
        # Also installed by `on_attach`; remove it so nothing is left behind.
        del fgraph._history_nb

    def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
        """Record how to undo the change of ``node.inputs[i]`` from `r`."""
        # Changes made while reverting are the undo steps themselves and must
        # not be recorded again.
        if fgraph._history_is_reverting:
            return
        fgraph._history_history.append(LambdaExtract(fgraph, node, i, r, str(reason)))

    @staticmethod
    def checkpoint(fgraph):
        """Discard the recorded changes and return a new checkpoint number."""
        fgraph._history_history = []
        fgraph._history_nb += 1
        return fgraph._history_nb

    @staticmethod
    def revert(fgraph, checkpoint):
        """
        Reverts the graph to whatever it was at the provided
        checkpoint (undoes all replacements). A checkpoint at any
        given time can be obtained using :meth:`self.checkpoint`.

        Only the most recent checkpoint can be reverted to; this is checked
        with an assertion before anything is undone.

        """
        # Check before touching any state, so that a stale checkpoint does
        # not leave the graph flagged as "reverting".
        assert fgraph._history_nb == checkpoint
        h = fgraph._history_history
        fgraph._history_is_reverting = True
        try:
            # Undo the recorded changes, most recent first.
            while h:
                f = h.pop()
                f()
        finally:
            # Always clear the flag: if it stayed set after a failed undo
            # step, `on_change_input` would stop recording changes.
            fgraph._history_is_reverting = False


class Validator(Feature):
    """Install ``validate`` and ``consistent`` methods on a `FunctionGraph`.

    ``fgraph.validate()`` runs the ``validate`` callbacks through
    ``fgraph.execute_callbacks``; ``fgraph.consistent()`` reports whether
    that call completed without raising.
    """

    def on_attach(self, fgraph):
        """Install the validation methods on `fgraph`.

        Raises
        ------
        AlreadyThere
            If `fgraph` already has a ``validate`` or ``validate_time``
            attribute.

        """
        # NOTE(review): this checks `validate_time`, while the attributes
        # installed below are `validate` and `consistent` -- confirm whether
        # `consistent` should be checked instead.
        for attr in ("validate", "validate_time"):
            if hasattr(fgraph, attr):
                raise AlreadyThere(
                    "Validator feature is already present or in"
                    " conflict with another plugin."
                )
        fgraph.validate = types.MethodType(self.validate_, fgraph)
        fgraph.consistent = types.MethodType(self.consistent_, fgraph)

    def on_detach(self, fgraph):
        """
        Should remove any dynamically added functionality
        that it installed into the function_graph
        """
        del fgraph.validate
        del fgraph.consistent

    @staticmethod
    def validate_(fgraph):
        """
        Run the ``validate`` callbacks of `fgraph` and return their result.

        Any exception raised by the callbacks is re-raised.  If the caller is
        replace_all_validate, the exception is just raised:
        replace_all_validate will print out the verbose output.  For any
        other caller that has a truthy local variable named ``verbose``, the
        failure is printed here before the exception is re-raised.

        When ``fgraph.profile`` is truthy, the time spent in a successful
        validation is added to ``fgraph.profile.validate_time``.
        """
        t0 = time.perf_counter()
        try:
            ret = fgraph.execute_callbacks("validate")
        except Exception as e:
            # The caller must be looked up from this very function: moving
            # this into a helper would shift the frame depth.  Our own frame
            # is never bound to a local, and `inspect.currentframe` may
            # return None on implementations without stack frame support.
            caller = getattr(inspect.currentframe(), "f_back", None)
            try:
                if caller is not None:
                    caller_info = inspect.getframeinfo(caller)
                    if (
                        caller_info.function != "replace_all_validate"
                        and caller.f_locals.get("verbose", False)
                    ):
                        r = caller.f_locals.get("r", "")
                        reason = caller_info.function
                        print(f"validate failed on node {r}.\n Reason: {reason}, {e}")
            finally:
                # Drop the frame reference so the propagating traceback does
                # not keep a frame reference cycle alive.
                del caller
            raise
        t1 = time.perf_counter()
        if fgraph.profile:
            fgraph.profile.validate_time += t1 - t0
        return ret

    @staticmethod
    def consistent_(fgraph):
        """Return ``True`` if ``fgraph.validate()`` succeeds, else ``False``."""
        try:
            fgraph.validate()
        except Exception:
            return False
        return True


class ReplaceValidate(History, Validator):
    """Replace variables in a `FunctionGraph` and revert if validation fails.

    Attaching installs ``replace_validate``, ``replace_all_validate`` and
    ``replace_all_validate_remove`` on the `FunctionGraph`, in addition to
    everything `History` and `Validator` install.  The feature also keeps
    track of the nodes passed to ``replace_all_validate_remove`` and makes
    validation fail if one of them is imported again.
    """

    def on_attach(self, fgraph):
        """Install the replacement methods and bookkeeping on `fgraph`.

        Raises
        ------
        AlreadyThere
            If `fgraph` already has one of the replacement methods, or if
            `History` or `Validator` refuse to attach.

        """
        for attr in (
            "replace_validate",
            "replace_all_validate",
            "replace_all_validate_remove",
        ):
            if hasattr(fgraph, attr):
                raise AlreadyThere(
                    "ReplaceValidate feature is already present"
                    " or in conflict with another plugin."
                )
        fgraph._replace_nodes_removed = set()
        fgraph._replace_validate_failed = False

        History.on_attach(self, fgraph)
        Validator.on_attach(self, fgraph)

        fgraph.replace_validate = types.MethodType(self.replace_validate, fgraph)
        fgraph.replace_all_validate = types.MethodType(
            self.replace_all_validate, fgraph
        )
        fgraph.replace_all_validate_remove = types.MethodType(
            self.replace_all_validate_remove, fgraph
        )

    def clone(self):
        """Return a fresh instance; all state is stored on the `FunctionGraph`."""
        return type(self)()

    def on_detach(self, fgraph):
        """Remove everything `on_attach` installed on `fgraph`."""
        History.on_detach(self, fgraph)
        Validator.on_detach(self, fgraph)
        del fgraph._replace_nodes_removed
        del fgraph._replace_validate_failed
        del fgraph.replace_validate
        del fgraph.replace_all_validate
        del fgraph.replace_all_validate_remove

    @staticmethod
    def replace_validate(fgraph, r, new_r, reason=None, **kwargs):
        """Replace `r` with `new_r`; see `replace_all_validate`."""
        ReplaceValidate.replace_all_validate(
            fgraph, [(r, new_r)], reason=reason, **kwargs
        )

    @staticmethod
    def replace_all_validate(fgraph, replacements, reason=None, verbose=None, **kwargs):
        """Apply all `replacements`, then validate the graph.

        Parameters
        ----------
        fgraph
            The `FunctionGraph` to modify.
        replacements
            An iterable of ``(r, new_r)`` pairs.
        reason
            Passed on to ``fgraph.replace`` and used in the printed messages.
        verbose
            Whether to print what happened; defaults to
            ``config.optimizer_verbose`` when ``None``.
        **kwargs
            Passed on to ``fgraph.replace``.

        Returns
        -------
        The checkpoint taken before the replacements were applied.

        Raises
        ------
        Exception
            Whatever ``fgraph.replace`` or ``fgraph.validate`` raised.  The
            graph is reverted to the checkpoint first, except for recursion
            depth errors.

        """
        chk = fgraph.checkpoint()

        if verbose is None:
            verbose = config.optimizer_verbose

        # Bound up front so that the messages below cannot hit an unbound
        # local when `replacements` is empty.
        r = new_r = None
        replaced_any = False

        for r, new_r in replacements:
            replaced_any = True
            try:
                fgraph.replace(r, new_r, reason=reason, verbose=False, **kwargs)
            except Exception as e:
                msg = str(e)
                s1 = "The type of the replacement must be the same"
                s2 = "does not belong to this FunctionGraph"
                s3 = "maximum recursion depth exceeded"
                if s3 in msg:
                    # There is nothing safe we can do to recover from this.
                    # So don't revert as this raise a different error
                    # that isn't helpful.
                    e.args += (
                        " As a temporary work around, you can raise Python"
                        " stack limit with:"
                        " import sys; sys.setrecursionlimit(10000)",
                    )
                    raise
                elif s1 not in msg and s2 not in msg:
                    out = sys.stderr
                    print(
                        "<<!! BUG IN FGRAPH.REPLACE OR A LISTENER !!>>",
                        type(e),
                        e,
                        reason,
                        file=out,
                    )
                # this might fail if the error is in a listener:
                # (fgraph.replace kinda needs better internal error handling)
                fgraph.revert(chk)
                raise
        try:
            fgraph.validate()
        except Exception as e:
            fgraph.revert(chk)
            if verbose:
                # Only the last replacement is reported.
                print(
                    f"rewriting: validate failed on node {r}.\n Reason: {reason}, {e}"
                )
            raise

        if verbose and replaced_any:
            # Only the last replacement is reported.
            print(
                f"rewriting: rewrite {reason} replaces {r} of {r.owner} with {new_r} of {new_r.owner}"
            )

        # The return is needed by replace_all_validate_remove
        return chk

    @staticmethod
    def replace_all_validate_remove(
        fgraph, replacements, remove, reason=None, warn=True, **kwargs
    ):
        """
        As replace_all_validate, revert the replacement if the ops
        in the list remove are still in the graph. Also print a warning.

        Parameters
        ----------
        fgraph
            The `FunctionGraph` to modify.
        replacements
            An iterable of ``(r, new_r)`` pairs.
        remove
            The nodes or variables that must no longer be in the graph once
            the replacements are applied.
        reason
            Passed on to ``replace_all_validate`` and shown in the warning.
        warn
            Whether to emit a warning when the replacement is reverted.
        **kwargs
            Passed on to ``replace_all_validate``.

        Raises
        ------
        ReplacementDidNotRemoveError
            If an element of `remove` is still in the graph.

        """
        # `remove` is traversed twice below; materialize it so that a
        # one-shot iterable is still checked.
        remove = list(remove)
        chk = fgraph.replace_all_validate(replacements, reason=reason, **kwargs)
        fgraph._replace_nodes_removed.update(remove)
        for rm in remove:
            if rm in fgraph.apply_nodes or rm in fgraph.variables:
                fgraph.revert(chk)
                if warn:
                    warnings.warn(
                        "An optimization wanted to replace a Variable"
                        " in the graph, but the replacement for it doesn't"
                        " remove it. We disabled the optimization. "
                        f"{reason}: {replacements}",
                    )
                raise ReplacementDidNotRemoveError()

    def __getstate__(self):
        """Return the instance state without any ``history`` entry."""
        d = self.__dict__.copy()
        d.pop("history", None)
        return d

    def on_import(self, fgraph, node, reason):
        """Flag the graph when a previously removed node is imported again."""
        if node in fgraph._replace_nodes_removed:
            fgraph._replace_validate_failed = True

    def validate(self, fgraph):
        """Fail validation once if a removed node was reintroduced.

        Raises
        ------
        InconsistencyError
            If `on_import` flagged a reintroduced node since the last check.

        """
        if fgraph._replace_validate_failed:
            # Reset the flag so that the failure is reported only once.
            fgraph._replace_validate_failed = False
            raise InconsistencyError("Trying to reintroduce a removed node")


class NodeFinder(Bookkeeper):
    """Maintain a mapping from each `Op` to the nodes that apply it.

    Attaching installs ``fgraph.get_nodes(op)``, which returns the nodes
    recorded for `op`, or an empty list when there are none.
    """

    def on_attach(self, fgraph):
        if hasattr(fgraph, "get_nodes"):
            raise AlreadyThere("NodeFinder is already present")

        def query(fg, op):
            # Ops without any recorded node yield a fresh empty list.
            return fg._finder_ops_to_nodes.get(op, [])

        fgraph._finder_ops_to_nodes = {}
        fgraph.get_nodes = types.MethodType(query, fgraph)

        # `Bookkeeper.on_attach` feeds the existing nodes to `on_import`.
        super().on_attach(fgraph)

    def clone(self):
        return type(self)()

    def on_detach(self, fgraph):
        del fgraph.get_nodes
        del fgraph._finder_ops_to_nodes

    def on_import(self, fgraph, node, reason):
        try:
            bucket = fgraph._finder_ops_to_nodes.setdefault(node.op, [])
        except TypeError:
            # An unhashable `Op` cannot be used as a key; skip it.
            return
        bucket.append(node)

    def on_prune(self, fgraph, node, reason):
        try:
            bucket = fgraph._finder_ops_to_nodes[node.op]
        except TypeError:
            # An unhashable `Op` was never recorded.
            return

        bucket.remove(node)

        # Drop the entry once the last node of this `Op` is gone.
        if len(bucket) == 0:
            del fgraph._finder_ops_to_nodes[node.op]


class PrintListener(Feature):
    """Print a line for every callback received while `active` is truthy."""

    def __init__(self, active=True):
        self.active = active

    def on_attach(self, fgraph):
        if not self.active:
            return
        print("-- attaching to: ", fgraph)

    def on_detach(self, fgraph):
        """Report the detachment; this feature installs nothing to remove."""
        if not self.active:
            return
        print("-- detaching from: ", fgraph)

    def on_import(self, fgraph, node, reason):
        if not self.active:
            return
        print(f"-- importing: {node}, reason: {reason}")

    def on_prune(self, fgraph, node, reason):
        if not self.active:
            return
        print(f"-- pruning: {node}, reason: {reason}")

    def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
        if not self.active:
            return
        print(f"-- changing ({node}.inputs[{i}]) from {r} to {new_r}")


class PreserveVariableAttributes(Feature):
    """
    Carry a variable's name and ``nan_guard_mode_check`` tag over to the
    variable that replaces it during optimization.
    """

    def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
        # Hand the name over only when the replacement is still unnamed.
        replacement_is_unnamed = new_r.name is None
        if r.name is not None and replacement_is_unnamed:
            new_r.name = r.name

        # Copy the tag only when the old variable has it set to a truthy
        # value and the replacement's is missing or exactly `False`.
        old_check = getattr(r.tag, "nan_guard_mode_check", False)
        if old_check:
            if getattr(new_r.tag, "nan_guard_mode_check", False) is False:
                new_r.tag.nan_guard_mode_check = old_check


class NoOutputFromInplace(Feature):
    """Prevent selected `FunctionGraph` outputs from being computed in-place.

    Parameters
    ----------
    protected_out_ids
        The indices, into ``fgraph.outputs``, of the outputs to protect.

    """

    def __init__(self, protected_out_ids):
        self.protected_out_ids = tuple(protected_out_ids)

    def on_attach(self, fgraph):
        """Register this feature on `fgraph`.

        Raises
        ------
        AlreadyThere
            If `fgraph` already has a ``_no_output_from_inplace`` attribute.

        """
        if hasattr(fgraph, "_no_output_from_inplace"):
            raise AlreadyThere(f"NoOutputFromInplace is already attached to {fgraph}.")

        fgraph._no_output_from_inplace = self

    def on_detach(self, fgraph):
        """Remove the marker installed by `on_attach`.

        Without this, the marker would outlive the feature and every later
        attempt to attach a `NoOutputFromInplace` would raise `AlreadyThere`.
        """
        # Only remove the marker if it is the one this instance installed.
        if getattr(fgraph, "_no_output_from_inplace", None) is self:
            del fgraph._no_output_from_inplace

    def clone(self):
        """Return a new instance protecting the same output indices."""
        return type(self)(self.protected_out_ids)

    def validate(self, fgraph):
        """Check that no protected output is produced by an in-place operation.

        Returns ``True`` when the check passes.  The check is skipped (and
        ``True`` returned) when `fgraph` has no ``destroyers`` attribute.

        Raises
        ------
        InconsistencyError
            If a protected output is listed in its owner `Op`'s
            ``destroy_map``.

        """
        if not hasattr(fgraph, "destroyers"):
            return True

        # Resolved eagerly, so an invalid index fails before any check runs.
        protected_outputs = [fgraph.outputs[i] for i in self.protected_out_ids]

        for out in protected_outputs:
            node = out.owner

            # Outputs without an owner are not produced by any `Op`.
            if node is None:
                continue

            # Validate that the node that produces the output does not produce
            # it by modifying something else in-place.
            op = node.op
            out_idx = node.outputs.index(out)
            if out_idx in op.destroy_map:
                raise InconsistencyError(
                    "A function graph Feature has requested that outputs of the graph "
                    "be prevented from being the result of in-place "
                    f"operations. This has prevented the output {out} from "
                    "being computed by modifying another variable in-place."
                )

        return True
