Source code for persim.wasserstein

"""

    Implementation of the Wasserstein distance using
    the Hungarian algorithm

    Author: Chris Tralie

"""
import warnings

import numpy as np
from scipy import optimize
from scipy.spatial import distance
from sklearn import metrics

__all__ = ["wasserstein"]


def _finite_points(dgm, name):
    """Return ``dgm`` as a 2-D float array without non-finite-death points.

    Parameters
    ----------
    dgm: array-like
        Diagram whose first two columns are (birth, death).
    name: str
        Label used in the warning message (e.g. "dgm1").

    Returns
    -------
    pts: ndarray
        The rows of ``dgm`` whose death time is finite. An empty input
        yields an empty ``(0, 2)`` array so column slicing still works.
    """
    pts = np.array(dgm, dtype=float)
    if pts.size == 0:
        return np.zeros((0, 2))
    finite = np.isfinite(pts[:, 1])
    if not np.all(finite):
        warnings.warn(
            name + " has points with non-finite death times; "
            "ignoring those points"
        )
        pts = pts[finite, :]
    return pts


def _diagonal_distances(pts):
    """Euclidean distance from each (birth, death) row to the diagonal.

    Computed as |death - birth| / sqrt(2). The subtraction is exact for
    nearby values, and the absolute value keeps the cost non-negative even
    for points lying below the diagonal.
    """
    return np.abs(pts[:, 1] - pts[:, 0]) / np.sqrt(2)


def wasserstein(dgm1, dgm2, matching=False):
    """
    Perform the Wasserstein distance matching between persistence diagrams.
    Assumes first two columns of dgm1 and dgm2 are the coordinates of the
    persistence points, but allows for other coordinate columns (which are
    ignored in diagonal matching).

    Points with non-finite death times are dropped with a warning.
    Point-to-point costs are Euclidean distances over all columns.
    Point-to-diagonal costs are Euclidean distances from (birth, death) to
    the diagonal. The returned distance is the sum of the costs of the
    optimal matching.

    See the `distances` notebook for an example of how to use this.

    Parameters
    ------------

    dgm1: Mx(>=2)
        array of birth/death pairs for PD 1
    dgm2: Nx(>=2)
        array of birth/death pairs for PD 2
    matching: bool, default False
        if True, also return the optimal matching

    Returns
    ---------

    d: float
        Wasserstein distance between dgm1 and dgm2
    matching: ndarray, shape (K, 3)
        Only returned if `matching=True`. One row ``(i, j, cost)`` per
        matched pair. ``i`` indexes the finite-death points of dgm1 (in
        order) and ``j`` those of dgm2. An index of -1 means the other
        point was matched to the diagonal.
    """
    S = _finite_points(dgm1, "dgm1")
    T = _finite_points(dgm2, "dgm2")
    M, N = S.shape[0], T.shape[0]

    if M + N == 0:
        # Two empty diagrams: zero distance, nothing matched.
        if matching:
            return 0.0, np.zeros((0, 3))
        return 0.0

    # Point-to-point costs. cdist evaluates each distance directly, so it
    # stays accurate for nearby points with large coordinates (the expanded
    # ||x||^2 - 2x.y + ||y||^2 form does not). Skip it if a side is empty.
    if M > 0 and N > 0:
        DUL = distance.cdist(S, T)
    else:
        DUL = np.zeros((M, N))

    # (M+N)x(M+N) cost matrix:
    #   [ point-to-point (MxN) | S-to-diagonal (MxM, inf off-diagonal) ]
    #   [ T-to-diagonal (NxN)  | diagonal-to-diagonal (NxM, zero)      ]
    D = np.zeros((M + N, M + N))
    D[:M, :N] = DUL
    D[:M, N:] = np.inf
    D[M:, :N] = np.inf
    idx_s = np.arange(M)
    D[idx_s, N + idx_s] = _diagonal_distances(S)
    idx_t = np.arange(N)
    D[M + idx_t, idx_t] = _diagonal_distances(T)

    # Run the Hungarian algorithm
    matchi, matchj = optimize.linear_sum_assignment(D)
    costs = D[matchi, matchj]
    matchdist = np.sum(costs)
    if not matching:
        return matchdist

    ret = np.column_stack((matchi, matchj, costs))
    # Indices past the real points refer to diagonal slots.
    ret[ret[:, 0] >= M, 0] = -1
    ret[ret[:, 1] >= N, 1] = -1
    # Drop diagonal-to-diagonal filler pairs.
    ret = ret[~((ret[:, 0] == -1) & (ret[:, 1] == -1)), :]
    return matchdist, ret