Source code for pcmdi_metrics.graphics.parallel_coordinate_plot.parallel_coordinate_plot_lib

import sys

import matplotlib.pylab as pylab
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.cbook import flatten

from pcmdi_metrics.graphics import add_logo


def parallel_coordinate_plot(
    data,
    metric_names,
    model_names,
    models_to_highlight=list(),
    models_to_highlight_by_line=True,
    models_to_highlight_colors=None,
    models_to_highlight_labels=None,
    models_to_highlight_markers=["s", "o", "^", "*"],
    models_to_highlight_markers_size=10,
    fig=None,
    ax=None,
    figsize=(15, 5),
    show_boxplot=False,
    show_violin=False,
    violin_colors=("lightgrey", "pink"),
    violin_label=None,
    title=None,
    identify_all_models=True,
    xtick_labelsize=None,
    ytick_labelsize=None,
    colormap="viridis",
    num_color=20,
    legend_off=False,
    legend_ncol=6,
    legend_bbox_to_anchor=(0.5, -0.14),
    legend_loc="upper center",
    legend_fontsize=10,
    logo_rect=None,
    logo_off=False,
    model_names2=None,
    group1_name="group1",
    group2_name="group2",
    comparing_models=None,
    fill_between_lines=False,
    fill_between_lines_colors=("red", "green"),
    arrow_between_lines=False,
    arrow_between_lines_colors=("red", "green"),
    arrow_alpha=1,
    arrow_width=0.05,
    arrow_linewidth=0,
    arrow_head_width=0.15,
    arrow_head_length=0.15,
    vertical_center=None,
    vertical_center_line=False,
    vertical_center_line_label=None,
    ymax=None,
    ymin=None,
    debug=False,
):
    """
    Create a parallel coordinate plot for visualizing multi-dimensional data.

    .. image:: /_static/images/parallel_coordiate_plot_docstring_example.png
        :alt: Example parallel coordinate plot
        :align: center
        :width: 600px

    Parameters
    ----------
    data : ndarray
        2-d numpy array for metrics.
    metric_names : list
        Names of metrics for individual vertical axes (axis=1).
    model_names : list
        Name of models for markers/lines (axis=0).
    models_to_highlight : list or str, optional
        List of models to highlight as lines or markers. A single model name
        given as a string is treated as a one-element list.
    models_to_highlight_by_line : bool, optional
        If True, highlight as lines. If False, highlight as markers.
        Default is True.
    models_to_highlight_colors : list, optional
        List of colors for models to highlight. Entries are consumed in the
        order in which the highlighted models appear in `model_names`.
    models_to_highlight_labels : list, optional
        List of string labels for models to highlight. Entries are consumed in
        the order in which the highlighted models appear in `model_names`.
    models_to_highlight_markers : list, optional
        Matplotlib markers for models to highlight if as marker.
        Default is ["s", "o", "^", "*"]. Markers are cycled when more models
        are highlighted than markers are given.
    models_to_highlight_markers_size : float, optional
        Size of matplotlib markers for models to highlight if as marker.
        Default is 10.
    fig : matplotlib.figure.Figure, optional
        Figure instance to which the parallel coordinate plot is plotted.
        If only `ax` is given, the figure owning `ax` is used.
    ax : matplotlib.axes.Axes, optional
        Axes instance to which the parallel coordinate plot is plotted.
        If only `fig` is given, its current axes are used.
    figsize : tuple, optional
        Figure size (width, height) in inches. Default is (15, 5).
    show_boxplot : bool, optional
        If True, show box and whiskers plot. Default is False.
    show_violin : bool, optional
        If True, show violin plot. Default is False.
    violin_colors : tuple or list, optional
        Two strings for colors of violin. Default is ("lightgrey", "pink").
    violin_label : str, optional
        Label for the violin plot when not split. Default is None.
    title : str, optional
        Plot title.
    identify_all_models : bool, optional
        If True, show and identify all models using markers. Default is True.
    xtick_labelsize : int or str, optional
        Fontsize for x-axis tick labels.
    ytick_labelsize : int or str, optional
        Fontsize for y-axis tick labels.
    colormap : str, optional
        Matplotlib colormap. Default is 'viridis'.
    num_color : int, optional
        Number of colors to use. Default is 20.
    legend_off : bool, optional
        If True, turn off legend. Default is False.
    legend_ncol : int, optional
        Number of columns for legend text. Default is 6.
    legend_bbox_to_anchor : tuple, optional
        Set legend box location. Default is (0.5, -0.14).
    legend_loc : str, optional
        Set legend box location. Default is "upper center".
    legend_fontsize : float, optional
        Legend font size. Default is 10.
    logo_rect : sequence of float, optional
        The dimensions [left, bottom, width, height] of the new Axes for logo.
    logo_off : bool, optional
        If True, turn off PMP logo. Default is False.
    model_names2 : list of str, optional
        Should be a subset of `model_names`. If given, violin plot will be
        split into 2 groups.
    group1_name : str, optional
        Name for the first group in split violin plot. Default is 'group1'.
    group2_name : str, optional
        Name for the second group in split violin plot. Default is 'group2'.
    comparing_models : tuple or list, optional
        Two strings for models to compare with colors filled between the two
        lines.
    fill_between_lines : bool, optional
        If True, fill color between lines for models in comparing_models.
        Default is False.
    fill_between_lines_colors : tuple or list, optional
        Two strings of colors for filled between the two lines.
        Default is ('red', 'green').
    arrow_between_lines : bool, optional
        If True, place arrows between two lines for models in
        comparing_models. Default is False.
    arrow_between_lines_colors : tuple or list, optional
        Two strings of colors for arrow between the two lines.
        Default is ('red', 'green').
    arrow_alpha : float, optional
        Transparency of arrow (fraction between 0 to 1). Default is 1.
    arrow_width : float, optional
        Width of arrow. Default is 0.05.
    arrow_linewidth : float, optional
        Width of arrow edge line. Default is 0.
    arrow_head_width : float, optional
        Width of arrow head. Default is 0.15.
    arrow_head_length : float, optional
        Length of arrow head. Default is 0.15.
    vertical_center : str or float or int, optional
        Adjust range of vertical axis to set center of vertical axis as
        median, mean, or given number.
    vertical_center_line : bool, optional
        If True, show median as line. Default is False.
    vertical_center_line_label : str, optional
        Label in legend for the horizontal vertical center line. If not given,
        it will be automatically assigned. Pass "off" to omit the label.
    ymax : int or float or str, optional
        Specify value of vertical axis top. If 'percentile', 95th percentile
        or extended for top.
    ymin : int or float or str, optional
        Specify value of vertical axis bottom. If 'percentile', 5th percentile
        or extended for bottom.
    debug : bool, optional
        If True, print debug information. Default is False.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The figure component of the plot.
    ax : matplotlib.axes.Axes
        The axes component of the plot.

    Notes
    -----
    This function creates a parallel coordinate plot for visualizing
    multi-dimensional data. It supports various customization options
    including highlighting specific models, adding violin plots, and comparing
    models with filled areas or arrows. The function uses matplotlib for
    plotting and can integrate with existing figure and axes objects.

    Calling this function updates the global matplotlib ``rcParams`` (legend,
    axes label/title and tick label font sizes).

    Author: Jiwoo Lee @ LLNL (2021. 7)

    Update history:

    - 2021-07 Plotting code created. Inspired by https://stackoverflow.com/questions/8230638/parallel-coordinates-plot-in-matplotlib
    - 2022-09 violin plots added
    - 2023-03 median centered option added
    - 2023-04 vertical center option diversified (median, mean, or given number)
    - 2024-03 parameter added for violin plot label
    - 2024-04 parameters added for arrow and option added for ymax/ymin setting
    - 2024-11 docstring cleaned up

    Examples
    --------
    >>> from pcmdi_metrics.graphics import parallel_coordinate_plot
    >>> import numpy as np
    >>> data = np.random.rand(10, 10)
    >>> metric_names = ['Metric' + str(i) for i in range(10)]
    >>> model_names = ['Model' + str(i) for i in range(10)]
    >>> fig, ax = parallel_coordinate_plot(data, metric_names, model_names, models_to_highlight=model_names[0])

    .. image:: /_static/images/parallel_coordiate_plot_docstring_example.png
        :alt: Example parallel coordinate plot
        :align: center
        :width: 600px

    Further examples can be found `here <https://github.com/PCMDI/pcmdi_metrics/tree/main/pcmdi_metrics/graphics/parallel_coordinate_plot#readme>`__.
    """
    pylab.rcParams.update(
        {
            "legend.fontsize": "large",
            "axes.labelsize": "x-large",
            "axes.titlesize": "x-large",
            "xtick.labelsize": "x-large",
            "ytick.labelsize": "x-large",
        }
    )

    # Quick initial QC
    _quick_qc(data, model_names, metric_names, model_names2=model_names2)

    # A bare string would otherwise be matched by substring (`model in "abc"`),
    # so normalize it to a one-element list; None means "nothing to highlight".
    if models_to_highlight is None:
        models_to_highlight = []
    elif isinstance(models_to_highlight, str):
        models_to_highlight = [models_to_highlight]

    # Transform data for plotting
    zs, zs_middle, N, ymins, ymaxs, _df_stacked, df2_stacked = _data_transform(
        data,
        metric_names,
        model_names,
        model_names2=model_names2,
        group1_name=group1_name,
        group2_name=group2_name,
        vertical_center=vertical_center,
        ymax=ymax,
        ymin=ymin,
    )

    if debug:
        print("ymins:", ymins)
        print("ymaxs:", ymaxs)

    # Prepare plot: use smaller tick labels when there are many vertical axes
    default_tick_labelsize = "large" if N > 20 else "x-large"
    if xtick_labelsize is None:
        xtick_labelsize = default_tick_labelsize
    if ytick_labelsize is None:
        ytick_labelsize = default_tick_labelsize
    pylab.rcParams.update(
        {
            "xtick.labelsize": xtick_labelsize,
            "ytick.labelsize": ytick_labelsize,
        }
    )

    # Resolve figure/axes; either one may be supplied on its own.
    if ax is None:
        if fig is None:
            fig, ax = plt.subplots(figsize=figsize)
        else:
            ax = fig.gca()
    elif fig is None:
        fig = ax.get_figure()

    # One twin axis per metric; the host axis serves the first metric.
    axes = [ax] + [ax.twinx() for _ in range(N - 1)]

    for i, ax_y in enumerate(axes):
        ax_y.set_ylim(ymins[i], ymaxs[i])
        ax_y.spines["top"].set_visible(False)
        ax_y.spines["bottom"].set_visible(False)
        if ax_y is ax:
            ax_y.spines["left"].set_position(("data", i))
        else:
            ax_y.spines["left"].set_visible(False)
            ax_y.yaxis.set_ticks_position("right")
            ax_y.spines["right"].set_position(("data", i))

    # Population distribution on each vertical axis
    violin = None
    if show_boxplot or show_violin:
        # Remove NaN values for box/violin plot
        y_filtered = []
        for i in range(N):
            column = zs[:, i]
            y_filtered.append(column[~np.isnan(column)])

        # Box plot
        if show_boxplot:
            box = ax.boxplot(
                y_filtered, positions=range(N), patch_artist=True, widths=0.15
            )
            for item in ["boxes", "whiskers", "fliers", "medians", "caps"]:
                plt.setp(box[item], color="darkgrey")
            plt.setp(box["boxes"], facecolor="None")
            plt.setp(box["fliers"], markeredgecolor="darkgrey")

        # Violin plot
        if show_violin:
            if model_names2 is None:
                # matplotlib for regular violin plot
                violin = ax.violinplot(
                    y_filtered,
                    positions=range(N),
                    showmeans=False,
                    showmedians=False,
                    showextrema=False,
                )
                if isinstance(violin_colors, (tuple, list)):
                    violin_color = violin_colors[0]
                else:
                    violin_color = violin_colors
                for pc in violin["bodies"]:
                    pc.set_facecolor(violin_color)
                    pc.set_edgecolor("None")
                    pc.set_alpha(0.8)
            else:
                # seaborn for split violin plot
                # NOTE(review): `scale`/`scale_hue` appear to be deprecated in
                # recent seaborn releases in favor of `density_norm` /
                # `common_norm`; confirm the minimum supported seaborn version
                # before switching.
                violin = sns.violinplot(
                    data=df2_stacked,
                    x="Metric",
                    y="value",
                    ax=ax,
                    hue="group",
                    split=True,
                    linewidth=0.1,
                    scale="count",
                    scale_hue=False,
                    palette={
                        group1_name: violin_colors[0],
                        group2_name: violin_colors[1],
                    },
                )

    # Line or marker
    num_color = max(1, min(len(model_names), num_color))
    cmap = plt.get_cmap(colormap)
    base_colors = [cmap(c) for c in np.linspace(0, 1, num_color)]
    marker_types = ["o", "s", "*", "^", "X", "D", "p"]
    # Each marker type is paired with the full color cycle before moving on to
    # the next marker type.
    markers = [marker for marker in marker_types for _ in range(num_color)]
    colors = base_colors * len(marker_types)
    n_styles = len(colors)

    mh_index = 0
    for j, model in enumerate(model_names):
        # Wrap around so that very large ensembles reuse styles instead of
        # raising IndexError.
        style = j % n_styles
        if model in models_to_highlight:
            if models_to_highlight_colors is not None:
                color = models_to_highlight_colors[mh_index]
            else:
                color = colors[style]

            if models_to_highlight_labels is not None:
                label = models_to_highlight_labels[mh_index]
            else:
                label = model

            if models_to_highlight_by_line:
                # Straight lines between the axes
                ax.plot(range(N), zs[j, :], "-", c=color, label=label, lw=3)
            else:
                highlight_marker = models_to_highlight_markers[
                    mh_index % len(models_to_highlight_markers)
                ]
                ax.plot(
                    range(N),
                    zs[j, :],
                    highlight_marker,
                    c=color,
                    label=label,
                    markersize=models_to_highlight_markers_size,
                )

            mh_index += 1
        elif identify_all_models:
            ax.plot(
                range(N),
                zs[j, :],
                markers[style],
                c=colors[style],
                label=model,
                clip_on=False,
            )

    if vertical_center_line:
        if vertical_center_line_label is None:
            vertical_center_line_label = str(vertical_center)
        elif vertical_center_line_label == "off":
            vertical_center_line_label = None
        ax.plot(range(N), zs_middle, "-", c="k", label=vertical_center_line_label, lw=1)

    # Compare two models
    if comparing_models is not None:
        if isinstance(comparing_models, tuple) or (
            isinstance(comparing_models, list) and len(comparing_models) == 2
        ):
            x = range(N)
            m1 = model_names.index(comparing_models[0])
            m2 = model_names.index(comparing_models[1])
            y1 = zs[m1, :]
            y2 = zs[m2, :]

            # Fill between lines
            if fill_between_lines:
                ax.fill_between(
                    x,
                    y1,
                    y2,
                    where=(y2 > y1),
                    facecolor=fill_between_lines_colors[0],
                    interpolate=False,
                    alpha=0.5,
                )
                ax.fill_between(
                    x,
                    y1,
                    y2,
                    where=(y2 < y1),
                    facecolor=fill_between_lines_colors[1],
                    interpolate=False,
                    alpha=0.5,
                )

            # Add vertical arrows
            if arrow_between_lines:
                for xi, yi1, yi2 in zip(x, y1, y2):
                    if yi2 > yi1:
                        arrow_color = arrow_between_lines_colors[0]
                    elif yi2 < yi1:
                        arrow_color = arrow_between_lines_colors[1]
                    else:
                        arrow_color = None
                    arrow_length = yi2 - yi1
                    ax.arrow(
                        xi,
                        yi1,
                        0,
                        arrow_length,
                        color=arrow_color,
                        length_includes_head=True,
                        alpha=arrow_alpha,
                        width=arrow_width,
                        linewidth=arrow_linewidth,
                        head_width=arrow_head_width,
                        head_length=arrow_head_length,
                        zorder=999,
                    )

    ax.set_xlim(-0.5, N - 0.5)
    ax.set_xticks(range(N))
    ax.set_xticklabels(metric_names, fontsize=xtick_labelsize)
    ax.tick_params(axis="x", which="major", pad=7)
    ax.spines["right"].set_visible(False)
    ax.set_title(title, fontsize=18)

    if not legend_off:
        # A violin legend entry is only possible when the matplotlib
        # (non-split) violin plot was drawn; it returns a dict with "bodies".
        violin_handle = None
        if violin_label is not None and isinstance(violin, dict):
            bodies = violin.get("bodies")
            if bodies:
                violin_handle = bodies[0]

        if violin_handle is not None:
            handles = [violin_handle] + list(ax.lines)
            labels = [violin_label] + [line.get_label() for line in ax.lines]
            # Drop entries whose label starts with '_' to avoid matplotlib's
            # warning about ignored legend entries
            kept = [
                (handle, label)
                for handle, label in zip(handles, labels)
                if not label.startswith("_")
            ]
            ax.legend(
                [handle for handle, _ in kept],
                [label for _, label in kept],
                loc=legend_loc,
                ncol=legend_ncol,
                bbox_to_anchor=legend_bbox_to_anchor,
                fontsize=legend_fontsize,
            )
        else:
            ax.legend(
                loc=legend_loc,
                ncol=legend_ncol,
                bbox_to_anchor=legend_bbox_to_anchor,
                fontsize=legend_fontsize,
            )

    if not logo_off:
        fig, ax = add_logo(fig, ax, logo_rect)

    return fig, ax


