"""Plottng functions for visualizing distributions."""
from __future__ import division
import inspect
import colorsys
import numpy as np
from scipy import stats
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import warnings

try:
    import statsmodels.api as sm
    _has_statsmodels = True
except ImportError:
    _has_statsmodels = False

from .external.six.moves import range

from .utils import set_hls_values, desaturate, percentiles, iqr, _kde_support
from .palettes import color_palette, husl_palette, blend_palette
from .axisgrid import JointGrid


def _box_reshape(vals, groupby, names, order):
    """Reshape the box/violinplot input options and find plot labels.

    Parameters
    ----------
    vals : DataFrame, Series, array, or list
        Input data. A DataFrame or 2D array is treated as "wide" (one group
        per column); a Series combined with `groupby` is split into groups;
        other 1D input becomes a single group unless its elements are
        themselves sequences.
    groupby : grouping object or None
        Passed to ``vals.groupby`` when `vals` is a Series.
    names : sequence or None
        Labels for the groups. When None they are inferred from the data.
    order : sequence or None
        Order of the groups. Only valid when `vals` is a Pandas object.

    Returns
    -------
    vals : list of float arrays
        One array per group; the arrays may differ in length.
    xlabel, ylabel : object or None
        Axis labels taken from Pandas ``name`` attributes, if any.
    names : sequence
        Labels for the groups.

    Raises
    ------
    ValueError
        If `order` and `names` differ in length, if `order` is used with a
        non-Pandas input, or if an array input has more than 2 dimensions.

    """
    # Set up default label outputs
    xlabel, ylabel = None, None

    # If order is provided, make sure it was used correctly
    if order is not None:
        # Assure that order is the same length as names, if provided
        if names is not None and len(order) != len(names):
            raise ValueError("`order` must have same length as `names`")
        # Assure that order is only used with the right inputs
        if not isinstance(vals, (pd.Series, pd.DataFrame)):
            raise ValueError("`vals` must be a Pandas object to use `order`.")

    # Handle case where data is a wide DataFrame
    if isinstance(vals, pd.DataFrame):
        if order is not None:
            vals = vals[order]
        if names is None:
            names = vals.columns.tolist()
        if vals.columns.name is not None:
            xlabel = vals.columns.name
        vals = vals.values.T

    # Handle case where data is a long Series and there is a grouping object
    elif isinstance(vals, pd.Series) and groupby is not None:
        # Use the method form: the top-level ``pd.groupby`` function is not
        # available in current pandas releases
        groups = vals.groupby(groupby).groups
        order = sorted(groups) if order is None else order
        if hasattr(groupby, "name"):
            if groupby.name is not None:
                xlabel = groupby.name
        if vals.name is not None:
            ylabel = vals.name
        vals = [vals.reindex(groups[name]) for name in order]
        if names is None:
            names = order

    else:

        # Handle case where the input data is an array or there was no groupby
        if hasattr(vals, 'shape'):
            if len(vals.shape) == 1:
                # Inspect the first element by position; ``vals[0]`` on a
                # Series is a label lookup and fails without a ``0`` label
                if np.isscalar(np.asarray(vals)[0]):
                    vals = [vals]
                else:
                    vals = list(vals)
            elif len(vals.shape) == 2:
                nr, nc = vals.shape
                if nr == 1:
                    vals = [vals]
                elif nc == 1:
                    vals = [vals.ravel()]
                else:
                    vals = [vals[:, i] for i in range(nc)]
            else:
                error = "Input `vals` can have no more than 2 dimensions"
                raise ValueError(error)

        # This should catch things like flat lists
        elif np.isscalar(vals[0]):
            vals = [vals]

        # By default, just use the plot positions as names
        if names is None:
            names = list(range(1, len(vals) + 1))
        elif hasattr(names, "name"):
            if names.name is not None:
                xlabel = names.name

    # Now convert vals to a common representation
    # The plotting functions will work with a list of arrays
    # The list allows each array to possibly be of a different length
    # (the builtin ``float`` replaces the removed ``np.float`` alias)
    vals = [np.asarray(a, float) for a in vals]

    return vals, xlabel, ylabel, names


def _box_colors(vals, color, sat):
    """Find colors to use for boxplots or violinplots.

    Parameters
    ----------
    vals : sequence
        The groups being plotted; only the number of groups is used.
    color : mpl color, sequence of colors, seaborn palette name, or None
        A single matplotlib color is repeated for every group; anything that
        matplotlib cannot convert is handed to ``color_palette``. None picks
        the current palette, or a husl palette when the current color cycle
        has fewer entries than there are groups.
    sat : float
        Passed to ``desaturate`` for every color.

    Returns
    -------
    colors : list
        One color per group, converted to RGB and passed through
        ``desaturate``.
    gray : tuple
        Gray RGB triple for the line elements, at 60% of the lightness of
        the darkest color.

    """
    if color is None:
        # Default uses either the current palette or husl
        try:
            # Current matplotlib stores the cycle under ``axes.prop_cycle``
            prop_cycle = mpl.rcParams["axes.prop_cycle"]
            current_palette = prop_cycle.by_key().get("color", [])
        except KeyError:
            # Older matplotlib only knows the ``axes.color_cycle`` key
            current_palette = mpl.rcParams["axes.color_cycle"]
        if len(vals) <= len(current_palette):
            colors = color_palette(n_colors=len(vals))
        else:
            colors = husl_palette(len(vals), l=.7)
    else:
        try:
            color = mpl.colors.colorConverter.to_rgb(color)
            colors = [color for _ in vals]
        except ValueError:
            colors = color_palette(color, len(vals))

    # Desaturate a bit because these are patches
    colors = [mpl.colors.colorConverter.to_rgb(c) for c in colors]
    colors = [desaturate(c, sat) for c in colors]

    # Determine the gray color for the lines
    light_vals = [colorsys.rgb_to_hls(*c)[1] for c in colors]
    lightness = min(light_vals) * .6
    gray = (lightness, lightness, lightness)

    return colors, gray


