"""Plotting functions."""
import logging
import os
import matplotlib
import matplotlib.pyplot as plt
try:
    import networkx as nx
except ImportError:  # pragma: no cover
    print('Please install networkx via pip install "pygenstability[networkx]" for full plotting.')
import numpy as np
from matplotlib import gridspec
from matplotlib import patches
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from tqdm import tqdm
try:
    import plotly.graph_objects as go
    from plotly.offline import plot as _plot
except ImportError:  # pragma: no cover
    pass
from pygenstability.optimal_scales import identify_optimal_scales
L = logging.getLogger(__name__)
[docs]def plot_scan(
    all_results,
    figsize=(6, 5),
    scale_axis=True,
    figure_name="scan_results.pdf",
    use_plotly=False,
    live=True,
    plotly_filename="scan_results.html",
):
    """Plot results of pygenstability with matplotlib or plotly.
    Args:
        all_results (dict): results of pygenstability scan
        figsize (tuple): matplotlib figure size
        scale_axis (bool): display scale of scale index on scale axis
        figure_name (str): name of matplotlib figure
        use_plotly (bool): use matplotlib or plotly backend
        live (bool): for plotly backend, open browser with pot
        plotly_filename (str): filename of .html figure from plotly
    """
    if len(all_results["scales"]) == 1:  # pragma: no cover
        L.info("Cannot plot the results if only one scale point, we display the result instead:")
        L.info(all_results)
        return None
    if use_plotly:
        return plot_scan_plotly(all_results, live=live, filename=plotly_filename)
    return plot_scan_plt(
        all_results, figsize=figsize, scale_axis=scale_axis, figure_name=figure_name
    ) 
[docs]def plot_scan_plotly(  # pylint: disable=too-many-branches,too-many-statements,too-many-locals
    all_results,
    live=False,
    filename="clusters.html",
):
    """Plot results of pygenstability with plotly."""
    scales = _get_scales(all_results, scale_axis=True)
    hovertemplate = str("<b>scale</b>: %{x:.2f}, <br>%{text}<extra></extra>")
    if "NVI" in all_results:
        nvi_data = all_results["NVI"]
        nvi_opacity = 1.0
        nvi_title = "Variation of information"
        nvi_ticks = True
    else:  # pragma: no cover
        nvi_data = np.zeros(len(scales))
        nvi_opacity = 0.0
        nvi_title = None
        nvi_ticks = False
    text = [
        f"""Number of communities: {n}, <br> Stability: {np.round(s, 3)},
        <br> Normalised Variation Information: {np.round(vi, 3)}, <br> Index: {i}"""
        for n, s, vi, i in zip(
            all_results["number_of_communities"],
            all_results["stability"],
            nvi_data,
            np.arange(0, len(scales)),
        )
    ]
    ncom = go.Scatter(
        x=scales,
        y=all_results["number_of_communities"],
        mode="lines+markers",
        hovertemplate=hovertemplate,
        name="Number of communities",
        xaxis="x2",
        yaxis="y4",
        text=text,
        marker_color="red",
    )
    if "ttprime" in all_results:
        z = all_results["ttprime"]
        showscale = True
        tprime_title = "log10(scale)"
    else:  # pragma: no cover
        z = np.nan + np.zeros([len(scales), len(scales)])
        showscale = False
        tprime_title = None
    ttprime = go.Heatmap(
        z=z,
        x=scales,
        y=scales,
        colorscale="YlOrBr_r",
        yaxis="y2",
        xaxis="x2",
        hoverinfo="skip",
        colorbar={"title": "VI", "len": 0.2, "yanchor": "middle", "y": 0.5},
        showscale=showscale,
    )
    if "stability" in all_results:
        stab = go.Scatter(
            x=scales,
            y=all_results["stability"],
            mode="lines+markers",
            hovertemplate=hovertemplate,
            text=text,
            name="Stability",
            marker_color="blue",
        )
    vi = go.Scatter(
        x=scales,
        y=nvi_data,
        mode="lines+markers",
        hovertemplate=hovertemplate,
        text=text,
        name="NVI",
        yaxis="y3",
        xaxis="x",
        marker_color="green",
        opacity=nvi_opacity,
    )
    layout = go.Layout(
        yaxis={
            "title": "Stability",
            "titlefont": {"color": "blue"},
            "tickfont": {"color": "blue"},
            "domain": [0.0, 0.28],
        },
        yaxis2={
            "title": tprime_title,
            "titlefont": {"color": "black"},
            "tickfont": {"color": "black"},
            "domain": [0.32, 1],
            "side": "right",
            "range": [scales[0], scales[-1]],
        },
        yaxis3={
            "title": nvi_title,
            "titlefont": {"color": "green"},
            "tickfont": {"color": "green"},
            "showticklabels": nvi_ticks,
            "overlaying": "y",
            "side": "right",
        },
        yaxis4={
            "title": "Number of communities",
            "titlefont": {"color": "red"},
            "tickfont": {"color": "red"},
            "overlaying": "y2",
        },
        xaxis={"range": [scales[0], scales[-1]]},
        xaxis2={"range": [scales[0], scales[-1]]},
    )
    fig = go.Figure(data=[stab, ncom, vi, ttprime], layout=layout)
    fig.update_layout(xaxis_title="log10(scale)")
    if filename is not None:
        _plot(fig, filename=filename, auto_open=live)
    return fig, layout 
