import math
from copy import deepcopy
from typing import List, Union
import numpy as np
from bokeh.layouts import column, row
from bokeh.models import ColumnDataSource, HoverTool, Label, LabelSet
from bokeh.plotting import figure, show
from bokeh.transform import factor_cmap
from .support_functions import (
convert_to_numpy_array,
create_click_callback,
create_dropdown_callback,
create_image_display,
create_name_select,
create_navigation_buttons,
create_navigation_callbacks,
debug_print,
load_colormap,
)
# -------------
# Main function
# -------------
[docs]
def taylor_diagram(
    std_devs: Union[List[float], np.ndarray],
    correlations: Union[List[float], np.ndarray],
    names: List[str],
    refstd: float,
    title: str = "Interactive Taylor Diagram",
    normalize: bool = False,
    step: float = 0.2,
    show_reference: bool = True,
    reference_name: str = "Reference",
    reference_image: str = None,
    colormap: Union[str, List[str]] = "Spectral",
    width: int = 600,
    show_plot: bool = True,
    images: List[str] = None,
    debug: bool = False,
) -> figure:
    """
    Create an interactive Taylor diagram using Bokeh.

    .. image:: /_static/example_taylor_diagram.gif
        :alt: Example interactive Taylor diagram
        :align: center
        :width: 600px

    The Taylor diagram visually represents the relationship between the
    standard deviation and correlation of different models against a
    reference model. This function allows for the comparison of multiple
    models based on their standard deviations and correlations to a
    specified reference standard deviation.

    Parameters
    ----------
    std_devs : list or np.ndarray
        Standard deviations of the models being compared.
    correlations : list or np.ndarray
        Correlation coefficients of the models with respect to the reference model.
    names : list of str
        Names of the models being compared, used for labeling in the plot.
        The list is copied; the caller's list is not modified.
    refstd : float
        The standard deviation of the reference model, used for calculating RMSE
        and for normalization if applicable.
    title : str, optional
        The title of the plot (default is "Interactive Taylor Diagram").
    normalize : bool, optional
        If True, the standard deviations are divided by ``refstd`` and the
        reference standard deviation becomes 1.0 (default is False).
    step : float, optional
        The step size for the arcs and grid lines in the Taylor diagram
        (default is 0.2). Must be positive.
    show_reference : bool, optional
        If True, the reference point (standard deviation ``refstd``,
        correlation 1.0) is added to the plotted points (default is True).
    reference_name : str, optional
        The name of the reference dataset (default is "Reference").
    reference_image : str, optional
        The path to an image file to be used as the reference image
        (default is None). Only used when ``show_reference`` is True and
        ``images`` is a non-empty list.
    colormap : str or list, optional
        A name of the `Matplotlib` colormap or a list of colors to use for the
        model points. Available names of `Matplotlib` colormaps can be found
        `here <https://matplotlib.org/stable/users/explain/colors/colormaps.html>`_.
        Default is Spectral. When a list is given and the reference is shown,
        "black" is appended to a copy of the list for the reference point.
    width : int, optional
        The width of the plot in pixels (default is 600). The height is set
        equal to the width. When images are provided, an image display panel
        built with the same width and height is placed to the right of the diagram.
    show_plot : bool, optional
        If True, the resulting object is passed to ``bokeh.plotting.show``
        (default is True).
    images : list of str, optional
        Image paths, one per model, shown in the hover tooltip and in the
        image display panel. The list is copied; the caller's list is not modified.
    debug : bool, optional
        If True, prints additional debugging information (default is False).

    Returns
    -------
    bokeh.plotting.figure or bokeh layout
        The Taylor diagram figure when no images are given; otherwise a row
        layout holding the figure and a column of controls (name selector,
        image display, and previous/next buttons).

    Raises
    ------
    ValueError
        If the lengths of ``std_devs``, ``correlations``, ``names`` (and
        ``images`` when given) differ, if ``step`` is not positive, or if
        ``normalize`` is True while ``refstd`` is zero.

    Example
    -------
    >>> from ESMBenchmarkViz import taylor_diagram
    >>> std_devs = [0.8, 1.0, 1.2]  # Standard deviations of models
    >>> correlations = [0.9, 0.85, 0.7]  # Correlation coefficients
    >>> names = ["Model A", "Model B", "Model C"]  # Names of models
    >>> refstd = 1.0  # Standard deviation of reference model
    >>> taylor_diagram(std_devs, correlations, names, refstd)

    Example use case can be found `here <../examples/example_taylor_diagram.html>`_.

    Notes
    -----
    The Taylor diagram is a polar plot where the radial distance represents the standard deviation
    and the azimuthal angle represents the correlation coefficient. The reference standard deviation
    is used as a reference point for the radial distance. The correlation coefficient is represented
    by the angle between the model point and the reference point. The RMSE (Root Mean Square Error)
    is calculated as the distance between the model point and the reference point.

    2024-10-04: Jiwoo Lee, initial version
    """
    # Sanity check for input data
    if not (len(std_devs) == len(correlations) == len(names)):
        raise ValueError(
            "The lengths of 'std_devs', 'correlations', and 'names' must be equal."
        )
    if images is not None and len(std_devs) != len(images):
        raise ValueError(
            "The lengths of 'std_devs', 'correlations', 'names', and 'images' must be equal."
        )
    if step <= 0:
        raise ValueError("'step' must be a positive number.")
    if normalize and refstd == 0:
        raise ValueError("'refstd' must be non-zero when 'normalize' is True.")

    # Taylor diagram width and height are equal.
    height = width

    # Convert input lists to numpy arrays for consistency
    std_devs = convert_to_numpy_array(std_devs)
    correlations = convert_to_numpy_array(correlations)

    # Work on private copies so that appending the reference entry below
    # never modifies the lists owned by the caller.
    names = deepcopy(names)
    if images is not None:
        images = list(images)
    if isinstance(colormap, list):
        colormap = list(colormap)

    # Standard deviation axis extent
    if normalize:
        std_devs = std_devs / refstd
        refstd = 1.0
        std_name = "Normalized St. Dv."
    else:
        std_name = "Standard Deviation"

    # Add the reference to the list of models
    if show_reference:
        names.append(reference_name)
        std_devs = np.append(std_devs, refstd)
        correlations = np.append(correlations, 1.0)
        if images:
            images.append(reference_image)
        if isinstance(colormap, list):
            colormap.append("black")

    # Calculate RMSE values (law of cosines). The radicand is clamped at zero
    # because rounding can make it slightly negative for points that coincide
    # with the reference, which would otherwise produce NaN.
    rmse = [
        np.sqrt(max(refstd**2 + rs**2 - 2 * refstd * rs * ts, 0.0))
        for rs, ts in zip(std_devs, correlations)
    ]

    # Calculate polar coordinates
    r = std_devs
    theta = np.arccos(correlations)

    # Create figure
    max_stddev = max(std_devs)  # Get the largest standard deviation
    # Radius of the outermost standard deviation arc. This intentionally uses
    # the same expression as add_reference_arcs so both agree exactly.
    outermost_radius = np.arange(step, max_stddev + 2 * step, step)[-1]
    # Keep the original extent, but never let it be smaller than the outermost
    # arc plus a 10% margin for the correlation labels drawn at the line ends.
    max_range = max(max_stddev * 1.1 + step, outermost_radius + 0.1 * max_stddev)
    p = figure(
        width=width,
        height=height,
        x_range=(step * -1, max_range),
        y_range=(step * -1, max_range),
        aspect_ratio=1,
        title=title,
        tools="tap, pan, wheel_zoom, box_zoom, reset",
    )
    p.grid.visible = False
    p.axis.visible = False

    # Standard deviation and RMSE arcs
    add_reference_arcs(p, max_stddev, step=step, refstd=refstd)

    # Correlation reference lines end at the outermost standard deviation arc
    add_reference_lines(p, outermost_radius)

    # Get the selected colormap
    selected_colors = load_colormap(colormap, len(names))
    if debug:
        print("selected_colors:", selected_colors)

    # Color mapping based on model names
    colors = factor_cmap("names", palette=selected_colors, factors=names)
    if debug:
        print("colors:", colors)

    # Wrap up input as a dictionary
    data = {
        "x": r * np.cos(theta),
        "y": r * np.sin(theta),
        "names": names,
        "std_devs": std_devs,
        "correlations": correlations,
        "rmse": rmse,
    }
    if images:
        data["images"] = images

    if debug:
        print("data:", data)
        for key, values in data.items():
            print(f"len(data[{key}]):", len(values))

    # Create a ColumnDataSource
    source = ColumnDataSource(data=data)

    # Plot data points with color mapping
    points = p.scatter(
        "x", "y", size=10, source=source, color=colors, legend_field="names"
    )
    debug_print(debug, "points added")

    # Add labels for data points
    labels = LabelSet(
        x="x",
        y="y",
        text="names",
        x_offset=5,
        y_offset=5,
        source=source,
        text_font_size="8pt",
    )
    p.add_layout(labels)
    if debug:
        print("label for data points added")

    # Add hover tool
    if images:
        # Add hover tool with image tooltip
        hover = HoverTool(
            renderers=[points],
            tooltips="""
            <div>
            <img src="@images" alt="" style="width:100px;height:auto;"/>
            <div><strong>Model:</strong> @names</div>
            <div><strong>STD:</strong> @std_devs{0.000}</div>
            <div><strong>COR:</strong> @correlations{0.000}</div>
            <div><strong>RMSE:</strong> @rmse{0.000}</div>
            </div>
            """,
        )
    else:
        hover = HoverTool(
            renderers=[points],
            tooltips=[
                ("Model", "@names"),
                (std_name, "@std_devs{0.000}"),
                ("Correlation", "@correlations{0.000}"),
                ("RMSE", "@rmse{0.000}"),
            ],
        )
    p.add_tools(hover)
    debug_print(debug, "hover tool added")

    # Add axes labels with improved alignment
    p.add_layout(
        Label(
            x=max_range / 2,
            y=-0.18,
            text="Standard Deviation",
            text_font_style="italic",
            text_align="center",
        )
    )
    p.add_layout(
        Label(
            x=max_range / 1.5,
            y=max_range / 1.4,
            text="Correlation",
            text_font_style="italic",
            angle=np.deg2rad(-45),
            text_align="center",
        )
    )
    debug_print(debug, "axes labels added")

    # Customize legend
    p.legend.location = "top_right"

    # Return the plot object if images are not provided, otherwise return the layout
    if not images:
        return_object = p
    else:
        # Div to display image and x, y values on click
        # maximum height is for the actual image display inside the image_display Div
        image_display, max_height = create_image_display(width, height)

        # Dropdown menu for names with default "Select Data"
        name_select = create_name_select(data)
        debug_print(debug, "name_select made")

        # Create buttons for Previous and Next Image Navigation
        previous_button, next_button = create_navigation_buttons()

        # JavaScript callback for dropdown selection changes
        dropdown_callback = create_dropdown_callback(
            source, image_display, name_select, max_height
        )
        name_select.js_on_change("value", dropdown_callback)
        debug_print(debug, "dropdown_callback made")

        # JavaScript callback for click events
        click_callback = create_click_callback(
            source, image_display, name_select, max_height
        )
        source.selected.js_on_change("indices", click_callback)
        debug_print(debug, "click_callback added to source")

        # JavaScript callbacks for Previous and Next Image buttons
        previous_callback, next_callback = create_navigation_callbacks(
            source, image_display, name_select, max_height
        )
        previous_button.js_on_event("button_click", previous_callback)
        next_button.js_on_event("button_click", next_callback)
        debug_print(debug, "navigation callbacks added")

        # Arrange the Previous and Next buttons side by side
        navigation_buttons = row(previous_button, next_button)

        # Arrange layout
        controls = column(name_select, image_display, navigation_buttons)
        layout = row(p, controls)
        debug_print(debug, "layout added")
        return_object = layout

    # Show the plot if requested
    if show_plot:
        show(return_object)

    return return_object
