#!/usr/bin/env python
# -*- coding: utf-8 -*-

# Copyright (C) Joseph Areeda (2015)
#
# This file is part of GWpy.
#
# GWpy is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# GWpy is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with GWpy.  If not, see <http://www.gnu.org/licenses/>.
#

""" Spectrogram plots
"""

from .cliproduct import (FFTMixin, TimeDomainProduct, ImageProduct, unique)

__author__ = 'Joseph Areeda <joseph.areeda@ligo.org>'


class Spectrogram(FFTMixin, TimeDomainProduct, ImageProduct):
    """Plot the spectrogram of a time series
    """
    action = 'spectrogram'

    def __init__(self, *args, **kwargs):
        super(Spectrogram, self).__init__(*args, **kwargs)

        #: attribute to hold calculated Spectrogram data array
        self.result = None

    @classmethod
    def arg_yaxis(cls, parser):
        """Add the y-axis arguments to ``parser``

        The y-axis of a spectrogram is frequency, so this delegates to
        ``_arg_faxis`` with the ``'y'`` axis name.
        """
        return cls._arg_faxis('y', parser)

    def _finalize_arguments(self, args):
        """Apply spectrogram-specific defaults to the parsed arguments

        The colour scale defaults to ``'log'`` when none was given; the
        parent classes then finalize the remaining arguments.
        """
        if args.color_scale is None:
            args.color_scale = 'log'
        super(Spectrogram, self)._finalize_arguments(args)

    @property
    def units(self):
        """The unique unit(s) of the calculated spectrogram

        ``self.result`` is only populated by `make_plot`, so this must not
        be accessed before the plot has been made.
        """
        return unique([self.result.unit])

    def get_ylabel(self):
        """Default text for y-axis label
        """
        return 'Frequency (Hz)'

    def get_title(self):
        """Default title: the FFT length and overlap used
        """
        return 'fftlength={0}, overlap={1}'.format(self.args.secpfft,
                                                   self.args.overlap)

    def get_suptitle(self):
        """Default super-title naming the (first) channel plotted
        """
        return 'Spectrogram: {0}'.format(self.chan_list[0])

    def get_color_label(self):
        """Text for colorbar label

        Returns the normalisation description when ``--norm`` was given,
        otherwise an ASD label with the single data unit (LaTeX-formatted
        when ``usetex`` is enabled), falling back to the parent label when
        there is not exactly one unit.
        """
        if self.args.norm:
            return 'Normalized to {}'.format(self.args.norm)
        if len(self.units) == 1 and self.usetex:
            return r'ASD $\left({0}\right)$'.format(
                self.units[0].to_string('latex').strip('$'))
        elif len(self.units) == 1:
            return 'ASD ({0})'.format(self.units[0].to_string('generic'))
        return super(Spectrogram, self).get_color_label()

    def get_stride(self):
        """Calculate the stride for the spectrogram

        This method returns the stride as a `float`, or `None` to indicate
        selected usage of `TimeSeries.spectrogram2`.
        """
        fftlength = float(self.args.secpfft)
        overlap = fftlength * self.args.overlap
        stride = fftlength - overlap
        nfft = self.duration / stride  # number of FFTs
        # FFTs per unit of usable plot width (80% of self.width)
        ffps = int(nfft / (self.width * 0.8))
        if ffps > 3:
            # too many FFTs to resolve individually: average them in strides
            return max(2 * fftlength, ffps * stride + fftlength - 1)
        return None  # do not use strided spectrogram

    def get_spectrogram(self):
        """Calculate the spectrogram to be plotted

        This exists as a separate method to allow subclasses to override
        this and not the entire `get_plot` method, e.g. `Coherencegram`.

        This method should not apply the normalisation from `args.norm`.

        Returns the amplitude (ASD) spectrogram, i.e. the square root of
        the spectrogram returned by the underlying `TimeSeries` method.
        """
        args = self.args

        fftlength = float(args.secpfft)
        overlap = fftlength * args.overlap
        self.log(2, "Calculating spectrogram secpfft: %s, overlap: %s" %
                 (fftlength, overlap))

        stride = self.get_stride()

        if stride:
            specgram = self.timeseries[0].spectrogram(
                stride, fftlength=fftlength, overlap=overlap,
                window=args.window)
            # total FFTs = (number of strides) x (FFTs that fit in a stride)
            fft_per_stride = int((stride - overlap) // (fftlength - overlap))
            nfft = specgram.shape[0] * fft_per_stride
            self.log(3, 'Spectrogram calc, stride: %s, fftlength: %s, '
                        'overlap: %s, #fft: %d' % (stride, fftlength,
                                                   overlap, nfft))
        else:
            specgram = self.timeseries[0].spectrogram2(
                fftlength=fftlength, overlap=overlap, window=args.window)
            nfft = specgram.shape[0]
            self.log(3, 'HR-Spectrogram calc, fftlength: %s, overlap: %s, '
                        '#fft: %d' % (fftlength, overlap, nfft))

        return specgram ** (1/2.)   # ASD

    def make_plot(self):
        """Generate the plot from time series and arguments

        The (optionally normalised) spectrogram is stored in ``self.result``
        for later use by `units` and `scale_axes_from_data`.
        """
        args = self.args

        # create 'raw' spectrogram
        specgram = self.get_spectrogram()

        # apply normalisation
        if args.norm:
            specgram = specgram.ratio(args.norm)

        self.result = specgram

        # -- update plot defaults

        if not args.ymin:
            # a zero lower limit is invalid on a log axis, so start at the
            # frequency resolution; float() avoids integer division on py2
            if args.yscale == 'log':
                args.ymin = 1. / float(args.secpfft)
            else:
                args.ymin = 0

        norm = 'log' if args.color_scale == 'log' else None
        # vmin/vmax set in scale_axes_from_data()
        return specgram.plot(figsize=self.figsize, dpi=self.dpi,
                             norm=norm, cmap=args.cmap)

    def scale_axes_from_data(self):
        """Set axis limits and colour limits from the calculated spectrogram

        Any unset x/y limit defaults to the full span of ``self.result``;
        the colour limits are computed from the data inside those limits
        unless ``--imin``/``--imax`` were given.
        """
        args = self.args

        # get tight axes limits from time and frequency Axes
        if args.xmin is None:
            args.xmin = self.result.xspan[0]
        if args.xmax is None:
            args.xmax = self.result.xspan[1]
        if args.ymin is None:
            args.ymin = self.result.yspan[0]
        if args.ymax is None:
            args.ymax = self.result.yspan[1]

        specgram = self.result.crop(
            args.xmin, args.xmax).crop_frequencies(
                args.ymin, args.ymax)

        # auto scale colours; use the bare array in both branches so the
        # colour limits are plain numbers rather than unit-ful quantities
        from numpy import percentile
        data = specgram.value
        if args.norm:
            imin = data.min()
            imax = data.max()
        else:
            imin = percentile(data, .01)
            imax = percentile(data, 100.)
        imin = args.imin if args.imin is not None else imin
        imax = args.imax if args.imax is not None else imax
        try:
            image = self.ax.images[0]
        except IndexError:
            image = self.ax.collections[0]
        image.set_clim(imin, imax)