def _quick_qc(data, model_names, metric_names, model_names2=None, debug=False):
    """
    Check that `data` is consistent with the given model and metric names.

    Parameters
    ----------
    data : ndarray
        2-d array whose rows correspond to `model_names` and whose columns
        correspond to `metric_names`.
    model_names : list
        Names of models (axis=0 of `data`).
    metric_names : list
        Names of metrics (axis=1 of `data`).
    model_names2 : list, optional
        Names that must all be present in `model_names`.
    debug : bool, optional
        If True, print a message when all checks pass. Default is False.

    Raises
    ------
    SystemExit
        If a dimension of `data` does not match the corresponding name list,
        or if `model_names2` contains a name missing from `model_names`.
    """
    if data.shape[0] != len(model_names):
        sys.exit(
            f"Error: data.shape[0], {data.shape[0]}, "
            f"mismatch to len(model_names), {len(model_names)}"
        )
    if data.shape[1] != len(metric_names):
        sys.exit(
            f"Error: data.shape[1], {data.shape[1]}, "
            f"mismatch to len(metric_names), {len(metric_names)}"
        )
    if model_names2 is not None:
        # Check: model_names2 should be a subset of model_names
        for model in model_names2:
            if model not in model_names:
                sys.exit(
                    "Error: model_names2 should be a subset of model_names, but "
                    f"{model} is not in model_names"
                )
    if debug:
        print("Passed a quick QC")