# -----------------
# Support functions
# -----------------
def find_circle_intersection(x1, y1, r1, x2, y2, r2):
    """
    Compute the points where two circles cross.

    Parameters
    ----------
    x1, y1 : float
        Center of the first circle.
    r1 : float
        Radius of the first circle.
    x2, y2 : float
        Center of the second circle.
    r2 : float
        Radius of the second circle.

    Returns
    -------
    list of tuple
        Two ``(x, y)`` tuples, or an empty list when the circles are too far
        apart, one lies inside the other, or they are identical. Tangent
        circles yield the same point twice.

    Example
    -------
    >>> find_circle_intersection(0, 0, 5, 4, 0, 3)
    [(4.0, -3.0), (4.0, 3.0)]
    """
    dx = x2 - x1
    dy = y2 - y1
    d = math.sqrt(dx**2 + dy**2)  # distance between the two centers

    disjoint_or_nested = d > r1 + r2 or d < abs(r1 - r2)
    coincident = d == 0 and r1 == r2  # infinitely many common points
    if disjoint_or_nested or coincident:
        return []

    # `a` is the distance from the first center to the chord joining the two
    # intersection points; `h` is half the length of that chord.
    a = (r1**2 - r2**2 + d**2) / (2 * d)
    h = math.sqrt(abs(r1**2 - a**2))

    # Foot of the chord on the line connecting the centers.
    x3 = x1 + a * dx / d
    y3 = y1 + a * dy / d

    # Step perpendicular to the center line, once in each direction.
    shift_x = h * dy / d
    shift_y = h * dx / d
    return [(x3 + shift_x, y3 - shift_y), (x3 - shift_x, y3 + shift_y)]
