import numpy as np
import matplotlib.pyplot as plt
__all__ = ["plot_diagrams", "bottleneck_matching", "wasserstein_matching"]
[docs]def plot_diagrams(
diagrams,
plot_only=None,
title=None,
xy_range=None,
labels=None,
colormap="default",
size=20,
ax_color=np.array([0.0, 0.0, 0.0]),
diagonal=True,
lifetime=False,
legend=True,
show=False,
ax=None
):
"""A helper function to plot persistence diagrams.
Parameters
----------
diagrams: ndarray (n_pairs, 2) or list of diagrams
A diagram or list of diagrams. If diagram is a list of diagrams,
then plot all on the same plot using different colors.
plot_only: list of numeric
If specified, an array of only the diagrams that should be plotted.
title: string, default is None
If title is defined, add it as title of the plot.
xy_range: list of numeric [xmin, xmax, ymin, ymax]
User provided range of axes. This is useful for comparing
multiple persistence diagrams.
labels: string or list of strings
Legend labels for each diagram.
If none are specified, we use H_0, H_1, H_2,... by default.
colormap: string, default is 'default'
Any of matplotlib color palettes.
Some options are 'default', 'seaborn', 'sequential'.
See all available styles with
.. code:: python
import matplotlib as mpl
print(mpl.styles.available)
size: numeric, default is 20
Pixel size of each point plotted.
ax_color: any valid matplotlib color type.
See [https://matplotlib.org/api/colors_api.html](https://matplotlib.org/api/colors_api.html) for complete API.
diagonal: bool, default is True
Plot the diagonal x=y line.
lifetime: bool, default is False. If True, diagonal is turned to False.
Plot life time of each point instead of birth and death.
Essentially, visualize (x, y-x).
legend: bool, default is True
If true, show the legend.
show: bool, default is False
Call plt.show() after plotting. If you are using self.plot() as part
of a subplot, set show=False and call plt.show() only once at the end.
"""
ax = ax or plt.gca()
plt.style.use(colormap)
xlabel, ylabel = "Birth", "Death"
if not isinstance(diagrams, list):
# Must have diagrams as a list for processing downstream
diagrams = [diagrams]
if labels is None:
# Provide default labels for diagrams if using self.dgm_
labels = ["$H_{{{}}}$".format(i) for i , _ in enumerate(diagrams)]
if plot_only:
diagrams = [diagrams[i] for i in plot_only]
labels = [labels[i] for i in plot_only]
if not isinstance(labels, list):
labels = [labels] * len(diagrams)
# Construct copy with proper type of each diagram
# so we can freely edit them.
diagrams = [dgm.astype(np.float32, copy=True) for dgm in diagrams]
# find min and max of all visible diagrams
concat_dgms = np.concatenate(diagrams).flatten()
has_inf = np.any(np.isinf(concat_dgms))
finite_dgms = concat_dgms[np.isfinite(concat_dgms)]
# clever bounding boxes of the diagram
if not xy_range:
# define bounds of diagram
ax_min, ax_max = np.min(finite_dgms), np.max(finite_dgms)
x_r = ax_max - ax_min
# Give plot a nice buffer on all sides.
# ax_range=0 when only one point,
buffer = 1 if xy_range == 0 else x_r / 5
x_down = ax_min - buffer / 2
x_up = ax_max + buffer
y_down, y_up = x_down, x_up
else:
x_down, x_up, y_down, y_up = xy_range
yr = y_up - y_down
if lifetime:
# Don't plot landscape and diagonal at the same time.
diagonal = False
# reset y axis so it doesn't go much below zero
y_down = -yr * 0.05
y_up = y_down + yr
# set custom ylabel
ylabel = "Lifetime"
# set diagrams to be (x, y-x)
for dgm in diagrams:
dgm[:, 1] -= dgm[:, 0]
# plot horizon line
ax.plot([x_down, x_up], [0, 0], c=ax_color)
# Plot diagonal
if diagonal:
ax.plot([x_down, x_up], [x_down, x_up], "--", c=ax_color)
# Plot inf line
if has_inf:
# put inf line slightly below top
b_inf = y_down + yr * 0.95
ax.plot([x_down, x_up], [b_inf, b_inf], "--", c="k", label=r"$\infty$")
# convert each inf in each diagram with b_inf
for dgm in diagrams:
dgm[np.isinf(dgm)] = b_inf
# Plot each diagram
for dgm, label in zip(diagrams, labels):
# plot persistence pairs
ax.scatter(dgm[:, 0], dgm[:, 1], size, label=label, edgecolor="none")
ax.set_xlabel(xlabel)
ax.set_ylabel(ylabel)
ax.set_xlim([x_down, x_up])
ax.set_ylim([y_down, y_up])
ax.set_aspect('equal', 'box')
if title is not None:
ax.set_title(title)
if legend is True:
ax.legend(loc="lower right")
if show is True:
plt.show()
def plot_a_bar(p, q, c='b', linestyle='-'):
plt.plot([p[0], q[0]], [p[1], q[1]], c=c, linestyle=linestyle, linewidth=1)
[docs]def bottleneck_matching(dgm1, dgm2, matching, labels=["dgm1", "dgm2"], ax=None):
""" Visualize bottleneck matching between two diagrams
Parameters
===========
dgm1: Mx(>=2)
array of birth/death pairs for PD 1
dgm2: Nx(>=2)
array of birth/death paris for PD 2
matching: ndarray(Mx+Nx, 3)
A list of correspondences in an optimal matching, as well as their distance, where:
* First column is index of point in first persistence diagram, or -1 if diagonal
* Second column is index of point in second persistence diagram, or -1 if diagonal
* Third column is the distance of each matching
labels: list of strings
names of diagrams for legend. Default = ["dgm1", "dgm2"],
ax: matplotlib Axis object
For plotting on a particular axis.
Examples
==========
dist, matching = persim.bottleneck(A_h1, B_h1, matching=True)
persim.bottleneck_matching(A_h1, B_h1, matching)
"""
ax = ax or plt.gca()
plot_diagrams([dgm1, dgm2], labels=labels, ax=ax)
cp = np.cos(np.pi / 4)
sp = np.sin(np.pi / 4)
R = np.array([[cp, -sp], [sp, cp]])
if dgm1.size == 0:
dgm1 = np.array([[0, 0]])
if dgm2.size == 0:
dgm2 = np.array([[0, 0]])
dgm1Rot = dgm1.dot(R)
dgm2Rot = dgm2.dot(R)
max_idx = np.argmax(matching[:, 2])
for idx, [i, j, d] in enumerate(matching):
i = int(i)
j = int(j)
linestyle = '--'
linewidth = 1
c = 'C2'
if idx == max_idx:
linestyle = '-'
linewidth = 2
c = 'C3'
if i != -1 or j != -1: # At least one point is a non-diagonal point
if i == -1:
diagElem = np.array([dgm2Rot[j, 0], 0])
diagElem = diagElem.dot(R.T)
plt.plot([dgm2[j, 0], diagElem[0]], [dgm2[j, 1], diagElem[1]], c, linewidth=linewidth, linestyle=linestyle)
elif j == -1:
diagElem = np.array([dgm1Rot[i, 0], 0])
diagElem = diagElem.dot(R.T)
ax.plot([dgm1[i, 0], diagElem[0]], [dgm1[i, 1], diagElem[1]], c, linewidth=linewidth, linestyle=linestyle)
else:
ax.plot([dgm1[i, 0], dgm2[j, 0]], [dgm1[i, 1], dgm2[j, 1]], c, linewidth=linewidth, linestyle=linestyle)
[docs]def wasserstein_matching(dgm1, dgm2, matching, labels=["dgm1", "dgm2"], ax=None):
""" Visualize bottleneck matching between two diagrams
Parameters
===========
dgm1: array
A diagram
dgm2: array
A diagram
matching: ndarray(Mx+Nx, 3)
A list of correspondences in an optimal matching, as well as their distance, where:
* First column is index of point in first persistence diagram, or -1 if diagonal
* Second column is index of point in second persistence diagram, or -1 if diagonal
* Third column is the distance of each matching
labels: list of strings
names of diagrams for legend. Default = ["dgm1", "dgm2"],
ax: matplotlib Axis object
For plotting on a particular axis.
Examples
==========
bn_matching, (matchidx, D) = persim.wasserstien(A_h1, B_h1, matching=True)
persim.wasserstein_matching(A_h1, B_h1, matchidx, D)
"""
ax = ax or plt.gca()
cp = np.cos(np.pi / 4)
sp = np.sin(np.pi / 4)
R = np.array([[cp, -sp], [sp, cp]])
if dgm1.size == 0:
dgm1 = np.array([[0, 0]])
if dgm2.size == 0:
dgm2 = np.array([[0, 0]])
dgm1Rot = dgm1.dot(R)
dgm2Rot = dgm2.dot(R)
for [i, j, d] in matching:
i = int(i)
j = int(j)
if i != -1 or j != -1: # At least one point is a non-diagonal point
if i == -1:
diagElem = np.array([dgm2Rot[j, 0], 0])
diagElem = diagElem.dot(R.T)
plt.plot([dgm2[j, 0], diagElem[0]], [dgm2[j, 1], diagElem[1]], "g")
elif j == -1:
diagElem = np.array([dgm1Rot[i, 0], 0])
diagElem = diagElem.dot(R.T)
ax.plot([dgm1[i, 0], diagElem[0]], [dgm1[i, 1], diagElem[1]], "g")
else:
ax.plot([dgm1[i, 0], dgm2[j, 0]], [dgm1[i, 1], dgm2[j, 1]], "g")
plot_diagrams([dgm1, dgm2], labels=labels, ax=ax)