def _resolve_axis_limits(limit, ys, n_axes, default_reducer, percentile, name):
    """Return a float array of per-axis limits for the `ymax`/`ymin` option."""
    if limit is None:
        # Data-driven limit (ignores NaN values)
        return default_reducer(ys, axis=0)
    if isinstance(limit, str):
        if limit == "percentile":
            return np.nanpercentile(ys, percentile, axis=0)
        raise ValueError(f"Invalid input for {name}: {limit}")
    try:
        # Always float so that later in-place adjustments are not truncated
        return np.full(n_axes, limit, dtype=float)
    except (TypeError, ValueError) as err:
        raise ValueError(f"Invalid input for {name}: {limit}") from err


def _data_transform(
    data,
    metric_names,
    model_names,
    model_names2=None,
    group1_name="group1",
    group2_name="group2",
    vertical_center=None,
    ymax=None,
    ymin=None,
):
    """
    Rescale every metric column onto the value range of the first column.

    Parameters
    ----------
    data : ndarray
        2-d array (models x metrics). Integer input is converted to float.
    metric_names : list
        Names of metrics (axis=1 of `data`).
    model_names : list
        Names of models (axis=0 of `data`).
    model_names2 : list, optional
        Models assigned to the second group in the returned DataFrames.
    group1_name, group2_name : str, optional
        Group labels used in the returned DataFrames.
    vertical_center : str or float or int, optional
        "median", "mean", or a number. When given, each axis range is made
        symmetric around that center.
    ymax, ymin : int or float or str, optional
        Axis top/bottom. None uses the column max/min plus 5% padding;
        "percentile" uses the 95th/5th percentile; a number is applied to
        every axis.

    Returns
    -------
    zs : ndarray
        Data mapped onto the coordinate range of the first axis.
    zs_middle : ndarray
        Position of the vertical center of each axis in the same coordinates.
    N : int
        Number of vertical axes (i.e., number of metrics).
    ymins, ymaxs : ndarray
        Lower and upper limits of each vertical axis (float arrays).
    df_stacked, df2_stacked : pandas.DataFrame
        Stacked DataFrames built from the input data and from `zs`.

    Raises
    ------
    ValueError
        If `ymax`, `ymin` or `vertical_center` is not a supported value.
    """
    # Work in floating point: integer arrays cannot receive the fractional
    # padding below and would truncate the rescaled values.
    ys = np.asanyarray(data, dtype=float)  # stacked y-axis values
    N = ys.shape[1]  # number of vertical axis (i.e., =len(metric_names))

    ymaxs = _resolve_axis_limits(ymax, ys, N, np.nanmax, 95, "ymax")
    ymins = _resolve_axis_limits(ymin, ys, N, np.nanmin, 5, "ymin")

    if vertical_center is not None:
        if isinstance(vertical_center, str):
            if vertical_center == "median":
                ymids = np.nanmedian(ys, axis=0)
            elif vertical_center == "mean":
                ymids = np.nanmean(ys, axis=0)
            else:
                raise ValueError(f"vertical center {vertical_center} unknown.")
        elif isinstance(vertical_center, (int, float, np.integer, np.floating)):
            ymids = np.full(N, float(vertical_center))
        else:
            raise ValueError(f"vertical center {vertical_center} unknown.")

        # Make each axis symmetric around its center
        half_range = np.fmax(np.abs(ymaxs - ymids), np.abs(ymids - ymins))
        ymaxs = ymids + half_range
        ymins = ymids - half_range

    dys = ymaxs - ymins
    if ymin is None:
        ymins = ymins - dys * 0.05  # add 5% padding below and above
    if ymax is None:
        ymaxs = ymaxs + dys * 0.05
    dys = ymaxs - ymins

    # Transform all data to be compatible with the main axis
    zs = np.zeros_like(ys)
    zs[:, 0] = ys[:, 0]
    zs[:, 1:] = (ys[:, 1:] - ymins[1:]) / dys[1:] * dys[0] + ymins[0]

    if vertical_center is not None:
        zs_middle = (ymids - ymins) / dys * dys[0] + ymins[0]
    else:
        zs_middle = (ymaxs - ymins) / 2 / dys * dys[0] + ymins[0]

    if model_names2 is not None:
        print("Models in the second group:", model_names2)

    # Pandas dataframe for seaborn plotting
    df_stacked = _to_pd_dataframe(
        ys,
        metric_names,
        model_names,
        model_names2=model_names2,
        group1_name=group1_name,
        group2_name=group2_name,
    )
    df2_stacked = _to_pd_dataframe(
        zs,
        metric_names,
        model_names,
        model_names2=model_names2,
        group1_name=group1_name,
        group2_name=group2_name,
    )

    return zs, zs_middle, N, ymins, ymaxs, df_stacked, df2_stacked


