# Source code for ESMBenchmarkViz.core_portrait_plot
# - Generate an interactive Portrait Plot using Bokeh.
# - Author: Jiwoo Lee (2021.08)
# - Last update: 2024.11

import math
import sys
from copy import deepcopy
from typing import List, Optional, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from bokeh.colors import RGB
from bokeh.models import (
    BasicTicker,
    ColorBar,
    ColumnDataSource,
    LinearAxis,
    LinearColorMapper,
    OpenURL,
    Patches,
    TapTool,
)
from bokeh.plotting import figure, show

# -------------
# Main function
# -------------
def portrait_plot(
    data: Union[np.ndarray, List[np.ndarray]],
    xaxis_labels: List[str],
    yaxis_labels: List[str],
    width: Union[int, str] = 600,
    height: Union[int, str] = 600,
    annotate: bool = False,
    annotate_data: Optional[np.ndarray] = None,
    vrange: Optional[Tuple[float, float]] = None,
    xaxis_fontsize: Optional[int] = None,
    yaxis_fontsize: Optional[int] = None,
    xaxis_fontstyle: Optional[str] = None,
    yaxis_fontstyle: Optional[str] = None,
    xaxis_location: str = "above",
    xaxis_rotation: int = 45,
    title: Optional[str] = None,
    cmap: str = "RdBu_r",
    cmap_bounds: Optional[List[float]] = None,
    cbar_place: str = "right",
    cbar_tick_fontsize: Optional[int] = None,
    invert_yaxis: bool = True,
    clickable: bool = False,
    line_color: str = "grey",
    legend_name: Optional[str] = "Group",
    legend_labels: Optional[List[str]] = None,
    img_url: Optional[List[str]] = None,
    tooltips: Optional[Union[str, List[Tuple[str, str]]]] = None,
    url_open: Optional[List[str]] = None,
    missing_color: str = "grey",
    aspect_scale: float = 1,
    show_plot: bool = True,
    bokeh_toolbar: bool = True,
    bokeh_logo: bool = True,
    debug: bool = False,
):
    """
    Generate an interactive portrait plot using Bokeh.

    Each cell of the grid is drawn as one or more patches: a single box for
    2D data, or ``N`` sectors (triangles / wedges) for ``N`` stacked fields.

    Parameters
    ----------
    data : numpy.ndarray or list of numpy.ndarray
        A 2D array, a list of 2D arrays, or a 3D array (stacked 2D arrays)
        containing the data to plot. For 3D data the first axis selects the
        sub-field drawn in each sector of a cell.
    xaxis_labels : list of str
        Labels for the x-axis, one per column (last axis) of ``data``.
    yaxis_labels : list of str
        Labels for the y-axis, one per row (second-to-last axis) of ``data``.
    width : int or 'auto', optional
        Width of the plot in pixels. Default is 600. If 'auto', the width is
        ``30 * n_columns + 150``.
    height : int or 'auto', optional
        Height of the plot in pixels. Default is 600. If 'auto', the height is
        ``30 * n_rows``.
    annotate : bool, optional
        If True, values from ``annotate_data`` are attached to every patch and
        shown as "Value (Act.)" in the default tooltips. Default is False.
    annotate_data : numpy.ndarray, optional
        Array with the same shape as ``data`` used for annotations. If None,
        ``data`` itself is used. Default is None.
    vrange : tuple of float, optional
        (min, max) of the color scale. Default is None, meaning the NaN-aware
        min/max of the plotted values.
    xaxis_fontsize : int, optional
        Font size (pt) for the x-axis tick labels. Default is None.
    yaxis_fontsize : int, optional
        Font size (pt) for the y-axis tick labels. Default is None.
    xaxis_fontstyle : str, optional
        Font style for the x-axis labels. Options are
        ['normal', 'italic', 'bold', 'bold italic']. Default is None.
    yaxis_fontstyle : str, optional
        Font style for the y-axis labels. Options are
        ['normal', 'italic', 'bold', 'bold italic']. Default is None.
    xaxis_location : str, optional
        Location of the x-axis. Options are ['above', 'below', 'both'].
        Default is 'above'.
    xaxis_rotation : int, optional
        Rotation angle of the x-axis tick labels in degrees. Default is 45.
    title : str, optional
        Title of the figure (centered). Default is None.
    cmap : str, optional
        Name of the matplotlib colormap to use. Default is 'RdBu_r'.
    cmap_bounds : list of float, optional
        If specified, the palette is reduced to ``len(cmap_bounds) - 1``
        discrete colors and the color range is widened to include the bounds.
        The colors are spread linearly between the range ends, so the bounds
        are assumed to be evenly spaced. Default is None.
    cbar_place : str, optional
        Location of the colorbar. Options are
        ['left', 'right', 'above', 'below', 'center']. Default is 'right'.
    cbar_tick_fontsize : int, optional
        Font size (px) for the colorbar tick labels. Default is None (12px).
    invert_yaxis : bool, optional
        If True, the first row of ``data`` (and first y label) is drawn at the
        top of the plot. Default is True.
    clickable : bool, optional
        If True, clicking a patch opens its entry of ``url_open``.
        Default is False.
    line_color : str, optional
        Color of the patch outlines. Default is 'grey'.
    legend_name : str, optional
        Tooltip label for ``legend_labels`` (used for sector plots). None
        falls back to 'Group'. Default is 'Group'.
    legend_labels : list of str, optional
        One label per sub-field, shown in tooltips (used for sector plots).
        Default is None.
    img_url : list of str, optional
        Links to images displayed in tooltips, one per patch, ordered by
        sub-field, then row (after any y-axis inversion), then column. Cells
        whose annotation value is missing show a "no data" placeholder image.
        Default is None.
    tooltips : str or list of tuple, optional
        Custom Bokeh tooltips. If None a default tooltip is built.
        Default is None.
    url_open : list of str, optional
        Links to open when a patch is clicked, in the same order as
        ``img_url``. Defaults to ``img_url``.
    missing_color : str, optional
        Color for missing (NaN) values in the plot. Default is 'grey'.
    aspect_scale : float, optional
        Scale factor for the plot aspect ratio. Default is 1.
    show_plot : bool, optional
        If True, the plot is displayed with ``bokeh.plotting.show``.
        Default is True.
    bokeh_toolbar : bool, optional
        If True, displays the Bokeh toolbar in the plot. Default is True.
    bokeh_logo : bool, optional
        If True, displays the Bokeh logo in the plot. Default is True.
    debug : bool, optional
        If True, prints debug messages and adds a "Position" tooltip for
        sector plots. Default is False.

    Returns
    -------
    plot : bokeh.plotting.figure
        The Bokeh figure representing the interactive portrait plot.

    Raises
    ------
    SystemExit
        Via ``sys.exit`` when the inputs are inconsistent (mismatched label
        counts or annotation shape, unsupported ``xaxis_location``, ...).

    Example
    -------
    >>> from ESMBenchmarkViz import portrait_plot
    >>> plot = portrait_plot(data, xaxis_labels=xs, yaxis_labels=ys,
    ...                      show_plot=False)  # doctest: +SKIP

    Notes
    -----
    - The function supports both 2D and stacked 3D data for generating
      portrait plots.
    - Interactive features include tooltips and clickable URLs, enabled
      through Bokeh.
    - Missing values are displayed using the specified ``missing_color``.
    """
    # ----------------
    # Prepare plotting
    # ----------------
    data, num_divide = prepare_data(data, xaxis_labels, yaxis_labels, debug)

    if annotate:
        if annotate_data is None:
            # Documented fallback: annotate with the plotted values themselves.
            annotate_data, num_divide_annotate = data, num_divide
        else:
            annotate_data, num_divide_annotate = prepare_data(
                annotate_data, xaxis_labels, yaxis_labels, debug
            )
        # Every patch needs a matching annotation value, so shapes must agree.
        if num_divide_annotate != num_divide or annotate_data.shape != data.shape:
            sys.exit("Error: annotate_data does not have same size as data")

    if url_open is None:
        url_open = img_url

    # Polygon outlines inside a unit cell (one per sub-field) and, for 1-4
    # sub-fields, their human-readable position names (None otherwise).
    xpts_list, ypts_list = get_x_y_points(num_divide)
    positions = get_positions(num_divide)

    # The rows are flipped below when invert_yaxis is set; flip the labels
    # *before* building the patches so that the tooltip "Model" name of a row
    # matches the axis label drawn next to it.
    if invert_yaxis:
        yaxis_labels = list(yaxis_labels)[::-1]

    # Prepare data for plotting
    # ~~~~~~~~~~~~~~~~~~~~~~~~~
    xs, ys = [], []
    field, field2 = [], []
    positions_list, position_description_list = [], []
    xname_list, yname_list = [], []

    for i in range(num_divide):
        xpts = xpts_list[i]
        ypts = ypts_list[i]

        # prepare_data guarantees a 2D array or a 3D stack of num_divide slices.
        if data.ndim == 3:
            a = data[i]
            annotate_a = annotate_data[i] if annotate else None
        else:
            a = data
            annotate_a = annotate_data if annotate else None

        if invert_yaxis:
            # np.flipud returns a view; nothing below mutates it.
            a = np.flipud(a)
            if annotate:
                annotate_a = np.flipud(annotate_a)

        # xs, ys: x- and y-coordinates for all the patches,
        # given as a "list of lists" (one outline per patch).
        for iy in range(a.shape[0]):
            yname = yaxis_labels[iy]
            for ix in range(a.shape[1]):
                xname = xaxis_labels[ix]
                xs.append([tmp_x + ix for tmp_x in xpts])
                ys.append([tmp_y + iy for tmp_y in ypts])
                field.append(a[iy, ix])
                if annotate:
                    field2.append(annotate_a[iy, ix])
                if legend_labels is not None:
                    position_description_list.append(legend_labels[i])
                if positions is not None:
                    positions_list.append(positions[i])
                xname_list.append(xname)
                yname_list.append(yname)

    # Gathered data for plotting
    col_dict = dict(
        xs=xs,
        ys=ys,
        field=field,
        xname=xname_list,
        yname=yname_list,
    )

    if img_url is not None:
        col_dict["img"] = img_url

    if url_open is not None:
        col_dict["url"] = url_open

    if field2:
        col_dict["field2"] = field2
        if img_url is not None:
            # Show a placeholder image for cells whose actual value is missing.
            # np.where requires equal lengths, so a wrong-sized img_url still
            # raises instead of being silently truncated.
            no_data_img = "https://pcmdi.llnl.gov/pmp-preliminary-results/interactive_plot/mean_climate/no-data-whitebg.png"
            col_dict["img"] = np.where(
                pd.isna(field2), no_data_img, np.asarray(img_url, dtype=object)
            ).tolist()

    # Positions are only known for 1-4 sub-fields; attach them when complete.
    if len(positions_list) == len(xname_list):
        col_dict["position"] = positions_list

    if position_description_list:
        col_dict["position_description"] = position_description_list

    if debug:
        print("col_dict: ", col_dict)
        print("col_dict.keys(): ", col_dict.keys())
        print("col_dict['xs']: ", col_dict["xs"])
        print("col_dict['ys']: ", col_dict["ys"])
        print("col_dict['position']: ", col_dict.get("position"))

    source = ColumnDataSource(col_dict)

    # ----------------
    # Ready to plot!!
    # ----------------
    # Figure size
    plot_width = data.shape[-1] * 30 + 150 if width == "auto" else width
    plot_height = data.shape[-2] * 30 if height == "auto" else height

    # hover is needed for the tooltips, tap for opening URLs on click
    tools = "hover, tap, save" if clickable else "hover, save"

    legend_title = (legend_name or "Group").capitalize()

    if img_url is not None:
        if tooltips is None:
            # Customized HTML tooltip with the linked image
            tooltips = """
    <div>
        <div>
            <img src="@img" alt="@img" width="300" height="200"
                style="float: left; margin: 0px 5px 5px 0px;"
                border="1" onerror style="margin: -75px 0 0 -100px"
            ></img><br>
            <span style="font-size: 14px">
            <font color=darkgreen>Model:</font> <b>@yname</b><br>
            <font color=darkgreen>Variable:</font> <b>@xname</b><br>"""
            if position_description_list:
                tooltips += f"<font color=darkgreen>{legend_title}:</font> <b>@position_description</b><br>"
            tooltips += """
            <font color=darkgreen>Value (Nor.):</font> @field<br>
            <font color=darkgreen>Value (Act.):</font> @field2</span>
        </div>
    </div>"""
    else:
        if tooltips is None:
            # Build a fresh list so a caller-supplied list is never mutated.
            tooltips = [("Model", "@yname"), ("Variable", "@xname")]
            if position_description_list:
                tooltips.append((legend_title, "@position_description"))
            tooltips.append(("Value (Nor.)", "@field"))
            if annotate:
                tooltips.append(("Value (Act.)", "@field2"))
            if debug and num_divide > 1:
                tooltips.append(("Position", "@position"))

    if xaxis_location in ("above", "below"):
        x_axis_location = xaxis_location
    elif xaxis_location == "both":
        # The primary axis goes on top; a second one is added below later.
        x_axis_location = "above"
    else:
        sys.exit("Error: xaxis_location should be either above, below, or both")

    plot = figure(
        title=title,
        x_range=xaxis_labels,
        y_range=yaxis_labels,
        width=plot_width,
        height=plot_height,
        min_border=50,
        tools=tools,
        tooltips=tooltips,
        x_axis_location=x_axis_location,
        aspect_scale=aspect_scale,
    )

    # Color Map control
    ncolors = 255 if cmap_bounds is None else len(cmap_bounds) - 1
    colormap = plt.get_cmap(cmap, ncolors)
    rgba_rows = (255 * colormap(np.arange(ncolors))).astype("int")
    # Drop the alpha column: RGB's 4th argument is an alpha in [0, 1], so the
    # 0-255 value produced above must not be passed through.
    colors = [RGB(*(int(c) for c in row[:3])).to_hex() for row in rgba_rows]

    if vrange is None:
        vmin = np.nanmin(np.array(field))
        vmax = np.nanmax(np.array(field))
    else:
        vmin = np.min(vrange)
        vmax = np.max(vrange)

    if cmap_bounds is not None:
        vmin = min(vmin, min(cmap_bounds))
        vmax = max(vmax, max(cmap_bounds))

    mapper = LinearColorMapper(
        palette=colors, low=vmin, high=vmax, nan_color=missing_color
    )

    glyph = Patches(
        xs="xs",
        ys="ys",
        fill_color={"field": "field", "transform": mapper},
        line_color=line_color,
        line_width=0.5,
    )

    # Generate actual plot; reuse the same glyph for selection states so the
    # transparency stays the same regardless of selection.
    plot.add_glyph(source, glyph, selection_glyph=glyph, nonselection_glyph=glyph)

    # x-axis tick labels
    plot.xaxis.major_label_orientation = math.radians(xaxis_rotation)
    if xaxis_fontsize is not None:
        plot.xaxis.major_label_text_font_size = f"{xaxis_fontsize}pt"
    if xaxis_fontstyle is not None:
        plot.xaxis.major_label_text_font_style = xaxis_fontstyle

    # y-axis tick labels
    if yaxis_fontsize is not None:
        plot.yaxis.major_label_text_font_size = f"{yaxis_fontsize}pt"
    if yaxis_fontstyle is not None:
        plot.yaxis.major_label_text_font_style = yaxis_fontstyle

    # x-axis at the bottom as well: a numeric axis with ticks at the cell
    # centers, relabelled with the categorical names.
    if xaxis_location == "both":
        plot.add_layout(LinearAxis(), "below")
        bottom_axis = plot.xaxis[1]  # plot.xaxis lists "above" axes first
        bottom_axis.ticker = [ix + 0.5 for ix in range(len(xaxis_labels))]
        bottom_axis.major_label_overrides = {
            ix + 0.5: label for ix, label in enumerate(xaxis_labels)
        }
        bottom_axis.major_label_orientation = math.radians(xaxis_rotation)
        if xaxis_fontsize is not None:
            bottom_axis.major_label_text_font_size = f"{xaxis_fontsize}pt"
        if xaxis_fontstyle is not None:
            bottom_axis.major_label_text_font_style = xaxis_fontstyle

    # Color Bar
    if cbar_tick_fontsize is None:
        cbar_font = "12px"
    else:
        cbar_font = f"{int(cbar_tick_fontsize)}px"

    color_bar = ColorBar(
        color_mapper=mapper,
        major_label_text_font_size=cbar_font,
        ticker=BasicTicker(desired_num_ticks=10),
        label_standoff=6,
        border_line_color=None,
        location=(0, 0),
    )
    plot.add_layout(color_bar, cbar_place)

    # Link to open when clicked
    if clickable:
        taptool = plot.select(type=TapTool)
        taptool.callback = OpenURL(url="@url")

    # Turn off bokeh logo
    if bokeh_logo is False:
        plot.toolbar.logo = None

    # Turn off bokeh toolbar
    if bokeh_toolbar is False:
        plot.toolbar_location = None

    # Title
    if title is not None:
        plot.title.align = "center"

    # Show the plot if requested
    if show_plot:
        show(plot)

    return plot
