"""Tools to plot basemaps"""

import warnings
import numpy as np
from . import providers
from xyzservices import TileProvider
from .tile import bounds2img, _sm2ll, warp_tiles, _warper
from rasterio.enums import Resampling
from rasterio.warp import transform_bounds
from matplotlib import patheffects
from matplotlib.pyplot import draw

# Default `interpolation` argument of `add_basemap`, forwarded to `imshow`.
INTERPOLATION = "bilinear"
# Default `zoom` argument of `add_basemap`; "auto" means the level of detail
# is calculated automatically instead of being fixed by the caller.
ZOOM = "auto"
# Default font size used for the attribution text (`add_basemap` and
# `add_attribution`).
ATTRIBUTION_SIZE = 8


def add_basemap(
    ax,
    zoom=ZOOM,
    source=None,
    interpolation=INTERPOLATION,
    attribution=None,
    attribution_size=ATTRIBUTION_SIZE,
    reset_extent=True,
    crs=None,
    resampling=Resampling.bilinear,
    **extra_imshow_args
):
    """
    Add a (web/local) basemap to `ax`.

    Parameters
    ----------
    ax : AxesSubplot
        Matplotlib axes object on which to add the basemap. The extent of the
        axes is assumed to be in Spherical Mercator (EPSG:3857), unless the `crs`
        keyword is specified.
    zoom : int or 'auto'
        [Optional. Default='auto'] Level of detail for the basemap. If 'auto',
        it is calculated automatically. Ignored if `source` is a local file.
    source : xyzservices.TileProvider object or str
        [Optional. Default: Stamen Terrain web tiles]
        The tile source: web tile provider or path to local file. The web tile
        provider can be in the form of a :class:`xyzservices.TileProvider` object or a
        URL. The placeholders for the XYZ in the URL need to be `{x}`, `{y}`,
        `{z}`, respectively. For local file paths, the file is read with
        `rasterio` and all bands are loaded into the basemap.
        IMPORTANT: tiles are assumed to be in the Spherical Mercator
        projection (EPSG:3857), unless the `crs` keyword is specified.
    interpolation : str
        [Optional. Default='bilinear'] Interpolation algorithm to be passed
        to `imshow`. See `matplotlib.pyplot.imshow` for further details.
    attribution : str
        [Optional. Defaults to attribution specified by the source]
        Text to be added at the bottom of the axis. This
        defaults to the attribution of the provider specified
        in `source` if available. Specify False to not
        automatically add an attribution, or a string to pass
        a custom attribution.
    attribution_size : int
        [Optional. Defaults to `ATTRIBUTION_SIZE`].
        Font size to render attribution text with.
    reset_extent : bool
        [Optional. Default=True] If True, the extent of the
        basemap added is reset to the original extent (xlim,
        ylim) of `ax`. For local files this also means only the
        window covering the current axes extent is read; if False,
        the whole file is read and the axes limits are grown to the
        union of the original extent and the image extent.
    crs : None or str or CRS
        [Optional. Default=None] coordinate reference system (CRS),
        expressed in any format permitted by rasterio, to use for the
        resulting basemap. If None (default), no warping is performed
        and the original Spherical Mercator (EPSG:3857) is used.
    resampling : <enum 'Resampling'>
        [Optional. Default=Resampling.bilinear] Resampling
        method for executing warping, expressed as a
        `rasterio.enums.Resampling` method
    **extra_imshow_args :
        Other parameters to be passed to `imshow`. If the web source is
        detected as an overlay layer and `zorder` is not given, `zorder`
        is set to 9.

    Returns
    -------
    None
        The basemap (and attribution text, if any) is drawn on `ax` in place.

    Examples
    --------

    >>> import geopandas
    >>> import contextily as ctx
    >>> db = geopandas.read_file(ps.examples.get_path('virginia.shp'))

    Ensure the data is in Spherical Mercator:

    >>> db = db.to_crs(epsg=3857)

    Add a web basemap:

    >>> ax = db.plot(alpha=0.5, color='k', figsize=(6, 6))
    >>> ctx.add_basemap(ax, source=url)
    >>> plt.show()

    Or download a basemap to a local file and then plot it:

    >>> source = 'virginia.tiff'
    >>> _ = ctx.bounds2raster(*db.total_bounds, zoom=6, source=source)
    >>> ax = db.plot(alpha=0.5, color='k', figsize=(6, 6))
    >>> ctx.add_basemap(ax, source=source)
    >>> plt.show()

    """
    # Remember the current view so it can be restored (or extended) later.
    xmin, xmax, ymin, ymax = ax.axis()

    is_web_source = (
        source is None
        or isinstance(source, (dict, TileProvider))
        or (isinstance(source, str) and source.startswith("http"))
    )

    if is_web_source:
        left, right, bottom, top = xmin, xmax, ymin, ymax
        # Tiles are queried in Web Mercator, so convert the axes extent from
        # `crs` first. An authority string is used instead of the deprecated
        # PROJ ``{"init": ...}`` mapping.
        if crs is not None:
            left, right, bottom, top = _reproj_bb(
                left, right, bottom, top, crs, "EPSG:3857"
            )
        # Download image
        image, extent = bounds2img(
            left, bottom, right, top, zoom=zoom, source=source, ll=False
        )
        # Warp the mosaic back into the CRS of the axes
        if crs is not None:
            image, extent = warp_tiles(image, extent, t_crs=crs, resampling=resampling)
        # Overlays are drawn above other artists unless the caller chose
        # an explicit zorder.
        if _is_overlay(source):
            extra_imshow_args.setdefault("zorder", 9)
    else:
        # Local raster file
        import rasterio as rio

        with rio.open(source) as raster:
            if reset_extent:
                from rasterio.mask import mask as riomask

                # Only read the window covering the axes extent, expressed
                # in the CRS of the raster.
                if crs is not None:
                    left, bottom, right, top = transform_bounds(
                        crs, raster.crs, xmin, ymin, xmax, ymax
                    )
                else:
                    left, bottom, right, top = xmin, ymin, xmax, ymax
                window = [
                    {
                        "type": "Polygon",
                        "coordinates": (
                            (
                                (left, bottom),
                                (right, bottom),
                                (right, top),
                                (left, top),
                                (left, bottom),
                            ),
                        ),
                    }
                ]
                image, img_transform = riomask(raster, window, crop=True)
                extent = left, right, bottom, top
            else:
                # Read the full raster, all bands
                image = raster.read()
                img_transform = raster.transform
                bb = raster.bounds
                extent = bb.left, bb.right, bb.bottom, bb.top
            # Warp into the requested CRS when it differs from the file's
            if (crs is not None) and (raster.crs != crs):
                image, bounds, _ = _warper(
                    image, img_transform, raster.crs, crs, resampling
                )
                extent = bounds.left, bounds.right, bounds.bottom, bounds.top
            # rasterio is band-first; imshow expects band-last
            image = image.transpose(1, 2, 0)

    # Plotting: a single band is shown as a 2D array so colormaps apply
    if image.shape[2] == 1:
        image = image[:, :, 0]
    ax.imshow(image, extent=extent, interpolation=interpolation, **extra_imshow_args)

    if reset_extent:
        ax.axis((xmin, xmax, ymin, ymax))
    else:
        # Grow the view to the union of the original and image extents
        max_bounds = (
            min(xmin, extent[0]),
            max(xmax, extent[1]),
            min(ymin, extent[2]),
            max(ymax, extent[3]),
        )
        ax.axis(max_bounds)

    # Add attribution text
    if source is None:
        source = providers.Stamen.Terrain
    if isinstance(source, (dict, TileProvider)) and attribution is None:
        attribution = source.get("attribution")
    if attribution:
        add_attribution(ax, attribution, font_size=attribution_size)