def _to_pd_dataframe(
    data: np.ndarray,
    metric_names: list[str],
    model_names: list[str],
    model_names2: list[str] = None,
    group1_name: str = "group1",
    group2_name: str = "group2",
    debug=False,
) -> pd.DataFrame:
    """
    Converts data into a stacked pandas DataFrame for seaborn plotting.

    Parameters
    ----------
    data : np.ndarray
        2D array of data values, where rows correspond to `model_names` and
        columns to `metric_names`.
    metric_names : list of str
        List of metric names for DataFrame columns.
    model_names : list of str
        List of model names for DataFrame index.
    model_names2 : list of str, optional
        Secondary list of model names for alternate grouping.
    group1_name : str, default="group1"
        Name assigned to the group when `model_names2` is not matched.
    group2_name : str, default="group2"
        Name assigned to the group when `model_names2` is matched.
    debug : bool, default=False
        If True, print debug information.

    Returns
    -------
    pd.DataFrame
        Stacked DataFrame with columns: 'Model', 'Metric', 'value', and
        'group'. Rows holding NaN values are kept.

    Raises
    ------
    ValueError
        If the shape of `data` does not match the lengths of `model_names`
        and `metric_names`.
    """
    if debug:
        print("data.shape:", data.shape)

    # Check input validity
    if data.shape[1] != len(metric_names):
        raise ValueError(
            "Number of columns in `data` must match length of `metric_names`."
        )
    if data.shape[0] != len(model_names):
        raise ValueError("Number of rows in `data` must match length of `model_names`.")

    # Create DataFrame
    df = pd.DataFrame(data, columns=metric_names, index=model_names)

    # Stack without dropping NaN entries
    try:
        stacked = df.stack(future_stack=True)
    except TypeError:
        # Older pandas releases do not know `future_stack`
        stacked = df.stack(dropna=False)

    df_stacked = stacked.reset_index().rename(
        columns={"level_0": "Model", "level_1": "Metric", 0: "value"}
    )

    # Assign groups in one pass instead of rebuilding the column per model
    df_stacked["group"] = group1_name
    if model_names2 is not None:
        in_group2 = df_stacked["Model"].isin(model_names2)
        df_stacked["group"] = np.where(in_group2, group2_name, group1_name)

    return df_stacked