"""
The eqplot module contains functions for general plotting of
the results of equilibrium calculations.
"""
from pycalphad.core.utils import unpack_condition
from pycalphad.plot.utils import phase_legend
import pycalphad.variables as v
from matplotlib import collections as mc
import matplotlib.pyplot as plt
import numpy as np
from collections import OrderedDict
# Axis label text for the state variables that can currently be labelled.
# TODO: support other state variables here, or replace the isinstance check in
# _axis_label with an explicit comparison against v.T / v.P.
_plot_labels = {
    v.T: 'Temperature (K)',
    v.P: 'Pressure (Pa)',
}
def _axis_label(ax_var):
    """
    Return the label to display on a plot axis for ``ax_var``.

    Parameters
    ----------
    ax_var : object
        Variable plotted on the axis.

    Returns
    -------
    object
        ``'X(<species name>)'`` for a mole fraction, the matching entry of
        ``_plot_labels`` for any other state variable, and ``ax_var`` itself
        for anything else.

    Raises
    ------
    KeyError
        If ``ax_var`` is a state variable with no entry in ``_plot_labels``.
    """
    # Mole fractions are tested first, before the general state variable case.
    if isinstance(ax_var, v.MoleFraction):
        species_name = ax_var.species.name
        return 'X({})'.format(species_name)
    if isinstance(ax_var, v.StateVariable):
        return _plot_labels[ax_var]
    # Not a recognised variable type: hand it back unchanged.
    return ax_var
def _map_coord_to_variable(coord):
    """
    Translate a coordinate name into the corresponding variable object.

    Parameters
    ----------
    coord : str
        Name of coordinate in equilibrium object.

    Returns
    -------
    object
        ``v.X(<suffix>)`` when the name starts with ``'X_'``, ``v.T`` or
        ``v.P`` for ``'T'`` or ``'P'``, and the name itself otherwise.
    """
    state_variables = {'T': v.T, 'P': v.P}
    if coord.startswith('X_'):
        # Everything after the 'X_' prefix identifies the mole fraction.
        return v.X(coord[2:])
    # Unknown names fall through unchanged.
    return state_variables.get(coord, coord)