def boxplot(vals, groupby=None, names=None, join_rm=False, order=None,
            color=None, alpha=None, fliersize=3, linewidth=1.5, widths=.8,
            saturation=.7, label=None, ax=None, **kwargs):
    """Wrapper for matplotlib boxplot with better aesthetics and functionality.

    Parameters
    ----------
    vals : DataFrame, Series, 2D array, list of vectors, or vector.
        Data for plot. DataFrames and 2D arrays are assumed to be "wide" with
        each column mapping to a box. Lists of data are assumed to have one
        element per box.  Can also provide one long Series in conjunction with
        a grouping element as the `groupby` parameter to reshape the data into
        several boxes. Otherwise 1D data will produce a single box.
    groupby : grouping object
        If `vals` is a Series, this is used to group into boxes.
    names : list of strings, optional
        Names to plot on x axis; otherwise plots numbers. This will override
        names inferred from Pandas inputs.
    order : list of strings, optional
        If vals is a Pandas object with name information, you can control the
        order of the boxes by providing the box names in your preferred order.
    join_rm : boolean, optional
        If True, positions in the input arrays are treated as repeated
        measures and are joined with a line plot.
    color : mpl color, sequence of colors, or seaborn palette name
        Inner box color.
    alpha : float
        Transparency of the inner box color.
    fliersize : float, optional
        Markersize for the fliers.
    linewidth : float, optional
        Width for the box outlines and whiskers.
    widths : float, optional
        Box widths, passed through to matplotlib's boxplot.
    saturation : float, 0-1
        Saturation relative to the fully-saturated color. Large patches tend
        to look better at lower saturations, so this dims the palette colors
        a bit by default.
    label : string, optional
        Legend label; attached to an invisible zero-size patch drawn in the
        color of the first box.
    ax : matplotlib axis, optional
        Existing axis to plot into, otherwise grab current axis.
    kwargs : additional keyword arguments to boxplot

    Returns
    -------
    ax : matplotlib axis
        Axis where boxplot is plotted.

    """
    if ax is None:
        ax = plt.gca()

    # Reshape and find labels for the plot
    vals, xlabel, ylabel, names = _box_reshape(vals, groupby, names, order)

    # Find plot colors
    colors, gray = _box_colors(vals, color, saturation)

    # Make a flierprops dict and set symbol to override buggy default behavior
    # on matplotlib 1.4.0
    kwargs["sym"] = "d"

    # Later versions of matplotlib only (but those are the one with the bug).
    # ``inspect.getargspec`` is gone from Python 3.11, so prefer
    # ``inspect.signature`` and only fall back where it does not exist.
    try:
        boxplot_params = inspect.signature(ax.boxplot).parameters
    except AttributeError:
        boxplot_params = inspect.getargspec(ax.boxplot).args
    if "flierprops" in boxplot_params:
        kwargs["flierprops"] = {"markerfacecolor": gray,
                                "markeredgecolor": gray,
                                "markersize": fliersize}

    # Draw the boxplot using matplotlib
    boxes = ax.boxplot(vals, patch_artist=True, widths=widths, **kwargs)

    # Set the new aesthetics
    for box, box_color in zip(boxes["boxes"], colors):
        box.set_color(box_color)
        if alpha is not None:
            box.set_alpha(alpha)
        box.set_edgecolor(gray)
        box.set_linewidth(linewidth)
    for whisk in boxes["whiskers"]:
        whisk.set_color(gray)
        whisk.set_linewidth(linewidth)
        whisk.set_linestyle("-")
    for cap in boxes["caps"]:
        cap.set_color(gray)
        cap.set_linewidth(linewidth)
    for med in boxes["medians"]:
        med.set_color(gray)
        med.set_linewidth(linewidth)

    # As of matplotlib 1.4.0 there is a bug where these values are being
    # ignored, so this is redundant with what's above but I am keeping it
    for fly in boxes["fliers"]:
        fly.set_color(gray)
        fly.set_marker("d")
        fly.set_markeredgecolor(gray)
        fly.set_markersize(fliersize)

    # This is a hack to get labels to work
    # It's unclear whether this is actually broken in matplotlib or just not
    # implemented, either way it's annoying.
    if label is not None:
        pos = kwargs.get("positions", [1])[0]
        med = np.median(vals[0])
        ax.add_patch(plt.Rectangle([pos, med], 0, 0,
                                   color=colors[0], label=label))

    # Is this a vertical plot?
    vertical = kwargs.get("vert", True)

    # Draw the joined repeated measures at the same positions as the boxes
    if join_rm:
        x = kwargs.get("positions", np.arange(1, len(vals) + 1))
        y = vals
        if not vertical:
            x, y = y, x
        ax.plot(x, y, color=gray, alpha=2. / 3)

    # Label the axes and ticks
    if vertical:
        ax.set_xticklabels(list(names))
    else:
        ax.set_yticklabels(list(names))
        xlabel, ylabel = ylabel, xlabel
    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)

    # Turn off the grid parallel to the boxes
    if vertical:
        ax.xaxis.grid(False)
    else:
        ax.yaxis.grid(False)

    return ax