def find_circle_y_axis_intersection(x1, y1, r):
    """
    Compute the points where a circle crosses the y-axis (the line x = 0).

    Parameters
    ----------
    x1, y1 : float
        Center of the circle.
    r : float
        Radius of the circle.

    Returns
    -------
    list of tuple
        Two ``(0, y)`` tuples, upper point first, or an empty list when the
        circle does not reach the y-axis. A tangent circle yields the same
        point twice.

    Example
    -------
    >>> find_circle_y_axis_intersection(3, 0, 5)
    [(0, 4.0), (0, -4.0)]
    """
    # Setting x = 0 in the circle equation leaves (y - y1)^2 = r^2 - x1^2.
    radicand = r**2 - x1**2
    if radicand < 0:
        return []  # the circle stays entirely on one side of the axis

    half_chord = math.sqrt(radicand)
    return [(0, y1 + half_chord), (0, y1 - half_chord)]
def angle_with_x_axis(x1, y1, x2, y2):
    """
    Return the direction, in degrees, of the segment from point 1 to point 2.

    Parameters
    ----------
    x1, y1 : float
        Coordinates of the first point.
    x2, y2 : float
        Coordinates of the second point.

    Returns
    -------
    float
        Angle measured counter-clockwise from the positive x-axis,
        normalized by adding 360 to negative values. When both points share
        the same x-coordinate the result is 90.0 if the second point is
        higher and 270.0 otherwise (including when the points coincide).

    Example
    -------
    >>> round(angle_with_x_axis(1, 1, 4, 5), 2)
    53.13
    """
    # Same x-coordinate: answer directly (coincident points fall in the
    # 270-degree branch).
    if x1 == x2:
        return 90.0 if y2 > y1 else 270.0

    degrees = math.degrees(math.atan2(y2 - y1, x2 - x1))
    # atan2 reports angles in (-180, 180]; shift negative ones up by a full turn.
    return degrees + 360 if degrees < 0 else degrees
