# Source code for ESMBenchmarkViz.core_portrait_plot

# - Generate an interactive Portrait Plot using Bokeh.
# - Author: Jiwoo Lee (2021.08)
# - Last update: 2024.11

import math
import sys
from copy import deepcopy
from typing import List, Optional, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from bokeh.colors import RGB
from bokeh.models import (
    BasicTicker,
    ColorBar,
    ColumnDataSource,
    LinearAxis,
    LinearColorMapper,
    OpenURL,
    Patches,
    TapTool,
)
from bokeh.plotting import figure, show

# -------------
# Main function
# -------------

def portrait_plot(
    data: Union[np.ndarray, List[np.ndarray]],
    xaxis_labels: List[str],
    yaxis_labels: List[str],
    width: Union[int, str] = 600,
    height: Union[int, str] = 600,
    annotate: bool = False,
    annotate_data: Optional[np.ndarray] = None,
    vrange: Optional[Tuple[float, float]] = None,
    xaxis_fontsize: Optional[int] = None,
    yaxis_fontsize: Optional[int] = None,
    xaxis_fontstyle: Optional[str] = None,
    yaxis_fontstyle: Optional[str] = None,
    xaxis_location: str = "above",
    xaxis_rotation: int = 45,
    title: Optional[str] = None,
    cmap: str = "RdBu_r",
    cmap_bounds: Optional[List[float]] = None,
    cbar_place: str = "right",
    cbar_tick_fontsize: Optional[int] = None,
    invert_yaxis: bool = True,
    clickable: bool = False,
    line_color: str = "grey",
    legend_name: Optional[str] = "Group",
    legend_labels: Optional[List[str]] = None,
    img_url: Optional[List[str]] = None,
    tooltips: Optional[Union[str, List[Tuple[str, str]]]] = None,
    url_open: Optional[List[str]] = None,
    missing_color: str = "grey",
    aspect_scale: float = 1,
    show_plot: bool = True,
    bokeh_toolbar: bool = True,
    bokeh_logo: bool = True,
    debug: bool = False,
):
    """
    Generates an interactive portrait plot using Bokeh.

    Parameters
    ----------
    data : numpy.ndarray or list of numpy.ndarray
        A 2D array, a list of 2D arrays, or a 3D array (stacked 2D arrays)
        containing the data to plot.
    xaxis_labels : list of str
        Labels for the x-axis. The number of labels must match the x-axis
        dimension of `data`.
    yaxis_labels : list of str
        Labels for the y-axis. The number of labels must match the y-axis
        dimension of `data`.
    width : int or str, optional
        Width of the plot in pixels. Default is 600. If 'auto', the width is
        calculated based on the data size.
    height : int or str, optional
        Height of the plot in pixels. Default is 600. If 'auto', the height is
        calculated based on the data size.
    annotate : bool, optional
        If True, a second value column (``field2``) is attached to every cell
        and shown in the tooltip. Default is False.
    annotate_data : numpy.ndarray, optional
        Array to use for annotations. If None, `data` is used. Default is None.
    vrange : tuple of float, optional
        Range of values for the color scale. Default is None (data min/max).
    xaxis_fontsize : int, optional
        Font size (pt) for the x-axis tick labels. Default is None.
    yaxis_fontsize : int, optional
        Font size (pt) for the y-axis tick labels. Default is None.
    xaxis_fontstyle : str, optional
        Font style for the x-axis labels. Options are
        ['normal', 'italic', 'bold', 'bold italic']. Default is None.
    yaxis_fontstyle : str, optional
        Font style for the y-axis labels. Options are
        ['normal', 'italic', 'bold', 'bold italic']. Default is None.
    xaxis_location : str, optional
        Location of the x-axis. Options are ['above', 'below', 'both'].
        Default is 'above'.
    xaxis_rotation : int, optional
        Rotation angle of the x-axis tick labels in degrees. Default is 45.
    title : str, optional
        Title of the figure. Default is None.
    cmap : str, optional
        Name of the matplotlib colormap to use. Default is 'RdBu_r'.
    cmap_bounds : list of float, optional
        If specified, the colormap is sampled with ``len(cmap_bounds) - 1``
        colors and the color range is widened to include the bounds.
        Default is None.
    cbar_place : str, optional
        Location of the colorbar. Options are
        ['left', 'right', 'above', 'below', 'center']. Default is 'right'.
    cbar_tick_fontsize : int, optional
        Font size (px) for the colorbar tick labels. Default is None (12px).
    invert_yaxis : bool, optional
        If True, places y=0 at the top of the plot. Default is True.
    clickable : bool, optional
        If True, clicking a cell opens the URL stored for it. Default is False.
    line_color : str, optional
        Color of the cell outlines. Default is 'grey'.
    legend_name : str, optional
        Name of the legend (used for triangular plots). Default is 'Group'.
    legend_labels : list of str, optional
        Labels for the legend (used for triangular plots). Default is None.
    img_url : list of str, optional
        Links to images displayed in tooltips, one per cell. Default is None.
    tooltips : str or list of tuple, optional
        Custom tooltips for the plot. Default is None (built automatically).
    url_open : list of str, optional
        Links to open when a cell is clicked. Defaults to `img_url`.
    missing_color : str, optional
        Color for missing values in the plot. Default is 'grey'.
    aspect_scale : float, optional
        Scale factor for the plot aspect ratio. Default is 1.
    show_plot : bool, optional
        If True, the plot is displayed before returning. Default is True.
    bokeh_toolbar : bool, optional
        If True, displays the Bokeh toolbar in the plot. Default is True.
    bokeh_logo : bool, optional
        If True, displays the Bokeh logo in the plot. Default is True.
    debug : bool, optional
        If True, prints debug messages. Default is False.

    Returns
    -------
    plot : Bokeh component
        A Bokeh plot object representing the interactive portrait plot.

    Example
    -------
    >>> from ESMBenchmarkViz import portrait_plot

    Notes
    -----
    - The function supports both 2D and stacked 3D data for generating
      portrait plots.
    - Interactive features include tooltips and clickable URLs, enabled
      through Bokeh.
    - Missing values are displayed using the specified `missing_color`.
    - Invalid input terminates through ``sys.exit`` (``SystemExit``).
    """
    # ----------------
    # Prepare plotting
    # ----------------
    data, num_divide = prepare_data(data, xaxis_labels, yaxis_labels, debug)

    if annotate:
        if annotate_data is None:
            # Documented fallback: annotate with the plotted values themselves.
            annotate_data = data
        annotate_data, num_divide_annotate = prepare_data(
            annotate_data, xaxis_labels, yaxis_labels, debug
        )
        if num_divide_annotate != num_divide:
            sys.exit("Error: annotate_data does not have same size as data")

    if url_open is None:
        url_open = img_url

    # Figure type
    if num_divide > 1 and len(data.shape) == 3:
        if num_divide != len(data):
            sys.exit("Error: data.shape[0] is not equal to num_divide")
        if annotate:
            if num_divide != len(annotate_data):
                sys.exit("Error: annotate_data.shape[0] is not equal to num_divide")

    xpts_list, ypts_list = get_x_y_points(num_divide)
    positions = get_positions(num_divide)

    # Rows are drawn bottom-up, so when the y-axis is inverted both the data
    # and the labels have to be flipped together; otherwise the hover text
    # names a different row than the axis tick next to the cell.
    row_labels = yaxis_labels[::-1] if invert_yaxis else yaxis_labels

    # Prepare data for plotting
    # ~~~~~~~~~~~~~~~~~~~~~~~~~
    xs, ys = [], []
    field, field2 = [], []
    positions_list, position_description_list = [], []
    xname_list, yname_list = [], []

    for i in range(num_divide):
        xpts = xpts_list[i]
        ypts = ypts_list[i]

        # A 3D stack contributes one layer per sector; a 2D array is the
        # single layer itself.
        a = data[i] if data.ndim == 3 else data
        if annotate:
            annotate_a = annotate_data[i] if annotate_data.ndim == 3 else annotate_data

        if invert_yaxis:
            a = np.flipud(a)
            if annotate:
                annotate_a = np.flipud(annotate_a)

        # xs, ys: x- and y-coordinates for all the patches,
        # given as a "list of lists".
        for iy in range(a.shape[0]):
            yname = row_labels[iy]
            for ix in range(a.shape[1]):
                xname = xaxis_labels[ix]
                xs.append([tmp_x + ix for tmp_x in xpts])
                ys.append([tmp_y + iy for tmp_y in ypts])
                field.append(a[iy, ix])
                if annotate:
                    field2.append(annotate_a[iy, ix])
                if legend_labels is not None:
                    position_description_list.append(legend_labels[i])
                if positions is not None:
                    positions_list.append(positions[i])
                xname_list.append(xname)
                yname_list.append(yname)

    # Gathered data for plotting
    col_dict = dict(
        xs=xs,
        ys=ys,
        field=field,
        xname=xname_list,
        yname=yname_list,
    )

    if img_url is not None:
        col_dict.update(dict(img=img_url))

    if url_open is not None:
        col_dict.update(dict(url=url_open))

    if len(field2) > 0:
        col_dict.update(dict(field2=field2))
        if img_url is not None:
            # Blank the image link wherever the annotation value is missing so
            # the tooltip does not show an image for an empty cell.
            col_dict_df = pd.DataFrame.from_dict(col_dict)
            col_dict_df.loc[col_dict_df.field2.isna(), "img"] = ""
            col_dict.update(dict(img=col_dict_df["img"].tolist()))

    # Positions are only known for up to four sectors.
    if len(positions_list) == len(xname_list):
        col_dict.update(dict(position=positions_list))

    if len(position_description_list) > 0:
        col_dict.update(dict(position_description=position_description_list))

    if debug:
        print("col_dict: ", col_dict)
        print("col_dict.keys(): ", col_dict.keys())
        print("col_dict['xs']: ", col_dict["xs"])
        print("col_dict['ys']: ", col_dict["ys"])
        # 'position' is absent when the sector count has no named positions.
        print("col_dict['position']: ", col_dict.get("position"))

    source = ColumnDataSource(col_dict)

    # ----------------
    # Ready to plot!!
    # ----------------

    # Figure size
    if width == "auto":
        plot_width = data.shape[-1] * 30 + 150
    else:
        plot_width = width

    if height == "auto":
        plot_height = data.shape[-2] * 30
    else:
        plot_height = height

    # yaxis starts from top
    y_range_labels = list(row_labels)

    if clickable:
        tools = "hover, tap, save"  # hover needed for tooltip, tap needed for url open
    else:
        tools = "hover, save"

    legend_title = (legend_name or "Group").capitalize()

    if img_url is not None:
        if tooltips is None:
            # Customized tooltip
            tooltips = """
                <div>
                    <div>
                        <img
                            src="@img" alt="@img" width="300" height="200"
                            style="float: left; margin: 0px 5px 5px 0px;"
                            border="1"
                            onerror style="margin: -75px 0 0 -100px"
                        ></img><br>
                        <span style="font-size: 14px">
                        <font color=darkgreen>Model:</font> <b>@yname</b><br>
                        <font color=darkgreen>Variable:</font> <b>@xname</b><br>"""
            if len(position_description_list) > 0:
                tooltips += (
                    f"<font color=darkgreen>{legend_title}:</font> "
                    "<b>@position_description</b><br>"
                )
            tooltips += """
                        <font color=darkgreen>Value (Nor.):</font> @field"""
            if annotate:
                # Only reference @field2 when that column actually exists.
                tooltips += """<br>
                        <font color=darkgreen>Value (Act.):</font> @field2"""
            tooltips += """</span>
                    </div>
                </div>"""
    else:
        if tooltips is None:
            tooltips = [("Model", "@yname"), ("Variable", "@xname")]
            if len(position_description_list) > 0:
                tooltips.append((legend_title, "@position_description"))
            tooltips.append(("Value (Nor.)", "@field"))
            if annotate:
                tooltips.append(("Value (Act.)", "@field2"))
            if debug and num_divide > 1 and "position" in col_dict:
                tooltips.append(("Position", "@position"))

    if xaxis_location in ["above", "below"]:
        x_axis_location = xaxis_location
    elif xaxis_location == "both":
        x_axis_location = "above"
    else:
        sys.exit("Error: xaxis_location should be either above, below, or both")

    plot = figure(
        title=title,
        x_range=xaxis_labels,
        y_range=y_range_labels,
        width=plot_width,
        height=plot_height,
        min_border=50,
        tools=tools,
        tooltips=tooltips,
        x_axis_location=x_axis_location,
        aspect_scale=aspect_scale,
    )

    # Color Map control
    if cmap_bounds is None:
        ncolors = 255
    else:
        ncolors = len(cmap_bounds) - 1

    colormap = plt.get_cmap(cmap, ncolors)
    m_colormap_rgb = (255 * colormap(range(0, ncolors))).astype("int")
    # The sampled rows are RGBA; only the RGB channels feed the hex palette.
    colors = [
        RGB(int(rgba[0]), int(rgba[1]), int(rgba[2])).to_hex()
        for rgba in m_colormap_rgb
    ]

    if vrange is None:
        vmin = np.nanmin(np.array(field))
        vmax = np.nanmax(np.array(field))
    else:
        vmin = np.min(vrange)
        vmax = np.max(vrange)

    if cmap_bounds is not None:
        vmin = min([vmin, min(cmap_bounds)])
        vmax = max([vmax, max(cmap_bounds)])

    mapper = LinearColorMapper(
        palette=colors, low=vmin, high=vmax, nan_color=missing_color
    )

    glyph = Patches(
        xs="xs",
        ys="ys",
        fill_color={"field": "field", "transform": mapper},
        line_color=line_color,
        line_width=0.5,
    )

    # Generate actual plot
    plot.add_glyph(
        source, glyph, selection_glyph=glyph, nonselection_glyph=glyph
    )  # keep same transparency regardless of selection

    # x-axis tick labels
    plot.xaxis.major_label_orientation = math.radians(xaxis_rotation)
    if xaxis_fontsize is not None:
        plot.xaxis.major_label_text_font_size = str(xaxis_fontsize) + "pt"
    if xaxis_fontstyle is not None:
        plot.xaxis.major_label_text_font_style = xaxis_fontstyle

    # y-axis tick labels
    if yaxis_fontsize is not None:
        plot.yaxis.major_label_text_font_size = str(yaxis_fontsize) + "pt"
    if yaxis_fontstyle is not None:
        plot.yaxis.major_label_text_font_style = yaxis_fontstyle

    # x-axis at the bottom as well
    if xaxis_location == "both" and x_axis_location == "above":
        plot.add_layout(LinearAxis(), "below")
        # The extra axis is numeric, so ticks are pinned to the cell centers
        # and relabelled with the categorical names.
        tick_positions = [ix + 0.5 for ix in range(len(xaxis_labels))]
        plot.xaxis[1].ticker = tick_positions
        plot.xaxis[1].major_label_overrides = dict(zip(tick_positions, xaxis_labels))
        plot.xaxis[1].major_label_orientation = math.radians(xaxis_rotation)
        if xaxis_fontsize is not None:
            plot.xaxis[1].major_label_text_font_size = str(xaxis_fontsize) + "pt"
        if xaxis_fontstyle is not None:
            plot.xaxis[1].major_label_text_font_style = xaxis_fontstyle

    # Color Bar
    if cbar_tick_fontsize is None:
        cbar_tick_fontsize = "12px"
    else:
        cbar_tick_fontsize = str(int(cbar_tick_fontsize)) + "px"

    color_bar = ColorBar(
        color_mapper=mapper,
        major_label_text_font_size=cbar_tick_fontsize,
        ticker=BasicTicker(desired_num_ticks=10),
        label_standoff=6,
        border_line_color=None,
        location=(0, 0),
    )
    plot.add_layout(color_bar, cbar_place)

    # Link to open when clicked
    if clickable:
        # Attach the URL opener to the tap tool created by tools="... tap ...".
        taptool = plot.select(type=TapTool)
        taptool.callback = OpenURL(url="@url")

    # Turn off bokeh logo
    if bokeh_logo is False:
        plot.toolbar.logo = None

    # Turn off bokeh toolbar
    if bokeh_toolbar is False:
        plot.toolbar_location = None

    # Title
    if title is not None:
        plot.title.align = "center"

    return_object = plot

    # Show the plot if requested
    if show_plot:
        show(return_object)

    return return_object
