def free_support_sinkhorn_barycenter(measures_locations, measures_weights, X_init, reg, b=None, weights=None,
                                     numItermax=100, numInnerItermax=1000, stopThr=1e-7, verbose=False, log=None,
                                     **kwargs):
    r"""
    Solves the free support (locations of the barycenters are optimized, not the weights)
    regularized Wasserstein barycenter problem (i.e. the weighted Frechet mean for the
    2-Sinkhorn divergence), formally:

    .. math::
        \min_\mathbf{X} \quad \sum_{i=1}^N w_i W_{reg}^2(\mathbf{b}, \mathbf{X}, \mathbf{a}_i, \mathbf{X}_i)

    where :

    - :math:`w \in (0, 1)^{N}` are the barycenter weights and sum to one
    - `measure_weights` denotes the :math:`\mathbf{a}_i \in \mathbb{R}^{k_i}`: empirical measures weights (on simplex)
    - `measures_locations` denotes the :math:`\mathbf{X}_i \in \mathbb{R}^{k_i, d}`: empirical measures atoms locations
    - :math:`\mathbf{b} \in \mathbb{R}^{k}` is the desired weights vector of the barycenter

    This problem is considered in :ref:`[20] <references-free-support-sinkhorn-barycenter>` (Algorithm 2).
    There are three differences with the algorithm described there:

    - we do not optimize over the weights
    - we do not do line search for the locations updates, we use i.e.
      :math:`\theta = 1` in :ref:`[20] <references-free-support-sinkhorn-barycenter>`
      (Algorithm 2). This can be seen as a discrete implementation of the
      fixed-point algorithm of
      :ref:`[43] <references-free-support-sinkhorn-barycenter>` proposed in the
      continuous setting.
    - at each iteration, instead of solving an exact OT problem, we use the
      Sinkhorn algorithm for calculating the transport plan in
      :ref:`[20] <references-free-support-sinkhorn-barycenter>` (Algorithm 2).

    Parameters
    ----------
    measures_locations : list of N (k_i,d) array-like
        The discrete support of a measure supported on :math:`k_i` locations of a `d`-dimensional space
        (:math:`k_i` can be different for each element of the list)
    measures_weights : list of N (k_i,) array-like
        Numpy arrays where each numpy array has :math:`k_i` non-negatives values summing to one
        representing the weights of each discrete input measure
    X_init : (k,d) array-like
        Initialization of the support locations (on `k` atoms) of the barycenter
    reg : float
        Regularization term >0
    b : (k,) array-like
        Weights of the barycenter atoms (non-negatives, sum to 1). Uniform by default.
    weights : (N,) array-like
        Weight of each input measure in the barycenter (non-negatives, sum to 1). Uniform by default.
    numItermax : int, optional
        Max number of (outer) fixed-point iterations
    numInnerItermax : int, optional
        Max number of iterations when calculating the transport plans with Sinkhorn
    stopThr : float, optional
        Stop threshold on the squared displacement of the support between two iterations (>0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True

    Returns
    -------
    X : (k,d) array-like
        Support locations (on k atoms) of the barycenter
    log : dict
        Only returned when ``log`` is True; contains the list of squared
        displacement norms under the key ``'displacement_square_norms'``.

    See Also
    --------
    ot.bregman.sinkhorn : Entropic regularized OT solver
    ot.lp.free_support_barycenter : Barycenter solver based on Linear Programming

    .. _references-free-support-sinkhorn-barycenter:
    References
    ----------
    .. [20] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014.

    .. [43] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762.
    """
    nx = get_backend(*measures_locations, *measures_weights, X_init)

    iter_count = 0

    N = len(measures_locations)
    k = X_init.shape[0]
    d = X_init.shape[1]
    # Default to uniform weights, both for the barycenter atoms (b) and for
    # the contribution of each input measure (weights).
    if b is None:
        b = nx.ones((k,), type_as=X_init) / k
    if weights is None:
        weights = nx.ones((N,), type_as=X_init) / N

    X = X_init

    log_dict = {}
    displacement_square_norms = []

    # Force at least one pass through the fixed-point loop.
    displacement_square_norm = stopThr + 1.

    while (displacement_square_norm > stopThr and iter_count < numItermax):

        T_sum = nx.zeros((k, d), type_as=X_init)

        # Fixed-point update: barycentric projection of each Sinkhorn plan,
        # averaged over the input measures with their barycenter weights.
        for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, weights):
            M_i = dist(X, measure_locations_i)
            T_i = sinkhorn(b, measure_weights_i, M_i, reg=reg, numItermax=numInnerItermax, **kwargs)
            T_sum = T_sum + weight_i * 1. / b[:, None] * nx.dot(T_i, measure_locations_i)

        displacement_square_norm = nx.sum((T_sum - X) ** 2)
        if log:
            displacement_square_norms.append(displacement_square_norm)

        X = T_sum

        if verbose:
            # BUGFIX: interpolate the values with %; the original passed them
            # as extra positional arguments to print(), so the format string
            # was printed verbatim followed by the raw values.
            print('iteration %d, displacement_square_norm=%f\n'
                  % (iter_count, displacement_square_norm))

        iter_count += 1

    if log:
        log_dict['displacement_square_norms'] = displacement_square_norms
        return X, log_dict
    else:
        return X
# -*- coding: utf-8 -*-
"""
========================================================
2D free support Sinkhorn barycenters of distributions
========================================================

Illustration of Sinkhorn barycenter calculation between empirical
distributions understood as point clouds

"""

# Authors: Eduardo Fernandes Montesuma
#
# License: MIT License

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_moons
import ot

# %%
# General Parameters
# ------------------
reg = 1e-2  # Entropic Regularization
numItermax = 20  # Maximum number of iterations for the Barycenter algorithm
numInnerItermax = 50  # Maximum number of sinkhorn iterations
n_samples = 200  # Number of atoms in the barycenter support

# %%
# Generate Data
# -------------


def get_rotation(angle):
    """Return the 2x2 rotation matrix for an angle given in degrees."""
    theta = (angle / 180) * np.pi
    return np.array([
        [np.cos(theta), -np.sin(theta)],
        [np.sin(theta), np.cos(theta)]
    ])


R2, R3, R4 = get_rotation(15), get_rotation(30), get_rotation(45)

# NOTE(review): the input clouds have 300 points while the barycenter support
# has n_samples = 200 atoms; supports of different sizes are supported by the
# free-support solver, so this is intentional.
X1, _ = make_moons(n_samples=300, noise=1e-1)
a1 = ot.utils.unif(X1.shape[0], type_as=X1)
X2 = np.dot(X1, R2)
a2 = ot.utils.unif(X2.shape[0], type_as=X2)
X3 = np.dot(X1, R3)
a3 = ot.utils.unif(X3.shape[0], type_as=X3)
X4 = np.dot(X1, R4)
a4 = ot.utils.unif(X4.shape[0], type_as=X4)

# %%
# Inspect generated distributions
# -------------------------------

fig, axes = plt.subplots(1, 4, figsize=(16, 4))

for ax, X, angle in zip(axes, [X1, X2, X3, X4], [0, 15, 30, 45]):
    ax.scatter(x=X[:, 0], y=X[:, 1], c='steelblue', edgecolor='k')
    ax.set_xlim([-3, 3])
    ax.set_ylim([-3, 3])
    # BUGFIX: each subplot gets its own title (the original set every title
    # on axes[0], leaving axes[1..3] untitled).
    ax.set_title('Rotation: {}'.format(angle))

plt.tight_layout()
plt.show()

# %%
# Interpolating Empirical Distributions
# -------------------------------------

fig = plt.figure(figsize=(10, 10))

# Interpolation weights between the two endpoint distributions of each edge.
weights = np.array([
    [3 / 3, 0 / 3],
    [2 / 3, 1 / 3],
    [1 / 3, 2 / 3],
    [0 / 3, 3 / 3],
]).astype(np.float32)

# Top row of the grid: interpolate between X1 and X2.
for k in range(4):
    XB_init = np.random.randn(n_samples, 2)
    XB = ot.bregman.free_support_sinkhorn_barycenter(
        measures_locations=[X1, X2],
        measures_weights=[a1, a2],
        weights=weights[k],
        X_init=XB_init,
        reg=reg,
        numItermax=numItermax,
        numInnerItermax=numInnerItermax
    )
    ax = plt.subplot2grid((4, 4), (0, k))
    ax.scatter(XB[:, 0], XB[:, 1], color='steelblue', edgecolor='k')
    ax.set_xlim([-3, 3])
    ax.set_ylim([-3, 3])

# Left column: interpolate between X1 and X3.
for k in range(1, 4, 1):
    XB_init = np.random.randn(n_samples, 2)
    XB = ot.bregman.free_support_sinkhorn_barycenter(
        measures_locations=[X1, X3],
        # BUGFIX: pair each location array with its own weights (was [a1, a2]).
        measures_weights=[a1, a3],
        weights=weights[k],
        X_init=XB_init,
        reg=reg,
        numItermax=numItermax,
        numInnerItermax=numInnerItermax
    )
    ax = plt.subplot2grid((4, 4), (k, 0))
    ax.scatter(XB[:, 0], XB[:, 1], color='steelblue', edgecolor='k')
    ax.set_xlim([-3, 3])
    ax.set_ylim([-3, 3])

# Bottom row: interpolate between X3 and X4.
for k in range(1, 4, 1):
    XB_init = np.random.randn(n_samples, 2)
    XB = ot.bregman.free_support_sinkhorn_barycenter(
        measures_locations=[X3, X4],
        # BUGFIX: pair each location array with its own weights (was [a1, a2]).
        measures_weights=[a3, a4],
        weights=weights[k],
        X_init=XB_init,
        reg=reg,
        numItermax=numItermax,
        numInnerItermax=numInnerItermax
    )
    ax = plt.subplot2grid((4, 4), (3, k))
    ax.scatter(XB[:, 0], XB[:, 1], color='steelblue', edgecolor='k')
    ax.set_xlim([-3, 3])
    ax.set_ylim([-3, 3])

# Right column: interpolate between X2 and X4 (corners already covered above).
for k in range(1, 3, 1):
    XB_init = np.random.randn(n_samples, 2)
    XB = ot.bregman.free_support_sinkhorn_barycenter(
        measures_locations=[X2, X4],
        # BUGFIX: pair each location array with its own weights (was [a1, a2]).
        measures_weights=[a2, a4],
        weights=weights[k],
        X_init=XB_init,
        reg=reg,
        numItermax=numItermax,
        numInnerItermax=numInnerItermax
    )
    ax = plt.subplot2grid((4, 4), (k, 3))
    ax.scatter(XB[:, 0], XB[:, 1], color='steelblue', edgecolor='k')
    ax.set_xlim([-3, 3])
    ax.set_ylim([-3, 3])