def find_line_circle_intersection(x1, y1, x2, y2, x3, y3, r):
    """
    Compute where the infinite line through two points crosses a circle.

    Parameters
    ----------
    x1, y1 : float
        First point on the line.
    x2, y2 : float
        Second point on the line.
    x3, y3 : float
        Center of the circle.
    r : float
        Radius of the circle.

    Returns
    -------
    list of tuple
        Two ``(x, y)`` tuples, or an empty list when the line misses the
        circle. For a non-vertical line the point with the larger x comes
        first; for a vertical line the point with the larger y comes first.
        A tangent line yields the same point twice.

    Example
    -------
    >>> pts = find_line_circle_intersection(1, 2, 4, 6, 3, 3, 5)
    >>> [(round(x, 2), round(y, 2)) for x, y in pts]
    [(5.14, 7.52), (-0.74, -0.32)]
    """
    # A vertical line has no finite slope, so solve for y directly at x = x1.
    if x1 == x2:
        radicand = r**2 - (x1 - x3) ** 2
        if radicand < 0:
            return []
        reach = math.sqrt(radicand)
        return [(x1, y3 + reach), (x1, y3 - reach)]

    # Line in slope-intercept form: y = slope * x + intercept.
    slope = (y2 - y1) / (x2 - x1)
    intercept = y1 - slope * x1

    # Substituting the line into (x - x3)^2 + (y - y3)^2 = r^2 gives the
    # quadratic qa * x^2 + qb * x + qc = 0.
    rise = intercept - y3
    qa = 1 + slope**2
    qb = 2 * (slope * rise - x3)
    qc = x3**2 + rise**2 - r**2

    discriminant = qb**2 - 4 * qa * qc
    if discriminant < 0:
        return []

    root = math.sqrt(discriminant)
    x_candidates = ((-qb + root) / (2 * qa), (-qb - root) / (2 * qa))
    return [(x, slope * x + intercept) for x in x_candidates]