# -----------------
# Support functions
# -----------------
def prepare_data(data, xaxis_labels, yaxis_labels, debug=False):
    """
    Prepare data for plotting.

    Parameters
    ----------
    data : list or np.ndarray
        Data to plot: a 2D array, a 3D array (stacked 2D fields), or a list of
        2D arrays (a one-element list is unwrapped, longer lists are stacked).
    xaxis_labels : list of str
        Labels for the x-axis; must match ``data.shape[-1]`` unless empty.
    yaxis_labels : list of str
        Labels for the y-axis; must match ``data.shape[-2]`` unless empty.
    debug : bool, optional
        If True, print debug information. The default is False.

    Returns
    -------
    np.ndarray, int
        The data (2D or 3D) and the number of divisions per cell
        (1 for 2D data, ``data.shape[0]`` for 3D data).

    Raises
    ------
    SystemExit
        Via ``sys.exit`` when the data is not an array, has an unsupported
        number of dimensions, or does not match the labels.
    """
    # In case data was given as list of arrays, convert it to numpy (stacked) array
    if isinstance(data, list):
        if debug:
            print("data type is list")
            print("len(data):", len(data))
        if len(data) == 1:
            # list has only 1 array as element
            if isinstance(data[0], np.ndarray) and (len(data[0].shape) == 2):
                data = data[0]
            else:
                sys.exit("Error: Element of given list is not in np.ndarray type")
        else:
            # list has more than 1 arrays as elements
            data = np.stack(data)

    # Validate type and rank *before* indexing .shape[-1] / .shape[-2], so bad
    # input ends in the intended error message instead of AttributeError or
    # IndexError.
    if not isinstance(data, np.ndarray):
        sys.exit("Error: Converted or given data is not in np.ndarray type")

    if debug:
        print("data.shape:", data.shape)

    if data.ndim not in (2, 3):
        print("data.shape:", data.shape)
        sys.exit("Error: data.shape is not right")

    if data.shape[-1] != len(xaxis_labels) and len(xaxis_labels) > 0:
        sys.exit("Error: Number of elements in xaxis_label mismatchs to the data")

    if data.shape[-2] != len(yaxis_labels) and len(yaxis_labels) > 0:
        sys.exit("Error: Number of elements in yaxis_label mismatchs to the data")

    num_divide = 1 if data.ndim == 2 else data.shape[0]

    if debug:
        print("num_divide:", num_divide)

    return data, num_divide