# -----------------
# Support functions
# -----------------


def prepare_data(data, xaxis_labels, yaxis_labels, debug=False):
    """
    Prepare data for plotting.

    Parameters
    ----------
    data : list or np.ndarray
        Data to plot: a 2D array, a 3D array, or a list of 2D arrays.
    xaxis_labels : list of str
        Labels for the x-axis. An empty list skips the size check.
    yaxis_labels : list of str
        Labels for the y-axis. An empty list skips the size check.
    debug : bool, optional
        If True, print debug information. The default is False.

    Returns
    -------
    np.ndarray, int
        The data and the number of divisions (1 for a 2D array, the length of
        the first axis for a 3D array).

    Notes
    -----
    Invalid input terminates through ``sys.exit`` (``SystemExit``).
    """
    # In case data was given as list of arrays, convert it to numpy (stacked) array
    if isinstance(data, list):
        if debug:
            print("data type is list")
            print("len(data):", len(data))

        if len(data) == 1:  # list has only 1 array as element
            if isinstance(data[0], np.ndarray) and (len(data[0].shape) == 2):
                data = data[0]
            else:
                sys.exit("Error: Element of given list is not in np.ndarray type")
        else:  # list has more than 1 arrays as elements
            data = np.stack(data)

    # Validate type and rank BEFORE touching data.shape[-1] / [-2]; otherwise
    # bad input surfaces as AttributeError/IndexError instead of the intended
    # error message.
    if not isinstance(data, np.ndarray):
        sys.exit("Error: Converted or given data is not in np.ndarray type")

    if debug:
        print("data.shape:", data.shape)

    if len(data.shape) == 2:
        num_divide = 1
    elif len(data.shape) == 3:
        num_divide = data.shape[0]
    else:
        print("data.shape:", data.shape)
        sys.exit("Error: data.shape is not right")

    if data.shape[-1] != len(xaxis_labels) and len(xaxis_labels) > 0:
        sys.exit("Error: Number of elements in xaxis_label mismatchs to the data")

    if data.shape[-2] != len(yaxis_labels) and len(yaxis_labels) > 0:
        sys.exit("Error: Number of elements in yaxis_label mismatchs to the data")

    if debug:
        print("num_divide:", num_divide)

    return data, num_divide