def violinplot(vals, groupby=None, inner="box", color=None, positions=None,
               names=None, order=None, bw="scott", widths=.8, alpha=None,
               saturation=.7, join_rm=False, gridsize=100, cut=3,
               inner_kws=None, ax=None, vert=True, **kwargs):

    """Create a violin plot (a combination of boxplot and kernel density plot).

    Parameters
    ----------
    vals : DataFrame, Series, 2D array, or list of vectors.
        Data for plot. DataFrames and 2D arrays are assumed to be "wide" with
        each column mapping to a box. Lists of data are assumed to have one
        element per box.  Can also provide one long Series in conjunction with
        a grouping element as the `groupby` parameter to reshape the data into
        several violins. Otherwise 1D data will produce a single violin.
    groupby : grouping object
        If `vals` is a Series, this is used to group into violins.
    inner : {'box' | 'stick' | 'points'}
        Plot quartiles or individual sample values inside violin. Any other
        value draws nothing inside the violins.
    color : mpl color, sequence of colors, or seaborn palette name
        Inner violin colors
    positions : number or sequence of numbers
        Position of first violin or positions of each violin.
    names : list of strings, optional
        Names to plot on x axis; otherwise plots numbers. This will override
        names inferred from Pandas inputs.
    order : list of strings, optional
        If vals is a Pandas object with name information, you can control the
        order of the plot by providing the violin names in your preferred
        order.
    bw : {'scott' | 'silverman' | scalar}
        Name of reference method to determine kernel size, or size as a
        scalar.
    widths : float
        Width of each violin at maximum density.
    alpha : float, optional
        Transparency of violin fill.
    saturation : float, 0-1
        Saturation relative to the fully-saturated color. Large patches tend
        to look better at lower saturations, so this dims the palette colors
        a bit by default.
    join_rm : boolean, optional
        If True, positions in the input arrays are treated as repeated
        measures and are joined with a line plot.
    gridsize : int
        Number of discrete gridpoints to evaluate the density on.
    cut : scalar
        Draw the estimate to cut * bw from the extreme data points.
    inner_kws : dict, optional
        Keyword arguments for inner plot. The dictionary is copied, so the
        caller's object is left untouched.
    ax : matplotlib axis, optional
        Axis to plot on, otherwise grab current axis.
    vert : boolean, optional
        If true (default), draw vertical plots; otherwise, draw horizontal
        ones.
    kwargs : additional keyword arguments
        Only ``lw`` / ``linewidth`` is read, to set the width of the violin
        outlines (1.5 when neither is given).

    Returns
    -------
    ax : matplotlib axis
        Axis with violin plot.

    Raises
    ------
    ValueError
        If the number of names does not match the number of violins.

    """

    if ax is None:
        ax = plt.gca()

    # Reshape and find labels for the plot
    vals, xlabel, ylabel, names = _box_reshape(vals, groupby, names, order)

    # Sort out the plot colors
    colors, gray = _box_colors(vals, color, saturation)

    # Initialize the kwarg dict for the inner plot. Work on a copy: the
    # defaults and the alpha scaling below must not leak into a dictionary
    # that the caller might pass in again.
    inner_kws = {} if inner_kws is None else dict(inner_kws)
    inner_kws.setdefault("alpha", .6 if inner == "points" else 1)
    inner_kws["alpha"] *= 1 if alpha is None else alpha
    inner_kws.setdefault("color", gray)
    inner_kws.setdefault("marker", "." if inner == "points" else "")
    lw = inner_kws.pop("lw", 1.5 if inner == "box" else .8)
    inner_kws.setdefault("linewidth", lw)

    # Find where the violins are going
    if positions is None:
        positions = np.arange(1, len(vals) + 1)
    elif not hasattr(positions, "__iter__"):
        positions = np.arange(positions, len(vals) + positions)

    # Set the default linewidth if not provided in kwargs
    # (popping from an empty set intersection raises the KeyError)
    try:
        lw = kwargs[({"lw", "linewidth"} & set(kwargs)).pop()]
    except KeyError:
        lw = 1.5

    # Everything below is written as if the violins were vertical, with
    # ``x`` the position axis; for horizontal plots the coordinates are
    # swapped at draw time by these two helpers.
    if vert:
        fill = ax.fill_betweenx

        def ax_plot(x, y, *args, **kws):
            ax.plot(x, y, *args, **kws)

    else:
        fill = ax.fill_between

        def ax_plot(x, y, *args, **kws):
            ax.plot(y, x, *args, **kws)

    # Iterate over the variables
    for i, a in enumerate(vals):

        x = positions[i]

        # If we only have a single value, plot a line across the position
        if len(a) == 1:
            y = a[0]
            ax_plot([x - widths / 2, x + widths / 2], [y, y], **inner_kws)
            continue

        # Fit the KDE
        try:
            kde = stats.gaussian_kde(a, bw)
        except TypeError:
            kde = stats.gaussian_kde(a)
            if bw != "scott":  # scipy default
                msg = ("Ignoring bandwidth choice, "
                       "please upgrade scipy to use a different bandwidth.")
                warnings.warn(msg, UserWarning)

        # Determine the support region
        if isinstance(bw, str):
            bw_name = "scotts" if bw == "scott" else bw
            _bw = getattr(kde, "%s_factor" % bw_name)() * a.std(ddof=1)
        else:
            _bw = bw
        y = _kde_support(a, _bw, gridsize, cut, (-np.inf, np.inf))
        dens = kde.evaluate(y)
        # Rescale so the widest point of the violin spans ``widths``
        scl = 1 / (dens.max() / (widths / 2))
        dens *= scl

        # Draw the violin body
        fill(y, x - dens, x + dens, alpha=alpha, color=colors[i])

        if inner == "box":
            for quant in percentiles(a, [25, 75]):
                q_x = kde.evaluate(quant) * scl
                q_x = [x - q_x, x + q_x]
                ax_plot(q_x, [quant, quant], linestyle=":",  **inner_kws)
            med = np.median(a)
            m_x = kde.evaluate(med) * scl
            m_x = [x - m_x, x + m_x]
            ax_plot(m_x, [med, med], linestyle="--", **inner_kws)
        elif inner == "stick":
            x_vals = kde.evaluate(a) * scl
            x_vals = [x - x_vals, x + x_vals]
            ax_plot(x_vals, [a, a], linestyle="-", **inner_kws)
        elif inner == "points":
            x_vals = [x for _ in a]
            ax_plot(x_vals, a, mew=0, linestyle="", **inner_kws)
        for side in [-1, 1]:
            ax_plot((side * dens) + x, y, c=gray, lw=lw)

    # Draw the repeated measure bridges at the violin positions, honoring
    # the plot orientation
    if join_rm:
        ax_plot(positions, vals, color=inner_kws["color"], alpha=2. / 3)

    # Add in semantic labels
    if names is not None:
        if len(vals) != len(names):
            raise ValueError("Length of names list must match nuber of bins")
        names = list(names)

    if vert:
        ax.set_xticks(positions)
        ax.set_xlim(positions[0] - .5, positions[-1] + .5)
        ax.set_xticklabels(names)

        if xlabel is not None:
            ax.set_xlabel(xlabel)
        if ylabel is not None:
            ax.set_ylabel(ylabel)

        # Turn off the grid parallel to the violins
        ax.xaxis.grid(False)
    else:
        ax.set_yticks(positions)
        ax.set_yticklabels(names)
        ax.set_ylim(positions[0] - .5, positions[-1] + .5)

        # The group label goes on the y axis and the value label on the x
        # axis; each is only set when it actually exists
        if xlabel is not None:
            ax.set_ylabel(xlabel)
        if ylabel is not None:
            ax.set_xlabel(ylabel)

        # Turn off the grid parallel to the violins
        ax.yaxis.grid(False)

    return ax