def find_intersection_points_centered(num_sectors: int) -> list:
    """
    Find where evenly spaced rays from the center of the unit square hit the
    square's boundary.

    The rays run clockwise and are rotated so that the first sector is
    centered at 90 degrees (straight up), for odd and even sector counts.

    Parameters
    ----------
    num_sectors : int
        Number of sectors to divide the square into (>= 2).

    Returns
    -------
    list of tuples
        One (x, y) boundary point per ray, in clockwise order. For 2 sectors
        the bottom-left and top-right corners (the diagonal) are returned.

    Raises
    ------
    ValueError
        If ``num_sectors`` is less than 2.
    """
    if num_sectors < 2:
        raise ValueError("Number of sectors must be greater than 1.")

    if num_sectors == 2:
        return [(0, 0), (1, 1)]

    # Define square edges
    edges = [
        ((0, 0), (1, 0)),  # Bottom edge
        ((1, 0), (1, 1)),  # Right edge
        ((1, 1), (0, 1)),  # Top edge
        ((0, 1), (0, 0)),  # Left edge
    ]

    # Circle center and radius (the radius only fixes the ray direction vector)
    center = (0.5, 0.5)
    radius = 0.5

    # A ray through a square corner lies exactly at the end of two edges; with
    # floating-point rounding the edge parameter can land a hair outside
    # [0, 1] on both, which would drop the point. Accept a tiny tolerance.
    tol = 1e-9

    # Evenly spaced angles in the clockwise direction (negative = clockwise)
    angles = np.linspace(0, -2 * np.pi, num_sectors, endpoint=False)
    if num_sectors % 2 == 0:
        # Shift angles to center the first sector at 90 degrees (pi/2)
        angles = angles - (angles[1] / 2) + np.pi / 2
    else:
        # Shift angles for odd sectors to ensure first centered at 90 degrees
        angles = angles + np.pi / num_sectors + np.pi / 2

    intersection_points = []
    for angle in angles:
        # Define the line direction from the center
        dx, dy = np.cos(angle), np.sin(angle)
        ray_end = (center[0] + dx * radius, center[1] + dy * radius)

        for (x1, y1), (x2, y2) in edges:
            # Solve line-line intersection: t is the position along the edge,
            # u the position along the ray (u >= 0 means in front of center).
            det = (x1 - x2) * (center[1] - ray_end[1]) - (y1 - y2) * (
                center[0] - ray_end[0]
            )
            if det == 0:
                continue
            t = (
                (x1 - center[0]) * (center[1] - ray_end[1])
                - (y1 - center[1]) * (center[0] - ray_end[0])
            ) / det
            u = -((x1 - x2) * (y1 - center[1]) - (y1 - y2) * (x1 - center[0])) / det

            if -tol <= t <= 1 + tol and u >= 0:
                intersection_points.append((x1 + t * (x2 - x1), y1 + t * (y2 - y1)))
                break

    return intersection_points