# Function to add reference arcs and labels (Standard deviation and RMSE) on Taylor Diagram
def add_reference_arcs(plot, max_stddev, step, refstd=1, thick_refstd=True):
    """
    Draw the standard deviation arcs and the RMSE arcs on a Taylor diagram.

    Parameters
    ----------
    plot : bokeh figure
        Figure to draw on; its ``arc`` and ``add_layout`` methods are used.
    max_stddev : float
        Largest standard deviation being plotted.
    step : float
        Spacing between consecutive arcs.
    refstd : float, optional
        Reference standard deviation; the RMSE arcs are centered at
        ``(refstd, 0)`` (default is 1).
    thick_refstd : bool, optional
        If True, an extra, thicker arc is drawn at radius ``refstd``
        (default is True).
    """
    # Radii shared by the standard deviation arcs and the RMSE arcs.
    # NOTE(review): np.arange with a float step can gain or lose the last
    # element through rounding; the largest radius below is whatever it yields.
    loop_range = np.arange(step, max_stddev + 2 * step, step)
    outermost_radius = loop_range[-1]
    final_index = len(loop_range) - 1

    for n, radius in enumerate(loop_range):
        # The outermost arc is drawn as a heavy black border; inner arcs are
        # thin gray guides.
        is_border = n == final_index
        arc_color = "black" if is_border else "gray"
        arc_width = 3 if is_border else 1

        # ----- standard deviation arc (quarter circle about the origin) -----
        plot.arc(
            0,
            0,
            radius=radius,
            start_angle=0,
            end_angle=np.pi / 2,
            color=arc_color,
            line_dash="solid",
            alpha=0.3,
            line_width=arc_width,
        )
        plot.add_layout(
            Label(
                x=radius + 0.05,
                y=0,
                text=f"{radius:.1f}",
                text_font_size="10pt",
                text_align="right",
                text_alpha=0.7,
                x_offset=0,
                y_offset=-12,
            )
        )

        # ----- RMSE arc (circle about the reference point) -----
        # Default sweep is the full upper half circle; it is trimmed below so
        # the arc starts on the outermost standard deviation arc and stops at
        # the y-axis whenever those crossings exist in the first quadrant.
        start_angle = 0
        for x_i, y_i in find_circle_intersection(
            refstd, 0, radius, 0, 0, outermost_radius
        ):
            if x_i > 0 and y_i > 0:
                start_angle = np.deg2rad(angle_with_x_axis(refstd, 0, x_i, y_i))

        end_angle = np.pi
        for _, y_i in find_circle_y_axis_intersection(refstd, 0, radius):
            if y_i > 0:
                end_angle = np.deg2rad(angle_with_x_axis(refstd, 0, 0, y_i))

        plot.arc(
            refstd,
            0,
            radius=radius,
            start_angle=start_angle,
            end_angle=end_angle,
            color="gray",
            line_dash="dashed",
            alpha=0.3,
        )

    # Numeric labels along the RMSE arcs
    add_rmse_labels(plot, loop_range, refstd)

    if not thick_refstd:
        return

    # Emphasize the reference standard deviation with a thicker arc.
    plot.arc(
        0,
        0,
        radius=refstd,
        start_angle=0,
        end_angle=np.pi / 2,
        color="black",
        line_dash="solid",
        alpha=0.3,
        line_width=2,
    )