def _reproj_bb(left, right, bottom, top, s_crs, t_crs):
    """
    Reproject a bounding box from `s_crs` into `t_crs`.

    The box is taken and returned in (left, right, bottom, top) order, while
    `transform_bounds` works in (left, bottom, right, top) order, so the
    values are reshuffled on the way in and on the way out.
    """
    new_left, new_bottom, new_right, new_top = transform_bounds(
        s_crs, t_crs, left, bottom, right, top
    )
    return new_left, new_right, new_bottom, new_top

def _is_overlay(source):
    """
    Check if the identified source is an overlay (partially transparent) layer.

    Parameters
    ----------
    source : dict
        The tile source: web tile provider.  Must be preprocessed as
        into a dictionary, not just a string.

    Returns
    -------
    bool

    Notes
    -----
    This function is based on a very similar javascript version found in leaflet:
    https://github.com/leaflet-extras/leaflet-providers/blob/9eb968f8442ea492626c9c8f0dac8ede484e6905/preview/preview.js#L56-L70
    """
    if not isinstance(source, dict):
        return False
    if source.get('opacity', 1.0) < 1.0:
        return True
    overlayPatterns = [
        '^(OpenWeatherMap|OpenSeaMap)',
        'OpenMapSurfer.(Hybrid|AdminBounds|ContourLines|Hillshade|ElementsAtRisk)',
        'Stamen.Toner(Hybrid|Lines|Labels)',
        'CartoDB.(Positron|DarkMatter|Voyager)OnlyLabels',
        'Hydda.RoadsAndLabels',
        '^JusticeMap',
        'OpenPtMap',
        'OpenRailwayMap',
        'OpenFireMap',
        'SafeCast'
    ]
    import re
    return bool(re.match('(' + '|'.join(overlayPatterns) + ')', source.get('name', '')))


def add_attribution(ax, text, font_size=ATTRIBUTION_SIZE, **kwargs):
    """
    Utility to add attribution text.

    Parameters
    ----------
    ax : AxesSubplot
        Matplotlib axes object on which to add the attribution text.
    text : str
        Text to be added at the bottom of the axis.
    font_size : int
        [Optional. Defaults to 8] Font size in which to render
        the attribution text.
    **kwargs : Additional keywords to pass to the matplotlib `text` method.

    Returns
    -------
    matplotlib.text.Text
        Matplotlib Text object added to the plot.
    """
    # Trigger a draw first, as it resizes the axis and allows the wrapping to
    # work as expected. See https://github.com/darribas/contextily/issues/95
    # for some details on the issue. The draw is requested on the figure that
    # owns `ax` rather than on pyplot's "current" figure, which may be a
    # different one (or may not exist yet).
    ax.figure.canvas.draw_idle()

    text_artist = ax.text(
        0.005,
        0.005,
        text,
        transform=ax.transAxes,
        size=font_size,
        path_effects=[patheffects.withStroke(linewidth=2, foreground="w")],
        wrap=True,
        **kwargs,
    )
    # hack to have the text wrapped in the ax extent, for some explanation see
    # https://stackoverflow.com/questions/48079364/wrapping-text-not-working-in-matplotlib
    wrap_width = ax.get_window_extent().width * 0.99
    text_artist._get_wrap_line_width = lambda: wrap_width
    return text_artist
