# -*- coding: utf-8 -*-
# Copyright (C) Duncan Macleod (2013)
#
# This file is part of GWpy.
#
# GWpy is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# GWpy is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with GWpy.  If not, see <http://www.gnu.org/licenses/>.

"""Extension of the simple Plot class for displaying segment objects
"""

import operator

from six import string_types
from six.moves import reduce

from matplotlib.artist import allow_rasterization
from matplotlib.ticker import (Formatter, MultipleLocator, NullLocator)
from matplotlib.projections import register_projection
from matplotlib.collections import PatchCollection
from matplotlib.patches import Rectangle
try:
    from mpl_toolkits.axes_grid1 import make_axes_locatable
except ImportError:
    from mpl_toolkits.axes_grid import make_axes_locatable

import ligo.segments

from ..segments import (Segment, SegmentList, DataQualityFlag, DataQualityDict)
from .timeseries import (TimeSeriesPlot, TimeSeriesAxes)
from .decorators import auto_refresh
from .text import to_string

# module metadata: maintainer contact, and the names exported by a
# star-import of this module
__author__ = 'Duncan Macleod <duncan.macleod@ligo.org>'
__all__ = ['SegmentAxes', 'SegmentPlot']


class SegmentAxes(TimeSeriesAxes):
    """Custom `Axes` for a `~gwpy.plotter.SegmentPlot`.

    This `SegmentAxes` provides custom methods for displaying any of

    - `~gwpy.segments.DataQualityFlag`
    - `~gwpy.segments.Segment` or :class:`ligo.segments.segment`
    - `~gwpy.segments.SegmentList` or :class:`ligo.segments.segmentlist`
    - `~gwpy.segments.SegmentListDict` or
      :class:`ligo.segments.segmentlistdict`

    Parameters
    ----------
    insetlabels : `bool`, default: `False`
        display segment labels inside the axes. Prevents very long segment
        names from getting squeezed off the end of a standard figure

    See also
    --------
    gwpy.plotter.TimeSeriesAxes
        for documentation of other args and kwargs
    """
    name = 'segments'

    def __init__(self, *args, **kwargs):
        """Create a new `SegmentAxes`

        All positional and keyword arguments are forwarded to the parent
        `TimeSeriesAxes`; ``insetlabels`` is added with a value of `False`
        when the caller did not give it.
        """
        # the parent constructor must always receive 'insetlabels'
        if 'insetlabels' not in kwargs:
            kwargs['insetlabels'] = False
        super(SegmentAxes, self).__init__(*args, **kwargs)

        # tick the y-axis with a default MultipleLocator, and label each
        # tick through the SegmentFormatter
        yaxis = self.yaxis
        yaxis.set_major_locator(MultipleLocator())
        yaxis.set_major_formatter(SegmentFormatter())

    def plot(self, *args, **kwargs):
        """Plot data onto these axes

        Parameters
        ----------
        args
            any number of leading

                - `~gwpy.segments.DataQualityDict`
                - `~gwpy.segments.DataQualityFlag`
                - `~gwpy.segments.SegmentList`
                - `~gwpy.segments.SegmentListDict`

            objects, or the equivalent upstream types from
            :mod:`ligo.segments`. Everything from the first argument that
            is not one of these types onwards is handed, untouched, to
            the ``plot`` method of the parent axes class.

        kwargs
            keyword arguments passed to each of the plotting calls made
            for ``args``

        Returns
        -------
        out : `list`
            the outputs of each plotting call, in the order they were made

        Raises
        ------
        ValueError
            if a bare `~gwpy.segments.Segment` (or
            :class:`ligo.segments.segment`) is found among the leading
            arguments

        See Also
        --------
        :meth:`matplotlib.axes.Axes.plot`
            for a full description of acceptable ``*args`` and ``**kwargs``
        """
        layers = []
        pending = list(args)
        while pending:
            item = pending[0]
            # NOTE: the order of these checks is significant, the most
            #       specific types are tested first
            if isinstance(item, DataQualityDict):
                layers.append(self.plot_dqdict(item, **kwargs))
            elif isinstance(item, DataQualityFlag):
                layers.append(self.plot_dqflag(item, **kwargs))
            elif isinstance(item, ligo.segments.segmentlistdict):
                layers.extend(self.plot_segmentlistdict(item, **kwargs))
            elif isinstance(item, ligo.segments.segmentlist):
                layers.append(self.plot_segmentlist(item, **kwargs))
            elif isinstance(item, ligo.segments.segment):
                raise ValueError("Input must be DataQualityFlag, "
                                 "SegmentListDict, or SegmentList")
            else:
                # not a segment-like object, stop consuming arguments
                break
            del pending[0]
        # anything left over is treated as regular plot() input
        if pending:
            layers.extend(super(SegmentAxes, self).plot(*pending, **kwargs))
        self.autoscale(axis='y')
        return layers

    @auto_refresh
    def plot_dqdict(self, flags, label='key', known='x', **kwargs):
        """Plot a `~gwpy.segments.DataQualityDict` onto these axes

        Parameters
        ----------
        flags : `~gwpy.segments.DataQualityDict`
            data-quality dict to display

        label : `str`, optional, default: ``'key'``
            labelling system to use (matched case-insensitively), or a
            fixed label for all `DataQualityFlags`. Special values are

            - ``'key'``: use the key of the `DataQualityDict`,
            - ``'name'``: use the :attr:`~DataQualityFlag.name` of the
              `DataQualityFlag`

            Any other value is used verbatim as the label of every flag.

        known : `str`, `dict`, `None`, default: ``'x'``
            display option for the `known` segments of each flag, passed
            along with each flag to :meth:`~SegmentAxes.plot`

        **kwargs
            any other keyword arguments, passed along with each flag to
            :meth:`~SegmentAxes.plot`

        Returns
        -------
        out : `list`
            the output of :meth:`~SegmentAxes.plot` for each flag, in
            dict iteration order
        """
        plotted = []
        for key, flag in flags.items():
            scheme = label.lower()
            if scheme == 'name':
                text = flag.name
            elif scheme == 'key':
                text = key
            else:
                text = label
            layer = self.plot(flag, label=to_string(text), known=known,
                              **kwargs)
            plotted.append(layer)
        return plotted

    @auto_refresh
    def plot_dqflag(self, flag, y=None, known='red', facecolor=(0.2, 0.8, 0.2),
                    **kwargs):
        """Plot a `~gwpy.segments.DataQualityFlag`
        onto these axes

        Parameters
        ----------
        flag : `~gwpy.segments.DataQualityFlag`
            data-quality flag to display

        y : `float`, optional
            y-axis value for new segments, defaults to the output of
            :meth:`~SegmentAxes.get_next_y`

        known : `str`, `dict`, `None`, default: ``'red'``
            how to display the `known` segments of this flag

            - a string made up only of matplotlib hatch characters
              (any of ``/``, ``\\``, ``|``, ``-``, ``+``, ``x``, ``o``,
              ``O``, ``.``, ``*``) draws unfilled, hatched rectangles,
            - a `dict` is used (as a copy) as the keyword arguments for
              :meth:`~SegmentAxes.plot_segmentlist`,
            - `None` hides the `known` segments,
            - anything else is used as the face colour of filled
              rectangles with a black edge.

            For backwards compatibility, if ``known='x'`` and a ``valid``
            keyword is given, the value of ``valid`` is used instead.

        facecolor : colour, default: ``(0.2, 0.8, 0.2)``
            face colour for the `active` segments

        height : `float`, optional, default: 0.8
            height for each segment block

        **kwargs
            any other keyword arguments acceptable for
            `~matplotlib.patches.Rectangle`

        Returns
        -------
        collection : `~matplotlib.patches.PatchCollection`
            the output of :meth:`~SegmentAxes.plot_segmentlist` for the
            `active` segments of this flag
        """
        # every character that matplotlib understands as a hatch pattern
        hatch_chars = frozenset('/\\|-+xoO.*')

        # get y axis position
        if y is None:
            y = self.get_next_y()
        # get flag name
        name = kwargs.pop('label', flag.label or flag.name)

        # get epoch (a flag without known segments cannot provide one)
        try:
            start = flag.known[0][0]
        except IndexError:
            pass
        else:
            if not self.epoch:
                self.set_epoch(start)
            else:
                self.set_epoch(min(self.epoch, start))

        # make known collection
        if known == 'x' and 'valid' in kwargs:
            known = kwargs.pop('valid')
        if known is not None:
            if isinstance(known, dict):
                # copy, so that the caller's dict is not modified below
                vkwargs = dict(known)
            else:
                # 'label' was already popped out of kwargs above
                vkwargs = kwargs.copy()
                is_hatch = (isinstance(known, string_types) and
                            bool(known) and hatch_chars.issuperset(known))
                if is_hatch:
                    vkwargs['fill'] = False
                    vkwargs['hatch'] = known
                else:
                    vkwargs['fill'] = True
                    vkwargs['facecolor'] = known
                    vkwargs['edgecolor'] = 'black'
            # mark this collection so that it isn't counted by get_next_y,
            # and give it a very low zorder relative to the active segments
            vkwargs['collection'] = 'ignore'
            vkwargs['zorder'] = -1000
            self.plot_segmentlist(flag.known, y=y, label=name, **vkwargs)

        # make active collection
        collection = self.plot_segmentlist(flag.active, y=y, label=name,
                                           facecolor=facecolor, **kwargs)

        # if this is the first flag on these axes, fit the axes to it
        ncollections = len(self.collections)
        if (known is not None and ncollections == 2) or ncollections == 1:
            if flag.known:
                self.set_xlim(*map(float, flag.extent))
            self.autoscale(axis='y')
        return collection

    @auto_refresh
    def plot_segmentlist(self, segmentlist, y=None, collection=True,
                         label=None, rasterized=None, **kwargs):
        """Plot a `~gwpy.segments.SegmentList` onto these axes

        Parameters
        ----------
        segmentlist : `~gwpy.segments.SegmentList`
            list of segments to display

        y : `float`, optional
            y-axis value for new segments, defaults to the output of
            :meth:`~SegmentAxes.get_next_y`

        collection : `bool`, `str`, default: `True`
            if true, add all patches as a single
            `~matplotlib.collections.PatchCollection` (doesn't seem
            to work for hatched rectangles), otherwise add each patch
            separately. The special value ``'ignore'`` additionally marks
            the collection as ignored, so that it is not counted by
            :meth:`~SegmentAxes.get_next_y`.

        label : `str`, optional
            custom descriptive name to print as y-axis tick label

        rasterized : `bool`, optional
            rasterization state to set on the new collection or patches

        **kwargs
            any other keyword arguments acceptable for
            `~matplotlib.patches.Rectangle`

        Returns
        -------
        collection : `~matplotlib.patches.PatchCollection`
            the new collection, as returned by ``add_collection``, or a
            `list` of the outputs of ``add_patch`` if ``collection`` is
            false
        """
        ypos = self.get_next_y() if y is None else y
        boxes = [self.build_segment(seg, ypos, **kwargs)
                 for seg in segmentlist]

        # move the epoch back to the start of this list, if needed;
        # an empty list has no start, so leaves the epoch alone
        try:
            if self.epoch:
                self.set_epoch(min(self.epoch, segmentlist[0][0]))
            else:
                self.set_epoch(segmentlist[0][0])
        except IndexError:
            pass

        if not collection:
            # add patches one-by-one, only the first one gets the label
            added = []
            for box in boxes:
                box.set_label(label)
                box.set_rasterized(rasterized)
                label = ''
                added.append(self.add_patch(box))
        else:
            # the patch list itself is given as `match_original`, so that
            # option is falsy when there are no patches
            coll = PatchCollection(boxes, match_original=boxes)
            coll.set_rasterized(rasterized)
            coll._ignore = collection == 'ignore'
            coll._ypos = ypos
            added = self.add_collection(coll)
            # the label can only be formatted after add_collection, which
            # is when the matplotlib default label is applied
            if label is None:
                label = coll.get_label()
            coll.set_label(to_string(label))
        self.autoscale(axis='y')
        return added

    @auto_refresh
    def plot_segmentlistdict(self, segmentlistdict, y=None, dy=1, **kwargs):
        """Plot a `~gwpy.segments.SegmentListDict` onto
        these axes

        Parameters
        ----------
        segmentlistdict : `~gwpy.segments.SegmentListDict`
            (name, `~gwpy.segments.SegmentList`) dict

        y : `float`, optional
            starting y-axis value for new segmentlists, defaults to the
            output of :meth:`~SegmentAxes.get_next_y`

        dy : `float`, optional, default: 1
            y-axis step between one segmentlist and the next

        **kwargs
            any other keyword arguments acceptable for
            `~matplotlib.patches.Rectangle`

        Returns
        -------
        collections : `list`
            list of `~matplotlib.patches.PatchCollection` sets for
            each segmentlist
        """
        row = self.get_next_y() if y is None else y
        plotted = []
        # each entry is drawn on its own row, labelled with its key
        for key, seglist in segmentlistdict.items():
            layer = self.plot_segmentlist(seglist, y=row, label=key,
                                          **kwargs)
            plotted.append(layer)
            row += dy
        return plotted

    @staticmethod
    def build_segment(segment, y, height=.8, valign='center', **kwargs):
        """Build a `~matplotlib.patches.Rectangle` to display
        a single `~gwpy.segments.Segment`

        Parameters
        ----------
        segment : `~gwpy.segments.Segment`
            ``[start, stop)`` GPS segment

        y : `float`
            y-axis position for segment

        height : `float`, optional, default: 0.8
            height (in y-axis units) for segment

        valign : `str`, optional, default: ``'center'``
            alignment of segment on y-axis value (case-insensitive):
            `top`, `center` (or `centre`), or `bottom`

        **kwargs
            any other keyword arguments acceptable for
            `~matplotlib.patches.Rectangle`

        Returns
        -------
        box : `~matplotlib.patches.Rectangle`
            rectangle patch for segment display

        Raises
        ------
        ValueError
            if ``valign`` is not one of the recognised alignments
        """
        # work out where the lower edge of the rectangle sits
        align = valign.lower()
        if align == 'top':
            lower = y - height
        elif align == 'bottom':
            lower = y
        elif align in ('center', 'centre'):
            lower = y - height/2.
        else:
            raise ValueError("valign must be one of 'top', 'center', or "
                             "'bottom'")
        start = segment[0]
        duration = segment[1] - start
        return Rectangle((start, lower), width=duration, height=height,
                         **kwargs)

    def set_xlim(self, *args, **kwargs):
        # NOTE: the docstring is copied over from the parent class below
        result = super(SegmentAxes, self).set_xlim(*args, **kwargs)

        # keep any segment labels at 1% of the way along the new x-axis
        left, right = self.get_xlim()
        label_x = left + (right - left) * 0.01
        for text in getattr(self, 'texts', ()):
            if getattr(text, '_is_segment_label', False):
                text.set_x(label_x)
        return result
    set_xlim.__doc__ = TimeSeriesAxes.set_xlim.__doc__

    def get_next_y(self):
        """Find the next y-axis value at which a segment list can be placed

        This is the number of collections on these axes that have not
        been marked as ignored, as returned by
        :meth:`~SegmentAxes.get_collections` with ``ignore=False``.
        """
        counted = self.get_collections(ignore=False)
        return len(counted)

    def get_collections(self, ignore=None):
        """Return the collections matching the given `_ignore` value

        Parameters
        ----------
        ignore : `bool`, or `None`
            value of `_ignore` to match

        Returns
        -------
        collections : `list`
            if `ignore=None`, simply returns all collections, otherwise
            returns those collections whose `_ignore` attribute is equal
            to the `ignore` parameter; collections without that attribute
            are never matched
        """
        everything = self.collections
        if ignore is None:
            return everything
        matched = []
        for coll in everything:
            if getattr(coll, '_ignore', None) == ignore:
                matched.append(coll)
        return matched

    def set_insetlabels(self, inset=None):
        """Set the labels to be inset or not

        Parameters
        ----------
        inset : `bool`, `None`
            if `None`, toggle the inset state, otherwise set the labels to
            be inset (`True`) or not (`False`)
        """
        # pylint: disable=attribute-defined-outside-init
        if inset is None:
            # no explicit state given, flip the current one
            inset = not self._insetlabels
        self._insetlabels = inset

    def get_insetlabels(self):
        """Return the current inset labels state
        """
        state = self._insetlabels
        return state

    # attribute-style access to the getter/setter pair above
    insetlabels = property(get_insetlabels, set_insetlabels,
                           doc=get_insetlabels.__doc__)

    @allow_rasterization
    def draw(self, *args, **kwargs):  # pylint: disable=missing-docstring
        # NOTE: the docstring is copied over from the parent class below
        #
        # Before handing over to the parent draw(), move the y-axis tick
        # labels inside the axes (if insetlabels is true), or put them back
        # where they came from (if insetlabels is `False`).
        inset = self.get_insetlabels()
        for tick in self.get_yaxis().get_ticklabels():
            # pylint: disable=protected-access
            if inset:
                # record the parameters we are changing, but only the first
                # time around: on a later redraw the label has already been
                # moved, so recording again would overwrite the true
                # original state with the inset state
                if not hasattr(tick, '_orig_pos'):
                    tick._orig_bbox = tick.get_bbox_patch()
                    tick._orig_ha = tick.get_ha()
                    tick._orig_pos = tick.get_position()
                # modify tick
                tick.set_horizontalalignment('left')
                tick.set_position((0.01, tick.get_position()[1]))
                tick.set_bbox({'alpha': 0.5, 'facecolor': 'white',
                               'edgecolor': 'none'})
            elif inset is False and hasattr(tick, '_orig_pos'):
                # label has been moved, reset things
                try:
                    tick.set_bbox(tick._orig_bbox)
                except AttributeError:
                    # best effort: the recorded value comes from
                    # get_bbox_patch(), which set_bbox may not accept back
                    # when the label already had a box of its own
                    pass
                tick.set_horizontalalignment(tick._orig_ha)
                tick.set_position(tick._orig_pos)
                del tick._orig_bbox
                del tick._orig_ha
                del tick._orig_pos
        return super(SegmentAxes, self).draw(*args, **kwargs)
    draw.__doc__ = TimeSeriesAxes.draw.__doc__