def find_intersection_points_centered(num_sectors: int) -> list:
    """
    Finds intersection points of evenly spaced radius lines from the center
    of a circle with the unit square on the x-y plane.

    The first partition is centered at angle 90° (π/2 radians), regardless of
    whether the number of sectors is odd or even. Rays are generated in the
    clockwise direction.

    Parameters
    ----------
    num_sectors : int
        Number of sectors to divide the circle into.

    Returns
    -------
    list of tuples
        Intersection points of the radius lines with the square boundary.

    Raises
    ------
    ValueError
        If `num_sectors` is smaller than 2.
    """
    if num_sectors < 2:
        raise ValueError("Number of sectors must be greater than 1.")

    if num_sectors == 2:
        # Two sectors are split along the diagonal of the square.
        return [(0, 0), (1, 1)]

    # Define square edges
    edges = [
        ((0, 0), (1, 0)),  # Bottom edge
        ((1, 0), (1, 1)),  # Right edge
        ((1, 1), (0, 1)),  # Top edge
        ((0, 1), (0, 0)),  # Left edge
    ]

    # Circle center and radius
    center = (0.5, 0.5)
    radius = 0.5

    # Calculate evenly spaced angles in the clockwise direction
    angles = np.linspace(
        0, -2 * np.pi, num_sectors, endpoint=False
    )  # Negative to reverse direction

    if num_sectors % 2 == 0:
        # Shift angles to center the first sector at 90° (π/2)
        angles = angles - (angles[1] / 2) + np.pi / 2
    else:
        # Shift angles for odd sectors to ensure first centered at 90° (π/2)
        angles = angles + np.pi / num_sectors + np.pi / 2

    intersection_points = []

    for angle in angles:
        # Define the line direction from the center
        dx, dy = np.cos(angle), np.sin(angle)
        ray_end = (center[0] + dx * radius, center[1] + dy * radius)

        for edge in edges:
            (x1, y1), (x2, y2) = edge

            # Solve line-line intersection: t runs along the edge,
            # u runs along the ray from the center.
            det = (x1 - x2) * (center[1] - ray_end[1]) - (y1 - y2) * (
                center[0] - ray_end[0]
            )
            if det == 0:
                # Ray is parallel to this edge.
                continue

            t = (
                (x1 - center[0]) * (center[1] - ray_end[1])
                - (y1 - center[1]) * (center[0] - ray_end[0])
            ) / det
            u = -((x1 - x2) * (y1 - center[1]) - (y1 - y2) * (x1 - center[0])) / det

            # Accept only hits on the edge segment and in the ray's forward
            # direction; the first matching edge wins.
            if 0 <= t <= 1 and u >= 0:
                ix = x1 + t * (x2 - x1)
                iy = y1 + t * (y2 - y1)
                intersection_points.append((ix, iy))
                break

    return intersection_points