def _freedman_diaconis_bins(a):
    """Calculate number of hist bins using Freedman-Diaconis rule.

    Parameters
    ----------
    a : array-like
        Observations.

    Returns
    -------
    n_bins : int
        Number of bins. Inputs with fewer than two observations give 1, and
        data whose interquartile range is zero fall back to the square root
        of the number of observations.

    """
    # From http://stats.stackexchange.com/questions/798/
    a = np.asarray(a)
    if len(a) < 2:
        return 1
    h = 2 * iqr(a) / (len(a) ** (1 / 3))
    if h == 0:
        # A zero IQR would divide by zero below; use the square-root rule
        return int(np.sqrt(a.size))
    # Histogram functions expect an integer bin count, not a float
    return int(np.ceil((a.max() - a.min()) / h))


def distplot(a, bins=None, hist=True, kde=True, rug=False, fit=None,
             hist_kws=None, kde_kws=None, rug_kws=None, fit_kws=None,
             color=None, vertical=False, norm_hist=False, axlabel=None,
             label=None, ax=None):
    """Flexibly plot a distribution of observations.

    Parameters
    ----------

    a : (squeezable to) 1d array
        Observed data.
    bins : argument for matplotlib hist(), or None, optional
        Specification of hist bins, or None to use Freedman-Diaconis rule.
    hist : bool, optional
        Whether to plot a (normed) histogram.
    kde : bool, optional
        Whether to plot a gaussian kernel density estimate.
    rug : bool, optional
        Whether to draw a rugplot on the support axis.
    fit : random variable object, optional
        An object with `fit` method, returning a tuple that can be passed to a
        `pdf` method as positional arguments following a grid of values to
        evaluate the pdf on.
    {hist, kde, rug, fit}_kws : dictionaries, optional
        Keyword arguments for underlying plotting functions. The dictionaries
        are copied, so the caller's objects are never modified. `fit_kws` may
        also hold `gridsize`, `cut` and `clip` to control the grid the fitted
        pdf is evaluated on.
    color : matplotlib color, optional
        Color to plot everything but the fitted curve in.
    vertical : bool, optional
        If True, observed values are on y-axis.
    norm_hist : bool, optional
        If True, the histogram height shows a density rather than a count.
        This is implied if a KDE or fitted density is plotted.
    axlabel : string, False, or None, optional
        Name for the support axis label. If None, will try to get it
        from a.name; if False, do not set a label.
    label : string, optional
        Legend label for the relevant component of the plot
    ax : matplotlib axis, optional
        if provided, plot on this axis

    Returns
    -------
    ax : matplotlib axis

    """
    if ax is None:
        ax = plt.gca()

    # Intelligently label the support axis
    label_ax = bool(axlabel)
    if axlabel is None and hasattr(a, "name"):
        axlabel = a.name
        if axlabel is not None:
            label_ax = True

    # Make a a 1-d array
    a = np.asarray(a).squeeze()

    # Decide if the hist is normed
    norm_hist = norm_hist or kde or (fit is not None)

    # Handle dictionary defaults. Always work on copies: keys are added to
    # and popped from these dictionaries below, and that must not be visible
    # in objects the caller may reuse for another call.
    hist_kws = {} if hist_kws is None else dict(hist_kws)
    kde_kws = {} if kde_kws is None else dict(kde_kws)
    rug_kws = {} if rug_kws is None else dict(rug_kws)
    fit_kws = {} if fit_kws is None else dict(fit_kws)

    # Get the color from the current color cycle by drawing, and
    # immediately removing, a throwaway line
    if color is None:
        if vertical:
            line, = ax.plot(0, a.mean())
        else:
            line, = ax.plot(a.mean(), 0)
        color = line.get_color()
        line.remove()

    # Plug the label into the right kwarg dictionary; only the first
    # component that is drawn receives it
    if label is not None:
        if hist:
            hist_kws["label"] = label
        elif kde:
            kde_kws["label"] = label
        elif rug:
            rug_kws["label"] = label
        elif fit is not None:
            fit_kws["label"] = label

    if hist:
        if bins is None:
            # hist() needs an integer number of bins
            bins = int(_freedman_diaconis_bins(a))
        hist_kws.setdefault("alpha", 0.4)
        if "normed" not in hist_kws and "density" not in hist_kws:
            # ``normed`` was replaced by ``density`` in matplotlib; use
            # whichever keyword this version of hist() accepts
            try:
                hist_params = inspect.signature(ax.hist).parameters
            except AttributeError:
                hist_params = inspect.getargspec(ax.hist).args
            norm_key = "density" if "density" in hist_params else "normed"
            hist_kws[norm_key] = norm_hist
        orientation = "horizontal" if vertical else "vertical"
        hist_color = hist_kws.pop("color", color)
        ax.hist(a, bins, orientation=orientation,
                color=hist_color, **hist_kws)

    if kde:
        kde_color = kde_kws.pop("color", color)
        kdeplot(a, vertical=vertical, ax=ax, color=kde_color, **kde_kws)

    if rug:
        rug_color = rug_kws.pop("color", color)
        axis = "y" if vertical else "x"
        rugplot(a, axis=axis, ax=ax, color=rug_color, **rug_kws)

    if fit is not None:
        fit_color = fit_kws.pop("color", "#282828")
        gridsize = fit_kws.pop("gridsize", 200)
        cut = fit_kws.pop("cut", 3)
        clip = fit_kws.pop("clip", (-np.inf, np.inf))
        # Use the Scott bandwidth only to decide how far past the data the
        # evaluation grid extends
        bw = stats.gaussian_kde(a).scotts_factor() * a.std(ddof=1)
        x = _kde_support(a, bw, gridsize, cut, clip)
        params = fit.fit(a)
        y = fit.pdf(x, *params)
        if vertical:
            x, y = y, x
        ax.plot(x, y, color=fit_color, **fit_kws)

    if label_ax:
        if vertical:
            ax.set_ylabel(axlabel)
        else:
            ax.set_xlabel(axlabel)

    return ax


