from textwrap import indent

import numpy as np

from aesara.link.numba.dispatch import basic as numba_basic
from aesara.link.numba.dispatch.basic import _numba_funcify, create_tuple_string
from aesara.link.utils import compile_function_src, unique_name_generator
from aesara.tensor.basic import (
    Alloc,
    AllocDiag,
    AllocEmpty,
    ARange,
    ExtractDiag,
    Eye,
    Join,
    MakeVector,
    ScalarFromTensor,
    Split,
    TensorFromScalar,
)
from aesara.tensor.shape import Unbroadcast


@_numba_funcify.register(AllocEmpty)
def numba_funcify_AllocEmpty(op, node, **kwargs):
    """Build a Numba kernel for `AllocEmpty`.

    The kernel is generated from source because it takes one argument per
    shape input of ``node``.  Each shape argument is converted with
    ``to_scalar`` and the resulting tuple is handed to ``np.empty`` together
    with ``np.dtype(op.dtype)``.
    """
    out_dtype = np.dtype(op.dtype)

    # Names that the generated source uses itself; input names must avoid them.
    reserved = ["np", "to_scalar", "dtype", "allocempty", "scalar_shape"]
    name_for = unique_name_generator(reserved, suffix_sep="_")

    dim_args = [name_for(inp, force_unique=True) for inp in node.inputs]
    dim_items = [arg + "_item" for arg in dim_args]

    conversions_src = indent(
        "\n".join(
            f"{item} = to_scalar({arg})" for item, arg in zip(dim_items, dim_args)
        ),
        "    ",
    )

    src = f"""
def allocempty({", ".join(dim_args)}):
{conversions_src}
    scalar_shape = {create_tuple_string(dim_items)}
    return np.empty(scalar_shape, dtype)
    """

    env = {
        **globals(),
        "np": np,
        "to_scalar": numba_basic.to_scalar,
        "dtype": out_dtype,
    }
    return numba_basic.numba_njit(compile_function_src(src, "allocempty", env))


@_numba_funcify.register(Alloc)
def numba_funcify_Alloc(op, node, **kwargs):
    """Build a Numba kernel for `Alloc`.

    The generated ``alloc`` takes the fill value followed by one argument per
    shape input of ``node``.  It allocates an array of that shape with the
    dtype of ``np.asarray(val)`` and assigns ``val`` into every element via
    ``res[...] = val_np``.
    """
    reserved = ["np", "to_scalar", "alloc", "val_np", "val", "scalar_shape", "res"]
    name_for = unique_name_generator(reserved, suffix_sep="_")

    # ``node.inputs[0]`` is the fill value; everything after it is a shape entry.
    dim_args = [name_for(inp, force_unique=True) for inp in node.inputs[1:]]
    dim_items = [arg + "_item" for arg in dim_args]

    conversions_src = indent(
        "\n".join(
            f"{item} = to_scalar({arg})" for item, arg in zip(dim_items, dim_args)
        ),
        "    ",
    )

    src = f"""
def alloc(val, {", ".join(dim_args)}):
    val_np = np.asarray(val)
{conversions_src}
    scalar_shape = {create_tuple_string(dim_items)}
    res = np.empty(scalar_shape, dtype=val_np.dtype)
    res[...] = val_np
    return res
    """

    env = {**globals(), "np": np, "to_scalar": numba_basic.to_scalar}
    return numba_basic.numba_njit(compile_function_src(src, "alloc", env))


@_numba_funcify.register(AllocDiag)
def numba_funcify_AllocDiag(op, **kwargs):
    """Return a Numba kernel computing ``np.diag(v, k=op.offset)``."""
    diag_offset = op.offset

    def allocdiag(v):
        return np.diag(v, k=diag_offset)

    return numba_basic.numba_njit(inline="always")(allocdiag)


@_numba_funcify.register(ARange)
def numba_funcify_ARange(op, **kwargs):
    """Return a Numba kernel for `ARange`.

    ``start``, ``stop`` and ``step`` are each converted with ``to_scalar``
    before being passed to ``np.arange`` with dtype ``op.dtype``.
    """
    out_dtype = np.dtype(op.dtype)

    @numba_basic.numba_njit(inline="always")
    def arange(start, stop, step):
        first = numba_basic.to_scalar(start)
        last = numba_basic.to_scalar(stop)
        stride = numba_basic.to_scalar(step)
        return np.arange(first, last, stride, dtype=out_dtype)

    return arange


@_numba_funcify.register(Join)
def numba_funcify_Join(op, **kwargs):
    """Return a Numba kernel concatenating ``tensors`` along ``axis``.

    Raises
    ------
    NotImplementedError
        If ``op.view`` is anything other than ``-1``.
    """
    # TODO: Where (and why) is this `Join.view` even being used?  From a
    # quick search, the answer appears to be "nowhere", so we should
    # probably just remove it.
    if op.view != -1:
        raise NotImplementedError("The `view` parameter to `Join` is not supported")

    @numba_basic.numba_njit
    def join(axis, *tensors):
        concat_axis = numba_basic.to_scalar(axis)
        return np.concatenate(tensors, concat_axis)

    return join