def create_polygons(num_sectors: int):
    """
    Creates polygons for each sector formed by the radius lines and the
    square's boundary.

    Parameters
    ----------
    num_sectors : int
        Number of sectors to divide the circle into.

    Returns
    -------
    list of list of tuples
        Each polygon is represented as a list of (x, y) tuples.
    """
    center = (0.5, 0.5)
    points = find_intersection_points_centered(num_sectors)
    polygons = []
    square_apexes = [(0, 0), (1, 0), (1, 1), (0, 1)]

    for i in range(num_sectors):
        # Define the current and next intersection points
        p1 = points[i]
        p2 = points[(i + 1) % len(points)]

        # Initialize the polygon with the center and intersection points
        if num_sectors != 2:
            polygon = [center, p1]
        else:
            polygon = [p1]

        # Special case for 2 sectors: Add the square apex again to form a triangle
        if num_sectors == 2:
            if i == 0:
                polygon.append(square_apexes[3])
            elif i == 1:
                polygon.append(square_apexes[1])
        # Special case for 3 sectors: the top sector spans both upper corners
        elif num_sectors == 3 and i == 0:
            polygon.append(square_apexes[3])
            polygon.append(square_apexes[2])
        # Special case for 4 sectors: the rays already end on the corners
        elif num_sectors == 4:
            pass
        else:
            # Add any square corner lying inside the bounding box of p1 and p2
            for apex in square_apexes:
                if min(p1[0], p2[0]) <= apex[0] <= max(p1[0], p2[0]) and min(
                    p1[1], p2[1]
                ) <= apex[1] <= max(p1[1], p2[1]):
                    polygon.append(apex)

        # Add the second intersection point and close the polygon
        polygon.append(p2)

        polygons.append(polygon)

    return polygons


