Source code for persim.sliced_wasserstein

import numpy as np
from scipy.spatial.distance import cityblock

# Public API: the only name exported by a star-import of this module.
__all__ = ["sliced_wasserstein"]

def _as_diagram(PD, name):
    """Coerce *PD* to a float64 array of shape (k, 2); empty input becomes (0, 2)."""
    points = np.asarray(PD, dtype=np.float64)
    if points.size == 0:
        # Accept every spelling of "no points": [], np.array([]), np.empty((0, 2)).
        return points.reshape(0, 2)
    if points.ndim != 2 or points.shape[1] != 2:
        raise ValueError(
            "{} must have shape (k, 2), got {}".format(name, points.shape)
        )
    return points


def sliced_wasserstein(PD1, PD2, M=50):
    """Approximate the Sliced Wasserstein distance between two diagrams.

    Implementation of Sliced Wasserstein distance as described in
    "Sliced Wasserstein Kernel for Persistence Diagrams" by Mathieu Carriere,
    Marco Cuturi, Steve Oudot (https://arxiv.org/abs/1706.03358).

    Each diagram is augmented with the orthogonal projection of the *other*
    diagram onto the diagonal, so both point sets have the same size. Both
    sets are then projected onto ``M`` evenly spaced directions covering a
    half circle, and the L1 distances between the sorted projections are
    averaged over the directions.

    Parameters
    -----------
    PD1: array-like, size (m, 2)
        Persistence diagram. An empty diagram is accepted.
    PD2: array-like, size (n, 2)
        Persistence diagram. An empty diagram is accepted.
    M: int, default is 50
        Number of directions used for the approximation. Should be a
        positive integer.

    Returns
    --------
    sw: float
        Sliced Wasserstein distance between PD1 and PD2

    Raises
    ------
    ValueError
        If a non-empty diagram is not of shape (k, 2).
    ZeroDivisionError
        If ``M`` is 0.
    """
    diagram1 = _as_diagram(PD1, "PD1")
    diagram2 = _as_diagram(PD2, "PD2")

    # Orthogonal projection of (b, d) onto the diagonal is ((b+d)/2, (b+d)/2).
    # Computing the midpoint directly (rather than via sqrt(x ** 2 / 2)) keeps
    # the sign, so diagrams with negative coordinates are projected correctly.
    midpoints1 = diagram1.sum(axis=1) / 2.0
    midpoints2 = diagram2.sum(axis=1) / 2.0
    on_diagonal1 = np.column_stack((midpoints1, midpoints1))
    on_diagonal2 = np.column_stack((midpoints2, midpoints2))

    # Each diagram is padded with the other's diagonal projections so that the
    # two point sets can be matched one-to-one after sorting.
    augmented1 = np.vstack((diagram1, on_diagonal2))
    augmented2 = np.vstack((diagram2, on_diagonal1))

    # Directions at angles pi/2, pi/2 + pi/M, ... (M of them, half circle).
    # Angles are derived from the index to avoid accumulating rounding error.
    step = 1.0 / M
    angles = (0.5 + step * np.arange(M)) * np.pi
    directions = np.vstack((np.cos(angles), np.sin(angles)))  # shape (2, M)

    # Column j holds the sorted projections onto direction j.
    projected1 = np.sort(np.dot(augmented1, directions), axis=0)
    projected2 = np.sort(np.dot(augmented2, directions), axis=0)

    # Sum of per-direction L1 distances, weighted by the step size.
    return step * np.abs(projected1 - projected2).sum()