plt.show()
Mon Sep 17 00:00:00 2001 From: Eduardo Montesuma Date: Mon, 4 Jul 2022 16:45:01 +0200 Subject: [PATCH 05/14] Removing extra line so that code follows pep8 --- ot/bregman.py | 1 - 1 file changed, 1 deletion(-) diff --git a/ot/bregman.py b/ot/bregman.py index c242215ed..b1321a4de 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -1660,7 +1660,6 @@ def free_support_sinkhorn_barycenter(measures_locations, measures_weights, X_ini return X - def _barycenter_sinkhorn_log(A, M, reg, weights=None, numItermax=1000, stopThr=1e-4, verbose=False, log=False, warn=True): r"""Compute the entropic wasserstein barycenter in log-domain From b36935ce109fcaf539f14af45722af35bcba72af Mon Sep 17 00:00:00 2001 From: eddardd Date: Mon, 4 Jul 2022 19:47:29 +0200 Subject: [PATCH 06/14] Fixing issues with pep8 in example --- .../plot_free_support_sinkhorn_barycenter.py | 26 ++++++++++--------- 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/examples/barycenters/plot_free_support_sinkhorn_barycenter.py b/examples/barycenters/plot_free_support_sinkhorn_barycenter.py index 4ea1bd796..fdef4add8 100644 --- a/examples/barycenters/plot_free_support_sinkhorn_barycenter.py +++ b/examples/barycenters/plot_free_support_sinkhorn_barycenter.py @@ -20,19 +20,21 @@ # %% # General Parameters # ------------------ -reg = 1e-2 # Entropic Regularization -numItermax = 20 # Maximum number of iterations for the Barycenter algorithm -numInnerItermax = 50 # Maximum number of sinkhorn iterations +reg = 1e-2 # Entropic Regularization +numItermax = 20 # Maximum number of iterations for the Barycenter algorithm +numInnerItermax = 50 # Maximum number of sinkhorn iterations n_samples = 200 # %% # Generate Data # ------------- -get_rotation = lambda angle: np.array([ - [np.cos((angle / 180) * np.pi), -np.sin((angle / 180) * np.pi)], - [np.sin((angle / 180) * np.pi), np.cos((angle / 180) * np.pi)] - ]) + +def get_rotation(angle): return np.array([ + [np.cos((angle / 180) * np.pi), -np.sin((angle / 180) * 
def get_rotation(angle):
    """Return the 2x2 rotation matrix for *angle* given in degrees.

    Parameters
    ----------
    angle : float
        Rotation angle in degrees.

    Returns
    -------
    R : (2, 2) numpy.ndarray
        Counter-clockwise rotation matrix.
    """
    # Hoist the degrees->radians conversion: the original recomputed
    # (angle / 180) * np.pi four times.
    theta = (angle / 180) * np.pi
    R = np.array([
        [np.cos(theta), -np.sin(theta)],
        [np.sin(theta), np.cos(theta)]
    ])
    return R
def test_free_support_sinkhorn_barycenter():
    # Two dirac measures, one at -1 and one at +1, each with all of its
    # mass on a single atom.
    measures_locations = [
        np.array([-1.]).reshape((1, 1)),
        np.array([1.]).reshape((1, 1))
    ]
    measures_weights = [
        np.array([1.]),
        np.array([1.])
    ]

    # Start the (single-atom) barycenter support far from the solution.
    X_init = np.array([-12.]).reshape((1, 1))

    # The barycenter of two symmetric diracs sits at the midpoint, 0.
    # (Compare with test_free_support_barycenter in test_ot.py.)
    bar_locations = np.array([0.]).reshape((1, 1))

    # reg=1 is arbitrary for this toy problem; in general the entropic
    # regularization should be tuned to the data at hand.
    X = ot.bregman.free_support_sinkhorn_barycenter(measures_locations, measures_weights, X_init, reg=1)

    # The computed support must match the known barycenter location.
    np.testing.assert_allclose(X, bar_locations, rtol=1e-5, atol=1e-7)
ot.barycenter(A, M, 1, method=method, stopThr=0, numItermax=1) + with pytest.warns(UserWarning): + ot.sinkhorn2(a1, a2, M, 1, method=method, stopThr=0, numItermax=1) + + +def test_not_implemented_method(): + # test sinkhorn + w = 10 + n = w ** 2 + rng = np.random.RandomState(42) + A_img = rng.rand(2, w, w) + A_flat = A_img.reshape(n, 2) + a1, a2 = A_flat.T + M_flat = ot.utils.dist0(n) + not_implemented = "new_method" + reg = 0.01 + with pytest.raises(ValueError): + ot.sinkhorn(a1, a2, M_flat, reg, method=not_implemented) + with pytest.raises(ValueError): + ot.sinkhorn2(a1, a2, M_flat, reg, method=not_implemented) + with pytest.raises(ValueError): + ot.barycenter(A_flat, M_flat, reg, method=not_implemented) + with pytest.raises(ValueError): + ot.bregman.barycenter_debiased(A_flat, M_flat, reg, + method=not_implemented) + with pytest.raises(ValueError): + ot.bregman.convolutional_barycenter2d(A_img, reg, + method=not_implemented) + with pytest.raises(ValueError): + ot.bregman.convolutional_barycenter2d_debiased(A_img, reg, + method=not_implemented) + + +@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"]) +def test_nan_warning(method): + # test sinkhorn + n = 100 + a1 = ot.datasets.make_1D_gauss(n, m=30, s=10) + a2 = ot.datasets.make_1D_gauss(n, m=40, s=10) + + M = ot.utils.dist0(n) + reg = 0 + with pytest.warns(UserWarning): + # warn set to False to avoid catching a convergence warning instead + ot.sinkhorn(a1, a2, M, reg, method=method, warn=False) + + +def test_sinkhorn_stabilization(): + # test sinkhorn + n = 100 + a1 = ot.datasets.make_1D_gauss(n, m=30, s=10) + a2 = ot.datasets.make_1D_gauss(n, m=40, s=10) + M = ot.utils.dist0(n) + reg = 1e-5 + loss1 = ot.sinkhorn2(a1, a2, M, reg, method="sinkhorn_log") + loss2 = ot.sinkhorn2(a1, a2, M, reg, tau=1, method="sinkhorn_stabilized") + np.testing.assert_allclose( + loss1, loss2, atol=1e-06) # cf convergence sinkhorn + + +@pytest.mark.parametrize("method, verbose, warn", + product(["sinkhorn", 
"sinkhorn_stabilized", + "sinkhorn_log"], + [True, False], [True, False])) +def test_sinkhorn_multi_b(method, verbose, warn): + # test sinkhorn + n = 10 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + u = ot.utils.unif(n) + + b = rng.rand(n, 3) + b = b / np.sum(b, 0, keepdims=True) + + M = ot.dist(x, x) + + loss0, log = ot.sinkhorn(u, b, M, .1, method=method, stopThr=1e-10, + log=True) + + loss = [ot.sinkhorn2(u, b[:, k], M, .1, method=method, stopThr=1e-10, + verbose=verbose, warn=warn) for k in range(3)] + # check constraints + np.testing.assert_allclose( + loss0, loss, atol=1e-4) # cf convergence sinkhorn + + +def test_sinkhorn_backends(nx): + n_samples = 100 + n_features = 2 + rng = np.random.RandomState(0) + + x = rng.randn(n_samples, n_features) + y = rng.randn(n_samples, n_features) + a = ot.utils.unif(n_samples) + + M = ot.dist(x, y) + + G = ot.sinkhorn(a, a, M, 1) + + ab, M_nx = nx.from_numpy(a, M) + + Gb = ot.sinkhorn(ab, ab, M_nx, 1) + + np.allclose(G, nx.to_numpy(Gb)) + + +def test_sinkhorn2_backends(nx): + n_samples = 100 + n_features = 2 + rng = np.random.RandomState(0) + + x = rng.randn(n_samples, n_features) + y = rng.randn(n_samples, n_features) + a = ot.utils.unif(n_samples) + + M = ot.dist(x, y) + + G = ot.sinkhorn(a, a, M, 1) + + ab, M_nx = nx.from_numpy(a, M) + + Gb = ot.sinkhorn2(ab, ab, M_nx, 1) + + np.allclose(G, nx.to_numpy(Gb)) + + +def test_sinkhorn2_gradients(): + n_samples = 100 + n_features = 2 + rng = np.random.RandomState(0) + + x = rng.randn(n_samples, n_features) + y = rng.randn(n_samples, n_features) + a = ot.utils.unif(n_samples) + + M = ot.dist(x, y) + + if torch: + + a1 = torch.tensor(a, requires_grad=True) + b1 = torch.tensor(a, requires_grad=True) + M1 = torch.tensor(M, requires_grad=True) + + val = ot.sinkhorn2(a1, b1, M1, 1) + + val.backward() + + assert a1.shape == a1.grad.shape + assert b1.shape == b1.grad.shape + assert M1.shape == M1.grad.shape + + +def test_sinkhorn_empty(): + # test sinkhorn + n = 100 + rng 
= np.random.RandomState(0) + + x = rng.randn(n, 2) + u = ot.utils.unif(n) + + M = ot.dist(x, x) + + G, log = ot.sinkhorn([], [], M, 1, stopThr=1e-10, method="sinkhorn_log", + verbose=True, log=True) + # check constraints + np.testing.assert_allclose(u, G.sum(1), atol=1e-05) + np.testing.assert_allclose(u, G.sum(0), atol=1e-05) + + G, log = ot.sinkhorn([], [], M, 1, stopThr=1e-10, verbose=True, log=True) + # check constraints + np.testing.assert_allclose(u, G.sum(1), atol=1e-05) + np.testing.assert_allclose(u, G.sum(0), atol=1e-05) + + G, log = ot.sinkhorn([], [], M, 1, stopThr=1e-10, + method='sinkhorn_stabilized', verbose=True, log=True) + # check constraints + np.testing.assert_allclose(u, G.sum(1), atol=1e-05) + np.testing.assert_allclose(u, G.sum(0), atol=1e-05) + + G, log = ot.sinkhorn( + [], [], M, 1, stopThr=1e-10, method='sinkhorn_epsilon_scaling', + verbose=True, log=True) + # check constraints + np.testing.assert_allclose(u, G.sum(1), atol=1e-05) + np.testing.assert_allclose(u, G.sum(0), atol=1e-05) + + # test empty weights greenkhorn + ot.sinkhorn([], [], M, 1, method='greenkhorn', stopThr=1e-10, log=True) + + +@pytest.skip_backend('tf') +@pytest.skip_backend("jax") +def test_sinkhorn_variants(nx): + # test sinkhorn + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + u = ot.utils.unif(n) + + M = ot.dist(x, x) + + ub, M_nx = nx.from_numpy(u, M) + + G = ot.sinkhorn(u, u, M, 1, method='sinkhorn', stopThr=1e-10) + Gl = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method='sinkhorn_log', stopThr=1e-10)) + G0 = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method='sinkhorn', stopThr=1e-10)) + Gs = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10)) + Ges = nx.to_numpy(ot.sinkhorn( + ub, ub, M_nx, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10)) + G_green = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method='greenkhorn', stopThr=1e-10)) + + # check values + np.testing.assert_allclose(G, G0, atol=1e-05) + 
np.testing.assert_allclose(G, Gl, atol=1e-05) + np.testing.assert_allclose(G0, Gs, atol=1e-05) + np.testing.assert_allclose(G0, Ges, atol=1e-05) + np.testing.assert_allclose(G0, G_green, atol=1e-5) + + +@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized", + "sinkhorn_epsilon_scaling", + "greenkhorn", + "sinkhorn_log"]) +@pytest.skip_arg(("nx", "method"), ("tf", "sinkhorn_epsilon_scaling"), reason="tf does not support sinkhorn_epsilon_scaling", getter=str) +@pytest.skip_arg(("nx", "method"), ("tf", "greenkhorn"), reason="tf does not support greenkhorn", getter=str) +@pytest.skip_arg(("nx", "method"), ("jax", "sinkhorn_epsilon_scaling"), reason="jax does not support sinkhorn_epsilon_scaling", getter=str) +@pytest.skip_arg(("nx", "method"), ("jax", "greenkhorn"), reason="jax does not support greenkhorn", getter=str) +def test_sinkhorn_variants_dtype_device(nx, method): + n = 100 + + x = np.random.randn(n, 2) + u = ot.utils.unif(n) + + M = ot.dist(x, x) + + for tp in nx.__type_list__: + print(nx.dtype_device(tp)) + + ub, Mb = nx.from_numpy(u, M, type_as=tp) + + Gb = ot.sinkhorn(ub, ub, Mb, 1, method=method, stopThr=1e-10) + + nx.assert_same_dtype_device(Mb, Gb) + + +@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"]) +def test_sinkhorn2_variants_dtype_device(nx, method): + n = 100 + + x = np.random.randn(n, 2) + u = ot.utils.unif(n) + + M = ot.dist(x, x) + + for tp in nx.__type_list__: + print(nx.dtype_device(tp)) + + ub, Mb = nx.from_numpy(u, M, type_as=tp) + + lossb = ot.sinkhorn2(ub, ub, Mb, 1, method=method, stopThr=1e-10) + + nx.assert_same_dtype_device(Mb, lossb) + + +@pytest.mark.skipif(not tf, reason="tf not installed") +@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"]) +def test_sinkhorn2_variants_device_tf(method): + nx = ot.backend.TensorflowBackend() + n = 100 + x = np.random.randn(n, 2) + u = ot.utils.unif(n) + M = ot.dist(x, x) + + # Check that everything stays on 
the CPU + with tf.device("/CPU:0"): + ub, Mb = nx.from_numpy(u, M) + Gb = ot.sinkhorn(ub, ub, Mb, 1, method=method, stopThr=1e-10) + lossb = ot.sinkhorn2(ub, ub, Mb, 1, method=method, stopThr=1e-10) + nx.assert_same_dtype_device(Mb, Gb) + nx.assert_same_dtype_device(Mb, lossb) + + if len(tf.config.list_physical_devices('GPU')) > 0: + # Check that everything happens on the GPU + ub, Mb = nx.from_numpy(u, M) + Gb = ot.sinkhorn(ub, ub, Mb, 1, method=method, stopThr=1e-10) + lossb = ot.sinkhorn2(ub, ub, Mb, 1, method=method, stopThr=1e-10) + nx.assert_same_dtype_device(Mb, Gb) + nx.assert_same_dtype_device(Mb, lossb) + assert nx.dtype_device(Gb)[1].startswith("GPU") + + +@pytest.skip_backend('tf') +@pytest.skip_backend("jax") +def test_sinkhorn_variants_multi_b(nx): + # test sinkhorn + n = 50 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + u = ot.utils.unif(n) + + b = rng.rand(n, 3) + b = b / np.sum(b, 0, keepdims=True) + + M = ot.dist(x, x) + + ub, bb, M_nx = nx.from_numpy(u, b, M) + + G = ot.sinkhorn(u, b, M, 1, method='sinkhorn', stopThr=1e-10) + Gl = nx.to_numpy(ot.sinkhorn(ub, bb, M_nx, 1, method='sinkhorn_log', stopThr=1e-10)) + G0 = nx.to_numpy(ot.sinkhorn(ub, bb, M_nx, 1, method='sinkhorn', stopThr=1e-10)) + Gs = nx.to_numpy(ot.sinkhorn(ub, bb, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10)) + + # check values + np.testing.assert_allclose(G, G0, atol=1e-05) + np.testing.assert_allclose(G, Gl, atol=1e-05) + np.testing.assert_allclose(G0, Gs, atol=1e-05) + + +@pytest.skip_backend('tf') +@pytest.skip_backend("jax") +def test_sinkhorn2_variants_multi_b(nx): + # test sinkhorn + n = 50 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + u = ot.utils.unif(n) + + b = rng.rand(n, 3) + b = b / np.sum(b, 0, keepdims=True) + + M = ot.dist(x, x) + + ub, bb, M_nx = nx.from_numpy(u, b, M) + + G = ot.sinkhorn2(u, b, M, 1, method='sinkhorn', stopThr=1e-10) + Gl = nx.to_numpy(ot.sinkhorn2(ub, bb, M_nx, 1, method='sinkhorn_log', stopThr=1e-10)) + G0 = 
nx.to_numpy(ot.sinkhorn2(ub, bb, M_nx, 1, method='sinkhorn', stopThr=1e-10)) + Gs = nx.to_numpy(ot.sinkhorn2(ub, bb, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10)) + + # check values + np.testing.assert_allclose(G, G0, atol=1e-05) + np.testing.assert_allclose(G, Gl, atol=1e-05) + np.testing.assert_allclose(G0, Gs, atol=1e-05) + + +def test_sinkhorn_variants_log(): + # test sinkhorn + n = 50 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + u = ot.utils.unif(n) + + M = ot.dist(x, x) + + G0, log0 = ot.sinkhorn(u, u, M, 1, method='sinkhorn', stopThr=1e-10, log=True) + Gl, logl = ot.sinkhorn(u, u, M, 1, method='sinkhorn_log', stopThr=1e-10, log=True) + Gs, logs = ot.sinkhorn(u, u, M, 1, method='sinkhorn_stabilized', stopThr=1e-10, log=True) + Ges, loges = ot.sinkhorn( + u, u, M, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10, log=True,) + G_green, loggreen = ot.sinkhorn(u, u, M, 1, method='greenkhorn', stopThr=1e-10, log=True) + + # check values + np.testing.assert_allclose(G0, Gs, atol=1e-05) + np.testing.assert_allclose(G0, Gl, atol=1e-05) + np.testing.assert_allclose(G0, Ges, atol=1e-05) + np.testing.assert_allclose(G0, G_green, atol=1e-5) + + +@pytest.mark.parametrize("verbose, warn", product([True, False], [True, False])) +def test_sinkhorn_variants_log_multib(verbose, warn): + # test sinkhorn + n = 50 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + u = ot.utils.unif(n) + b = rng.rand(n, 3) + b = b / np.sum(b, 0, keepdims=True) + + M = ot.dist(x, x) + + G0, log0 = ot.sinkhorn(u, b, M, 1, method='sinkhorn', stopThr=1e-10, log=True) + Gl, logl = ot.sinkhorn(u, b, M, 1, method='sinkhorn_log', stopThr=1e-10, log=True, + verbose=verbose, warn=warn) + Gs, logs = ot.sinkhorn(u, b, M, 1, method='sinkhorn_stabilized', stopThr=1e-10, log=True, + verbose=verbose, warn=warn) + + # check values + np.testing.assert_allclose(G0, Gs, atol=1e-05) + np.testing.assert_allclose(G0, Gl, atol=1e-05) + + +@pytest.mark.parametrize("method, verbose, warn", 
+ product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"], + [True, False], [True, False])) +def test_barycenter(nx, method, verbose, warn): + n_bins = 100 # nb bins + + # Gaussian distributions + a1 = ot.datasets.make_1D_gauss(n_bins, m=30, s=10) # m= mean, s= std + a2 = ot.datasets.make_1D_gauss(n_bins, m=40, s=10) + + # creating matrix A containing all distributions + A = np.vstack((a1, a2)).T + + # loss matrix + normalization + M = ot.utils.dist0(n_bins) + M /= M.max() + + alpha = 0.5 # 0<=alpha<=1 + weights = np.array([1 - alpha, alpha]) + + A_nx, M_nx, weights_nx = nx.from_numpy(A, M, weights) + reg = 1e-2 + + if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log": + with pytest.raises(NotImplementedError): + ot.bregman.barycenter(A_nx, M_nx, reg, weights, method=method) + else: + # wasserstein + bary_wass_np = ot.bregman.barycenter(A, M, reg, weights, method=method, verbose=verbose, warn=warn) + bary_wass, _ = ot.bregman.barycenter(A_nx, M_nx, reg, weights_nx, method=method, log=True) + bary_wass = nx.to_numpy(bary_wass) + + np.testing.assert_allclose(1, np.sum(bary_wass)) + np.testing.assert_allclose(bary_wass, bary_wass_np) + + ot.bregman.barycenter(A_nx, M_nx, reg, log=True) + + +def test_free_support_sinkhorn_barycenter(): + measures_locations = [ + np.array([-1.]).reshape((1, 1)), # First dirac support + np.array([1.]).reshape((1, 1)) # Second dirac support + ] + + measures_weights = [ + np.array([1.]), # First dirac sample weights + np.array([1.]) # Second dirac sample weights + ] + + # Barycenter initialization + X_init = np.array([-12.]).reshape((1, 1)) + + # Obvious barycenter locations. Take a look on test_ot.py, test_free_support_barycenter + bar_locations = np.array([0.]).reshape((1, 1)) + + # Calculate free support barycenter w/ Sinkhorn algorithm. We set the entropic regularization + # term to 1, but this should be, in general, fine-tuned to the problem. 
+ X = ot.bregman.free_support_sinkhorn_barycenter(measures_locations, measures_weights, X_init, reg=1) + + # Verifies if calculated barycenter matches ground-truth + np.testing.assert_allclose(X, bar_locations, rtol=1e-5, atol=1e-7) + + +@pytest.mark.parametrize("method, verbose, warn", + product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"], + [True, False], [True, False])) +def test_barycenter_assymetric_cost(nx, method, verbose, warn): + n_bins = 20 # nb bins + + # Gaussian distributions + A = ot.datasets.make_1D_gauss(n_bins, m=30, s=10) # m= mean, s= std + + # creating matrix A containing all distributions + A = A[:, None] + + # assymetric loss matrix + normalization + rng = np.random.RandomState(42) + M = rng.randn(n_bins, n_bins) ** 2 + M /= M.max() + + A_nx, M_nx = nx.from_numpy(A, M) + reg = 1e-2 + + if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log": + with pytest.raises(NotImplementedError): + ot.bregman.barycenter(A_nx, M_nx, reg, method=method) + else: + # wasserstein + bary_wass_np = ot.bregman.barycenter(A, M, reg, method=method, verbose=verbose, warn=warn) + bary_wass, _ = ot.bregman.barycenter(A_nx, M_nx, reg, method=method, log=True) + bary_wass = nx.to_numpy(bary_wass) + + np.testing.assert_allclose(1, np.sum(bary_wass)) + np.testing.assert_allclose(bary_wass, bary_wass_np) + + ot.bregman.barycenter(A_nx, M_nx, reg, log=True) + + +@pytest.mark.parametrize("method, verbose, warn", + product(["sinkhorn", "sinkhorn_log"], + [True, False], [True, False])) +def test_barycenter_debiased(nx, method, verbose, warn): + n_bins = 100 # nb bins + + # Gaussian distributions + a1 = ot.datasets.make_1D_gauss(n_bins, m=30, s=10) # m= mean, s= std + a2 = ot.datasets.make_1D_gauss(n_bins, m=40, s=10) + + # creating matrix A containing all distributions + A = np.vstack((a1, a2)).T + + # loss matrix + normalization + M = ot.utils.dist0(n_bins) + M /= M.max() + + alpha = 0.5 # 0<=alpha<=1 + weights = np.array([1 - alpha, alpha]) + + A_nx, M_nx, 
weights_nx = nx.from_numpy(A, M, weights) + + # wasserstein + reg = 1e-2 + if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log": + with pytest.raises(NotImplementedError): + ot.bregman.barycenter_debiased(A_nx, M_nx, reg, weights, method=method) + else: + bary_wass_np = ot.bregman.barycenter_debiased(A, M, reg, weights, method=method, + verbose=verbose, warn=warn) + bary_wass, _ = ot.bregman.barycenter_debiased(A_nx, M_nx, reg, weights_nx, method=method, log=True) + bary_wass = nx.to_numpy(bary_wass) + + np.testing.assert_allclose(1, np.sum(bary_wass), atol=1e-3) + np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-5) + + ot.bregman.barycenter_debiased(A_nx, M_nx, reg, log=True, verbose=False) + + +@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_log"]) +def test_convergence_warning_barycenters(method): + w = 10 + n_bins = w ** 2 # nb bins + + # Gaussian distributions + a1 = ot.datasets.make_1D_gauss(n_bins, m=30, s=10) # m= mean, s= std + a2 = ot.datasets.make_1D_gauss(n_bins, m=40, s=10) + + # creating matrix A containing all distributions + A = np.vstack((a1, a2)).T + A_img = A.reshape(2, w, w) + A_img /= A_img.sum((1, 2))[:, None, None] + + # loss matrix + normalization + M = ot.utils.dist0(n_bins) + M /= M.max() + + alpha = 0.5 # 0<=alpha<=1 + weights = np.array([1 - alpha, alpha]) + reg = 0.1 + with pytest.warns(UserWarning): + ot.bregman.barycenter_debiased(A, M, reg, weights, method=method, numItermax=1) + with pytest.warns(UserWarning): + ot.bregman.barycenter(A, M, reg, weights, method=method, numItermax=1) + with pytest.warns(UserWarning): + ot.bregman.convolutional_barycenter2d(A_img, reg, weights, + method=method, numItermax=1) + with pytest.warns(UserWarning): + ot.bregman.convolutional_barycenter2d_debiased(A_img, reg, weights, + method=method, numItermax=1) + + +def test_barycenter_stabilization(nx): + n_bins = 100 # nb bins + + # Gaussian distributions + a1 = ot.datasets.make_1D_gauss(n_bins, m=30, s=10) # m= mean, s= std 
+ a2 = ot.datasets.make_1D_gauss(n_bins, m=40, s=10) + + # creating matrix A containing all distributions + A = np.vstack((a1, a2)).T + + # loss matrix + normalization + M = ot.utils.dist0(n_bins) + M /= M.max() + + alpha = 0.5 # 0<=alpha<=1 + weights = np.array([1 - alpha, alpha]) + + A_nx, M_nx, weights_b = nx.from_numpy(A, M, weights) + + # wasserstein + reg = 1e-2 + bar_np = ot.bregman.barycenter(A, M, reg, weights, method="sinkhorn", stopThr=1e-8, verbose=True) + bar_stable = nx.to_numpy(ot.bregman.barycenter( + A_nx, M_nx, reg, weights_b, method="sinkhorn_stabilized", + stopThr=1e-8, verbose=True + )) + bar = nx.to_numpy(ot.bregman.barycenter( + A_nx, M_nx, reg, weights_b, method="sinkhorn", + stopThr=1e-8, verbose=True + )) + np.testing.assert_allclose(bar, bar_stable) + np.testing.assert_allclose(bar, bar_np) + + +@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_log"]) +def test_wasserstein_bary_2d(nx, method): + size = 20 # size of a square image + a1 = np.random.rand(size, size) + a1 += a1.min() + a1 = a1 / np.sum(a1) + a2 = np.random.rand(size, size) + a2 += a2.min() + a2 = a2 / np.sum(a2) + # creating matrix A containing all distributions + A = np.zeros((2, size, size)) + A[0, :, :] = a1 + A[1, :, :] = a2 + + A_nx = nx.from_numpy(A) + + # wasserstein + reg = 1e-2 + if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log": + with pytest.raises(NotImplementedError): + ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method) + else: + bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d(A, reg, method=method, verbose=True, log=True) + bary_wass = nx.to_numpy(ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method)) + + np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3) + np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3) + + # help in checking if log and verbose do not bug the function + ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True) + + +@pytest.mark.parametrize("method", 
["sinkhorn", "sinkhorn_log"]) +def test_wasserstein_bary_2d_debiased(nx, method): + size = 20 # size of a square image + a1 = np.random.rand(size, size) + a1 += a1.min() + a1 = a1 / np.sum(a1) + a2 = np.random.rand(size, size) + a2 += a2.min() + a2 = a2 / np.sum(a2) + # creating matrix A containing all distributions + A = np.zeros((2, size, size)) + A[0, :, :] = a1 + A[1, :, :] = a2 + + A_nx = nx.from_numpy(A) + + # wasserstein + reg = 1e-2 + if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log": + with pytest.raises(NotImplementedError): + ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method) + else: + bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d_debiased(A, reg, method=method, verbose=True, log=True) + bary_wass = nx.to_numpy(ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method)) + + np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3) + np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3) + + # help in checking if log and verbose do not bug the function + ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True) + + +def test_unmix(nx): + n_bins = 50 # nb bins + + # Gaussian distributions + a1 = ot.datasets.make_1D_gauss(n_bins, m=20, s=10) # m= mean, s= std + a2 = ot.datasets.make_1D_gauss(n_bins, m=40, s=10) + + a = ot.datasets.make_1D_gauss(n_bins, m=30, s=10) + + # creating matrix A containing all distributions + D = np.vstack((a1, a2)).T + + # loss matrix + normalization + M = ot.utils.dist0(n_bins) + M /= M.max() + + M0 = ot.utils.dist0(2) + M0 /= M0.max() + h0 = ot.unif(2) + + ab, Db, M_nx, M0b, h0b = nx.from_numpy(a, D, M, M0, h0) + + # wasserstein + reg = 1e-3 + um_np = ot.bregman.unmix(a, D, M, M0, h0, reg, 1, alpha=0.01) + um = nx.to_numpy(ot.bregman.unmix(ab, Db, M_nx, M0b, h0b, reg, 1, alpha=0.01)) + + np.testing.assert_allclose(1, np.sum(um), rtol=1e-03, atol=1e-03) + np.testing.assert_allclose([0.5, 0.5], um, rtol=1e-03, atol=1e-03) + 
np.testing.assert_allclose(um, um_np) + + ot.bregman.unmix(ab, Db, M_nx, M0b, h0b, reg, + 1, alpha=0.01, log=True, verbose=True) + + +def test_empirical_sinkhorn(nx): + # test sinkhorn + n = 10 + a = ot.unif(n) + b = ot.unif(n) + + X_s = np.reshape(1.0 * np.arange(n), (n, 1)) + X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) + M = ot.dist(X_s, X_t) + M_m = ot.dist(X_s, X_t, metric='euclidean') + + ab, bb, X_sb, X_tb, M_nx, M_mb = nx.from_numpy(a, b, X_s, X_t, M, M_m) + + G_sqe = nx.to_numpy(ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1)) + sinkhorn_sqe = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1)) + + G_log, log_es = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 0.1, log=True) + G_log = nx.to_numpy(G_log) + sinkhorn_log, log_s = ot.sinkhorn(ab, bb, M_nx, 0.1, log=True) + sinkhorn_log = nx.to_numpy(sinkhorn_log) + + G_m = nx.to_numpy(ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, metric='euclidean')) + sinkhorn_m = nx.to_numpy(ot.sinkhorn(ab, bb, M_mb, 1)) + + loss_emp_sinkhorn = nx.to_numpy(ot.bregman.empirical_sinkhorn2(X_sb, X_tb, 1)) + loss_sinkhorn = nx.to_numpy(ot.sinkhorn2(ab, bb, M_nx, 1)) + + # check constraints + np.testing.assert_allclose( + sinkhorn_sqe.sum(1), G_sqe.sum(1), atol=1e-05) # metric sqeuclidian + np.testing.assert_allclose( + sinkhorn_sqe.sum(0), G_sqe.sum(0), atol=1e-05) # metric sqeuclidian + np.testing.assert_allclose( + sinkhorn_log.sum(1), G_log.sum(1), atol=1e-05) # log + np.testing.assert_allclose( + sinkhorn_log.sum(0), G_log.sum(0), atol=1e-05) # log + np.testing.assert_allclose( + sinkhorn_m.sum(1), G_m.sum(1), atol=1e-05) # metric euclidian + np.testing.assert_allclose( + sinkhorn_m.sum(0), G_m.sum(0), atol=1e-05) # metric euclidian + np.testing.assert_allclose(loss_emp_sinkhorn, loss_sinkhorn, atol=1e-05) + + +def test_lazy_empirical_sinkhorn(nx): + # test sinkhorn + n = 10 + a = ot.unif(n) + b = ot.unif(n) + numIterMax = 1000 + + X_s = np.reshape(np.arange(n, dtype=np.float64), (n, 1)) + X_t = np.reshape(np.arange(0, n, dtype=np.float64), 
(n, 1)) + M = ot.dist(X_s, X_t) + M_m = ot.dist(X_s, X_t, metric='euclidean') + + ab, bb, X_sb, X_tb, M_nx, M_mb = nx.from_numpy(a, b, X_s, X_t, M, M_m) + + f, g = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=(1, 3), verbose=True) + f, g = nx.to_numpy(f), nx.to_numpy(g) + G_sqe = np.exp(f[:, None] + g[None, :] - M / 1) + sinkhorn_sqe = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1)) + + f, g, log_es = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 0.1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True) + f, g = nx.to_numpy(f), nx.to_numpy(g) + G_log = np.exp(f[:, None] + g[None, :] - M / 0.1) + sinkhorn_log, log_s = ot.sinkhorn(ab, bb, M_nx, 0.1, log=True) + sinkhorn_log = nx.to_numpy(sinkhorn_log) + + f, g = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, metric='euclidean', numIterMax=numIterMax, isLazy=True, batchSize=1) + f, g = nx.to_numpy(f), nx.to_numpy(g) + G_m = np.exp(f[:, None] + g[None, :] - M_m / 1) + sinkhorn_m = nx.to_numpy(ot.sinkhorn(ab, bb, M_mb, 1)) + + loss_emp_sinkhorn, log = ot.bregman.empirical_sinkhorn2(X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True) + loss_emp_sinkhorn = nx.to_numpy(loss_emp_sinkhorn) + loss_sinkhorn = nx.to_numpy(ot.sinkhorn2(ab, bb, M_nx, 1)) + + # check constraints + np.testing.assert_allclose( + sinkhorn_sqe.sum(1), G_sqe.sum(1), atol=1e-05) # metric sqeuclidian + np.testing.assert_allclose( + sinkhorn_sqe.sum(0), G_sqe.sum(0), atol=1e-05) # metric sqeuclidian + np.testing.assert_allclose( + sinkhorn_log.sum(1), G_log.sum(1), atol=1e-05) # log + np.testing.assert_allclose( + sinkhorn_log.sum(0), G_log.sum(0), atol=1e-05) # log + np.testing.assert_allclose( + sinkhorn_m.sum(1), G_m.sum(1), atol=1e-05) # metric euclidian + np.testing.assert_allclose( + sinkhorn_m.sum(0), G_m.sum(0), atol=1e-05) # metric euclidian + np.testing.assert_allclose(loss_emp_sinkhorn, loss_sinkhorn, atol=1e-05) + + +def test_empirical_sinkhorn_divergence(nx): + # Test sinkhorn 
divergence + n = 10 + a = np.linspace(1, n, n) + a /= a.sum() + b = ot.unif(n) + X_s = np.reshape(np.arange(n, dtype=np.float64), (n, 1)) + X_t = np.reshape(np.arange(0, n * 2, 2, dtype=np.float64), (n, 1)) + M = ot.dist(X_s, X_t) + M_s = ot.dist(X_s, X_s) + M_t = ot.dist(X_t, X_t) + + ab, bb, X_sb, X_tb, M_nx, M_sb, M_tb = nx.from_numpy(a, b, X_s, X_t, M, M_s, M_t) + + emp_sinkhorn_div = nx.to_numpy(ot.bregman.empirical_sinkhorn_divergence(X_sb, X_tb, 1, a=ab, b=bb)) + sinkhorn_div = nx.to_numpy( + ot.sinkhorn2(ab, bb, M_nx, 1) + - 1 / 2 * ot.sinkhorn2(ab, ab, M_sb, 1) + - 1 / 2 * ot.sinkhorn2(bb, bb, M_tb, 1) + ) + emp_sinkhorn_div_np = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 1, a=a, b=b) + + # check constraints + np.testing.assert_allclose(emp_sinkhorn_div, emp_sinkhorn_div_np, atol=1e-05) + np.testing.assert_allclose( + emp_sinkhorn_div, sinkhorn_div, atol=1e-05) # cf conv emp sinkhorn + + ot.bregman.empirical_sinkhorn_divergence(X_sb, X_tb, 1, a=ab, b=bb, log=True) + + +def test_stabilized_vs_sinkhorn_multidim(nx): + # test if stable version matches sinkhorn + # for multidimensional inputs + n = 100 + + # Gaussian distributions + a = ot.datasets.make_1D_gauss(n, m=20, s=5) # m= mean, s= std + b1 = ot.datasets.make_1D_gauss(n, m=60, s=8) + b2 = ot.datasets.make_1D_gauss(n, m=30, s=4) + + # creating matrix A containing all distributions + b = np.vstack((b1, b2)).T + + M = ot.utils.dist0(n) + M /= np.median(M) + epsilon = 0.1 + + ab, bb, M_nx = nx.from_numpy(a, b, M) + + G_np, _ = ot.bregman.sinkhorn(a, b, M, reg=epsilon, method="sinkhorn", log=True) + G, log = ot.bregman.sinkhorn(ab, bb, M_nx, reg=epsilon, + method="sinkhorn_stabilized", + log=True) + G = nx.to_numpy(G) + G2, log2 = ot.bregman.sinkhorn(ab, bb, M_nx, epsilon, + method="sinkhorn", log=True) + G2 = nx.to_numpy(G2) + + np.testing.assert_allclose(G_np, G2) + np.testing.assert_allclose(G, G2) + + +def test_implemented_methods(): + IMPLEMENTED_METHODS = ['sinkhorn', 'sinkhorn_stabilized'] + 
ONLY_1D_methods = ['greenkhorn', 'sinkhorn_epsilon_scaling'] + NOT_VALID_TOKENS = ['foo'] + # test generalized sinkhorn for unbalanced OT barycenter + n = 3 + rng = np.random.RandomState(42) + + x = rng.randn(n, 2) + a = ot.utils.unif(n) + + # make dists unbalanced + b = ot.utils.unif(n) + A = rng.rand(n, 2) + A /= A.sum(0, keepdims=True) + M = ot.dist(x, x) + epsilon = 1.0 + + for method in IMPLEMENTED_METHODS: + ot.bregman.sinkhorn(a, b, M, epsilon, method=method) + ot.bregman.sinkhorn2(a, b, M, epsilon, method=method) + ot.bregman.barycenter(A, M, reg=epsilon, method=method) + with pytest.raises(ValueError): + for method in set(NOT_VALID_TOKENS): + ot.bregman.sinkhorn(a, b, M, epsilon, method=method) + ot.bregman.sinkhorn2(a, b, M, epsilon, method=method) + ot.bregman.barycenter(A, M, reg=epsilon, method=method) + for method in ONLY_1D_methods: + ot.bregman.sinkhorn(a, b, M, epsilon, method=method) + with pytest.raises(ValueError): + ot.bregman.sinkhorn2(a, b, M, epsilon, method=method) + + +@pytest.skip_backend('tf') +@pytest.skip_backend("cupy") +@pytest.skip_backend("jax") +@pytest.mark.filterwarnings("ignore:Bottleneck") +def test_screenkhorn(nx): + # test screenkhorn + rng = np.random.RandomState(0) + n = 100 + a = ot.unif(n) + b = ot.unif(n) + + x = rng.randn(n, 2) + M = ot.dist(x, x) + + ab, bb, M_nx = nx.from_numpy(a, b, M) + + # sinkhorn + G_sink = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1e-1)) + # screenkhorn + G_screen = nx.to_numpy(ot.bregman.screenkhorn(ab, bb, M_nx, 1e-1, uniform=True, verbose=True)) + # check marginals + np.testing.assert_allclose(G_sink.sum(0), G_screen.sum(0), atol=1e-02) + np.testing.assert_allclose(G_sink.sum(1), G_screen.sum(1), atol=1e-02) + + +def test_convolutional_barycenter_non_square(nx): + # test for image with height not equal width + A = np.ones((2, 2, 3)) / (2 * 3) + A_nx = nx.from_numpy(A) + + b_np = ot.bregman.convolutional_barycenter2d(A, 1e-03) + b = nx.to_numpy(ot.bregman.convolutional_barycenter2d(A_nx, 1e-03)) 
+ + np.testing.assert_allclose(np.ones((2, 3)) / (2 * 3), b, atol=1e-02) + np.testing.assert_allclose(np.ones((2, 3)) / (2 * 3), b, atol=1e-02) + np.testing.assert_allclose(b, b_np) From 59c1457021b23938de1fae10ea2fa834c140fbba Mon Sep 17 00:00:00 2001 From: eddardd Date: Thu, 21 Jul 2022 10:47:41 +0200 Subject: [PATCH 09/14] Adding section on Sinkhorn barycenter to the example --- .../plot_free_support_barycenter.py | 28 +++++++++++++++++-- 1 file changed, 25 insertions(+), 3 deletions(-) diff --git a/examples/barycenters/plot_free_support_barycenter.py b/examples/barycenters/plot_free_support_barycenter.py index 226dfeb43..f4a13dd79 100644 --- a/examples/barycenters/plot_free_support_barycenter.py +++ b/examples/barycenters/plot_free_support_barycenter.py @@ -4,13 +4,14 @@ 2D free support Wasserstein barycenters of distributions ======================================================== -Illustration of 2D Wasserstein barycenters if distributions are weighted +Illustration of 2D Wasserstein and Sinkhorn barycenters if distributions are weighted sum of diracs. 
""" # Authors: Vivien Seguy # Rémi Flamary +# Eduardo Fernandes Montesuma # # License: MIT License @@ -48,7 +49,7 @@ # %% -# Compute free support barycenter +# Compute free support Wasserstein barycenter # ------------------------------- k = 200 # number of Diracs of the barycenter @@ -58,7 +59,28 @@ X = ot.lp.free_support_barycenter(measures_locations, measures_weights, X_init, b) # %% -# Plot the barycenter +# Plot the Wasserstein barycenter +# --------- + +pl.figure(2, (8, 3)) +pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5) +pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5) +pl.scatter(X[:, 0], X[:, 1], s=b * 1000, marker='s', label='2-Wasserstein barycenter') +pl.title('Data measures and their barycenter') +pl.legend(loc="lower right") +pl.show() + +# %% +# Compute free support Sinkhorn barycenter + +k = 200 # number of Diracs of the barycenter +X_init = np.random.normal(0., 1., (k, d)) # initial Dirac locations +b = np.ones((k,)) / k # weights of the barycenter (it will not be optimized, only the locations are optimized) + +X = ot.bregman.free_support_sinkhorn_barycenter(measures_locations, measures_weights, X_init, 20, b, numItermax=15) + +# %% +# Plot the Wasserstein barycenter # --------- pl.figure(2, (8, 3)) From eb691c66f01370aa76fbf41c457a5d2aa01a3ca4 Mon Sep 17 00:00:00 2001 From: eddardd Date: Thu, 21 Jul 2022 10:48:18 +0200 Subject: [PATCH 10/14] Changing distributions for the Sinkhorn barycenter example --- .../plot_free_support_sinkhorn_barycenter.py | 39 +++++++------------ 1 file changed, 15 insertions(+), 24 deletions(-) diff --git a/examples/barycenters/plot_free_support_sinkhorn_barycenter.py b/examples/barycenters/plot_free_support_sinkhorn_barycenter.py index 996a53f1e..ebe1f3b75 100644 --- a/examples/barycenters/plot_free_support_sinkhorn_barycenter.py +++ b/examples/barycenters/plot_free_support_sinkhorn_barycenter.py @@ -14,7 +14,6 @@ import numpy as np import matplotlib.pyplot as plt -from sklearn.datasets import make_moons import ot # %% @@ -29,26 
+28,18 @@ # Generate Data # ------------- +X1 = np.random.randn(200, 2) +X2 = 2 * np.concatenate([ + np.concatenate([- np.ones([50, 1]), np.linspace(-1, 1, 50)[:, None]], axis=1), + np.concatenate([np.linspace(-1, 1, 50)[:, None], np.ones([50, 1])], axis=1), + np.concatenate([np.ones([50, 1]), np.linspace(1, -1, 50)[:, None]], axis=1), + np.concatenate([np.linspace(1, -1, 50)[:, None], - np.ones([50, 1])], axis=1), +], axis=0) +X3 = np.random.randn(200, 2) +X3 = 2 * (X3 / np.linalg.norm(X3, axis=1)[:, None]) +X4 = np.random.multivariate_normal(np.array([0, 0]), np.array([[1., 0.5], [0.5, 1.]]), size=200) -def get_rotation(angle): - """Returns a rotation matrix for angle given in degrees""" - R = np.array([ - [np.cos((angle / 180) * np.pi), -np.sin((angle / 180) * np.pi)], - [np.sin((angle / 180) * np.pi), np.cos((angle / 180) * np.pi)] - ]) - return R - - -R2, R3, R4 = get_rotation(15), get_rotation(30), get_rotation(45) - -X1, _ = make_moons(n_samples=300, noise=1e-1) -a1 = ot.utils.unif(X1.shape[0], type_as=X1) -X2 = np.dot(X1, R2) -a2 = ot.utils.unif(X1.shape[0], type_as=X2) -X3 = np.dot(X1, R3) -a3 = ot.utils.unif(X1.shape[0], type_as=X3) -X4 = np.dot(X1, R4) -a4 = ot.utils.unif(X1.shape[0], type_as=X4) +a1, a2, a3, a4 = ot.unif(len(X1)), ot.unif(len(X1)), ot.unif(len(X1)), ot.unif(len(X1)) # %% # Inspect generated distributions @@ -63,19 +54,19 @@ def get_rotation(angle): axes[0].set_xlim([-3, 3]) axes[0].set_ylim([-3, 3]) -axes[0].set_title('Rotation: 0') +axes[0].set_title('Distribution 1') axes[1].set_xlim([-3, 3]) axes[1].set_ylim([-3, 3]) -axes[1].set_title('Rotation: 15') +axes[1].set_title('Distribution 2') axes[2].set_xlim([-3, 3]) axes[2].set_ylim([-3, 3]) -axes[2].set_title('Rotation: 30') +axes[2].set_title('Distribution 3') axes[3].set_xlim([-3, 3]) axes[3].set_ylim([-3, 3]) -axes[3].set_title('Rotation: 45') +axes[3].set_title('Distribution 4') plt.tight_layout() plt.show() From 551ac04b56695111c0ccf50d5134fff98d4c3e22 Mon Sep 17 00:00:00 2001 
From: eddardd Date: Fri, 22 Jul 2022 17:53:05 +0200 Subject: [PATCH 11/14] Removing file that should not be on the last commit --- test/tmp.py | 982 ---------------------------------------------------- 1 file changed, 982 deletions(-) delete mode 100644 test/tmp.py diff --git a/test/tmp.py b/test/tmp.py deleted file mode 100644 index e128ea225..000000000 --- a/test/tmp.py +++ /dev/null @@ -1,982 +0,0 @@ -"""Tests for module bregman on OT with bregman projections """ - -# Author: Remi Flamary -# Kilian Fatras -# Quang Huy Tran -# Eduardo Fernandes Montesuma -# -# License: MIT License - -from itertools import product - -import numpy as np -import pytest - -import ot -from ot.backend import torch, tf - - -@pytest.mark.parametrize("verbose, warn", product([True, False], [True, False])) -def test_sinkhorn(verbose, warn): - # test sinkhorn - n = 100 - rng = np.random.RandomState(0) - - x = rng.randn(n, 2) - u = ot.utils.unif(n) - - M = ot.dist(x, x) - - G = ot.sinkhorn(u, u, M, 1, stopThr=1e-10, verbose=verbose, warn=warn) - - # check constraints - np.testing.assert_allclose( - u, G.sum(1), atol=1e-05) # cf convergence sinkhorn - np.testing.assert_allclose( - u, G.sum(0), atol=1e-05) # cf convergence sinkhorn - - with pytest.warns(UserWarning): - ot.sinkhorn(u, u, M, 1, stopThr=0, numItermax=1) - - -@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized", - "sinkhorn_epsilon_scaling", - "greenkhorn", - "sinkhorn_log"]) -def test_convergence_warning(method): - # test sinkhorn - n = 100 - a1 = ot.datasets.make_1D_gauss(n, m=30, s=10) - a2 = ot.datasets.make_1D_gauss(n, m=40, s=10) - A = np.asarray([a1, a2]).T - M = ot.utils.dist0(n) - - with pytest.warns(UserWarning): - ot.sinkhorn(a1, a2, M, 1., method=method, stopThr=0, numItermax=1) - - if method in ["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"]: - with pytest.warns(UserWarning): - ot.barycenter(A, M, 1, method=method, stopThr=0, numItermax=1) - with pytest.warns(UserWarning): - ot.sinkhorn2(a1, a2, 
M, 1, method=method, stopThr=0, numItermax=1) - - -def test_not_implemented_method(): - # test sinkhorn - w = 10 - n = w ** 2 - rng = np.random.RandomState(42) - A_img = rng.rand(2, w, w) - A_flat = A_img.reshape(n, 2) - a1, a2 = A_flat.T - M_flat = ot.utils.dist0(n) - not_implemented = "new_method" - reg = 0.01 - with pytest.raises(ValueError): - ot.sinkhorn(a1, a2, M_flat, reg, method=not_implemented) - with pytest.raises(ValueError): - ot.sinkhorn2(a1, a2, M_flat, reg, method=not_implemented) - with pytest.raises(ValueError): - ot.barycenter(A_flat, M_flat, reg, method=not_implemented) - with pytest.raises(ValueError): - ot.bregman.barycenter_debiased(A_flat, M_flat, reg, - method=not_implemented) - with pytest.raises(ValueError): - ot.bregman.convolutional_barycenter2d(A_img, reg, - method=not_implemented) - with pytest.raises(ValueError): - ot.bregman.convolutional_barycenter2d_debiased(A_img, reg, - method=not_implemented) - - -@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"]) -def test_nan_warning(method): - # test sinkhorn - n = 100 - a1 = ot.datasets.make_1D_gauss(n, m=30, s=10) - a2 = ot.datasets.make_1D_gauss(n, m=40, s=10) - - M = ot.utils.dist0(n) - reg = 0 - with pytest.warns(UserWarning): - # warn set to False to avoid catching a convergence warning instead - ot.sinkhorn(a1, a2, M, reg, method=method, warn=False) - - -def test_sinkhorn_stabilization(): - # test sinkhorn - n = 100 - a1 = ot.datasets.make_1D_gauss(n, m=30, s=10) - a2 = ot.datasets.make_1D_gauss(n, m=40, s=10) - M = ot.utils.dist0(n) - reg = 1e-5 - loss1 = ot.sinkhorn2(a1, a2, M, reg, method="sinkhorn_log") - loss2 = ot.sinkhorn2(a1, a2, M, reg, tau=1, method="sinkhorn_stabilized") - np.testing.assert_allclose( - loss1, loss2, atol=1e-06) # cf convergence sinkhorn - - -@pytest.mark.parametrize("method, verbose, warn", - product(["sinkhorn", "sinkhorn_stabilized", - "sinkhorn_log"], - [True, False], [True, False])) -def test_sinkhorn_multi_b(method, verbose, warn): 
- # test sinkhorn - n = 10 - rng = np.random.RandomState(0) - - x = rng.randn(n, 2) - u = ot.utils.unif(n) - - b = rng.rand(n, 3) - b = b / np.sum(b, 0, keepdims=True) - - M = ot.dist(x, x) - - loss0, log = ot.sinkhorn(u, b, M, .1, method=method, stopThr=1e-10, - log=True) - - loss = [ot.sinkhorn2(u, b[:, k], M, .1, method=method, stopThr=1e-10, - verbose=verbose, warn=warn) for k in range(3)] - # check constraints - np.testing.assert_allclose( - loss0, loss, atol=1e-4) # cf convergence sinkhorn - - -def test_sinkhorn_backends(nx): - n_samples = 100 - n_features = 2 - rng = np.random.RandomState(0) - - x = rng.randn(n_samples, n_features) - y = rng.randn(n_samples, n_features) - a = ot.utils.unif(n_samples) - - M = ot.dist(x, y) - - G = ot.sinkhorn(a, a, M, 1) - - ab, M_nx = nx.from_numpy(a, M) - - Gb = ot.sinkhorn(ab, ab, M_nx, 1) - - np.allclose(G, nx.to_numpy(Gb)) - - -def test_sinkhorn2_backends(nx): - n_samples = 100 - n_features = 2 - rng = np.random.RandomState(0) - - x = rng.randn(n_samples, n_features) - y = rng.randn(n_samples, n_features) - a = ot.utils.unif(n_samples) - - M = ot.dist(x, y) - - G = ot.sinkhorn(a, a, M, 1) - - ab, M_nx = nx.from_numpy(a, M) - - Gb = ot.sinkhorn2(ab, ab, M_nx, 1) - - np.allclose(G, nx.to_numpy(Gb)) - - -def test_sinkhorn2_gradients(): - n_samples = 100 - n_features = 2 - rng = np.random.RandomState(0) - - x = rng.randn(n_samples, n_features) - y = rng.randn(n_samples, n_features) - a = ot.utils.unif(n_samples) - - M = ot.dist(x, y) - - if torch: - - a1 = torch.tensor(a, requires_grad=True) - b1 = torch.tensor(a, requires_grad=True) - M1 = torch.tensor(M, requires_grad=True) - - val = ot.sinkhorn2(a1, b1, M1, 1) - - val.backward() - - assert a1.shape == a1.grad.shape - assert b1.shape == b1.grad.shape - assert M1.shape == M1.grad.shape - - -def test_sinkhorn_empty(): - # test sinkhorn - n = 100 - rng = np.random.RandomState(0) - - x = rng.randn(n, 2) - u = ot.utils.unif(n) - - M = ot.dist(x, x) - - G, log = ot.sinkhorn([], 
[], M, 1, stopThr=1e-10, method="sinkhorn_log", - verbose=True, log=True) - # check constraints - np.testing.assert_allclose(u, G.sum(1), atol=1e-05) - np.testing.assert_allclose(u, G.sum(0), atol=1e-05) - - G, log = ot.sinkhorn([], [], M, 1, stopThr=1e-10, verbose=True, log=True) - # check constraints - np.testing.assert_allclose(u, G.sum(1), atol=1e-05) - np.testing.assert_allclose(u, G.sum(0), atol=1e-05) - - G, log = ot.sinkhorn([], [], M, 1, stopThr=1e-10, - method='sinkhorn_stabilized', verbose=True, log=True) - # check constraints - np.testing.assert_allclose(u, G.sum(1), atol=1e-05) - np.testing.assert_allclose(u, G.sum(0), atol=1e-05) - - G, log = ot.sinkhorn( - [], [], M, 1, stopThr=1e-10, method='sinkhorn_epsilon_scaling', - verbose=True, log=True) - # check constraints - np.testing.assert_allclose(u, G.sum(1), atol=1e-05) - np.testing.assert_allclose(u, G.sum(0), atol=1e-05) - - # test empty weights greenkhorn - ot.sinkhorn([], [], M, 1, method='greenkhorn', stopThr=1e-10, log=True) - - -@pytest.skip_backend('tf') -@pytest.skip_backend("jax") -def test_sinkhorn_variants(nx): - # test sinkhorn - n = 100 - rng = np.random.RandomState(0) - - x = rng.randn(n, 2) - u = ot.utils.unif(n) - - M = ot.dist(x, x) - - ub, M_nx = nx.from_numpy(u, M) - - G = ot.sinkhorn(u, u, M, 1, method='sinkhorn', stopThr=1e-10) - Gl = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method='sinkhorn_log', stopThr=1e-10)) - G0 = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method='sinkhorn', stopThr=1e-10)) - Gs = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10)) - Ges = nx.to_numpy(ot.sinkhorn( - ub, ub, M_nx, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10)) - G_green = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method='greenkhorn', stopThr=1e-10)) - - # check values - np.testing.assert_allclose(G, G0, atol=1e-05) - np.testing.assert_allclose(G, Gl, atol=1e-05) - np.testing.assert_allclose(G0, Gs, atol=1e-05) - np.testing.assert_allclose(G0, Ges, 
atol=1e-05) - np.testing.assert_allclose(G0, G_green, atol=1e-5) - - -@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized", - "sinkhorn_epsilon_scaling", - "greenkhorn", - "sinkhorn_log"]) -@pytest.skip_arg(("nx", "method"), ("tf", "sinkhorn_epsilon_scaling"), reason="tf does not support sinkhorn_epsilon_scaling", getter=str) -@pytest.skip_arg(("nx", "method"), ("tf", "greenkhorn"), reason="tf does not support greenkhorn", getter=str) -@pytest.skip_arg(("nx", "method"), ("jax", "sinkhorn_epsilon_scaling"), reason="jax does not support sinkhorn_epsilon_scaling", getter=str) -@pytest.skip_arg(("nx", "method"), ("jax", "greenkhorn"), reason="jax does not support greenkhorn", getter=str) -def test_sinkhorn_variants_dtype_device(nx, method): - n = 100 - - x = np.random.randn(n, 2) - u = ot.utils.unif(n) - - M = ot.dist(x, x) - - for tp in nx.__type_list__: - print(nx.dtype_device(tp)) - - ub, Mb = nx.from_numpy(u, M, type_as=tp) - - Gb = ot.sinkhorn(ub, ub, Mb, 1, method=method, stopThr=1e-10) - - nx.assert_same_dtype_device(Mb, Gb) - - -@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"]) -def test_sinkhorn2_variants_dtype_device(nx, method): - n = 100 - - x = np.random.randn(n, 2) - u = ot.utils.unif(n) - - M = ot.dist(x, x) - - for tp in nx.__type_list__: - print(nx.dtype_device(tp)) - - ub, Mb = nx.from_numpy(u, M, type_as=tp) - - lossb = ot.sinkhorn2(ub, ub, Mb, 1, method=method, stopThr=1e-10) - - nx.assert_same_dtype_device(Mb, lossb) - - -@pytest.mark.skipif(not tf, reason="tf not installed") -@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"]) -def test_sinkhorn2_variants_device_tf(method): - nx = ot.backend.TensorflowBackend() - n = 100 - x = np.random.randn(n, 2) - u = ot.utils.unif(n) - M = ot.dist(x, x) - - # Check that everything stays on the CPU - with tf.device("/CPU:0"): - ub, Mb = nx.from_numpy(u, M) - Gb = ot.sinkhorn(ub, ub, Mb, 1, method=method, stopThr=1e-10) - 
lossb = ot.sinkhorn2(ub, ub, Mb, 1, method=method, stopThr=1e-10) - nx.assert_same_dtype_device(Mb, Gb) - nx.assert_same_dtype_device(Mb, lossb) - - if len(tf.config.list_physical_devices('GPU')) > 0: - # Check that everything happens on the GPU - ub, Mb = nx.from_numpy(u, M) - Gb = ot.sinkhorn(ub, ub, Mb, 1, method=method, stopThr=1e-10) - lossb = ot.sinkhorn2(ub, ub, Mb, 1, method=method, stopThr=1e-10) - nx.assert_same_dtype_device(Mb, Gb) - nx.assert_same_dtype_device(Mb, lossb) - assert nx.dtype_device(Gb)[1].startswith("GPU") - - -@pytest.skip_backend('tf') -@pytest.skip_backend("jax") -def test_sinkhorn_variants_multi_b(nx): - # test sinkhorn - n = 50 - rng = np.random.RandomState(0) - - x = rng.randn(n, 2) - u = ot.utils.unif(n) - - b = rng.rand(n, 3) - b = b / np.sum(b, 0, keepdims=True) - - M = ot.dist(x, x) - - ub, bb, M_nx = nx.from_numpy(u, b, M) - - G = ot.sinkhorn(u, b, M, 1, method='sinkhorn', stopThr=1e-10) - Gl = nx.to_numpy(ot.sinkhorn(ub, bb, M_nx, 1, method='sinkhorn_log', stopThr=1e-10)) - G0 = nx.to_numpy(ot.sinkhorn(ub, bb, M_nx, 1, method='sinkhorn', stopThr=1e-10)) - Gs = nx.to_numpy(ot.sinkhorn(ub, bb, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10)) - - # check values - np.testing.assert_allclose(G, G0, atol=1e-05) - np.testing.assert_allclose(G, Gl, atol=1e-05) - np.testing.assert_allclose(G0, Gs, atol=1e-05) - - -@pytest.skip_backend('tf') -@pytest.skip_backend("jax") -def test_sinkhorn2_variants_multi_b(nx): - # test sinkhorn - n = 50 - rng = np.random.RandomState(0) - - x = rng.randn(n, 2) - u = ot.utils.unif(n) - - b = rng.rand(n, 3) - b = b / np.sum(b, 0, keepdims=True) - - M = ot.dist(x, x) - - ub, bb, M_nx = nx.from_numpy(u, b, M) - - G = ot.sinkhorn2(u, b, M, 1, method='sinkhorn', stopThr=1e-10) - Gl = nx.to_numpy(ot.sinkhorn2(ub, bb, M_nx, 1, method='sinkhorn_log', stopThr=1e-10)) - G0 = nx.to_numpy(ot.sinkhorn2(ub, bb, M_nx, 1, method='sinkhorn', stopThr=1e-10)) - Gs = nx.to_numpy(ot.sinkhorn2(ub, bb, M_nx, 1, 
method='sinkhorn_stabilized', stopThr=1e-10)) - - # check values - np.testing.assert_allclose(G, G0, atol=1e-05) - np.testing.assert_allclose(G, Gl, atol=1e-05) - np.testing.assert_allclose(G0, Gs, atol=1e-05) - - -def test_sinkhorn_variants_log(): - # test sinkhorn - n = 50 - rng = np.random.RandomState(0) - - x = rng.randn(n, 2) - u = ot.utils.unif(n) - - M = ot.dist(x, x) - - G0, log0 = ot.sinkhorn(u, u, M, 1, method='sinkhorn', stopThr=1e-10, log=True) - Gl, logl = ot.sinkhorn(u, u, M, 1, method='sinkhorn_log', stopThr=1e-10, log=True) - Gs, logs = ot.sinkhorn(u, u, M, 1, method='sinkhorn_stabilized', stopThr=1e-10, log=True) - Ges, loges = ot.sinkhorn( - u, u, M, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10, log=True,) - G_green, loggreen = ot.sinkhorn(u, u, M, 1, method='greenkhorn', stopThr=1e-10, log=True) - - # check values - np.testing.assert_allclose(G0, Gs, atol=1e-05) - np.testing.assert_allclose(G0, Gl, atol=1e-05) - np.testing.assert_allclose(G0, Ges, atol=1e-05) - np.testing.assert_allclose(G0, G_green, atol=1e-5) - - -@pytest.mark.parametrize("verbose, warn", product([True, False], [True, False])) -def test_sinkhorn_variants_log_multib(verbose, warn): - # test sinkhorn - n = 50 - rng = np.random.RandomState(0) - - x = rng.randn(n, 2) - u = ot.utils.unif(n) - b = rng.rand(n, 3) - b = b / np.sum(b, 0, keepdims=True) - - M = ot.dist(x, x) - - G0, log0 = ot.sinkhorn(u, b, M, 1, method='sinkhorn', stopThr=1e-10, log=True) - Gl, logl = ot.sinkhorn(u, b, M, 1, method='sinkhorn_log', stopThr=1e-10, log=True, - verbose=verbose, warn=warn) - Gs, logs = ot.sinkhorn(u, b, M, 1, method='sinkhorn_stabilized', stopThr=1e-10, log=True, - verbose=verbose, warn=warn) - - # check values - np.testing.assert_allclose(G0, Gs, atol=1e-05) - np.testing.assert_allclose(G0, Gl, atol=1e-05) - - -@pytest.mark.parametrize("method, verbose, warn", - product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"], - [True, False], [True, False])) -def test_barycenter(nx, 
method, verbose, warn): - n_bins = 100 # nb bins - - # Gaussian distributions - a1 = ot.datasets.make_1D_gauss(n_bins, m=30, s=10) # m= mean, s= std - a2 = ot.datasets.make_1D_gauss(n_bins, m=40, s=10) - - # creating matrix A containing all distributions - A = np.vstack((a1, a2)).T - - # loss matrix + normalization - M = ot.utils.dist0(n_bins) - M /= M.max() - - alpha = 0.5 # 0<=alpha<=1 - weights = np.array([1 - alpha, alpha]) - - A_nx, M_nx, weights_nx = nx.from_numpy(A, M, weights) - reg = 1e-2 - - if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log": - with pytest.raises(NotImplementedError): - ot.bregman.barycenter(A_nx, M_nx, reg, weights, method=method) - else: - # wasserstein - bary_wass_np = ot.bregman.barycenter(A, M, reg, weights, method=method, verbose=verbose, warn=warn) - bary_wass, _ = ot.bregman.barycenter(A_nx, M_nx, reg, weights_nx, method=method, log=True) - bary_wass = nx.to_numpy(bary_wass) - - np.testing.assert_allclose(1, np.sum(bary_wass)) - np.testing.assert_allclose(bary_wass, bary_wass_np) - - ot.bregman.barycenter(A_nx, M_nx, reg, log=True) - - -def test_free_support_sinkhorn_barycenter(): - measures_locations = [ - np.array([-1.]).reshape((1, 1)), # First dirac support - np.array([1.]).reshape((1, 1)) # Second dirac support - ] - - measures_weights = [ - np.array([1.]), # First dirac sample weights - np.array([1.]) # Second dirac sample weights - ] - - # Barycenter initialization - X_init = np.array([-12.]).reshape((1, 1)) - - # Obvious barycenter locations. Take a look on test_ot.py, test_free_support_barycenter - bar_locations = np.array([0.]).reshape((1, 1)) - - # Calculate free support barycenter w/ Sinkhorn algorithm. We set the entropic regularization - # term to 1, but this should be, in general, fine-tuned to the problem. 
- X = ot.bregman.free_support_sinkhorn_barycenter(measures_locations, measures_weights, X_init, reg=1) - - # Verifies if calculated barycenter matches ground-truth - np.testing.assert_allclose(X, bar_locations, rtol=1e-5, atol=1e-7) - - -@pytest.mark.parametrize("method, verbose, warn", - product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"], - [True, False], [True, False])) -def test_barycenter_assymetric_cost(nx, method, verbose, warn): - n_bins = 20 # nb bins - - # Gaussian distributions - A = ot.datasets.make_1D_gauss(n_bins, m=30, s=10) # m= mean, s= std - - # creating matrix A containing all distributions - A = A[:, None] - - # assymetric loss matrix + normalization - rng = np.random.RandomState(42) - M = rng.randn(n_bins, n_bins) ** 2 - M /= M.max() - - A_nx, M_nx = nx.from_numpy(A, M) - reg = 1e-2 - - if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log": - with pytest.raises(NotImplementedError): - ot.bregman.barycenter(A_nx, M_nx, reg, method=method) - else: - # wasserstein - bary_wass_np = ot.bregman.barycenter(A, M, reg, method=method, verbose=verbose, warn=warn) - bary_wass, _ = ot.bregman.barycenter(A_nx, M_nx, reg, method=method, log=True) - bary_wass = nx.to_numpy(bary_wass) - - np.testing.assert_allclose(1, np.sum(bary_wass)) - np.testing.assert_allclose(bary_wass, bary_wass_np) - - ot.bregman.barycenter(A_nx, M_nx, reg, log=True) - - -@pytest.mark.parametrize("method, verbose, warn", - product(["sinkhorn", "sinkhorn_log"], - [True, False], [True, False])) -def test_barycenter_debiased(nx, method, verbose, warn): - n_bins = 100 # nb bins - - # Gaussian distributions - a1 = ot.datasets.make_1D_gauss(n_bins, m=30, s=10) # m= mean, s= std - a2 = ot.datasets.make_1D_gauss(n_bins, m=40, s=10) - - # creating matrix A containing all distributions - A = np.vstack((a1, a2)).T - - # loss matrix + normalization - M = ot.utils.dist0(n_bins) - M /= M.max() - - alpha = 0.5 # 0<=alpha<=1 - weights = np.array([1 - alpha, alpha]) - - A_nx, M_nx, 
weights_nx = nx.from_numpy(A, M, weights) - - # wasserstein - reg = 1e-2 - if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log": - with pytest.raises(NotImplementedError): - ot.bregman.barycenter_debiased(A_nx, M_nx, reg, weights, method=method) - else: - bary_wass_np = ot.bregman.barycenter_debiased(A, M, reg, weights, method=method, - verbose=verbose, warn=warn) - bary_wass, _ = ot.bregman.barycenter_debiased(A_nx, M_nx, reg, weights_nx, method=method, log=True) - bary_wass = nx.to_numpy(bary_wass) - - np.testing.assert_allclose(1, np.sum(bary_wass), atol=1e-3) - np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-5) - - ot.bregman.barycenter_debiased(A_nx, M_nx, reg, log=True, verbose=False) - - -@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_log"]) -def test_convergence_warning_barycenters(method): - w = 10 - n_bins = w ** 2 # nb bins - - # Gaussian distributions - a1 = ot.datasets.make_1D_gauss(n_bins, m=30, s=10) # m= mean, s= std - a2 = ot.datasets.make_1D_gauss(n_bins, m=40, s=10) - - # creating matrix A containing all distributions - A = np.vstack((a1, a2)).T - A_img = A.reshape(2, w, w) - A_img /= A_img.sum((1, 2))[:, None, None] - - # loss matrix + normalization - M = ot.utils.dist0(n_bins) - M /= M.max() - - alpha = 0.5 # 0<=alpha<=1 - weights = np.array([1 - alpha, alpha]) - reg = 0.1 - with pytest.warns(UserWarning): - ot.bregman.barycenter_debiased(A, M, reg, weights, method=method, numItermax=1) - with pytest.warns(UserWarning): - ot.bregman.barycenter(A, M, reg, weights, method=method, numItermax=1) - with pytest.warns(UserWarning): - ot.bregman.convolutional_barycenter2d(A_img, reg, weights, - method=method, numItermax=1) - with pytest.warns(UserWarning): - ot.bregman.convolutional_barycenter2d_debiased(A_img, reg, weights, - method=method, numItermax=1) - - -def test_barycenter_stabilization(nx): - n_bins = 100 # nb bins - - # Gaussian distributions - a1 = ot.datasets.make_1D_gauss(n_bins, m=30, s=10) # m= mean, s= std 
- a2 = ot.datasets.make_1D_gauss(n_bins, m=40, s=10) - - # creating matrix A containing all distributions - A = np.vstack((a1, a2)).T - - # loss matrix + normalization - M = ot.utils.dist0(n_bins) - M /= M.max() - - alpha = 0.5 # 0<=alpha<=1 - weights = np.array([1 - alpha, alpha]) - - A_nx, M_nx, weights_b = nx.from_numpy(A, M, weights) - - # wasserstein - reg = 1e-2 - bar_np = ot.bregman.barycenter(A, M, reg, weights, method="sinkhorn", stopThr=1e-8, verbose=True) - bar_stable = nx.to_numpy(ot.bregman.barycenter( - A_nx, M_nx, reg, weights_b, method="sinkhorn_stabilized", - stopThr=1e-8, verbose=True - )) - bar = nx.to_numpy(ot.bregman.barycenter( - A_nx, M_nx, reg, weights_b, method="sinkhorn", - stopThr=1e-8, verbose=True - )) - np.testing.assert_allclose(bar, bar_stable) - np.testing.assert_allclose(bar, bar_np) - - -@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_log"]) -def test_wasserstein_bary_2d(nx, method): - size = 20 # size of a square image - a1 = np.random.rand(size, size) - a1 += a1.min() - a1 = a1 / np.sum(a1) - a2 = np.random.rand(size, size) - a2 += a2.min() - a2 = a2 / np.sum(a2) - # creating matrix A containing all distributions - A = np.zeros((2, size, size)) - A[0, :, :] = a1 - A[1, :, :] = a2 - - A_nx = nx.from_numpy(A) - - # wasserstein - reg = 1e-2 - if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log": - with pytest.raises(NotImplementedError): - ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method) - else: - bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d(A, reg, method=method, verbose=True, log=True) - bary_wass = nx.to_numpy(ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method)) - - np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3) - np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3) - - # help in checking if log and verbose do not bug the function - ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True) - - -@pytest.mark.parametrize("method", 
["sinkhorn", "sinkhorn_log"]) -def test_wasserstein_bary_2d_debiased(nx, method): - size = 20 # size of a square image - a1 = np.random.rand(size, size) - a1 += a1.min() - a1 = a1 / np.sum(a1) - a2 = np.random.rand(size, size) - a2 += a2.min() - a2 = a2 / np.sum(a2) - # creating matrix A containing all distributions - A = np.zeros((2, size, size)) - A[0, :, :] = a1 - A[1, :, :] = a2 - - A_nx = nx.from_numpy(A) - - # wasserstein - reg = 1e-2 - if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log": - with pytest.raises(NotImplementedError): - ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method) - else: - bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d_debiased(A, reg, method=method, verbose=True, log=True) - bary_wass = nx.to_numpy(ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method)) - - np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3) - np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3) - - # help in checking if log and verbose do not bug the function - ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True) - - -def test_unmix(nx): - n_bins = 50 # nb bins - - # Gaussian distributions - a1 = ot.datasets.make_1D_gauss(n_bins, m=20, s=10) # m= mean, s= std - a2 = ot.datasets.make_1D_gauss(n_bins, m=40, s=10) - - a = ot.datasets.make_1D_gauss(n_bins, m=30, s=10) - - # creating matrix A containing all distributions - D = np.vstack((a1, a2)).T - - # loss matrix + normalization - M = ot.utils.dist0(n_bins) - M /= M.max() - - M0 = ot.utils.dist0(2) - M0 /= M0.max() - h0 = ot.unif(2) - - ab, Db, M_nx, M0b, h0b = nx.from_numpy(a, D, M, M0, h0) - - # wasserstein - reg = 1e-3 - um_np = ot.bregman.unmix(a, D, M, M0, h0, reg, 1, alpha=0.01) - um = nx.to_numpy(ot.bregman.unmix(ab, Db, M_nx, M0b, h0b, reg, 1, alpha=0.01)) - - np.testing.assert_allclose(1, np.sum(um), rtol=1e-03, atol=1e-03) - np.testing.assert_allclose([0.5, 0.5], um, rtol=1e-03, atol=1e-03) - 
np.testing.assert_allclose(um, um_np) - - ot.bregman.unmix(ab, Db, M_nx, M0b, h0b, reg, - 1, alpha=0.01, log=True, verbose=True) - - -def test_empirical_sinkhorn(nx): - # test sinkhorn - n = 10 - a = ot.unif(n) - b = ot.unif(n) - - X_s = np.reshape(1.0 * np.arange(n), (n, 1)) - X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) - M = ot.dist(X_s, X_t) - M_m = ot.dist(X_s, X_t, metric='euclidean') - - ab, bb, X_sb, X_tb, M_nx, M_mb = nx.from_numpy(a, b, X_s, X_t, M, M_m) - - G_sqe = nx.to_numpy(ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1)) - sinkhorn_sqe = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1)) - - G_log, log_es = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 0.1, log=True) - G_log = nx.to_numpy(G_log) - sinkhorn_log, log_s = ot.sinkhorn(ab, bb, M_nx, 0.1, log=True) - sinkhorn_log = nx.to_numpy(sinkhorn_log) - - G_m = nx.to_numpy(ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, metric='euclidean')) - sinkhorn_m = nx.to_numpy(ot.sinkhorn(ab, bb, M_mb, 1)) - - loss_emp_sinkhorn = nx.to_numpy(ot.bregman.empirical_sinkhorn2(X_sb, X_tb, 1)) - loss_sinkhorn = nx.to_numpy(ot.sinkhorn2(ab, bb, M_nx, 1)) - - # check constraints - np.testing.assert_allclose( - sinkhorn_sqe.sum(1), G_sqe.sum(1), atol=1e-05) # metric sqeuclidian - np.testing.assert_allclose( - sinkhorn_sqe.sum(0), G_sqe.sum(0), atol=1e-05) # metric sqeuclidian - np.testing.assert_allclose( - sinkhorn_log.sum(1), G_log.sum(1), atol=1e-05) # log - np.testing.assert_allclose( - sinkhorn_log.sum(0), G_log.sum(0), atol=1e-05) # log - np.testing.assert_allclose( - sinkhorn_m.sum(1), G_m.sum(1), atol=1e-05) # metric euclidian - np.testing.assert_allclose( - sinkhorn_m.sum(0), G_m.sum(0), atol=1e-05) # metric euclidian - np.testing.assert_allclose(loss_emp_sinkhorn, loss_sinkhorn, atol=1e-05) - - -def test_lazy_empirical_sinkhorn(nx): - # test sinkhorn - n = 10 - a = ot.unif(n) - b = ot.unif(n) - numIterMax = 1000 - - X_s = np.reshape(np.arange(n, dtype=np.float64), (n, 1)) - X_t = np.reshape(np.arange(0, n, dtype=np.float64), 
(n, 1)) - M = ot.dist(X_s, X_t) - M_m = ot.dist(X_s, X_t, metric='euclidean') - - ab, bb, X_sb, X_tb, M_nx, M_mb = nx.from_numpy(a, b, X_s, X_t, M, M_m) - - f, g = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=(1, 3), verbose=True) - f, g = nx.to_numpy(f), nx.to_numpy(g) - G_sqe = np.exp(f[:, None] + g[None, :] - M / 1) - sinkhorn_sqe = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1)) - - f, g, log_es = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 0.1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True) - f, g = nx.to_numpy(f), nx.to_numpy(g) - G_log = np.exp(f[:, None] + g[None, :] - M / 0.1) - sinkhorn_log, log_s = ot.sinkhorn(ab, bb, M_nx, 0.1, log=True) - sinkhorn_log = nx.to_numpy(sinkhorn_log) - - f, g = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, metric='euclidean', numIterMax=numIterMax, isLazy=True, batchSize=1) - f, g = nx.to_numpy(f), nx.to_numpy(g) - G_m = np.exp(f[:, None] + g[None, :] - M_m / 1) - sinkhorn_m = nx.to_numpy(ot.sinkhorn(ab, bb, M_mb, 1)) - - loss_emp_sinkhorn, log = ot.bregman.empirical_sinkhorn2(X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True) - loss_emp_sinkhorn = nx.to_numpy(loss_emp_sinkhorn) - loss_sinkhorn = nx.to_numpy(ot.sinkhorn2(ab, bb, M_nx, 1)) - - # check constraints - np.testing.assert_allclose( - sinkhorn_sqe.sum(1), G_sqe.sum(1), atol=1e-05) # metric sqeuclidian - np.testing.assert_allclose( - sinkhorn_sqe.sum(0), G_sqe.sum(0), atol=1e-05) # metric sqeuclidian - np.testing.assert_allclose( - sinkhorn_log.sum(1), G_log.sum(1), atol=1e-05) # log - np.testing.assert_allclose( - sinkhorn_log.sum(0), G_log.sum(0), atol=1e-05) # log - np.testing.assert_allclose( - sinkhorn_m.sum(1), G_m.sum(1), atol=1e-05) # metric euclidian - np.testing.assert_allclose( - sinkhorn_m.sum(0), G_m.sum(0), atol=1e-05) # metric euclidian - np.testing.assert_allclose(loss_emp_sinkhorn, loss_sinkhorn, atol=1e-05) - - -def test_empirical_sinkhorn_divergence(nx): - # Test sinkhorn 
divergence - n = 10 - a = np.linspace(1, n, n) - a /= a.sum() - b = ot.unif(n) - X_s = np.reshape(np.arange(n, dtype=np.float64), (n, 1)) - X_t = np.reshape(np.arange(0, n * 2, 2, dtype=np.float64), (n, 1)) - M = ot.dist(X_s, X_t) - M_s = ot.dist(X_s, X_s) - M_t = ot.dist(X_t, X_t) - - ab, bb, X_sb, X_tb, M_nx, M_sb, M_tb = nx.from_numpy(a, b, X_s, X_t, M, M_s, M_t) - - emp_sinkhorn_div = nx.to_numpy(ot.bregman.empirical_sinkhorn_divergence(X_sb, X_tb, 1, a=ab, b=bb)) - sinkhorn_div = nx.to_numpy( - ot.sinkhorn2(ab, bb, M_nx, 1) - - 1 / 2 * ot.sinkhorn2(ab, ab, M_sb, 1) - - 1 / 2 * ot.sinkhorn2(bb, bb, M_tb, 1) - ) - emp_sinkhorn_div_np = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 1, a=a, b=b) - - # check constraints - np.testing.assert_allclose(emp_sinkhorn_div, emp_sinkhorn_div_np, atol=1e-05) - np.testing.assert_allclose( - emp_sinkhorn_div, sinkhorn_div, atol=1e-05) # cf conv emp sinkhorn - - ot.bregman.empirical_sinkhorn_divergence(X_sb, X_tb, 1, a=ab, b=bb, log=True) - - -def test_stabilized_vs_sinkhorn_multidim(nx): - # test if stable version matches sinkhorn - # for multidimensional inputs - n = 100 - - # Gaussian distributions - a = ot.datasets.make_1D_gauss(n, m=20, s=5) # m= mean, s= std - b1 = ot.datasets.make_1D_gauss(n, m=60, s=8) - b2 = ot.datasets.make_1D_gauss(n, m=30, s=4) - - # creating matrix A containing all distributions - b = np.vstack((b1, b2)).T - - M = ot.utils.dist0(n) - M /= np.median(M) - epsilon = 0.1 - - ab, bb, M_nx = nx.from_numpy(a, b, M) - - G_np, _ = ot.bregman.sinkhorn(a, b, M, reg=epsilon, method="sinkhorn", log=True) - G, log = ot.bregman.sinkhorn(ab, bb, M_nx, reg=epsilon, - method="sinkhorn_stabilized", - log=True) - G = nx.to_numpy(G) - G2, log2 = ot.bregman.sinkhorn(ab, bb, M_nx, epsilon, - method="sinkhorn", log=True) - G2 = nx.to_numpy(G2) - - np.testing.assert_allclose(G_np, G2) - np.testing.assert_allclose(G, G2) - - -def test_implemented_methods(): - IMPLEMENTED_METHODS = ['sinkhorn', 'sinkhorn_stabilized'] - 
ONLY_1D_methods = ['greenkhorn', 'sinkhorn_epsilon_scaling'] - NOT_VALID_TOKENS = ['foo'] - # test generalized sinkhorn for unbalanced OT barycenter - n = 3 - rng = np.random.RandomState(42) - - x = rng.randn(n, 2) - a = ot.utils.unif(n) - - # make dists unbalanced - b = ot.utils.unif(n) - A = rng.rand(n, 2) - A /= A.sum(0, keepdims=True) - M = ot.dist(x, x) - epsilon = 1.0 - - for method in IMPLEMENTED_METHODS: - ot.bregman.sinkhorn(a, b, M, epsilon, method=method) - ot.bregman.sinkhorn2(a, b, M, epsilon, method=method) - ot.bregman.barycenter(A, M, reg=epsilon, method=method) - with pytest.raises(ValueError): - for method in set(NOT_VALID_TOKENS): - ot.bregman.sinkhorn(a, b, M, epsilon, method=method) - ot.bregman.sinkhorn2(a, b, M, epsilon, method=method) - ot.bregman.barycenter(A, M, reg=epsilon, method=method) - for method in ONLY_1D_methods: - ot.bregman.sinkhorn(a, b, M, epsilon, method=method) - with pytest.raises(ValueError): - ot.bregman.sinkhorn2(a, b, M, epsilon, method=method) - - -@pytest.skip_backend('tf') -@pytest.skip_backend("cupy") -@pytest.skip_backend("jax") -@pytest.mark.filterwarnings("ignore:Bottleneck") -def test_screenkhorn(nx): - # test screenkhorn - rng = np.random.RandomState(0) - n = 100 - a = ot.unif(n) - b = ot.unif(n) - - x = rng.randn(n, 2) - M = ot.dist(x, x) - - ab, bb, M_nx = nx.from_numpy(a, b, M) - - # sinkhorn - G_sink = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1e-1)) - # screenkhorn - G_screen = nx.to_numpy(ot.bregman.screenkhorn(ab, bb, M_nx, 1e-1, uniform=True, verbose=True)) - # check marginals - np.testing.assert_allclose(G_sink.sum(0), G_screen.sum(0), atol=1e-02) - np.testing.assert_allclose(G_sink.sum(1), G_screen.sum(1), atol=1e-02) - - -def test_convolutional_barycenter_non_square(nx): - # test for image with height not equal width - A = np.ones((2, 2, 3)) / (2 * 3) - A_nx = nx.from_numpy(A) - - b_np = ot.bregman.convolutional_barycenter2d(A, 1e-03) - b = nx.to_numpy(ot.bregman.convolutional_barycenter2d(A_nx, 1e-03)) 
- - np.testing.assert_allclose(np.ones((2, 3)) / (2 * 3), b, atol=1e-02) - np.testing.assert_allclose(np.ones((2, 3)) / (2 * 3), b, atol=1e-02) - np.testing.assert_allclose(b, b_np) From 1dc2a7c30a34e1255711cd38394905d9db523868 Mon Sep 17 00:00:00 2001 From: eddardd Date: Fri, 22 Jul 2022 17:55:38 +0200 Subject: [PATCH 12/14] Adding PR number to REALEASES.md --- RELEASES.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/RELEASES.md b/RELEASES.md index 7efda04e7..14d11c42e 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -5,7 +5,7 @@ #### New features - Added Generalized Wasserstein Barycenter solver + example (PR #372), fixed graphical details on the example (PR #376) -- Added Free Support Sinkhorn Barycenter + example +- Added Free Support Sinkhorn Barycenter + example (PR #387) #### Closed issues From 8c258e5682df9d69003dc28cc8601128a7e0cb92 Mon Sep 17 00:00:00 2001 From: eddardd Date: Mon, 25 Jul 2022 10:18:17 +0200 Subject: [PATCH 13/14] Adding new contributors --- CONTRIBUTORS.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index c535c0991..5091bd2c0 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -39,6 +39,7 @@ The contributors to this library are: * [Cédric Vincent-Cuaz](https://github.com/cedricvincentcuaz) (Graph Dictionary Learning) * [Eloi Tanguy](https://github.com/eloitanguy) (Generalized Wasserstein Barycenters) * [Camille Le Coz](https://www.linkedin.com/in/camille-le-coz-8593b91a1/) (EMD2 debug) +* [Eduardo Fernandes Montesuma](https://eddardd.github.io/my-personal-blog/) ## Acknowledgments From 0f2fd2435f009253e7530f161331aeecc9b4c335 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Mon, 25 Jul 2022 16:18:09 +0200 Subject: [PATCH 14/14] Update CONTRIBUTORS.md --- CONTRIBUTORS.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 5091bd2c0..0524151c2 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -39,7 +39,7 
@@ The contributors to this library are: * [Cédric Vincent-Cuaz](https://github.com/cedricvincentcuaz) (Graph Dictionary Learning) * [Eloi Tanguy](https://github.com/eloitanguy) (Generalized Wasserstein Barycenters) * [Camille Le Coz](https://www.linkedin.com/in/camille-le-coz-8593b91a1/) (EMD2 debug) -* [Eduardo Fernandes Montesuma](https://eddardd.github.io/my-personal-blog/) +* [Eduardo Fernandes Montesuma](https://eddardd.github.io/my-personal-blog/) (Free support sinkhorn barycenter) ## Acknowledgments