def extract_polygon_coordinates(polygons):
    """
    Extracts x and y coordinates from a list of polygons, rounded to 2
    decimal places.

    Parameters
    ----------
    polygons : list of list of tuples
        Each polygon is represented as a list of (x, y) tuples.

    Returns
    -------
    tuple of lists
        xpts_list : list of lists of floats
            x-coordinates of the polygons, rounded to 2 decimal places.
        ypts_list : list of lists of floats
            y-coordinates of the polygons, rounded to 2 decimal places.
    """
    xpts_list = [[round(float(point[0]), 2) for point in polygon] for polygon in polygons]
    ypts_list = [[round(float(point[1]), 2) for point in polygon] for polygon in polygons]
    return xpts_list, ypts_list


def get_x_y_points(num_sectors: int):
    """
    Returns the x and y coordinates of the polygons formed by radius
    partitions.

    Parameters
    ----------
    num_sectors : int
        The number of sectors to divide the circle into. A value of 1 yields
        the full unit square.

    Returns
    -------
    tuple of lists
        xpts_list : list of lists of floats
            x-coordinates of the polygons, rounded to 2 decimal places.
        ypts_list : list of lists of floats
            y-coordinates of the polygons, rounded to 2 decimal places.
    """
    if num_sectors == 1:
        xpts = [0, 0, 1, 1]
        ypts = [1, 0, 0, 1]
        return ([xpts], [ypts])

    # Create polygons
    polygons = create_polygons(num_sectors)
    xpts_list, ypts_list = extract_polygon_coordinates(polygons)

    return (xpts_list, ypts_list)


def get_positions(num_sectors: int):
    """
    Returns the positions of the polygons formed by radius partitions.

    The names follow the order in which `create_polygons` emits the sectors,
    i.e. clockwise starting from the top.

    Parameters
    ----------
    num_sectors : int
        The number of sectors to divide the circle into.

    Returns
    -------
    list or None
        Positions of the polygons, or None when no names are defined for the
        given number of sectors (more than 4).
    """
    positions = None
    if num_sectors == 4:
        positions = ["top", "right", "bottom", "left"]
    elif num_sectors == 3:
        # Clockwise from the top: the second sector covers the lower right.
        positions = ["top", "lower-right", "lower-left"]
    elif num_sectors == 2:
        positions = ["upper", "lower"]
    elif num_sectors == 1:
        positions = ["box"]
    return positions