[docs]def plot_single_partition(
    graph, all_results, scale_id, edge_color="0.5", edge_width=0.5, node_size=100
):
    """Plot the community structures for a given scale.
    Args:
        graph (networkx.Graph): graph to plot
        all_results (dict): results of pygenstability scan
        scale_id (int): index of scale to plot
        folder (str): folder to save figures
        edge_color (str): color of edges
        edge_width (float): width of edges
        node_size (float): size of nodes
        ext (str): extension of figures files
    """
    if any("pos" not in graph.nodes[u] for u in graph):
        pos = nx.spring_layout(graph)
        for u in graph:
            graph.nodes[u]["pos"] = pos[u]
    pos = {u: graph.nodes[u]["pos"] for u in graph}
    node_color = all_results["community_id"][scale_id]
    nx.draw_networkx_nodes(
        graph,
        pos=pos,
        node_color=node_color,
        node_size=node_size,
        cmap=plt.get_cmap("tab20"),
    )
    nx.draw_networkx_edges(graph, pos=pos, width=edge_width, edge_color=edge_color)
    plt.axis("off")
    plt.title(
        str(r"$log_{10}(scale) =$ ")
        + str(np.round(np.log10(all_results["scales"][scale_id]), 2))
        + ", with "
        + str(all_results["number_of_communities"][scale_id])
        + " communities"
    ) 
[docs]def plot_optimal_partitions(
    graph,
    all_results,
    edge_color="0.5",
    edge_width=0.5,
    folder="optimal_partitions",
    ext=".pdf",
    show=False,
):
    """Plot the community structures at each optimal scale.
    Args:
        graph (networkx.Graph): graph to plot
        all_results (dict): results of pygenstability scan
        edge_color (str): color of edges
        edge_width (float): width of edgs
        folder (str): folder to save figures
        ext (str): extension of figures files
        show (bool): show each plot with plt.show() or not
    """
    if not os.path.isdir(folder):
        os.mkdir(folder)
    if "selected_partitions" not in all_results:  # pragma: no cover
        identify_optimal_scales(all_results)
    selected_scales = all_results["selected_partitions"]
    n_selected_scales = len(selected_scales)
    if n_selected_scales == 0:  # pragma: no cover
        return
    for optimal_scale_id in selected_scales:
        plot_single_partition(
            graph,
            all_results,
            optimal_scale_id,
            edge_color=edge_color,
            edge_width=edge_width,
        )
        plt.savefig(f"{folder}/scale_{optimal_scale_id}{ext}", bbox_inches="tight")
        if show:  # pragma: no cover
            plt.show() 
[docs]def plot_communities(
    graph,
    all_results,
    folder="communities",
    edge_color="0.5",
    edge_width=0.5,
    ext=".pdf",
):
    """Plot the community structures at each scale in a folder.
    Args:
        graph (networkx.Graph): graph to plot
        all_results (dict): results of pygenstability scan
        folder (str): folder to save figures
        edge_color (str): color of edges
        edge_width (float): width of edgs
        ext (str): extension of figures files
    """
    if not os.path.isdir(folder):
        os.mkdir(folder)
    mpl_backend = matplotlib.get_backend()
    matplotlib.use("Agg")
    for scale_id in tqdm(range(len(all_results["scales"]))):
        plt.figure()
        plot_single_partition(
            graph, all_results, scale_id, edge_color=edge_color, edge_width=edge_width
        )
        plt.savefig(os.path.join(folder, "scale_" + str(scale_id) + ext), bbox_inches="tight")
        plt.close()
    matplotlib.use(mpl_backend) 