def _univariate_kdeplot(data, shade, vertical, kernel, bw, gridsize, cut,
                        clip, legend, ax, cumulative=False, **kwargs):
    """Plot a univariate kernel density estimate on one of the axes.

    Parameters
    ----------
    data : 1d array-like
        Observations. If it has a ``name`` attribute, that is used as the
        default legend label.
    shade : bool
        If True, fill the area under the curve.
    vertical : bool
        If True, the density is drawn along the x axis.
    kernel : string
        Kernel code; anything other than "gau" requires statsmodels.
    bw, gridsize, cut, clip
        Passed on to the KDE backend; a `clip` of None means no clipping.
    legend : bool
        If True, and a label is available, add a legend to the axis.
    ax : matplotlib axis
        Axis to plot on.
    cumulative : bool
        If True, draw the cumulative distribution (statsmodels only).
    kwargs : other keyword arguments for ``ax.plot``
        ``label`` sets the legend label; ``alpha`` is also used for the
        shading (0.25 when not given).

    Returns
    -------
    ax : matplotlib axis

    Raises
    ------
    ImportError
        If `cumulative` is requested and statsmodels is not available.

    """

    # Sort out the clipping
    if clip is None:
        clip = (-np.inf, np.inf)

    # Calculate the KDE
    if _has_statsmodels:
        # Prefer using statsmodels for kernel flexibility
        x, y = _statsmodels_univariate_kde(data, kernel, bw,
                                           gridsize, cut, clip,
                                           cumulative=cumulative)
    else:
        # Fall back to scipy if missing statsmodels
        if kernel != "gau":
            kernel = "gau"
            msg = "Kernel other than `gau` requires statsmodels."
            warnings.warn(msg, UserWarning)
        if cumulative:
            # Note the trailing spaces: adjacent literals are concatenated
            raise ImportError("Cumulative distributions are currently "
                              "only implemented in statsmodels. "
                              "Please install statsmodels.")
        x, y = _scipy_univariate_kde(data, bw, gridsize, cut, clip)

    # Make sure the density is nonnegative
    y = np.amax(np.c_[np.zeros_like(y), y], axis=1)

    # Flip the data if the plot should be on the y axis
    if vertical:
        x, y = y, x

    # Check if a label was specified in the call
    label = kwargs.pop("label", None)

    # Otherwise check if the data object has a name
    if label is None and hasattr(data, "name"):
        label = data.name

    # Decide if we're going to add a legend
    legend = label is not None and legend
    label = "_nolegend_" if label is None else label

    # Use the active color cycle to find the plot color: draw a throwaway
    # line, read its color, and remove it again
    line, = ax.plot(x, y, **kwargs)
    color = line.get_color()
    line.remove()
    kwargs.pop("color", None)

    # Draw the KDE plot and, optionally, shade
    ax.plot(x, y, color=color, label=label, **kwargs)
    alpha = kwargs.get("alpha", 0.25)
    if shade:
        if vertical:
            ax.fill_betweenx(y, 1e-12, x, color=color, alpha=alpha)
        else:
            ax.fill_between(x, 1e-12, y, color=color, alpha=alpha)

    # Draw the legend here
    if legend:
        ax.legend(loc="best")

    return ax


def _statsmodels_univariate_kde(data, kernel, bw, gridsize, cut, clip,
                                cumulative=False):
    """Compute a univariate kernel density estimate using statsmodels.

    Fits a ``KDEUnivariate`` model to `data` and returns its support grid
    together with either the density or, when `cumulative` is true, the cdf.
    """
    estimator = sm.nonparametric.KDEUnivariate(data)
    # The FFT flag is only switched on for the gaussian kernel
    use_fft = kernel == "gau"
    estimator.fit(kernel, bw, use_fft, gridsize=gridsize, cut=cut, clip=clip)
    grid = estimator.support
    curve = estimator.cdf if cumulative else estimator.density
    return grid, curve