def create_polygons(num_sectors: int):
    """
    Create one polygon per sector formed by the rays and the square boundary.

    Parameters
    ----------
    num_sectors : int
        Number of sectors to divide the square into (>= 2).

    Returns
    -------
    list of list of tuples
        Each polygon is represented as a list of (x, y) tuples, in the same
        clockwise order as the rays from ``find_intersection_points_centered``.
    """
    center = (0.5, 0.5)
    points = find_intersection_points_centered(num_sectors)
    polygons = []
    square_apexes = [(0, 0), (1, 0), (1, 1), (0, 1)]

    for i in range(num_sectors):
        # Define the current and next intersection points
        p1 = points[i]
        p2 = points[(i + 1) % len(points)]

        # Two sectors split along the diagonal, so they do not use the center.
        polygon = [center, p1] if num_sectors != 2 else [p1]

        if num_sectors == 2:
            # Triangle: upper-left (0, 1) corner first, then lower-right (1, 0).
            polygon.append(square_apexes[3] if i == 0 else square_apexes[1])
        elif num_sectors == 3 and i == 0:
            # The top sector spans both upper corners, forming a pentagon.
            polygon.append(square_apexes[3])
            polygon.append(square_apexes[2])
        elif num_sectors == 4:
            # The rays end exactly at the corners; nothing to add.
            pass
        else:
            # Add any square corner lying between the two boundary points.
            for apex in square_apexes:
                if min(p1[0], p2[0]) <= apex[0] <= max(p1[0], p2[0]) and min(
                    p1[1], p2[1]
                ) <= apex[1] <= max(p1[1], p2[1]):
                    polygon.append(apex)

        # Add the second intersection point and close the polygon
        polygon.append(p2)
        polygons.append(polygon)

    return polygons