# register SegmentAxes with matplotlib under its `name` ('segments'), so
# that these axes can be requested by projection name
register_projection(SegmentAxes)


class SegmentPlot(TimeSeriesPlot):
    """`Plot` for displaying a `~gwpy.segments.DataQualityFlag`

    Parameters
    ----------
    *flags : `DataQualityFlag`
        any number of `~gwpy.segments.DataQualityFlag` to
        display on the plot

    insetlabels : `bool`, default: `False`
        display segment labels inside the axes. Prevents very long segment
        names from getting squeezed off the end of a standard figure

    **kwargs
        other keyword arguments as applicable for the `~gwpy.plotter.Plot`
    """
    _DefaultAxesClass = SegmentAxes

    def __init__(self, *flags, **kwargs):
        """Initialise a new SegmentPlot

        Parameters
        ----------
        *flags : `~gwpy.segments.DataQualityFlag`
            any number of flags to display, or a single
            `~gwpy.segments.DataQualityDict`, whose values are displayed

        sep : `bool`, optional, default: `False`
            passed as ``newax`` to :meth:`SegmentPlot.add_dataqualityflag`
            for each flag; if true, each axes is given a y-axis range of
            ``(-0.5, 0.5)``

        epoch : `float`, optional
            GPS epoch for all axes, defaults to the start of the combined
            `known` segments of all flags

        insetlabels : `bool`, optional, default: `False`
            inset-labels state to set on all axes

        figsize, auto_refresh
            if given, passed to the parent constructor

        **kwargs
            all other keyword arguments are passed to
            :meth:`SegmentPlot.add_dataqualityflag` for each flag
        """
        # separate kwargs into figure args and plotting args
        figargs = {}
        for key in ('figsize', 'auto_refresh'):
            if key in kwargs:
                figargs[key] = kwargs.pop(key)
        sep = kwargs.pop('sep', False)
        epoch = kwargs.pop('epoch', None)
        inset = kwargs.pop('insetlabels', False)

        # generate figure
        super(SegmentPlot, self).__init__(**figargs)

        # plot data
        if len(flags) == 1 and isinstance(flags[0], DataQualityDict):
            # unpack the flags themselves (not their keys): everything
            # below needs the `known` segments of each entry
            flags = list(flags[0].values())
        projection = self._DefaultAxesClass.name
        for flag in flags:
            self.add_dataqualityflag(flag, projection=projection,
                                     newax=sep, **kwargs)

        # set epoch
        if flags:
            span = reduce(operator.or_, [f.known for f in flags]).extent()
            if not epoch:
                epoch = span[0]
            for ax in self.axes:
                ax.set_epoch(epoch)
                ax.set_xlim(*map(float, span))
                ax.set_insetlabels(inset)
            # only the last axes keeps its x-axis label
            for ax in self.axes[:-1]:
                ax.set_xlabel("")

        # NOTE: the grid state is given positionally, since the name of
        #       that argument is not the same in all matplotlib versions
        if sep:
            for ax in self.axes:
                ax.set_ylim(-0.5, 0.5)
                ax.grid(False, which='both', axis='y')
        elif flags:
            # NOTE(review): `ax` here is whichever axes the loops above
            #               visited last, which is the only axes when all
            #               flags share one -- confirm if more are possible
            ax.set_ylim(-0.5, len(flags)-0.5)
            ax.grid(False, which='both', axis='y')

    def add_dataqualityflag(self, flag, **kwargs):
        # NOTE: the docstring is copied over from the parent class below
        out = super(SegmentPlot, self).add_dataqualityflag(flag, **kwargs)
        # use the start of this flag as the epoch, if there isn't one yet;
        # a flag without any known segments cannot provide one
        if self.epoch is None:
            try:
                self.set_epoch(flag.known[0][0])
            except IndexError:
                pass
        # hand back whatever the parent method returned, rather than
        # silently dropping it
        return out
    add_dataqualityflag.__doc__ = TimeSeriesPlot.add_dataqualityflag.__doc__

    def add_bitmask(self, mask, ax=None, width=0.2, pad=0.1,
                    visible=True, axes_class=SegmentAxes, topdown=False,
                    **plotargs):
        """Display a state-word bitmask on a new set of Axes.

        Parameters
        ----------
        mask : `int`, `str`
            the bitmask to display, as an integer, a hexadecimal string
            (anything containing ``'x'`` or ``'X'``), or a binary string
            with or without the ``'0b'`` prefix. The number of bits drawn
            is the number of binary digits given, so leading zeros in a
            binary string are preserved.

        ax : `~matplotlib.axes.Axes`, optional
            the axes to attach the new axes to, defaults to the last axes
            of this figure

        width : `float`, optional, default: 0.2
            ``size`` of the new axes, as passed to the axes divider

        pad : `float`, optional, default: 0.1
            ``pad`` between ``ax`` and the new axes, as passed to the
            axes divider

        visible : `bool`, optional, default: `True`
            if false, the new axes are not added to the figure and
            nothing is drawn

        axes_class : `type`, optional, default: `SegmentAxes`
            the class of the new axes

        topdown : `bool`, optional, default: `False`
            if true, draw the bits in descending order of bit index,
            otherwise in ascending order

        **plotargs
            other keyword arguments passed to the ``plot`` method of the
            new axes for each bit; ``facecolor`` defaults to ``'green'``
            and ``edgecolor`` to ``'black'``

        Returns
        -------
        maskax : `SegmentAxes`, `None`
            the new axes, or `None` if ``visible`` is false
        """
        # find default axes
        if ax is None:
            ax = self.axes[-1]

        # get new segment axes
        divider = make_axes_locatable(ax)
        maskax = divider.new_horizontal(size=width, pad=pad,
                                        axes_class=axes_class)
        maskax.set_xscale('gps')
        maskax.xaxis.set_major_locator(NullLocator())
        maskax.xaxis.set_minor_locator(NullLocator())
        maskax.yaxis.set_minor_locator(NullLocator())
        if not visible:
            return None
        self.add_axes(maskax)

        # format mask as a binary string and work out how many bits to set
        if isinstance(mask, int):
            mask = bin(mask)
        elif isinstance(mask, string_types) and 'x' in mask.lower():
            mask = bin(int(mask, 16))
        maskint = int(mask, 2)
        # take everything after the '0b' prefix, or the whole string if
        # the prefix was not given
        digits = mask.lower().split('b', 1)[-1]
        bits = list(range(len(digits)))
        if topdown:
            bits.reverse()

        # loop over bits: a set bit is drawn as a single unit segment,
        # an unset bit as an empty list, so each bit takes up one row
        plotargs.setdefault('facecolor', 'green')
        plotargs.setdefault('edgecolor', 'black')
        seg = Segment(0, 1)
        for bit in bits:
            if maskint >> bit & 1:
                seglist = SegmentList([seg])
            else:
                seglist = SegmentList()
            maskax.plot(seglist, **plotargs)
        maskax.set_title('Bitmask')
        maskax.set_xlim(0, 1)
        maskax.set_xticks([])
        maskax.yaxis.set_ticklabels([])
        maskax.set_xlabel('')
        # line the bits up with the rows of the parent axes
        maskax.set_ylim(*ax.get_ylim())

        return maskax


class SegmentFormatter(Formatter):
    """Custom tick formatter for y-axis flag names

    Each tick is labelled with the label of whatever was plotted at that
    y-axis value, or with an empty string if nothing was.
    """
    def __call__(self, t, pos=None):
        axes = self.axis.axes
        # first choice: a (non-ignored) collection sitting exactly on
        # this tick value
        for coll in axes.get_collections(ignore=False):
            if t == coll._ypos:  # pylint: disable=protected-access
                return coll.get_label()
        # otherwise: the first labelled patch whose y-extent contains
        # this tick value
        for patch in axes.patches:
            label = patch.get_label()
            if label and t in Segment(*patch.get_bbox().intervaly):
                return patch.get_label()
        return ''