def _scipy_univariate_kde(data, bw, gridsize, cut, clip):
    """Compute a univariate kernel density estimate using scipy.

    Parameters
    ----------
    data : 1d array-like
        Observations.
    bw : string or scalar
        Passed to ``gaussian_kde`` as `bw_method`. A scalar is also used
        as-is for the width of the support grid.
    gridsize, cut, clip
        Passed to ``_kde_support`` to build the evaluation grid.

    Returns
    -------
    grid, y : arrays
        Evaluation grid and the density evaluated on it.

    """
    try:
        kde = stats.gaussian_kde(data, bw_method=bw)
    except TypeError:
        # Older scipy has no ``bw_method`` argument
        kde = stats.gaussian_kde(data)
        if bw != "scott":  # scipy default
            msg = ("Ignoring bandwidth choice, "
                   "please upgrade scipy to use a different bandwidth.")
            warnings.warn(msg, UserWarning)
    if isinstance(bw, str):
        bw = "scotts" if bw == "scott" else bw
        # The ``*_factor`` methods return a unitless multiplier; scale it by
        # the sample standard deviation to get a bandwidth in data units,
        # as the other functions in this module do
        bw = getattr(kde, "%s_factor" % bw)() * np.std(data, ddof=1)
    grid = _kde_support(data, bw, gridsize, cut, clip)
    y = kde(grid)
    return grid, y


def _bivariate_kdeplot(x, y, filled, kernel, bw, gridsize, cut, clip, axlabel,
                       ax, **kwargs):
    """Plot a joint KDE estimate as a bivariate contour plot.

    The density is computed with statsmodels when it is available and with
    scipy otherwise, then drawn on `ax` with ``contourf`` (`filled`) or
    ``contour``. `kernel` is accepted but not used here. When `axlabel` is
    true, the ``name`` attributes of `x` and `y` become the axis labels.
    Returns `ax`.
    """
    # Normalise the clipping bounds to one (low, high) pair per variable
    if clip is None:
        clip = [(-np.inf, np.inf), (-np.inf, np.inf)]
    elif np.ndim(clip) == 1:
        clip = [clip, clip]

    # Pick the KDE backend and evaluate the density on a grid
    if _has_statsmodels:
        compute_kde = _statsmodels_bivariate_kde
    else:
        compute_kde = _scipy_bivariate_kde
    xx, yy, z = compute_kde(x, y, bw, gridsize, cut, clip)

    # Resolve the contour options
    n_levels = kwargs.pop("n_levels", 10)
    default_cmap = "BuGn" if filled else "BuGn_d"
    cmap = kwargs.get("cmap", default_cmap)
    if isinstance(cmap, str) and cmap.endswith("_d"):
        # "_d" names are built by blending dark gray into two colors of
        # the reversed palette
        blend_colors = ["#333333"]
        blend_colors.extend(color_palette(cmap.replace("_d", "_r"), 2))
        cmap = blend_palette(blend_colors, as_cmap=True)
    kwargs["cmap"] = cmap

    # Draw the contours
    draw_contours = ax.contourf if filled else ax.contour
    draw_contours(xx, yy, z, n_levels, **kwargs)
    kwargs["n_levels"] = n_levels

    # Label the axes from the names of the inputs
    if hasattr(x, "name") and axlabel:
        ax.set_xlabel(x.name)
    if hasattr(y, "name") and axlabel:
        ax.set_ylabel(y.name)

    return ax


def _statsmodels_bivariate_kde(x, y, bw, gridsize, cut, clip):
    """Compute a bivariate kde using statsmodels.

    Returns the meshgrid arrays ``xx`` and ``yy`` and the density ``z``
    evaluated on them with a ``KDEMultivariate`` model.
    """
    # Expand a rule name or a single scalar to one bandwidth per variable
    if isinstance(bw, str):
        rule = getattr(sm.nonparametric.bandwidths, "bw_" + bw)
        bw = [rule(x), rule(y)]
    elif np.isscalar(bw):
        bw = [bw, bw]

    # Unwrap Series into plain arrays
    if isinstance(x, pd.Series):
        x = x.values
    if isinstance(y, pd.Series):
        y = y.values

    model = sm.nonparametric.KDEMultivariate([x, y], "cc", bw)

    # Build the evaluation grid from the fitted per-variable bandwidths
    grid_x = _kde_support(x, model.bw[0], gridsize, cut, clip[0])
    grid_y = _kde_support(y, model.bw[1], gridsize, cut, clip[1])
    xx, yy = np.meshgrid(grid_x, grid_y)

    flat_density = model.pdf([xx.ravel(), yy.ravel()])
    z = flat_density.reshape(xx.shape)
    return xx, yy, z