[docs]def plot_communities_matrix(graph, all_results, folder="communities_matrix", ext=".pdf"):
    """Plot communities at all scales in matrix form.
    Args:
        graph (array): as a numpy matrix
        all_results (dict): clustring results
        folder (str): folder to save figures
        ext (str): figure file format
    """
    if not os.path.isdir(folder):
        os.mkdir(folder)
    for scale_id in tqdm(range(len(all_results["scales"]))):
        plt.figure()
        com_ids = all_results["community_id"][scale_id]
        ids = []
        lines = [0]
        for i in range(len(set(com_ids))):
            _ids = list(np.argwhere(com_ids == i).flatten())
            lines.append(len(_ids))
            ids += _ids
        plt.imshow(graph[ids][:, ids], origin="lower")
        lines = np.cumsum(lines)
        for i in range(len(lines) - 1):
            plt.plot((lines[i], lines[i + 1]), (lines[i], lines[i]), c="k")
            plt.plot((lines[i], lines[i]), (lines[i], lines[i + 1]), c="k")
            plt.plot((lines[i + 1], lines[i + 1]), (lines[i + 1], lines[i]), c="k")
            plt.plot((lines[i + 1], lines[i]), (lines[i + 1], lines[i + 1]), c="k")
        plt.savefig(os.path.join(folder, "scale_" + str(scale_id) + ext), bbox_inches="tight") 
def _get_scales(all_results, scale_axis=True):
    """Get the scale vector."""
    if not scale_axis:  # pragma: no cover
        return np.arange(len(all_results["scales"]))
    if all_results["run_params"]["log_scale"]:
        return np.log10(all_results["scales"])
    return all_results["scales"]  # pragma: no cover
def _plot_number_comm(all_results, ax, scales):
    """Plot number of communities."""
    ax.plot(scales, all_results["number_of_communities"], "-", c="C3", label="size", lw=2.0)
    ax.set_ylim(0, 1.1 * max(all_results["number_of_communities"]))
    ax.set_ylabel("# clusters", color="C3")
    ax.tick_params("y", colors="C3")
def _plot_ttprime(all_results, ax, scales):
    """Plot ttprime."""
    contourf_ = ax.contourf(scales, scales, all_results["ttprime"], cmap="YlOrBr_r", extend="min")
    ax.set_ylabel(r"$log_{10}(t^\prime)$")
    ax.yaxis.tick_left()
    ax.yaxis.set_label_position("left")
    ax.axis([scales[0], scales[-1], scales[0], scales[-1]])
    ax.set_xlabel(r"$log_{10}(t)$")
    axins = inset_axes(
        ax,
        width="3%",
        height="40%",
        loc="lower left",
        bbox_to_anchor=(0.05, 0.45, 1, 1),
        bbox_transform=ax.transAxes,
        borderpad=0,
    )
    axins.tick_params(labelsize=7)
    plt.colorbar(contourf_, cax=axins, label="NVI(t,t')")
def _plot_NVI(all_results, ax, scales):
    """Plot variation information."""
    ax.plot(scales, all_results["NVI"], "-", lw=2.0, c="C2", label="VI")
    ax.yaxis.tick_right()
    ax.tick_params("y", colors="C2")
    ax.set_ylabel(r"NVI", color="C2")
    ax.axhline(1, ls="--", lw=1.0, c="C2")
    ax.axis([scales[0], scales[-1], 0.0, np.max(all_results["NVI"]) * 1.1])
    ax.set_xlabel(r"$log_{10}(t)$")
def _plot_stability(all_results, ax, scales):
    """Plot stability."""
    ax.plot(scales, all_results["stability"], "-", label=r"Stability", c="C0")
    ax.tick_params("y", colors="C0")
    ax.set_ylabel("Stability", color="C0")
    ax.set_ylim(0, 1.1 * max(all_results["stability"]))
    ax.yaxis.set_label_position("left")