# Function to add RMSE labels on Taylor Diagram
def add_rmse_labels(plot, rmse_values, refstd):
    """
    Add one text label per RMSE arc.

    Labels are placed where each RMSE arc (centered at ``(refstd, 0)``)
    crosses a virtual line running from the arc center to the point
    ``(0, rmse_values[-1])`` on the y-axis, and are rotated to follow the arc.

    Parameters
    ----------
    plot : bokeh figure
        Figure to annotate; its ``add_layout`` method is used.
    rmse_values : sequence of float
        Radii of the RMSE arcs. The last entry also fixes the y-axis end of
        the virtual line.
    refstd : float
        Reference standard deviation, i.e. the x-coordinate of the arc center.
    """
    for rmse_value in rmse_values:
        # Crossings of this RMSE arc with the virtual line
        intersections = find_line_circle_intersection(
            refstd, 0, 0, rmse_values[-1], refstd, 0, rmse_value
        )

        # Only a crossing strictly inside the first quadrant is usable.
        # Selecting it explicitly keeps x, y and the angle defined even when
        # no crossing is returned, and stops a later unusable crossing from
        # overriding a usable one.
        crossing = next(
            ((px, py) for px, py in intersections if px > 0 and py > 0),
            None,
        )

        if crossing is not None:
            x, y = crossing
            # Rotate by 90 degrees from the radial direction so the text runs
            # along the arc.
            angle_deg = angle_with_x_axis(refstd, 0, x, y) - 90
        else:
            # just in case something fails use the below values
            x = 1 - rmse_value / 2  # x-coordinate for label placement along the RMSE arc
            y = rmse_value / 1.18  # y-coordinate for label placement along the RMSE arc
            angle_deg = 40

        label = Label(
            x=x,
            y=y,
            text=f"{rmse_value:.2f}",
            text_font_size="8pt",
            text_align="center",
            text_alpha=0.7,
            angle=np.deg2rad(angle_deg),
        )
        plot.add_layout(label)
# Function to add reference lines for correlation on Taylor Diagram
def add_reference_lines(plot, max_radius):
    """
    Draw the radial correlation guide lines and their labels.

    One line is drawn from the origin for each correlation value
    1, 0.99, 0.95, 0.9, 0.8, ..., 0.0, at the angle ``arccos(correlation)``
    and with length ``max_radius``. The two lines lying on the axes
    (correlation 1 and 0) are solid black; the others are dotted gray.

    Parameters
    ----------
    plot : bokeh figure
        Figure to draw on; its ``line`` and ``add_layout`` methods are used.
    max_radius : float
        Length of every guide line (intended to be the radius of the
        outermost standard deviation arc).
    """
    axis_style = {"color": "black", "line_width": 3, "line_dash": "solid"}
    guide_style = {"color": "gray", "line_width": 1, "line_dash": "dotted"}

    # Correlation tick values in descending order and their polar angles.
    tick_values = np.concatenate((np.arange(10) / 10.0, [0.95, 0.99, 1]))
    rlocs = np.flip(tick_values)
    tlocs = np.arccos(rlocs)

    for angle, correlation in zip(tlocs, rlocs):
        lies_on_axis = angle == 0 or angle == np.pi / 2
        style = axis_style if lies_on_axis else guide_style

        # End point of the guide line on the circle of radius max_radius.
        tip_x = max_radius * correlation
        tip_y = max_radius * np.sin(angle)

        plot.line([0, tip_x], [0, tip_y], alpha=0.3, **style)
        plot.add_layout(
            Label(
                x=tip_x,
                y=tip_y,
                text=f"{correlation}",
                text_font_size="8pt",
                text_align="left",
                text_alpha=0.7,
                x_offset=5,
                y_offset=5,
            )
        )