def _scipy_bivariate_kde(x, y, bw, gridsize, cut, clip):
    """Compute a bivariate kde using scipy.

    Parameters
    ----------
    x, y : 1d array-like
        Observations of the two variables.
    bw : string or scalar
        Passed to ``gaussian_kde`` as `bw_method`. A scalar is also used
        as-is for the width of the support grid in both dimensions.
    gridsize, cut : scalars
        Passed to ``_kde_support`` for each dimension.
    clip : pair of (low, high) pairs
        Clipping bounds for the x and y support.

    Returns
    -------
    xx, yy, z : 2d arrays
        Meshgrid coordinates and the density evaluated on them.

    Raises
    ------
    ValueError
        If `bw` is neither a string nor a scalar; this backend cannot use a
        separate bandwidth per dimension.

    """
    named_bw = isinstance(bw, str)
    if not named_bw and not np.isscalar(bw):
        msg = ("Cannot specify a different bandwidth for each dimension "
               "with the scipy backend. You should install statsmodels.")
        raise ValueError(msg)

    data = np.c_[x, y]
    try:
        kde = stats.gaussian_kde(data.T, bw_method=bw)
    except TypeError:
        # Older scipy has no ``bw_method`` argument
        kde = stats.gaussian_kde(data.T)
        if bw != "scott":  # scipy default
            msg = ("Ignoring bandwidth choice, "
                   "please upgrade scipy to use a different bandwidth.")
            warnings.warn(msg, UserWarning)

    if named_bw:
        # Scale the unitless factor by the spread of each variable
        data_std = data.std(axis=0, ddof=1)
        factor_name = "scotts" if bw == "scott" else bw
        factor = getattr(kde, "%s_factor" % factor_name)()
        bw_x = factor * data_std[0]
        bw_y = factor * data_std[1]
    else:
        bw_x = bw_y = bw

    x_support = _kde_support(data[:, 0], bw_x, gridsize, cut, clip[0])
    y_support = _kde_support(data[:, 1], bw_y, gridsize, cut, clip[1])
    xx, yy = np.meshgrid(x_support, y_support)
    z = kde([xx.ravel(), yy.ravel()]).reshape(xx.shape)
    return xx, yy, z


def kdeplot(data, data2=None, shade=False, vertical=False, kernel="gau",
            bw="scott", gridsize=100, cut=3, clip=None, legend=True, ax=None,
            cumulative=False, **kwargs):
    """Fit and plot a univariate or bivariate kernel density estimate.

    Parameters
    ----------
    data : 1d or 2d array-like
        Input data. If two-dimensional, assumed to be shaped (n_unit x n_var),
        and a bivariate contour plot will be drawn.
    data2: 1d array-like
        Second input data. If provided `data` must be one-dimensional, and
        a bivariate plot is produced.
    shade : bool, optional
        If true, shade in the area under the KDE curve (or draw with filled
        contours when data is bivariate).
    vertical : bool
        If True, density is on x-axis.
    kernel : {'gau' | 'cos' | 'biw' | 'epa' | 'tri' | 'triw' }, optional
        Code for shape of kernel to fit with. Bivariate KDE can only use
        gaussian kernel.
    bw : {'scott' | 'silverman' | scalar | pair of scalars }, optional
        Name of reference method to determine kernel size, scalar factor,
        or scalar for each dimension of the bivariate plot.
    gridsize : int, optional
        Number of discrete points in the evaluation grid.
    cut : scalar, optional
        Draw the estimate to cut * bw from the extreme data points.
    clip : pair of scalars, or pair of pair of scalars, optional
        Lower and upper bounds for datapoints used to fit KDE. Can provide
        a pair of (low, high) bounds for bivariate plots.
    legend : bool, optional
        If True, add a legend or label the axes when possible.
    ax : matplotlib axis, optional
        Axis to plot on, otherwise uses current axis.
    cumulative : bool
        If True, draw the cumulative distribution estimated by the kde.
    kwargs : other keyword arguments for plot()

    Returns
    -------
    ax : matplotlib axis
        Axis with plot.

    Raises
    ------
    TypeError
        If `cumulative` is requested for bivariate data.

    """
    if ax is None:
        ax = plt.gca()

    # Plain sequences such as lists have no ``astype``; convert them first.
    # Pandas objects are left alone so that their names survive for labels.
    if not hasattr(data, "astype"):
        data = np.asarray(data)
    data = data.astype(np.float64)
    if data2 is not None:
        if not hasattr(data2, "astype"):
            data2 = np.asarray(data2)
        data2 = data2.astype(np.float64)

    # Work out whether this is a bivariate plot and split out the variables
    bivariate = False
    if isinstance(data, np.ndarray) and np.ndim(data) > 1:
        bivariate = True
        x, y = data.T
    elif isinstance(data, pd.DataFrame) and np.ndim(data) > 1:
        bivariate = True
        x = data.iloc[:, 0].values
        y = data.iloc[:, 1].values
    elif data2 is not None:
        bivariate = True
        x = data
        y = data2

    if bivariate and cumulative:
        # Note the trailing space: adjacent literals are concatenated
        raise TypeError("Cumulative distribution plots are not "
                        "supported for bivariate distributions.")
    if bivariate:
        # ``legend`` doubles as the switch for labeling the axes
        ax = _bivariate_kdeplot(x, y, shade, kernel, bw, gridsize,
                                cut, clip, legend, ax, **kwargs)
    else:
        ax = _univariate_kdeplot(data, shade, vertical, kernel, bw,
                                 gridsize, cut, clip, legend, ax,
                                 cumulative=cumulative, **kwargs)

    return ax


def rugplot(a, height=None, axis="x", ax=None, **kwargs):
    """Plot datapoints in an array as sticks on an axis.

    Parameters
    ----------
    a : vector
        1D array of datapoints.
    height : scalar, optional
        Height of ticks, if None draw at 5% of axis range.
    axis : {'x' | 'y'}, optional
        Axis to draw rugplot on.
    ax : matplotlib axis
        Axis to draw plot into; otherwise grabs current axis.
    kwargs : other keyword arguments for plt.plot()
        A ``vertical`` keyword, when given and not None, overrides `axis`.

    Returns
    -------
    ax : matplotlib axis
        Axis with rugplot.

    """
    if ax is None:
        ax = plt.gca()
    a = np.asarray(a)

    # ``vertical`` is an alternative way of choosing the axis
    vertical = kwargs.pop("vertical", None)
    if vertical is not None:
        axis = "y" if vertical else "x"

    # The sticks start at the lower limit of the opposite axis
    other_axis = {"x": "y", "y": "x"}[axis]
    lower, upper = getattr(ax, "get_%slim" % other_axis)()
    if height is None:
        height = (upper - lower) * .05

    ticks = [a, a]
    extent = [lower, lower + height]
    if axis == "x":
        ax.plot(ticks, extent, **kwargs)
    else:
        ax.plot(extent, ticks, **kwargs)
    return ax