def extract_polygon_coordinates(polygons):
    """
    Extract x and y coordinates from a list of polygons, rounded to 2 decimals.

    Parameters
    ----------
    polygons : list of list of tuples
        Each polygon is represented as a list of (x, y) tuples.

    Returns
    -------
    tuple of lists
        xpts_list : list of lists of floats
            x-coordinates of the polygons, rounded to 2 decimal places.
        ypts_list : list of lists of floats
            y-coordinates of the polygons, rounded to 2 decimal places.
    """
    xpts_list = [[round(float(x), 2) for x, _ in polygon] for polygon in polygons]
    ypts_list = [[round(float(y), 2) for _, y in polygon] for polygon in polygons]
    return xpts_list, ypts_list


def get_x_y_points(num_sectors: int):
    """
    Return the x and y coordinates of the polygons that partition a unit cell.

    Parameters
    ----------
    num_sectors : int
        The number of sectors to divide the cell into. 1 yields the full box.

    Returns
    -------
    tuple of lists
        xpts_list : list of lists of floats
            x-coordinates of the polygons, rounded to 2 decimal places.
        ypts_list : list of lists of floats
            y-coordinates of the polygons, rounded to 2 decimal places.
    """
    if num_sectors == 1:
        xpts = [0, 0, 1, 1]
        ypts = [1, 0, 0, 1]
        return ([xpts], [ypts])

    polygons = create_polygons(num_sectors)
    xpts_list, ypts_list = extract_polygon_coordinates(polygons)
    return (xpts_list, ypts_list)


def get_positions(num_sectors: int):
    """
    Return human-readable names of the sectors produced by ``get_x_y_points``.

    Parameters
    ----------
    num_sectors : int
        The number of sectors the cell is divided into.

    Returns
    -------
    list or None
        Position names in the same (clockwise) order as the polygons, or None
        when ``num_sectors`` is not between 1 and 4.
    """
    # Order follows create_polygons: the first sector is on top and the rest
    # proceed clockwise, so for 3 sectors the second one contains the
    # bottom-right corner (lower-right) and the third the bottom-left corner.
    position_names = {
        4: ["top", "right", "bottom", "left"],
        3: ["top", "lower-right", "lower-left"],
        2: ["upper", "lower"],
        1: ["box"],
    }
    names = position_names.get(num_sectors)
    return None if names is None else list(names)