[docs]
def eqplot(eq, ax=None, x=None, y=None, z=None, tielines=True, tieline_color=(0, 1, 0, 1), tie_triangle_color=(1, 0, 0, 1), legend_generator=phase_legend, **kwargs):
    """
    Plot the result of an equilibrium calculation.

    The type of plot is controlled by the degrees of freedom in the
    equilibrium calculation: one independent composition and one independent
    potential gives a rectilinear plot, two independent compositions and no
    independent potential gives a triangular plot.

    Parameters
    ----------
    eq : xarray.Dataset
        Result of equilibrium calculation.
    ax : matplotlib.Axes, optional
        Axes to draw on. If None, a new figure and axes are created with the
        autodetected projection.
    x : StateVariable, optional
        Variable for the x axis. Defaults to the first independent
        composition.
    y : StateVariable, optional
        Variable for the y axis. Defaults to the first independent potential
        (rectilinear plot) or the second independent composition (triangular
        plot).
    z : StateVariable, optional
        Not supported; any value other than None raises NotImplementedError.
    tielines : bool
        If True, will plot tielines
    tieline_color: color
        A valid matplotlib color, such as a named color string, hex RGB
        string, or a tuple of RGBA components to set the color of the two
        phase region tielines. The default is an RGBA tuple for green:
        (0, 1, 0, 1).
    tie_triangle_color: color
        A valid matplotlib color, such as a named color string, hex RGB
        string, or a tuple of RGBA components to set the color of the three
        phase region tie triangles. The default is an RGBA tuple for red:
        (1, 0, 0, 1).
    legend_generator : Callable
        A function that will be called with the list of phases and will
        return legend labels and colors for each phase. By default
        pycalphad.plot.utils.phase_legend is used
    kwargs : kwargs
        Passed to every ``scatter`` call made on the axes.

    Returns
    -------
    matplotlib AxesSubplot

    Raises
    ------
    ValueError
        If the number of independent compositions and potentials does not
        match one of the supported plot types.
    NotImplementedError
        If ``z`` is given.

    Notes
    -----
    ``eq`` is modified in place: its ``component`` coordinate and ``Phase``
    values are replaced by unicode arrays.
    """
    conds = OrderedDict([(_map_coord_to_variable(key), unpack_condition(np.asarray(value)))
                         for key, value in sorted(eq.coords.items(), key=str)
                         if (key in ('T', 'P', 'N')) or (key.startswith('X_'))])
    # A condition is "independent" when it takes more than one value.
    indep_comps = sorted([key for key, value in conds.items() if isinstance(key, v.MoleFraction) and len(value) > 1], key=str)
    indep_pots = [key for key, value in conds.items() if isinstance(key, v.IndependentPotential) and len(value) > 1]

    # determine what the type of plot will be
    if len(indep_comps) == 1 and len(indep_pots) == 1:
        projection = None
    elif len(indep_comps) == 2 and len(indep_pots) == 0:
        projection = 'triangular'
    else:
        raise ValueError('The eqplot projection is not defined and cannot be autodetected. There are {} independent compositions and {} independent potentials.'.format(len(indep_comps), len(indep_pots)))

    if z is not None:
        raise NotImplementedError('3D plotting is not yet implemented')

    if ax is None:
        _, ax = plt.subplots(subplot_kw={'projection': projection})

    # Handle cases for different plot types
    if projection is None:
        x = indep_comps[0] if x is None else x
        y = indep_pots[0] if y is None else y
        # plot settings
        ax.set_xlim([np.min(conds[x]) - 1e-2, np.max(conds[x]) + 1e-2])
        ax.set_ylim([np.min(conds[y]), np.max(conds[y])])
    elif projection == 'triangular':
        x = indep_comps[0] if x is None else x
        y = indep_comps[1] if y is None else y
        # Here we adjust the x coordinate of the ylabel.
        # We make it reasonably comparable to the position of the xlabel from the xaxis
        # As the figure size gets very large, the label approaches ~0.55 on the yaxis
        # 0.55*cos(60 deg)=0.275, so that is the xcoord we are approaching.
        ax.yaxis.label.set_va('baseline')
        fig_x_size = ax.figure.get_size_inches()[0]
        y_label_offset = 1 / fig_x_size
        ax.yaxis.set_label_coords(x=(0.275 - y_label_offset), y=0.5)

    # get the active phases and support loading netcdf files from disk
    # Real lists (not one-shot ``map`` iterators) so that a custom
    # legend_generator may take len(), index, or iterate more than once.
    phases = [str(name) for name in sorted(set(np.array(eq.Phase.values.ravel(), dtype='U')) - {''}, key=str)]
    comps = [str(name) for name in sorted(np.array(eq.coords['component'].values, dtype='U'), key=str)]
    eq['component'] = np.array(eq['component'], dtype='U')
    eq['Phase'].values = np.array(eq['Phase'].values, dtype='U')

    # Select all two- and three-phase regions by counting the non-empty
    # phase names along the last axis of the Phase array.
    phase_counts = np.sum(eq.Phase.values != '', axis=-1, dtype=np.int_)
    three_phase_idx = np.nonzero(phase_counts == 3)
    two_phase_idx = np.nonzero(phase_counts == 2)

    legend_handles, colorlist = legend_generator(phases)

    # If we found two phase regions:
    if two_phase_idx[0].size > 0:
        found_two_phase = eq.Phase.values[two_phase_idx][..., :2]
        # get tieline endpoint compositions
        two_phase_x = eq.X.sel(component=x.species.name).values[two_phase_idx][..., :2]
        # handle special case for potential
        if isinstance(y, v.MoleFraction):
            two_phase_y = eq.X.sel(component=y.species.name).values[two_phase_idx][..., :2]
        else:
            # it's a StateVariable. This must be True
            # Position of y among the conditions selects which index array of
            # two_phase_idx to use for looking up the y values.
            y_dim = [str(key) for key in conds].index(str(y))
            two_phase_y = np.take(eq[str(y)].values, two_phase_idx[y_dim])
            # because the above gave us a shape of (n,) instead of (n,2) we are going to create it ourselves
            two_phase_y = np.array([two_phase_y, two_phase_y]).swapaxes(0, 1)

        # plot two phase points
        two_phase_plotcolors = np.array([[colorlist[pair[0]], colorlist[pair[1]]] for pair in found_two_phase], dtype='U')
        ax.scatter(two_phase_x[..., 0], two_phase_y[..., 0], s=3, c=two_phase_plotcolors[:, 0], edgecolors='None', zorder=2, **kwargs)
        ax.scatter(two_phase_x[..., 1], two_phase_y[..., 1], s=3, c=two_phase_plotcolors[:, 1], edgecolors='None', zorder=2, **kwargs)

        if tielines:
            # construct and plot tielines: one (x, y) array per endpoint,
            # then move the point axis to the front so each entry is a line.
            two_phase_tielines = np.array([
                np.concatenate((two_phase_x[..., k][..., np.newaxis], two_phase_y[..., k][..., np.newaxis]), axis=-1)
                for k in range(2)
            ])
            two_phase_tielines = np.rollaxis(two_phase_tielines, 1)
            lc = mc.LineCollection(two_phase_tielines, zorder=1, colors=tieline_color, linewidths=[0.5, 0.5])
            ax.add_collection(lc)

    # If we found three phase regions:
    if (three_phase_idx[0].size > 0) and (len(indep_comps) == 2):
        found_three_phase = eq.Phase.values[three_phase_idx][..., :3]
        # get tieline endpoints
        three_phase_x = eq.X.sel(component=x.species.name).values[three_phase_idx][..., :3]
        three_phase_y = eq.X.sel(component=y.species.name).values[three_phase_idx][..., :3]
        # three phase tielines, these are tie triangles and we always plot them
        three_phase_tielines = np.array([
            np.concatenate((three_phase_x[..., k][..., np.newaxis], three_phase_y[..., k][..., np.newaxis]), axis=-1)
            for k in range(3)
        ])
        three_phase_tielines = np.rollaxis(three_phase_tielines, 1)
        three_lc = mc.LineCollection(three_phase_tielines, zorder=1, colors=tie_triangle_color, linewidths=[0.5, 0.5])
        # plot three phase points and tielines
        three_phase_plotcolors = np.array([[colorlist[trio[0]], colorlist[trio[1]], colorlist[trio[2]]] for trio in found_three_phase], dtype='U')
        ax.scatter(three_phase_x[..., 0], three_phase_y[..., 0], s=3, c=three_phase_plotcolors[:, 0], edgecolors='None', zorder=2, **kwargs)
        ax.scatter(three_phase_x[..., 1], three_phase_y[..., 1], s=3, c=three_phase_plotcolors[:, 1], edgecolors='None', zorder=2, **kwargs)
        ax.scatter(three_phase_x[..., 2], three_phase_y[..., 2], s=3, c=three_phase_plotcolors[:, 2], edgecolors='None', zorder=2, **kwargs)
        ax.add_collection(three_lc)

    # position the phase legend and configure plot
    # Shrink the axes to 80% width so the legend fits to the right of it.
    box = ax.get_position()
    ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
    ax.legend(handles=legend_handles, loc='center left', bbox_to_anchor=(1, 0.5))
    ax.tick_params(axis='both', which='major', labelsize=14)
    ax.grid(True)
    plot_title = '-'.join([component.title() for component in sorted(comps) if component != 'VA'])
    ax.set_title(plot_title, fontsize=20)
    ax.set_xlabel(_axis_label(x), labelpad=15, fontsize=20)
    ax.set_ylabel(_axis_label(y), fontsize=20)
    return ax