def jointplot(x, y, data=None, kind="scatter", stat_func=stats.pearsonr,
              color=None, size=6, ratio=5, space=.2,
              dropna=True, xlim=None, ylim=None,
              joint_kws=None, marginal_kws=None, annot_kws=None):
    """Draw a plot of two variables with bivariate and univariate graphs.

    Parameters
    ----------
    x, y : strings or vectors
        Data or names of variables in `data`.
    data : DataFrame, optional
        DataFrame when `x` and `y` are variable names.
    kind : { "scatter" | "reg" | "resid" | "kde" | "hex" }, optional
        Kind of plot to draw.
    stat_func : callable or None
        Function used to calculate a statistic about the relationship and
        annotate the plot. Should map `x` and `y` either to a single value
        or to a (value, p) tuple. Set to ``None`` if you don't want to
        annotate the plot. Ignored when `kind` is "resid".
    color : matplotlib color, optional
        Color used for the plot elements.
    size : numeric, optional
        Size of the figure (it will be square).
    ratio : numeric, optional
        Ratio of joint axes size to marginal axes height.
    space : numeric, optional
        Space between the joint and marginal axes
    dropna : bool, optional
        If True, remove observations that are missing from `x` and `y`.
    {x, y}lim : two-tuples, optional
        Axis limits to set before plotting.
    {joint, marginal, annot}_kws : dicts
        Additional keyword arguments for the plot components. The dicts
        passed in are copied and never modified.

    Returns
    -------
    grid : JointGrid
        JointGrid object with the plot on it.

    Raises
    ------
    ValueError
        If `kind` is not one of the supported plot kinds.

    See Also
    --------
    JointGrid : The Grid class used for drawing this plot. Use it directly if
                you need more flexibility.

    """
    # Work on copies: the defaults filled in below must not leak back into
    # dicts the caller may reuse for a later call
    joint_kws = {} if joint_kws is None else dict(joint_kws)
    marginal_kws = {} if marginal_kws is None else dict(marginal_kws)
    annot_kws = {} if annot_kws is None else dict(annot_kws)

    # Reject an unsupported kind before any plotting objects are built
    if kind != "scatter" and not kind.startswith(("hex", "kde",
                                                  "reg", "resid")):
        msg = "kind must be either 'scatter', 'reg', 'resid', 'kde', or 'hex'"
        raise ValueError(msg)

    # Make a colormap based off the plot color
    if color is None:
        color = color_palette()[0]
    color_rgb = mpl.colors.colorConverter.to_rgb(color)
    colors = [set_hls_values(color_rgb, l=l) for l in np.linspace(1, 0, 12)]
    cmap = blend_palette(colors, as_cmap=True)

    # Initialize the JointGrid object
    grid = JointGrid(x, y, data, dropna=dropna,
                     size=size, ratio=ratio, space=space,
                     xlim=xlim, ylim=ylim)

    # Plot the data using the grid
    if kind == "scatter":

        joint_kws.setdefault("color", color)
        grid.plot_joint(plt.scatter, **joint_kws)

        marginal_kws.setdefault("kde", False)
        marginal_kws.setdefault("color", color)
        grid.plot_marginals(distplot, **marginal_kws)

    elif kind.startswith("hex"):

        # Derive the hexbin resolution from the two marginal bin counts
        x_bins = _freedman_diaconis_bins(grid.x)
        y_bins = _freedman_diaconis_bins(grid.y)
        gridsize = int(np.mean([x_bins, y_bins]))

        joint_kws.setdefault("gridsize", gridsize)
        joint_kws.setdefault("cmap", cmap)
        grid.plot_joint(plt.hexbin, **joint_kws)

        marginal_kws.setdefault("kde", False)
        marginal_kws.setdefault("color", color)
        grid.plot_marginals(distplot, **marginal_kws)

    elif kind.startswith("kde"):

        joint_kws.setdefault("shade", True)
        joint_kws.setdefault("cmap", cmap)
        grid.plot_joint(kdeplot, **joint_kws)

        marginal_kws.setdefault("shade", True)
        marginal_kws.setdefault("color", color)
        grid.plot_marginals(kdeplot, **marginal_kws)

    elif kind.startswith("reg"):

        from .linearmodels import regplot

        marginal_kws.setdefault("color", color)
        grid.plot_marginals(distplot, **marginal_kws)

        joint_kws.setdefault("color", color)
        grid.plot_joint(regplot, **joint_kws)

    elif kind.startswith("resid"):

        from .linearmodels import residplot

        joint_kws.setdefault("color", color)
        grid.plot_joint(residplot, **joint_kws)

        # The marginals are drawn from the points of the first collection
        # on the joint axes rather than from the raw inputs
        x, y = grid.ax_joint.collections[0].get_offsets().T
        marginal_kws.setdefault("color", color)
        marginal_kws.setdefault("kde", False)
        distplot(x, ax=grid.ax_marg_x, **marginal_kws)
        distplot(y, vertical=True, fit=stats.norm, ax=grid.ax_marg_y,
                 **marginal_kws)
        # No statistic annotation is drawn for residual plots
        stat_func = None

    if stat_func is not None:
        grid.annotate(stat_func, **annot_kws)

    return grid