def _plot_optimal_scales(all_results, ax, scales, ax1, ax2):
    """Plot stability."""
    ax.plot(
        scales,
        all_results["block_nvi"],
        "-",
        lw=2.0,
        c="C4",
        label="Block NVI",
    )
    ax.plot(
        scales[all_results["selected_partitions"]],
        all_results["block_nvi"][all_results["selected_partitions"]],
        "o",
        lw=2.0,
        c="C4",
        label="optimal scales",
    )
    ax.tick_params("y", colors="C4")
    ax.set_ylabel("Block NVI", color="C4")
    ax.yaxis.set_label_position("left")
    ax.set_xlabel(r"$log_{10}(t)$")
    for scale in scales[all_results["selected_partitions"]]:
        ax.axvline(scale, ls="--", color="C4")
        ax1.axvline(scale, ls="--", color="C4")
        ax2.axvline(scale, ls="--", color="C4")
[docs]def plot_scan_plt(all_results, figsize=(6, 5), scale_axis=True, figure_name="scan_results.svg"):
    """Plot results of pygenstability with matplotlib."""
    scales = _get_scales(all_results, scale_axis=scale_axis)
    plt.figure(figsize=figsize)
    gs = gridspec.GridSpec(3, 1, height_ratios=[0.5, 1.0, 0.5])
    gs.update(hspace=0)
    axes = []
    if "ttprime" in all_results:
        ax0 = plt.subplot(gs[1, 0])
        axes.append(ax0)
        _plot_ttprime(all_results, ax=ax0, scales=scales)
        ax1 = ax0.twinx()
    else:  # pragma: no cover
        ax1 = plt.subplot(gs[1, 0])
    axes.append(ax1)
    ax1.set_xticks([])
    _plot_NVI(all_results, ax=ax1, scales=scales)
    if "ttprime" in all_results:
        ax1.yaxis.tick_right()
        ax1.yaxis.set_label_position("right")
    ax2 = plt.subplot(gs[0, 0])
    if "stability" in all_results:
        _plot_stability(all_results, ax=ax2, scales=scales)
        ax2.set_xticks([])
        axes.append(ax2)
    if "NVI" in all_results:
        ax3 = ax2.twinx()
        _plot_number_comm(all_results, ax=ax3, scales=scales)
        axes.append(ax3)
    if "block_nvi" in all_results:
        ax4 = plt.subplot(gs[2, 0])
        _plot_optimal_scales(all_results, ax=ax4, scales=scales, ax1=ax1, ax2=ax2)
        axes.append(ax4)
    for ax in axes:
        ax.set_xlim(scales[0], scales[-1])
    if figure_name is not None:
        plt.savefig(figure_name)
    return axes 
[docs]def plot_clustered_adjacency(
    adjacency,
    all_results,
    scale,
    labels=None,
    figsize=(12, 10),
    cmap="Blues",
    figure_name="clustered_adjacency.pdf",
):
    """Plot the clustered adjacency matrix of the graph at a given scale.
    Args:
        adjacency (ndarray): adjacency matrix to plot
        all_results (dict): results of PyGenStability
        scale (int): scale index for clustering
        labels (list): node labels, or None
        figsize (tubple): figure size
        cmap (str): colormap for matrix elements
        figure_name (str): filename of the figure with extension
    """
    comms, counts = np.unique(all_results["community_id"][scale], return_counts=True)
    node_ids = []
    for comm in comms:
        node_ids += list(np.where(all_results["community_id"][scale] == comm)[0])
    adjacency = adjacency[np.ix_(node_ids, node_ids)]
    adjacency[adjacency == 0] = np.nan
    plt.figure(figsize=figsize)
    plt.imshow(adjacency, aspect="auto", cmap=cmap)
    ax = plt.gca()
    pos = 0
    for comm, count in zip(comms, counts):
        rect = patches.Rectangle(
            (pos - 0.5, pos - 0.5),
            count,
            count,
            linewidth=5,
            facecolor="none",
            edgecolor="g",
        )
        ax.add_patch(rect)
        pos += count
    ax.set_xticks(np.arange(len(adjacency)))
    ax.set_yticks(np.arange(len(adjacency)))
    if labels is not None:  # pragma: no cover
        labels_plot = [labels[i] for i in node_ids]
        ax.set_xticklabels(labels_plot)
        ax.set_yticklabels(labels_plot)
    plt.colorbar()
    plt.xticks(rotation=90)
    plt.axis([-0.5, len(adjacency) - 0.5, -0.5, len(adjacency) - 0.5])
    plt.suptitle(
        "log10(scale) = "
        + str(np.round(np.log10(all_results["scales"][scale]), 2))
        + ",  number_of_communities="
        + str(all_results["number_of_communities"][scale])
    )
    plt.savefig(figure_name, bbox_inches="tight")