@_numba_funcify.register(Split)
def numba_funcify_Split(op, **kwargs):
    """Return a Numba kernel for `Split`.

    ``indices`` holds the size of each piece; their running sum (minus the
    final total) gives the cut points handed to ``np.split``.
    """

    @numba_basic.numba_njit
    def split(tensor, axis, indices):
        # Work around for https://github.com/numba/numba/issues/8257:
        # map a possibly negative axis into [0, ndim) before calling np.split.
        split_axis = numba_basic.to_scalar(axis % tensor.ndim)
        cut_points = np.cumsum(indices)[:-1]
        return np.split(tensor, cut_points, axis=split_axis)

    return split


@_numba_funcify.register(ExtractDiag)
def numba_funcify_ExtractDiag(op, **kwargs):
    """Return a Numba kernel for `ExtractDiag`.

    The kernel is built on ``np.diag``, which can only extract a diagonal from
    a 2-d array.  ``op.axis1``/``op.axis2`` are therefore honoured for the two
    possible axis orderings of a 2-d input:

    * ``(0, 1)``: ``np.diag(x, k=op.offset)``
    * ``(1, 0)``: the diagonal of ``x.T``, i.e. ``np.diag(x, k=-op.offset)``

    Raises
    ------
    NotImplementedError
        If the input is known not to be 2-d, or the axes do not select the
        plane of a 2-d input.
    """
    offset = op.offset

    node = kwargs.get("node")
    # Without a node the input rank is unknown; `np.diag` only extracts from
    # 2-d inputs, so that is assumed.
    ndim = node.inputs[0].type.ndim if node is not None else 2
    if ndim != 2:
        raise NotImplementedError(
            "Numba `ExtractDiag` is only implemented for 2-d inputs"
        )

    # Normalize negative axes (e.g. ``-1``) before comparing them.
    axes = (op.axis1 % ndim, op.axis2 % ndim)
    if axes == (1, 0):
        # x.T[i, i + offset] == x[i + offset, i], i.e. diagonal ``-offset`` of x.
        offset = -offset
    elif axes != (0, 1):
        raise NotImplementedError(
            f"Numba `ExtractDiag` does not support axis1={op.axis1}, "
            f"axis2={op.axis2}"
        )

    @numba_basic.numba_njit(inline="always")
    def extract_diag(x):
        return np.diag(x, k=offset)

    return extract_diag


@_numba_funcify.register(Eye)
def numba_funcify_Eye(op, **kwargs):
    """Return a Numba kernel for `Eye`.

    ``N`` (rows), ``M`` (columns) and ``k`` (diagonal index) are each
    converted with ``to_scalar`` and passed to ``np.eye`` with dtype
    ``op.dtype``.
    """
    out_dtype = np.dtype(op.dtype)

    @numba_basic.numba_njit(inline="always")
    def eye(N, M, k):
        n_rows = numba_basic.to_scalar(N)
        n_cols = numba_basic.to_scalar(M)
        diag_idx = numba_basic.to_scalar(k)
        return np.eye(n_rows, n_cols, diag_idx, dtype=out_dtype)

    return eye


@_numba_funcify.register(MakeVector)
def numba_funcify_MakeVector(op, node, **kwargs):
    """Build a Numba kernel for `MakeVector`.

    The generated ``makevector`` takes one argument per input of ``node``.
    It converts each argument with ``to_scalar`` and packs the results into a
    1-d array of dtype ``op.dtype``.  With no inputs it returns an empty
    array of that dtype.
    """
    dtype = np.dtype(op.dtype)

    global_env = {"np": np, "to_scalar": numba_basic.to_scalar, "dtype": dtype}

    # Reserve every name the generated source refers to, so an input variable
    # named e.g. ``dtype`` cannot become a parameter that shadows the global.
    unique_names = unique_name_generator(
        ["np", "to_scalar", "dtype", "makevector"],
        suffix_sep="_",
    )
    input_names = [unique_names(v, force_unique=True) for v in node.inputs]

    if input_names:
        elements = ", ".join(f"to_scalar({name})" for name in input_names)
        body_src = f"return np.array([{elements}], dtype=dtype)"
    else:
        # Numba cannot infer the element type of an empty list literal, so
        # the empty vector is allocated directly.
        body_src = "return np.empty(0, dtype)"

    makevector_def_src = f"""
def makevector({", ".join(input_names)}):
    {body_src}
    """

    makevector_fn = compile_function_src(
        makevector_def_src, "makevector", {**globals(), **global_env}
    )

    return numba_basic.numba_njit(makevector_fn)


@_numba_funcify.register(Unbroadcast)
def numba_funcify_Unbroadcast(op, **kwargs):
    """Return a Numba kernel that hands its input back unchanged."""

    def unbroadcast(x):
        return x

    return numba_basic.numba_njit(unbroadcast)


@_numba_funcify.register(TensorFromScalar)
def numba_funcify_TensorFromScalar(op, **kwargs):
    """Return a Numba kernel that wraps its scalar input with ``np.array``."""

    def tensor_from_scalar(x):
        return np.array(x)

    return numba_basic.numba_njit(inline="always")(tensor_from_scalar)


@_numba_funcify.register(ScalarFromTensor)
def numba_funcify_ScalarFromTensor(op, **kwargs):
    """Return a Numba kernel that converts its input with ``to_scalar``."""

    def scalar_from_tensor(x):
        return numba_basic.to_scalar(x)

    return numba_basic.numba_njit(inline="always")(scalar_from_tensor)
