From 0327b5c14ddc68510fedce34d78b9155200cb8b5 Mon Sep 17 00:00:00 2001 From: x12hengyu Date: Tue, 4 Apr 2023 00:11:03 -0500 Subject: [PATCH 01/28] add demd.py to ot, add plot_demd_*.py to examples, updated init.py in ot, build failed need to fix --- examples/demd/plot_demd_1d.py | 94 ++++++++ examples/demd/plot_demd_gradient_minimize.py | 118 ++++++++++ ot/__init__.py | 4 +- ot/demd.py | 232 +++++++++++++++++++ 4 files changed, 447 insertions(+), 1 deletion(-) create mode 100644 examples/demd/plot_demd_1d.py create mode 100644 examples/demd/plot_demd_gradient_minimize.py create mode 100644 ot/demd.py diff --git a/examples/demd/plot_demd_1d.py b/examples/demd/plot_demd_1d.py new file mode 100644 index 000000000..4cd7ce562 --- /dev/null +++ b/examples/demd/plot_demd_1d.py @@ -0,0 +1,94 @@ +# -*- coding: utf-8 -*- +r""" +================================================================================= +1D Wasserstein barycenter: LP Barycenter vs DEMD +================================================================================= + +Compare the speed of 1D Wasserstein barycenter between LP and DEMD. +""" + +# Author: Ronak Mehta +# Xizheng Yu +# +# License: MIT License + +import numpy as np +import matplotlib.pyplot as pl +import ot +from demd import demd + +# %% +# Define 1d Barycenter Function and Compare Function +# -------------------------------------------------- + + +def lp_1d_bary(data, M, n, d): + A = np.vstack(data).T + + alpha = 1.0 # /d # 0<=alpha<=1 + weights = np.array(d*[alpha]) + + bary, bary_log = ot.lp.barycenter( + A, M, weights, solver='interior-point', verbose=False, log=True) + + return bary_log['fun'], bary + + +def compare_all(data, M, n, d): + print('IP LP Iterations:') + ot.tic() + lp_bary, lp_obj = lp_1d_bary(np.vstack(data), M, n, d) + lp_time = ot.toc('') + print('Obj\t: ', lp_bary) + print('Time\t: ', lp_time) + + print('') + print('D-EMD Algorithm:') + ot.tic() + demd_obj = demd(np.vstack(data), n, d) + demd_time = ot.toc('') + print('Obj\t: ', demd_obj) + print('Time\t: ', demd_time) + return lp_time, demd_time + +# %% +# 2 Random Dists with Increasing Bins +# ----------------------------------- + +def random2d(n=4): + print('*'*10) + d = 2 + # Gaussian distributions + a1 = ot.datasets.make_1D_gauss(n, m=20, s=5) # m= mean, s= std + a2 = ot.datasets.make_1D_gauss(n, m=60, s=8) + print(a1) + print(a2) + x = np.arange(n, dtype=np.float64).reshape((n, 1)) + M = ot.utils.dist(x, metric='minkowski') + lp_time, demd_time = compare_all([a1, a2], M, n, d) + print('*'*10, '\n') + return lp_time, demd_time + + +def increasing_bins(): + lp_times, demd_times = [], [] + ns = [5, 10, 20, 50, 100] + for n in ns: + lp_time, demd_time = random2d(n=n) + lp_times.append(lp_time) + demd_times.append(demd_time) + return ns, lp_times, demd_times + + +ns, lp_times, demd_times = increasing_bins() + + +# %% +# Plot and Compare data +# --------------------- +pl.plot(ns, lp_times, 'o', label="LP Barycenter") +pl.plot(ns, demd_times, 'o', label="DEMD") +# pl.yscale('log') +pl.ylabel('Time Per Epoch (Seconds)') +pl.xlabel('Number of Distributions') +pl.legend() \ No newline at end of file diff --git a/examples/demd/plot_demd_gradient_minimize.py b/examples/demd/plot_demd_gradient_minimize.py new file mode 100644 index 000000000..8f375a9d4 --- /dev/null +++ b/examples/demd/plot_demd_gradient_minimize.py @@ -0,0 +1,118 @@ +# -*- coding: utf-8 -*- +r""" +================================================================================= +DEMD vs LP Gradient Decent without Pytorch +================================================================================= + + +""" + +# Author: Ronak Mehta +# Xizheng Yu +# +# License: MIT License + +import io +import sys +import numpy as np +import matplotlib.pyplot as pl +import ot +from demd import demd, demd_minimize + +# %% +# Define function to get random (n, d) data +# ------------------------------------------- +def getData(n, d, dist='skewedGauss'): + print(f'Data: {d} Random Dists with {n} Bins ***') + + x = np.arange(n, dtype=np.float64).reshape((n, 1)) + M = ot.utils.dist(x, metric='minkowski') + + data = [] + for i in range(d): + # m = 100*np.random.rand(1) + m = n*(0.5*np.random.rand(1))*float(np.random.randint(2)+1) + if dist == 'skewedGauss': + a = ot.datasets.make_1D_gauss(n, m=m, s=5) + elif dist == 'uniform': + a = np.random.rand(n) + a = a / sum(a) + else: + print('unknown dist') + data.append(a) + + return data, M + +# %% +# Gradient Decent +# --------------- + +# %% parameters and data +n = 50 # nb bins +d = 7 + +vecsize = n*d + +# data, M = getData(n, d, 'uniform') +data, M = getData(n, d, 'skewedGauss') +data = np.vstack(data) + +# %% demd +# Redirect the standard output to a string buffer +old_stdout = sys.stdout +sys.stdout = output_buffer = io.StringIO() + +x = demd_minimize(demd, data, d, n, vecsize, niters=3000, lr=0.00001) + +# after minimization, any distribution can be used as a estimate of barycenter +bary = x[0] + +sys.stdout = old_stdout +output = output_buffer.getvalue() + +rows = output.strip().split("\n") +demd_loss = [float(row.split()[-3]) for row in rows[1:]] + +print(output) + +# %% lp barycenter +def lp_1d_bary(data, M, n, d): + + A = np.vstack(data).T + + alpha = 1.0 #/d # 0<=alpha<=1 + weights = np.array(d*[alpha]) + + bary, bary_log = ot.lp.barycenter(A, M, weights, solver='interior-point', verbose=True, log=True) + + return bary_log['fun'], bary + +# Redirect the standard output to a string buffer +old_stdout = sys.stdout +sys.stdout = output_buffer = io.StringIO() + +obj, lp_bary = lp_1d_bary(data, M, n, d) + +# Restore the standard output and get value +sys.stdout = old_stdout +output = output_buffer.getvalue() + +rows = output.strip().split("\n") +lp_loss = [float(row.split()[-1]) for row in rows[1:-3]] + +print(output) + + +#%% +# Compare the loss between DEMD and LP Barycenter +# --------- +# The barycenter approach does not minize the distance between +# the distributions, while our DEMD does. +index = [*range(0, len(demd_loss))] + +pl.plot(index, demd_loss, label = "DEMD") +pl.plot(index, lp_loss[:len(demd_loss)], label="LP") +pl.yscale('log') +pl.ylabel('Loss') +pl.xlabel('Epochs') +pl.legend() \ No newline at end of file diff --git a/ot/__init__.py b/ot/__init__.py index 45d5cfa44..47ccd37b1 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -52,6 +52,7 @@ from .weak import weak_optimal_transport from .factored import factored_optimal_transport from .solvers import solve +from .demd import (greedy_primal_dual, demd, demd_minimize) # utils functions from .utils import dist, unif, tic, toc, toq @@ -69,4 +70,5 @@ 'factored_optimal_transport', 'solve', 'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath', 'solvers', 'binary_search_circle', 'wasserstein_circle', - 'semidiscrete_wasserstein2_unif_circle', 'sliced_wasserstein_sphere_unif'] + 'semidiscrete_wasserstein2_unif_circle', 'sliced_wasserstein_sphere_unif', + 'greedy_primal_dual', 'demd', 'demd_minimize'] diff --git a/ot/demd.py b/ot/demd.py new file mode 100644 index 000000000..c52a26144 --- /dev/null +++ b/ot/demd.py @@ -0,0 +1,232 @@ +# -*- coding: utf-8 -*- +""" +DEMD solvers for optimal transport +""" + +# Author: Ronak Mehta +# Xizheng Yu +# +# License: MIT License + +import numpy as np +from .backend import get_backend + +def greedy_primal_dual(aa, verbose=False): + r""" + The greedy algorithm that solves both primal and dual generalized Earth + mover’s programs. + + The algorithm accepts $d$ distributions (i.e., histograms) :math:`p_{1}, + \ldots, p_{d} \in \mathbb{R}_{+}^{n}` with :math:`e^{\prime} p_{j}=1` + for all :math:`j \in[d]`. Although the algorithm states that all + histograms have the same number of bins, the algorithm can be easily + adapted to accept as inputs :math:`p_{i} \in \mathbb{R}_{+}^{n_{i}}$ + with $n_{i} \neq n_{j}`. + + Parameters + ---------- + aa : list of numpy arrays + The input arrays defining the optimization problem. They must have the + same shape. + verbose : bool, optional + If True, print debugging information during the execution of the + algorithm. Default is False. + + Returns + ------- + dict : dic + A dictionary containing the solution of the primal-dual problem: + - 'x': a dictionary that maps tuples of indices to the corresponding + primal variables. The tuples are the indices of the entries that are + set to their minimum value during the algorithm. + - 'primal objective': a float, the value of the objective function + evaluated at the solution. + - 'dual': a list of numpy arrays, the dual variables corresponding to + the input arrays. The i-th element of the list is the dual variable + corresponding to the i-th dimension of the input arrays. + - 'dual objective': a float, the value of the dual objective function + evaluated at the solution. + + References + ---------- + .. [1] Jeffery Kline. Properties of the d-dimensional earth mover’s + problem. Discrete Applied Mathematics, 265: 128–141, 2019. + + Examples + -------- + >>> import numpy as np + >>> aa = [np.array([[1, 2], [3, 4]]), np.array([[5, 6], [7, 8]])] + >>> result = greedy_primal_dual(aa) + >>> result['primal objective'] + -12 + """ + # function body here + + def OBJ(i): + return max(i) - min(i) + + # print(f"aa type is: {type(aa)}") + nx = get_backend(aa) + + AA = [nx.copy(_) for _ in aa] + + dims = tuple([len(_) for _ in AA]) + xx = {} + dual = [nx.zeros(d) for d in dims] + + idx = [0, ] * len(AA) + obj = 0 + if verbose: + print('i minval oldidx\t\tobj\t\tvals') + while all([i < _ for _, i in zip(dims, idx)]): + vals = [v[i] for v, i in zip(AA, idx)] + minval = min(vals) + i = vals.index(minval) + xx[tuple(idx)] = minval + obj += (OBJ(idx)) * minval + for v, j in zip(AA, idx): + v[j] -= minval + oldidx = nx.copy(idx) + idx[i] += 1 + if idx[i] < dims[i]: + dual[i][idx[i]] += OBJ(idx) - OBJ(oldidx) + dual[i][idx[i]-1] + if verbose: + print(i, minval, oldidx, obj, '\t', vals) + + # the above terminates when any entry in idx equals the corresponding + # value in dims this leaves other dimensions incomplete; the remaining + # terms of the dual solution must be filled-in + for _, i in enumerate(idx): + try: + dual[_][i:] = dual[_][i] + except Exception: + pass + + dualobj = sum([_.dot(_d) for _, _d in zip(aa, dual)]) + + return {'x': xx, 'primal objective': obj, + 'dual': dual, 'dual objective': dualobj} + + +def demd(x, d, n, return_dual_vars=False): + r""" + Solver of our proposed method: d−Dimensional Earch Mover’s Distance (DEMD). + + Parameters + ---------- + x : numpy array, shape (d * n, ) + The input vector containing coordinates of n points in d dimensions. + d : int + The number of dimensions of the points. + n : int + The number of points. + return_dual_vars : bool, optional + If True, also return the dual variables and the dual objective value of + the DEMD problem. Default is False. + + Returns + ------- + primal_obj : float + the value of the primal objective function evaluated at the solution. + dual_vars : numpy array, shape (d, n-1), optional + the values of the dual variables corresponding to the input points. + The i-th column of the array corresponds to the i-th point. + dual_obj : float, optional + the value of the dual objective function evaluated at the solution. + + References + ---------- + .. [1] Ronak Mehta, Jeffery Kline, Vishnu Suresh Lokhande, Glenn Fung, & + Vikas Singh (2023). Efficient Discrete Multi Marginal Optimal + Transport Regularization. In The Eleventh International + Conference on Learning Representations. + + """ + + # function body here + nx = get_backend(x) + log = greedy_primal_dual(x) + + if return_dual_vars: + dual = log['dual'] + return_dual = np.array(dual) + dualobj = log['dual objective'] + return log['primal objective'], return_dual, dualobj + else: + return log['primal objective'] + + +def demd_minimize(f, x, d, n, vecsize, niters=100, lr=0.1, print_rate=100): + r""" + Minimize a DEMD function using gradient descent. + + Parameters + ---------- + f : callable + The objective function to minimize. This function must take as input + a matrix x of shape (d, n) and return a scalar value representing + the objective function evaluated at x. It may also return a matrix of + shape (d, n) representing the gradient of the objective function + with respect to x, and/or any other dual variables needed for the + optimization algorithm. The signature of this function should be: + `f(x, d, n, return_dual_vars=False) -> float` + or + `f(x, d, n, return_dual_vars=True) -> (float, ndarray, ...)` + x : ndarray, shape (d, n) + The initial point for the optimization algorithm. + d : int + The number of rows in the matrix x. + n : int + The number of columns in the matrix x. + vecsize : int + The size of the vectors that make up the columns of x. + niters : int, optional (default=100) + The maximum number of iterations for the optimization algorithm. + lr : float, optional (default=0.1) + The learning rate (step size) for the optimization algorithm. + print_rate : int, optional (default=100) + The rate at which to print the objective value and gradient norm + during the optimization algorithm. + + Returns + ------- + list of ndarrays, each of shape (n,) + The optimal solution as a list of n vectors, each of length vecsize. + """ + + # function body here + nx = get_backend(x) + + def dualIter(f, x, d, n, vecsize, lr): + funcval, grad, _ = f(x, d, n, return_dual_vars=True) + xnew = nx.reshape(x, (d, n)) - grad * lr + return funcval, xnew, grad + + def renormalize(x, d, n, vecsize): + x = nx.reshape(x, (d, n)) + for i in range(x.shape[0]): + if min(x[i, :]) < 0: + x[i, :] -= min(x[i, :]) + x[i, :] /= nx.sum(x[i, :]) + return x + + def listify(x): + return [x[i, :] for i in range(x.shape[0])] + + # print(f"x type is {type(x)}") + funcval, _, grad = dualIter(f, x, d, n, vecsize, lr) + gn = nx.norm(grad) + + print(f'Inital:\t\tObj:\t{funcval:.4f}\tGradNorm:\t{gn:.4f}') + + for i in range(niters): + + x = renormalize(x, d, n, vecsize) + funcval, x, grad = dualIter(f, x, d, n, vecsize, lr) + gn = nx.norm(grad) + + if i % print_rate == 0: + print(f'Iter {i:2.0f}:\tObj:\t{funcval:.4f}\tGradNorm:\t{gn:.4f}') + + x = renormalize(x, d, n, vecsize) + return listify(nx.reshape(x, (d, n))) \ No newline at end of file From 27878b750a59b347f50e7ffba4d4e95698f59cfa Mon Sep 17 00:00:00 2001 From: x12hengyu Date: Wed, 5 Apr 2023 16:07:02 -0500 Subject: [PATCH 02/28] update REAMDME.md with citation to iclr23 paper and example link --- README.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/README.md b/README.md index 2a81e95ab..535255b40 100644 --- a/README.md +++ b/README.md @@ -43,6 +43,7 @@ POT provides the following generic OT solvers (links to examples): * [Spherical Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance_ssw.html) [46] * [Graph Dictionary Learning solvers](https://pythonot.github.io/auto_examples/gromov/plot_gromov_wasserstein_dictionary_learning.html) [38]. * [Semi-relaxed (Fused) Gromov-Wasserstein divergences](https://pythonot.github.io/auto_examples/gromov/plot_semirelaxed_fgw.html) [48]. +* [Efficient Discrete Multi Marginal Optimal Transport Regularization](https://pythonot.github.io/auto_examples/others/plot_demd_gradient_minimize.html) [50]. * [Several backends](https://pythonot.github.io/quickstart.html#solving-ot-with-multiple-backends) for easy use of POT with [Pytorch](https://pytorch.org/)/[jax](https://github.com/google/jax)/[Numpy](https://numpy.org/)/[Cupy](https://cupy.dev/)/[Tensorflow](https://www.tensorflow.org/) arrays. POT provides the following Machine Learning related solvers: @@ -308,3 +309,5 @@ Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Confer [48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty (2022). [Semi-relaxed Gromov-Wasserstein divergence and applications on graphs](https://openreview.net/pdf?id=RShaMexjc-x). International Conference on Learning Representations (ICLR), 2022. [49] Redko, I., Vayer, T., Flamary, R., and Courty, N. (2020). [CO-Optimal Transport](https://proceedings.neurips.cc/paper/2020/file/cc384c68ad503482fb24e6d1e3b512ae-Paper.pdf). Advances in Neural Information Processing Systems, 33. + +[50] Ronak Mehta, Jeffery Kline, Vishnu Suresh Lokhande, Glenn Fung, & Vikas Singh (2023). [Efficient Discrete Multi Marginal Optimal Transport Regularization](https://openreview.net/forum?id=R98ZfMt-jE). In The Eleventh International Conference on Learning Representations.git \ No newline at end of file From 4a3a4f1def05251b547ca51632c44d30181bb986 Mon Sep 17 00:00:00 2001 From: x12hengyu Date: Wed, 5 Apr 2023 16:08:56 -0500 Subject: [PATCH 03/28] chaneg directory of examples, build successful --- examples/{demd => others}/plot_demd_1d.py | 6 +- .../plot_demd_gradient_minimize.py | 74 ++++++++++--------- 2 files changed, 43 insertions(+), 37 deletions(-) rename examples/{demd => others}/plot_demd_1d.py (96%) rename examples/{demd => others}/plot_demd_gradient_minimize.py (61%) diff --git a/examples/demd/plot_demd_1d.py b/examples/others/plot_demd_1d.py similarity index 96% rename from examples/demd/plot_demd_1d.py rename to examples/others/plot_demd_1d.py index 4cd7ce562..8dc48ca61 100644 --- a/examples/demd/plot_demd_1d.py +++ b/examples/others/plot_demd_1d.py @@ -15,7 +15,6 @@ import numpy as np import matplotlib.pyplot as pl import ot -from demd import demd # %% # Define 1d Barycenter Function and Compare Function @@ -45,7 +44,7 @@ def compare_all(data, M, n, d): print('') print('D-EMD Algorithm:') ot.tic() - demd_obj = demd(np.vstack(data), n, d) + demd_obj = ot.demd(np.vstack(data), n, d) demd_time = ot.toc('') print('Obj\t: ', demd_obj) print('Time\t: ', demd_time) @@ -55,6 +54,7 @@ def compare_all(data, M, n, d): # 2 Random Dists with Increasing Bins # ----------------------------------- + def random2d(n=4): print('*'*10) d = 2 @@ -91,4 +91,4 @@ def increasing_bins(): # pl.yscale('log') pl.ylabel('Time Per Epoch (Seconds)') pl.xlabel('Number of Distributions') -pl.legend() \ No newline at end of file +pl.legend() diff --git a/examples/demd/plot_demd_gradient_minimize.py b/examples/others/plot_demd_gradient_minimize.py similarity index 61% rename from examples/demd/plot_demd_gradient_minimize.py rename to examples/others/plot_demd_gradient_minimize.py index 8f375a9d4..d46bd6a29 100644 --- a/examples/demd/plot_demd_gradient_minimize.py +++ b/examples/others/plot_demd_gradient_minimize.py @@ -4,10 +4,10 @@ DEMD vs LP Gradient Decent without Pytorch ================================================================================= - +Compare the loss between LP and DEMD. """ -# Author: Ronak Mehta +# Author: Ronak Mehta # Xizheng Yu # # License: MIT License @@ -17,41 +17,43 @@ import numpy as np import matplotlib.pyplot as pl import ot -from demd import demd, demd_minimize # %% # Define function to get random (n, d) data # ------------------------------------------- + + def getData(n, d, dist='skewedGauss'): - print(f'Data: {d} Random Dists with {n} Bins ***') - - x = np.arange(n, dtype=np.float64).reshape((n, 1)) - M = ot.utils.dist(x, metric='minkowski') - - data = [] - for i in range(d): - # m = 100*np.random.rand(1) - m = n*(0.5*np.random.rand(1))*float(np.random.randint(2)+1) - if dist == 'skewedGauss': - a = ot.datasets.make_1D_gauss(n, m=m, s=5) - elif dist == 'uniform': - a = np.random.rand(n) - a = a / sum(a) - else: - print('unknown dist') - data.append(a) - - return data, M + print(f'Data: {d} Random Dists with {n} Bins ***') + + x = np.arange(n, dtype=np.float64).reshape((n, 1)) + M = ot.utils.dist(x, metric='minkowski') + + data = [] + for i in range(d): + # m = 100*np.random.rand(1) + m = n * (0.5 * np.random.rand(1)) * float(np.random.randint(2) + 1) + if dist == 'skewedGauss': + a = ot.datasets.make_1D_gauss(n, m=m, s=5) + elif dist == 'uniform': + a = np.random.rand(n) + a = a / sum(a) + else: + print('unknown dist') + data.append(a) + + return data, M # %% # Gradient Decent # --------------- + # %% parameters and data n = 50 # nb bins d = 7 -vecsize = n*d +vecsize = n * d # data, M = getData(n, d, 'uniform') data, M = getData(n, d, 'skewedGauss') @@ -62,7 +64,7 @@ def getData(n, d, dist='skewedGauss'): old_stdout = sys.stdout sys.stdout = output_buffer = io.StringIO() -x = demd_minimize(demd, data, d, n, vecsize, niters=3000, lr=0.00001) +x = ot.demd_minimize(ot.demd, data, d, n, vecsize, niters=3000, lr=0.00001) # after minimization, any distribution can be used as a estimate of barycenter bary = x[0] @@ -76,22 +78,26 @@ def getData(n, d, dist='skewedGauss'): print(output) # %% lp barycenter + + def lp_1d_bary(data, M, n, d): - A = np.vstack(data).T + A = np.vstack(data).T + + alpha = 1.0 # /d # 0<=alpha<=1 + weights = np.array(d * [alpha]) - alpha = 1.0 #/d # 0<=alpha<=1 - weights = np.array(d*[alpha]) + bary, bary_log = ot.lp.barycenter(A, M, weights, solver='interior-point', + verbose=True, log=True) + + return bary_log['fun'], bary - bary, bary_log = ot.lp.barycenter(A, M, weights, solver='interior-point', verbose=True, log=True) - - return bary_log['fun'], bary # Redirect the standard output to a string buffer old_stdout = sys.stdout sys.stdout = output_buffer = io.StringIO() -obj, lp_bary = lp_1d_bary(data, M, n, d) +obj, lp_bary = lp_1d_bary(data, M, n, d) # Restore the standard output and get value sys.stdout = old_stdout @@ -103,16 +109,16 @@ def lp_1d_bary(data, M, n, d): print(output) -#%% +# %% # Compare the loss between DEMD and LP Barycenter # --------- # The barycenter approach does not minize the distance between # the distributions, while our DEMD does. index = [*range(0, len(demd_loss))] -pl.plot(index, demd_loss, label = "DEMD") +pl.plot(index, demd_loss, label="DEMD") pl.plot(index, lp_loss[:len(demd_loss)], label="LP") pl.yscale('log') pl.ylabel('Loss') pl.xlabel('Epochs') -pl.legend() \ No newline at end of file +pl.legend() From 94e0f443206881fb6914eeece694619eb4f4c402 Mon Sep 17 00:00:00 2001 From: x12hengyu Date: Wed, 5 Apr 2023 16:09:24 -0500 Subject: [PATCH 04/28] fix small latex bug --- ot/demd.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ot/demd.py b/ot/demd.py index c52a26144..ba90311d0 100644 --- a/ot/demd.py +++ b/ot/demd.py @@ -20,8 +20,8 @@ def greedy_primal_dual(aa, verbose=False): \ldots, p_{d} \in \mathbb{R}_{+}^{n}` with :math:`e^{\prime} p_{j}=1` for all :math:`j \in[d]`. Although the algorithm states that all histograms have the same number of bins, the algorithm can be easily - adapted to accept as inputs :math:`p_{i} \in \mathbb{R}_{+}^{n_{i}}$ - with $n_{i} \neq n_{j}`. + adapted to accept as inputs :math:`p_{i} \in \mathbb{R}_{+}^{n_{i}}` + with :math:`n_{i} \neq n_{j}`. Parameters ---------- From 295751059ba688f456f132b75eba648b39bbfa51 Mon Sep 17 00:00:00 2001 From: x12hengyu Date: Wed, 5 Apr 2023 16:10:21 -0500 Subject: [PATCH 05/28] update all.rst, examples and demd have passed pep8 and pyflake --- docs/source/all.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/all.rst b/docs/source/all.rst index a9d7fe2bb..6b8bf8cfc 100644 --- a/docs/source/all.rst +++ b/docs/source/all.rst @@ -19,6 +19,7 @@ API and modules coot da datasets + demd dr factored gaussian From 708b756d94ba9026cc1551bda6a3c7234f6a1dd9 Mon Sep 17 00:00:00 2001 From: x12hengyu Date: Wed, 5 Apr 2023 17:12:52 -0500 Subject: [PATCH 06/28] add more detailed comments for examples --- examples/others/plot_demd_1d.py | 23 ++++++++++++++++--- .../others/plot_demd_gradient_minimize.py | 10 +++++++- 2 files changed, 29 insertions(+), 4 deletions(-) diff --git a/examples/others/plot_demd_1d.py b/examples/others/plot_demd_1d.py index 8dc48ca61..afaa06414 100644 --- a/examples/others/plot_demd_1d.py +++ b/examples/others/plot_demd_1d.py @@ -4,7 +4,14 @@ 1D Wasserstein barycenter: LP Barycenter vs DEMD ================================================================================= -Compare the speed of 1D Wasserstein barycenter between LP and DEMD. +Compares the performance of two methods for computing the 1D Wasserstein +barycenter: +1. Linear Programming (LP) method +2. Discrete Earth Mover's Distance (DEMD) method + +The comparison is performed by generating random Gaussian distributions with +increasing numbers of bins and measuring the computation time of each method. +The results are then plotted for visualization. """ # Author: Ronak Mehta @@ -19,6 +26,10 @@ # %% # Define 1d Barycenter Function and Compare Function # -------------------------------------------------- +# This section defines the functions `lp_1d_bary` and `compare_all`. The +# `lp_1d_bary` function computes the barycenter using the LP method. The +# `compare_all` function compares the LP method and DEMD method in terms of +# computation time and objective values. def lp_1d_bary(data, M, n, d): @@ -53,6 +64,8 @@ def compare_all(data, M, n, d): # %% # 2 Random Dists with Increasing Bins # ----------------------------------- +# Generates two random Gaussian distributions with increasing bin +# sizes and compares the LP and DEMD methods def random2d(n=4): @@ -86,8 +99,12 @@ def increasing_bins(): # %% # Plot and Compare data # --------------------- -pl.plot(ns, lp_times, 'o', label="LP Barycenter") -pl.plot(ns, demd_times, 'o', label="DEMD") +# plots the computation times for the LP and DEMD methods for +# different bin sizes + + +pl.plot(ns, lp_times, 'o', linestyle="-", label="LP Barycenter") +pl.plot(ns, demd_times, 'o', linestyle="-", label="DEMD") # pl.yscale('log') pl.ylabel('Time Per Epoch (Seconds)') pl.xlabel('Number of Distributions') diff --git a/examples/others/plot_demd_gradient_minimize.py b/examples/others/plot_demd_gradient_minimize.py index d46bd6a29..c6fbfce8c 100644 --- a/examples/others/plot_demd_gradient_minimize.py +++ b/examples/others/plot_demd_gradient_minimize.py @@ -4,7 +4,9 @@ DEMD vs LP Gradient Decent without Pytorch ================================================================================= -Compare the loss between LP and DEMD. +Compare the loss between LP and DEMD. The comparison is performed using random +Gaussian or uniform distributions and calculating the loss for each method +during the optimization process. """ # Author: Ronak Mehta @@ -21,6 +23,8 @@ # %% # Define function to get random (n, d) data # ------------------------------------------- +# The following function generates random (n, d) data with either +# 'skewedGauss' or 'uniform' distributions def getData(n, d, dist='skewedGauss'): @@ -47,6 +51,8 @@ def getData(n, d, dist='skewedGauss'): # %% # Gradient Decent # --------------- +# The following section performs gradient descent optimization using +# the DEMD method # %% parameters and data @@ -78,6 +84,8 @@ def getData(n, d, dist='skewedGauss'): print(output) # %% lp barycenter +# ---------------- +# The following section computes 1D Wasserstein barycenter using the LP method def lp_1d_bary(data, M, n, d): From 7c813a7e802c418efaffd6ff6f3e61de9a84e72b Mon Sep 17 00:00:00 2001 From: x12hengyu Date: Wed, 5 Apr 2023 17:59:49 -0500 Subject: [PATCH 07/28] TODO: test module for demd, wrong demd index after build --- README.md | 4 +++- ot/demd.py | 5 +++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 535255b40..3ee378442 100644 --- a/README.md +++ b/README.md @@ -310,4 +310,6 @@ Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Confer [49] Redko, I., Vayer, T., Flamary, R., and Courty, N. (2020). [CO-Optimal Transport](https://proceedings.neurips.cc/paper/2020/file/cc384c68ad503482fb24e6d1e3b512ae-Paper.pdf). Advances in Neural Information Processing Systems, 33. -[50] Ronak Mehta, Jeffery Kline, Vishnu Suresh Lokhande, Glenn Fung, & Vikas Singh (2023). [Efficient Discrete Multi Marginal Optimal Transport Regularization](https://openreview.net/forum?id=R98ZfMt-jE). In The Eleventh International Conference on Learning Representations.git \ No newline at end of file +[50] Ronak Mehta, Jeffery Kline, Vishnu Suresh Lokhande, Glenn Fung, & Vikas Singh (2023). [Efficient Discrete Multi Marginal Optimal Transport Regularization](https://openreview.net/forum?id=R98ZfMt-jE). In The Eleventh International Conference on Learning Representations.git + +[51] Jeffery Kline. [Properties of the d-dimensional earth mover’s problem](https://www.sciencedirect.com/science/article/pii/S0166218X19301441). Discrete Applied Mathematics, 265: 128–141, 2019. \ No newline at end of file diff --git a/ot/demd.py b/ot/demd.py index ba90311d0..875898325 100644 --- a/ot/demd.py +++ b/ot/demd.py @@ -11,6 +11,7 @@ import numpy as np from .backend import get_backend + def greedy_primal_dual(aa, verbose=False): r""" The greedy algorithm that solves both primal and dual generalized Earth @@ -49,7 +50,7 @@ def greedy_primal_dual(aa, verbose=False): References ---------- - .. [1] Jeffery Kline. Properties of the d-dimensional earth mover’s + .. [51] Jeffery Kline. Properties of the d-dimensional earth mover’s problem. Discrete Applied Mathematics, 265: 128–141, 2019. Examples @@ -136,7 +137,7 @@ def demd(x, d, n, return_dual_vars=False): References ---------- - .. [1] Ronak Mehta, Jeffery Kline, Vishnu Suresh Lokhande, Glenn Fung, & + .. [50] Ronak Mehta, Jeffery Kline, Vishnu Suresh Lokhande, Glenn Fung, & Vikas Singh (2023). Efficient Discrete Multi Marginal Optimal Transport Regularization. In The Eleventh International Conference on Learning Representations. From 707152c6cf3028e19b3faa58a414b4037165984f Mon Sep 17 00:00:00 2001 From: x12hengyu Date: Fri, 7 Apr 2023 22:26:22 -0500 Subject: [PATCH 08/28] add test module --- test/test_demd.py | 64 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 64 insertions(+) create mode 100644 test/test_demd.py diff --git a/test/test_demd.py b/test/test_demd.py new file mode 100644 index 000000000..2e4978990 --- /dev/null +++ b/test/test_demd.py @@ -0,0 +1,64 @@ +"""Tests for ot.demd module """ + +# Author: Ronak Mehta +# Xizheng Yu +# +# License: MIT License + +import numpy as np +import ot +import pytest + + +def create_test_data(): + np.random.seed(1234) + d = 2 + n = 4 + a1 = ot.datasets.make_1D_gauss(n, m=20, s=5) + a2 = ot.datasets.make_1D_gauss(n, m=60, s=8) + aa = np.vstack([a1, a2]) + x = np.arange(n, dtype=np.float64).reshape((n, 1)) + return aa, x, d, n + + +def test_greedy_primal_dual(): + # test greedy_primal_dual object calculation + aa, _, _, _ = create_test_data() + result = ot.greedy_primal_dual(aa) + expected_primal_obj = 0.13667759626298503 + np.testing.assert_allclose(result['primal objective'], + expected_primal_obj, + rtol=1e-7, + err_msg="Test failed: \ + Expected different primal objective value") + + +def test_demd(): + # test one demd iteration result + aa, _, d, n = create_test_data() + primal_obj = ot.demd(aa, n, d) + expected_primal_obj = 0.13667759626298503 + np.testing.assert_allclose(primal_obj, + expected_primal_obj, + rtol=1e-7, + err_msg="Test failed: \ + Expected different primal objective value") + + +def test_demd_minimize(): + # test demd_minimize result + aa, _, d, n = create_test_data() + niters = 10 + result = ot.demd_minimize(ot.demd, aa, d, n, 2, niters, 0.001, 5) + + expected_obj = np.array([[0.05553516, 0.13082618, 0.27327479, 0.54036388], + [0.04185365, 0.09570724, 0.24384705, 0.61859206]]) + + assert len(result) == d, "Test failed: Expected a list of length n" + for i in range(d): + np.testing.assert_allclose(result[i], + expected_obj[i], + atol=1e-7, + rtol=1e-7, + err_msg="Test failed: \ + Expected vectors of all zeros") From 81ab727e7fd4c4f1f2a61b7b056695ecff0e3f09 Mon Sep 17 00:00:00 2001 From: x12hengyu Date: Fri, 7 Apr 2023 22:42:03 -0500 Subject: [PATCH 09/28] add contributors --- CONTRIBUTORS.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 6b356537c..0f1d040df 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -42,6 +42,8 @@ The contributors to this library are: * [Eduardo Fernandes Montesuma](https://eddardd.github.io/my-personal-blog/) (Free support sinkhorn barycenter) * [Theo Gnassounou](https://github.com/tgnassou) (OT between Gaussian distributions) * [Clément Bonet](https://clbonet.github.io) (Wassertstein on circle, Spherical Sliced-Wasserstein) +* [Ronak Mehta](https://ronakrm.github.io) (Efficient Discrete Multi Marginal Optimal Transport Regularization) +* [Xizheng Yu](https://github.com/x12hengyu) (Efficient Discrete Multi Marginal Optimal Transport Regularization) ## Acknowledgments From 706d6a52e4431b280deb92111640f6f16e171608 Mon Sep 17 00:00:00 2001 From: x12hengyu Date: Fri, 7 Apr 2023 22:50:57 -0500 Subject: [PATCH 10/28] pass pyflake checks, pass pep8 --- ot/demd.py | 16 ++++++++-------- test/test_demd.py | 1 - 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/ot/demd.py b/ot/demd.py index 875898325..7d6bcff39 100644 --- a/ot/demd.py +++ b/ot/demd.py @@ -3,7 +3,7 @@ DEMD solvers for optimal transport """ -# Author: Ronak Mehta +# Author: Ronak Mehta # Xizheng Yu # # License: MIT License @@ -65,7 +65,7 @@ def greedy_primal_dual(aa, verbose=False): def OBJ(i): return max(i) - min(i) - + # print(f"aa type is: {type(aa)}") nx = get_backend(aa) @@ -134,7 +134,7 @@ def demd(x, d, n, return_dual_vars=False): The i-th column of the array corresponds to the i-th point. dual_obj : float, optional the value of the dual objective function evaluated at the solution. - + References ---------- .. [50] Ronak Mehta, Jeffery Kline, Vishnu Suresh Lokhande, Glenn Fung, & @@ -143,9 +143,9 @@ def demd(x, d, n, return_dual_vars=False): Conference on Learning Representations. """ - + # function body here - nx = get_backend(x) + # nx = get_backend(x) log = greedy_primal_dual(x) if return_dual_vars: @@ -194,10 +194,10 @@ def demd_minimize(f, x, d, n, vecsize, niters=100, lr=0.1, print_rate=100): list of ndarrays, each of shape (n,) The optimal solution as a list of n vectors, each of length vecsize. """ - + # function body here nx = get_backend(x) - + def dualIter(f, x, d, n, vecsize, lr): funcval, grad, _ = f(x, d, n, return_dual_vars=True) xnew = nx.reshape(x, (d, n)) - grad * lr @@ -230,4 +230,4 @@ def listify(x): print(f'Iter {i:2.0f}:\tObj:\t{funcval:.4f}\tGradNorm:\t{gn:.4f}') x = renormalize(x, d, n, vecsize) - return listify(nx.reshape(x, (d, n))) \ No newline at end of file + return listify(nx.reshape(x, (d, n))) diff --git a/test/test_demd.py b/test/test_demd.py index 2e4978990..642b449eb 100644 --- a/test/test_demd.py +++ b/test/test_demd.py @@ -7,7 +7,6 @@ import numpy as np import ot -import pytest def create_test_data(): From 4e6f693756e60d9a090ae2760eb09fd98fd0b869 Mon Sep 17 00:00:00 2001 From: x12hengyu Date: Sat, 8 Apr 2023 13:24:14 -0500 Subject: [PATCH 11/28] added the PR to the RELEASES.md file --- RELEASES.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/RELEASES.md b/RELEASES.md index e9789054d..50eb262fb 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -1,5 +1,10 @@ # Releases +## 0.9.1dev + +#### New features +- Added feature Efficient Discrete Multi Marginal Optimal Transport Regularization + examples (PR #454) + ## 0.9.0 This new release contains so many new features and bug fixes since 0.8.2 that we From 08bb9198f1d84efefec3b27feec2f719157a5c41 Mon Sep 17 00:00:00 2001 From: x12hengyu Date: Thu, 4 May 2023 01:18:47 -0500 Subject: [PATCH 12/28] temporal changes with logs --- examples/others/log.MD | 60 +++++++++++ ot/lp/dmmot.py | 238 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 298 insertions(+) create mode 100644 examples/others/log.MD create mode 100644 ot/lp/dmmot.py diff --git a/examples/others/log.MD b/examples/others/log.MD new file mode 100644 index 000000000..ee0877995 --- /dev/null +++ b/examples/others/log.MD @@ -0,0 +1,60 @@ +Large code changes: +- Merged `greedy_primal_dual` and `demd` to `ot.lp.discrete_mmot`, named `demd_minimize` to `ot.lp.discrete_mmot_converge`. +- Changed parameter names and return values to follow POT's API standard. +- Merge two examples together, removing time and loss compare (discuss using text discription), and add barycenter plots. +- Add more detailed mathematic proves and explanations in documentation. + +Comments resolved below: +- The new solvers should be in `ot.lp.discrete_emd` or something else more descriptive + - Fixed. Move our method under `ot.lp` and named as `discrete_mmot` stands for multi marginal optimal transport +- the example should not need to transpose the data. it means that the API for the implemented function is not good (it should retrun the smae thing as ot.lp.barycenter) + - For `ot.lp.barycenter`, input distributions A is defined `A (np.ndarray (d,n))`, but stacked data is shaped (n, d), that's why we added transpose here. We have moved transpose to initial generation of distributions A to avoid confusions. This is the same approach as other examples in POT such as [1D Wasserstein barycenter: exact LP vs entropic regularization](https://pythonot.github.io/auto_examples/barycenters/plot_barycenter_lp_vs_entropic.html#sphx-glr-auto-examples-barycenters-plot-barycenter-lp-vs-entropic-py) +- where is the barycener? Objective is nice but the barycenetr shoudl be rtruned + - `discrete_mmot` denotes a single iteration in the converging calculation. Since we are not defining a learning rate here, we cannot return updated all distributions. But we have updated distributions in each iteration with `A_new = nx.reshape(A, (d, n)) - grad * lr` in converging method, and returned final barycenters/distributions at the end. +- why only plot thr time? pleade also plot the barycenter. + - Added. +- Plot the data + - Fixed. +- what is all this? sphinx-gallery can print the output of the function properly (assuming it is using print) + - Removed. We were trying to compare the time efficency between two methods. Therefore I used stdout to retreive the loss and time data from the output. +- not clear what is the comparaison, the baryenter estimation? + - Our method converges all the distributions rather than calculating a single center for all distributions, the comparasion is intending to show two different methods works on the same distributions. +- the functiuon should reurn data already well formated + - Fixed. The function is removed following previous comments. +- No need for a function if you just call it oncen alaos visualize the data + - Removed. +- this s a weird name for a function, if it computes a barycenter then just call it `discrete_emd_barycenter` + - Fixed. Named as `discrete_mmot_converge` +- also do `bary, ~ = ot.demd_minimize(ot.demd, data, d, n, vecsize, niters=3000, lr=0.00001)` to avoid betting the barycenter later + - Fixed, added log as an additional return value following the POT API. +- you must stor eand return the loss instead of parsing the text output of the function... + - Fixed. But we only can get one final loss from log instead all losses through iterations. We will add comment comparasions for the time difference. +- same comment as above + - Fixed. Added transpose when create A. +- `weights=np.ones(d)/d` or `weights=ot.unif(d)` + - Fixed. +- if you solve the same problem, the loss should be the same or similar. Why is it so different. + - They are not the same problem. **EXPLANATION NEEDED** +- you need to plot both baryceneters. Are they the same? if not why? Maybe a discussion here explaining the method vs standard LP + - We added plots for both barycenters. They are not the same. **EXPLANATION NEEDED** +- add more detzails heren alaos the references papers + - **ADD PAPERS** +- $ is not compatible with sphinx, please use proper math notations. Also what is aa? what is its relation with the p discussed above. + - Removed $. And we merged `greedy_primal_dual` to `discrete_mmot` for better structure and naming. aa refered to A in that context. +- are those the input distributions? how many of them? + - Yes, these are input distributions. Changed to A to follow POT API. +- this function needs to foloow POT API. look at https://pythonot.github.io/all.html#ot.emd or https://pythonot.github.io/all.html#ot.emd2. + - Fixed. Changed `x` to `A`, `aa` to `A`. +- hare you should express the optimization problem that the functio solves with proper maths and noations (agai see ot.emd), define what is generalized EMD, If it is a multimarginal then please explain it + - **EXPLANATION NEEDED** +- the function should return an OT plan (following ot.emd API) and a log dictionary with other information. + - Followed this schema in `discrete_mmot_converge` +- the output of the function should be of the same type as the input of the function +- should use nx.sum + - Fixed, used `nx.sum` and `nx.dot`. +- desctibe mroe precisely what you are solkving. is it a barycenter? + - **EXPLANATION NEEDED** +- bad naming, at least discrete_emd or discrete_emd2 if teh function return emd without the plan . if you give the function an empirical distribution then you shoud also put it in the name. Finally emd is computed beteen twoi distibutions so why is there only on naumpy arrya here? + - Naming changed. We are solving a different problem than emd. **EXPLANATION NEEDED** +- those shapes can be infered from x they should not be passed as parameters + - Removed unnacessary parameters that can be inferred. \ No newline at end of file diff --git a/ot/lp/dmmot.py b/ot/lp/dmmot.py new file mode 100644 index 000000000..23351f44a --- /dev/null +++ b/ot/lp/dmmot.py @@ -0,0 +1,238 @@ +# -*- coding: utf-8 -*- +""" +DEMD solvers for optimal transport +""" + +# Author: Ronak Mehta +# Xizheng Yu +# +# License: MIT License + +import numpy as np +from ..backend import get_backend + +# M -> obj + +def greedy_primal_dual(A, verbose=False): + r""" + The greedy algorithm that solves both primal and dual generalized Earth + mover’s programs. + + The algorithm accepts :math:`d` distributions (i.e., histograms) :math:`p_{1}, + \ldots, p_{d} \in \mathbb{R}_{+}^{n}` with :math:`e^{\prime} p_{j}=1` + for all :math:`j \in[d]`. Although the algorithm states that all + histograms have the same number of bins, the algorithm can be easily + adapted to accept as inputs :math:`p_{i} \in \mathbb{R}_{+}^{n_{i}}` + with :math:`n_{i} \neq n_{j}`. + + Parameters + ---------- + A : list of numpy arrays -> nd array + The input arrays are list of distributions + + + Returns + ------- + dict : dic + A dictionary containing the solution of the primal-dual problem: + - 'x': a dictionary that maps tuples of indices to the corresponding + primal variables. The tuples are the indices of the entries that are + set to their minimum value during the algorithm. + - 'primal objective': a float, the value of the objective function + evaluated at the solution. + - 'dual': a list of numpy arrays, the dual variables corresponding to + the input arrays. The i-th element of the list is the dual variable + corresponding to the i-th dimension of the input arrays. + - 'dual objective': a float, the value of the dual objective function + evaluated at the solution. + + Examples + -------- + >>> import numpy as np + >>> A = [np.array([[1, 2], [3, 4]]), np.array([[5, 6], [7, 8]])] + >>> result = greedy_primal_dual(A) + >>> result['primal objective'] + -12 + """ + + pass + + +def discrete_mmot(A, verbose=False, log=False): + r""" + Solver of our proposed method: d−Dimensional Earch Mover’s Distance (DEMD). + + multi marginal optimal transport + + Parameters + ---------- + A : numpy array, shape (d * n, ) + The input vector containing coordinates of n points in d dimensions. + d : int + The number of dimensions of the points. + n : int + The number of points. + verbose : bool, optional + If True, print debugging information during the execution of the + algorithm. Default is False. + return_dual_vars : bool, optional + If True, also return the dual variables and the dual objective value of + the DEMD problem. Default is False. + + Returns + ------- + primal_obj : float + the value of the primal objective function evaluated at the solution. + dual_vars : numpy array, shape (d, n-1), optional + the values of the dual variables corresponding to the input points. + The i-th column of the array corresponds to the i-th point. + dual_obj : float, optional + the value of the dual objective function evaluated at the solution. + + References + ---------- + .. [50] Ronak Mehta, Jeffery Kline, Vishnu Suresh Lokhande, Glenn Fung, & + Vikas Singh (2023). Efficient Discrete Multi Marginal Optimal + Transport Regularization. In The Eleventh International + Conference on Learning Representations. + .. [51] Jeffery Kline. Properties of the d-dimensional earth mover’s + problem. Discrete Applied Mathematics, 265: 128–141, 2019. + """ + + def OBJ(i): + return max(i) - min(i) + + # print(f"A type is: {type(A)}") + nx = get_backend(A) + + AA = [nx.copy(_) for _ in A] + + dims = tuple([len(_) for _ in AA]) + xx = {} + dual = [nx.zeros(d) for d in dims] + + idx = [0, ] * len(AA) + obj = 0 + if verbose: + print('i minval oldidx\t\tobj\t\tvals') + while all([i < _ for _, i in zip(dims, idx)]): + vals = [v[i] for v, i in zip(AA, idx)] + minval = min(vals) + i = vals.index(minval) + xx[tuple(idx)] = minval + obj += (OBJ(idx)) * minval + for v, j in zip(AA, idx): + v[j] -= minval + oldidx = nx.copy(idx) + idx[i] += 1 + if idx[i] < dims[i]: + dual[i][idx[i]] += OBJ(idx) - OBJ(oldidx) + dual[i][idx[i]-1] + if verbose: + print(i, minval, oldidx, obj, '\t', vals) + + # the above terminates when any entry in idx equals the corresponding + # value in dims this leaves other dimensions incomplete; the remaining + # terms of the dual solution must be filled-in + for _, i in enumerate(idx): + try: + dual[_][i:] = dual[_][i] + except Exception: + pass + + dualobj = nx.sum([nx.dot(arr, dual_arr) for arr, dual_arr in zip(A, dual)]) + + log_dict = {'A': xx, + 'primal objective': obj, + 'dual': dual, + 'dual objective': dualobj} + + if log: + return obj, log_dict + else: + return obj + + # if return_dual_vars: + # dual = log['dual'] + # return_dual = np.array(dual) + # dualobj = log['dual objective'] + # return log['primal objective'], return_dual, log['dual objective'] + # else: + # return log['primal objective'], log + + +def discrete_mmot_converge(A, niters=100, lr=0.1, print_rate=100, log=False): + r""" + Minimize a DEMD function using gradient descent. + + Parameters + ---------- + f : callable + The objective function to minimize. This function must take as input + a matrix x of shape (d, n) and return a scalar value representing + the objective function evaluated at x. It may also return a matrix of + shape (d, n) representing the gradient of the objective function + with respect to x, and/or any other dual variables needed for the + optimization algorithm. The signature of this function should be: + `f(x, d, n, return_dual_vars=False) -> float` + or + `f(x, d, n, return_dual_vars=True) -> (float, ndarray, ...)` + A : ndarray, shape (d, n) + The initial point for the optimization algorithm. + niters : int, optional (default=100) + The maximum number of iterations for the optimization algorithm. + lr : float, optional (default=0.1) + The learning rate (step size) for the optimization algorithm. + print_rate : int, optional (default=100) + The rate at which to print the objective value and gradient norm + during the optimization algorithm. + + Returns + ------- + list of ndarrays, each of shape (n,) + The optimal solution as a list of n vectors, each of length vecsize. + """ + + # function body here + nx = get_backend(A) + d, n = A.shape + + def dualIter(A, lr): + funcval, log_dict = discrete_mmot(A, log=True) + grad = np.array(log_dict['dual']) + A_new = nx.reshape(A, (d, n)) - grad * lr + # A_new = A - grad * lr + return funcval, A_new, grad, log_dict + + def renormalize(A): + A = nx.reshape(A, (d, n)) + for i in range(A.shape[0]): + if min(A[i, :]) < 0: + A[i, :] -= min(A[i, :]) + A[i, :] /= nx.sum(A[i, :]) + return A + + def listify(A): + return [A[i, :] for i in range(A.shape[0])] + + funcval, _, grad, log_dict = dualIter(A, lr) + gn = nx.norm(grad) + + print(f'Inital:\t\tObj:\t{funcval:.4f}\tGradNorm:\t{gn:.4f}') + + for i in range(niters): + + A = renormalize(A) + funcval, A, grad, log_dict = dualIter(A, lr) + gn = nx.norm(grad) + + if i % print_rate == 0: + print(f'Iter {i:2.0f}:\tObj:\t{funcval:.4f}\tGradNorm:\t{gn:.4f}') + + A = renormalize(A) + a = listify(nx.reshape(A, (d, n))) + + if log: + return a, log_dict + else: + return a + From 29d16f4507aaad39e17d257c559fa62df5266747 Mon Sep 17 00:00:00 2001 From: x12hengyu Date: Sun, 7 May 2023 13:07:09 -0500 Subject: [PATCH 13/28] init changes --- docs/source/all.rst | 1 - ot/__init__.py | 4 +- ot/demd.py | 233 -------------------------------------------- ot/lp/__init__.py | 4 +- 4 files changed, 4 insertions(+), 238 deletions(-) delete mode 100644 ot/demd.py diff --git a/docs/source/all.rst b/docs/source/all.rst index 6b8bf8cfc..a9d7fe2bb 100644 --- a/docs/source/all.rst +++ b/docs/source/all.rst @@ -19,7 +19,6 @@ API and modules coot da datasets - demd dr factored gaussian diff --git a/ot/__init__.py b/ot/__init__.py index 5fcbeab1c..eb00551ad 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -51,7 +51,6 @@ from .weak import weak_optimal_transport from .factored import factored_optimal_transport from .solvers import solve -from .demd import (greedy_primal_dual, demd, demd_minimize) # utils functions from .utils import dist, unif, tic, toc, toq @@ -69,5 +68,4 @@ 'factored_optimal_transport', 'solve', 'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath', 'solvers', 'binary_search_circle', 'wasserstein_circle', - 'semidiscrete_wasserstein2_unif_circle', 'sliced_wasserstein_sphere_unif', - 'greedy_primal_dual', 'demd', 'demd_minimize'] + 'semidiscrete_wasserstein2_unif_circle', 'sliced_wasserstein_sphere_unif'] diff --git a/ot/demd.py b/ot/demd.py deleted file mode 100644 index 7d6bcff39..000000000 --- a/ot/demd.py +++ /dev/null @@ -1,233 +0,0 @@ -# -*- coding: utf-8 -*- -""" -DEMD solvers for optimal transport -""" - -# Author: Ronak Mehta -# Xizheng Yu -# -# License: MIT License - -import numpy as np -from .backend import get_backend - - -def greedy_primal_dual(aa, verbose=False): - r""" - The greedy algorithm that solves both primal and dual generalized Earth - mover’s programs. - - The algorithm accepts $d$ distributions (i.e., histograms) :math:`p_{1}, - \ldots, p_{d} \in \mathbb{R}_{+}^{n}` with :math:`e^{\prime} p_{j}=1` - for all :math:`j \in[d]`. Although the algorithm states that all - histograms have the same number of bins, the algorithm can be easily - adapted to accept as inputs :math:`p_{i} \in \mathbb{R}_{+}^{n_{i}}` - with :math:`n_{i} \neq n_{j}`. - - Parameters - ---------- - aa : list of numpy arrays - The input arrays defining the optimization problem. They must have the - same shape. - verbose : bool, optional - If True, print debugging information during the execution of the - algorithm. Default is False. - - Returns - ------- - dict : dic - A dictionary containing the solution of the primal-dual problem: - - 'x': a dictionary that maps tuples of indices to the corresponding - primal variables. The tuples are the indices of the entries that are - set to their minimum value during the algorithm. - - 'primal objective': a float, the value of the objective function - evaluated at the solution. - - 'dual': a list of numpy arrays, the dual variables corresponding to - the input arrays. The i-th element of the list is the dual variable - corresponding to the i-th dimension of the input arrays. - - 'dual objective': a float, the value of the dual objective function - evaluated at the solution. - - References - ---------- - .. [51] Jeffery Kline. Properties of the d-dimensional earth mover’s - problem. Discrete Applied Mathematics, 265: 128–141, 2019. - - Examples - -------- - >>> import numpy as np - >>> aa = [np.array([[1, 2], [3, 4]]), np.array([[5, 6], [7, 8]])] - >>> result = greedy_primal_dual(aa) - >>> result['primal objective'] - -12 - """ - # function body here - - def OBJ(i): - return max(i) - min(i) - - # print(f"aa type is: {type(aa)}") - nx = get_backend(aa) - - AA = [nx.copy(_) for _ in aa] - - dims = tuple([len(_) for _ in AA]) - xx = {} - dual = [nx.zeros(d) for d in dims] - - idx = [0, ] * len(AA) - obj = 0 - if verbose: - print('i minval oldidx\t\tobj\t\tvals') - while all([i < _ for _, i in zip(dims, idx)]): - vals = [v[i] for v, i in zip(AA, idx)] - minval = min(vals) - i = vals.index(minval) - xx[tuple(idx)] = minval - obj += (OBJ(idx)) * minval - for v, j in zip(AA, idx): - v[j] -= minval - oldidx = nx.copy(idx) - idx[i] += 1 - if idx[i] < dims[i]: - dual[i][idx[i]] += OBJ(idx) - OBJ(oldidx) + dual[i][idx[i]-1] - if verbose: - print(i, minval, oldidx, obj, '\t', vals) - - # the above terminates when any entry in idx equals the corresponding - # value in dims this leaves other dimensions incomplete; the remaining - # terms of the dual solution must be filled-in - for _, i in enumerate(idx): - try: - dual[_][i:] = dual[_][i] - except Exception: - pass - - dualobj = sum([_.dot(_d) for _, _d in zip(aa, dual)]) - - return {'x': xx, 'primal objective': obj, - 'dual': dual, 'dual objective': dualobj} - - -def demd(x, d, n, return_dual_vars=False): - r""" - Solver of our proposed method: d−Dimensional Earch Mover’s Distance (DEMD). - - Parameters - ---------- - x : numpy array, shape (d * n, ) - The input vector containing coordinates of n points in d dimensions. - d : int - The number of dimensions of the points. - n : int - The number of points. - return_dual_vars : bool, optional - If True, also return the dual variables and the dual objective value of - the DEMD problem. Default is False. - - Returns - ------- - primal_obj : float - the value of the primal objective function evaluated at the solution. - dual_vars : numpy array, shape (d, n-1), optional - the values of the dual variables corresponding to the input points. - The i-th column of the array corresponds to the i-th point. - dual_obj : float, optional - the value of the dual objective function evaluated at the solution. - - References - ---------- - .. [50] Ronak Mehta, Jeffery Kline, Vishnu Suresh Lokhande, Glenn Fung, & - Vikas Singh (2023). Efficient Discrete Multi Marginal Optimal - Transport Regularization. In The Eleventh International - Conference on Learning Representations. - - """ - - # function body here - # nx = get_backend(x) - log = greedy_primal_dual(x) - - if return_dual_vars: - dual = log['dual'] - return_dual = np.array(dual) - dualobj = log['dual objective'] - return log['primal objective'], return_dual, dualobj - else: - return log['primal objective'] - - -def demd_minimize(f, x, d, n, vecsize, niters=100, lr=0.1, print_rate=100): - r""" - Minimize a DEMD function using gradient descent. - - Parameters - ---------- - f : callable - The objective function to minimize. This function must take as input - a matrix x of shape (d, n) and return a scalar value representing - the objective function evaluated at x. It may also return a matrix of - shape (d, n) representing the gradient of the objective function - with respect to x, and/or any other dual variables needed for the - optimization algorithm. The signature of this function should be: - `f(x, d, n, return_dual_vars=False) -> float` - or - `f(x, d, n, return_dual_vars=True) -> (float, ndarray, ...)` - x : ndarray, shape (d, n) - The initial point for the optimization algorithm. - d : int - The number of rows in the matrix x. - n : int - The number of columns in the matrix x. - vecsize : int - The size of the vectors that make up the columns of x. - niters : int, optional (default=100) - The maximum number of iterations for the optimization algorithm. - lr : float, optional (default=0.1) - The learning rate (step size) for the optimization algorithm. - print_rate : int, optional (default=100) - The rate at which to print the objective value and gradient norm - during the optimization algorithm. - - Returns - ------- - list of ndarrays, each of shape (n,) - The optimal solution as a list of n vectors, each of length vecsize. - """ - - # function body here - nx = get_backend(x) - - def dualIter(f, x, d, n, vecsize, lr): - funcval, grad, _ = f(x, d, n, return_dual_vars=True) - xnew = nx.reshape(x, (d, n)) - grad * lr - return funcval, xnew, grad - - def renormalize(x, d, n, vecsize): - x = nx.reshape(x, (d, n)) - for i in range(x.shape[0]): - if min(x[i, :]) < 0: - x[i, :] -= min(x[i, :]) - x[i, :] /= nx.sum(x[i, :]) - return x - - def listify(x): - return [x[i, :] for i in range(x.shape[0])] - - # print(f"x type is {type(x)}") - funcval, _, grad = dualIter(f, x, d, n, vecsize, lr) - gn = nx.norm(grad) - - print(f'Inital:\t\tObj:\t{funcval:.4f}\tGradNorm:\t{gn:.4f}') - - for i in range(niters): - - x = renormalize(x, d, n, vecsize) - funcval, x, grad = dualIter(f, x, d, n, vecsize, lr) - gn = nx.norm(grad) - - if i % print_rate == 0: - print(f'Iter {i:2.0f}:\tObj:\t{funcval:.4f}\tGradNorm:\t{gn:.4f}') - - x = renormalize(x, d, n, vecsize) - return listify(nx.reshape(x, (d, n))) diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 2ff02ab72..7540c4ca2 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -17,6 +17,7 @@ from . import cvx from .cvx import barycenter +from .dmmot import * # import compiled emd from .emd_wrap import emd_c, check_result, emd_1d_sorted @@ -30,7 +31,8 @@ __all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx', ' emd_1d_sorted', 'emd_1d', 'emd2_1d', 'wasserstein_1d', 'generalized_free_support_barycenter', - 'binary_search_circle', 'wasserstein_circle', 'semidiscrete_wasserstein2_unif_circle'] + 'binary_search_circle', 'wasserstein_circle', 'semidiscrete_wasserstein2_unif_circle', + 'discrete_mmot', 'discrete_mmot_converge'] def check_number_threads(numThreads): From 4226eee71b7de21e1fe4d99d86ae226f8fd5e87c Mon Sep 17 00:00:00 2001 From: x12hengyu Date: Sun, 7 May 2023 18:08:40 -0500 Subject: [PATCH 14/28] merge examples, demd -> lp.dmmot --- examples/others/log.MD | 60 ----- examples/others/plot_d-mmot.py | 170 +++++++++++++++ examples/others/plot_demd_1d.py | 111 ---------- .../others/plot_demd_gradient_minimize.py | 132 ----------- ot/lp/dmmot.py | 205 ++++++++++-------- 5 files changed, 279 insertions(+), 399 deletions(-) delete mode 100644 examples/others/log.MD create mode 100644 examples/others/plot_d-mmot.py delete mode 100644 examples/others/plot_demd_1d.py delete mode 100644 examples/others/plot_demd_gradient_minimize.py diff --git a/examples/others/log.MD b/examples/others/log.MD deleted file mode 100644 index ee0877995..000000000 --- a/examples/others/log.MD +++ /dev/null @@ -1,60 +0,0 @@ -Large code changes: -- Merged `greedy_primal_dual` and `demd` to `ot.lp.discrete_mmot`, named `demd_minimize` to `ot.lp.discrete_mmot_converge`. -- Changed parameter names and return values to follow POT's API standard. -- Merge two examples together, removing time and loss compare (discuss using text discription), and add barycenter plots. -- Add more detailed mathematic proves and explanations in documentation. - -Comments resolved below: -- The new solvers should be in `ot.lp.discrete_emd` or something else more descriptive - - Fixed. Move our method under `ot.lp` and named as `discrete_mmot` stands for multi marginal optimal transport -- the example should not need to transpose the data. it means that the API for the implemented function is not good (it should retrun the smae thing as ot.lp.barycenter) - - For `ot.lp.barycenter`, input distributions A is defined `A (np.ndarray (d,n))`, but stacked data is shaped (n, d), that's why we added transpose here. We have moved transpose to initial generation of distributions A to avoid confusions. This is the same approach as other examples in POT such as [1D Wasserstein barycenter: exact LP vs entropic regularization](https://pythonot.github.io/auto_examples/barycenters/plot_barycenter_lp_vs_entropic.html#sphx-glr-auto-examples-barycenters-plot-barycenter-lp-vs-entropic-py) -- where is the barycener? Objective is nice but the barycenetr shoudl be rtruned - - `discrete_mmot` denotes a single iteration in the converging calculation. Since we are not defining a learning rate here, we cannot return updated all distributions. But we have updated distributions in each iteration with `A_new = nx.reshape(A, (d, n)) - grad * lr` in converging method, and returned final barycenters/distributions at the end. -- why only plot thr time? pleade also plot the barycenter. - - Added. -- Plot the data - - Fixed. -- what is all this? sphinx-gallery can print the output of the function properly (assuming it is using print) - - Removed. We were trying to compare the time efficency between two methods. Therefore I used stdout to retreive the loss and time data from the output. -- not clear what is the comparaison, the baryenter estimation? - - Our method converges all the distributions rather than calculating a single center for all distributions, the comparasion is intending to show two different methods works on the same distributions. -- the functiuon should reurn data already well formated - - Fixed. The function is removed following previous comments. -- No need for a function if you just call it oncen alaos visualize the data - - Removed. -- this s a weird name for a function, if it computes a barycenter then just call it `discrete_emd_barycenter` - - Fixed. Named as `discrete_mmot_converge` -- also do `bary, ~ = ot.demd_minimize(ot.demd, data, d, n, vecsize, niters=3000, lr=0.00001)` to avoid betting the barycenter later - - Fixed, added log as an additional return value following the POT API. -- you must stor eand return the loss instead of parsing the text output of the function... - - Fixed. But we only can get one final loss from log instead all losses through iterations. We will add comment comparasions for the time difference. -- same comment as above - - Fixed. Added transpose when create A. -- `weights=np.ones(d)/d` or `weights=ot.unif(d)` - - Fixed. -- if you solve the same problem, the loss should be the same or similar. Why is it so different. - - They are not the same problem. **EXPLANATION NEEDED** -- you need to plot both baryceneters. Are they the same? if not why? Maybe a discussion here explaining the method vs standard LP - - We added plots for both barycenters. They are not the same. **EXPLANATION NEEDED** -- add more detzails heren alaos the references papers - - **ADD PAPERS** -- $ is not compatible with sphinx, please use proper math notations. Also what is aa? what is its relation with the p discussed above. - - Removed $. And we merged `greedy_primal_dual` to `discrete_mmot` for better structure and naming. aa refered to A in that context. -- are those the input distributions? how many of them? - - Yes, these are input distributions. Changed to A to follow POT API. -- this function needs to foloow POT API. look at https://pythonot.github.io/all.html#ot.emd or https://pythonot.github.io/all.html#ot.emd2. - - Fixed. Changed `x` to `A`, `aa` to `A`. -- hare you should express the optimization problem that the functio solves with proper maths and noations (agai see ot.emd), define what is generalized EMD, If it is a multimarginal then please explain it - - **EXPLANATION NEEDED** -- the function should return an OT plan (following ot.emd API) and a log dictionary with other information. - - Followed this schema in `discrete_mmot_converge` -- the output of the function should be of the same type as the input of the function -- should use nx.sum - - Fixed, used `nx.sum` and `nx.dot`. -- desctibe mroe precisely what you are solkving. is it a barycenter? - - **EXPLANATION NEEDED** -- bad naming, at least discrete_emd or discrete_emd2 if teh function return emd without the plan . if you give the function an empirical distribution then you shoud also put it in the name. Finally emd is computed beteen twoi distibutions so why is there only on naumpy arrya here? - - Naming changed. We are solving a different problem than emd. **EXPLANATION NEEDED** -- those shapes can be infered from x they should not be passed as parameters - - Removed unnacessary parameters that can be inferred. \ No newline at end of file diff --git a/examples/others/plot_d-mmot.py b/examples/others/plot_d-mmot.py new file mode 100644 index 000000000..2eb2b784a --- /dev/null +++ b/examples/others/plot_d-mmot.py @@ -0,0 +1,170 @@ +# -*- coding: utf-8 -*- +r""" +================================================================================= +d-MMOT vs LP Gradient Decent without Pytorch +================================================================================= + +Compare the loss convergence between LP and DEMD. The comparison is performed using random +Gaussian or uniform distributions and calculating the loss for each method +during the optimization process. +""" + +# Author: Ronak Mehta +# Xizheng Yu +# +# License: MIT License + +# %% +# 2 distributions +# ----- +import numpy as np +import matplotlib.pyplot as pl +import ot + +n = 100 +d = 2 +# Gaussian distributions +a1 = ot.datasets.make_1D_gauss(n, m=20, s=5) # m=mean, s=std +a2 = ot.datasets.make_1D_gauss(n, m=60, s=8) +A = np.vstack((a1, a2)).T +x = np.arange(n, dtype=np.float64) +# M = ot.utils.dist0(n) +# M /= M.max() +M = ot.utils.dist(x.reshape((n, 1)), metric='minkowski') + +pl.figure(1, figsize=(6.4, 3)) +pl.plot(x, a1, 'b', label='Source distribution') +pl.plot(x, a2, 'r', label='Target distribution') +pl.legend() + +# %% +# Run test +# ----- + +print('LP Iterations:') +ot.tic() +alpha = 1 # /d # 0<=alpha<=1 +weights = np.array(d * [alpha]) +lp_bary, lp_log = ot.lp.barycenter( + A, M, weights, solver='interior-point', verbose=False, log=True) +print('Time\t: ', ot.toc('')) +print('Obj\t: ', lp_log['fun']) + +print('') +print('Discrete MMOT Algorithm:') +ot.tic() +# dmmot_obj, log = ot.lp.discrete_mmot(A.T, n, d) +barys, log = ot.lp.discrete_mmot_converge(A.T, niters=3000, lr=0.000002, log=True) +dmmot_obj = log['primal objective'] +print('Time\t: ', ot.toc('')) +print('Obj\t: ', dmmot_obj) + +# %% +# Compare Barycenters in both methods +# --------- +pl.figure(1, figsize=(6.4, 3)) +for i in range(len(barys)): + if i == 0: + pl.plot(x, barys[i], 'g', label='Discrete MMOT') + else: + pl.plot(x, barys[i], 'g') +pl.plot(x, a1, 'b', label='Source distribution') +pl.plot(x, a2, 'r', label='Target distribution') +pl.title('Barycenters') +pl.legend() + +# %% +# Compare d-MMOOT with original distributions +# --------- +pl.figure(1, figsize=(6.4, 3)) +for i in range(len(barys)): + if i == 0: + pl.plot(x, barys[i], 'g', label='Discrete MMOT') + else: + pl.plot(x, barys[i], 'g') +# pl.plot(x, bary, 'g', label='Discrete MMOT') +pl.plot(x, lp_bary, 'b', label='LP Wasserstein') +pl.title('Barycenters') +pl.legend() + +# %% +# Define parameters, generate and plot distributions +# -------------------------------------------------- +# The following code generates random (n, d) data with in gauss +n = 50 # nb bins +d = 7 +vecsize = n * d + +data = [] +for i in range(d): + m = n * (0.5 * np.random.rand(1)) * float(np.random.randint(2) + 1) + a = ot.datasets.make_1D_gauss(n, m=m, s=5) + data.append(a) + +x = np.arange(n, dtype=np.float64) +M = ot.utils.dist(x.reshape((n, 1)), metric='minkowski') +A = np.vstack(data).T + +print(A.shape) + +pl.figure(1, figsize=(6.4, 3)) +for i in range(len(data)): + pl.plot(x, data[i]) + +pl.title('Distributions') +pl.legend() + +# %% +# Gradient Decent +# --------------- +# The following section performs gradient descent optimization using +# the DEMD method + +barys = ot.lp.discrete_mmot_converge(A.T, niters=9000, lr=0.00001) + +# after minimization, any distribution can be used as a estimate of barycenter +# bary = barys[0] + + +# %% lp barycenter +# ---------------- +# The following section computes 1D Wasserstein barycenter using the LP method +weights = ot.unif(d) +lp_bary, bary_log = ot.lp.barycenter(A, M, weights, solver='interior-point', p + verbose=True, log=True) + +# %% +# Compare Barycenters in both methods +# --------- +pl.figure(1, figsize=(6.4, 3)) +for i in range(len(barys)): + if i == 0: + pl.plot(x, barys[i], 'g', label='Discrete MMOT') + else: + pl.plot(x, barys[i], 'g') +# pl.plot(x, bary, 'g', label='Discrete MMOT') +pl.plot(x, lp_bary, 'b', label='LP Wasserstein') +pl.title('Barycenters') +pl.legend() + +# %% +# Compare d-MMOOT with original distributions +# --------- +pl.figure(1, figsize=(6.4, 3)) +for i in range(len(barys)): + if i == 0: + pl.plot(x, barys[i], 'g', label='Discrete MMOT') + else: + pl.plot(x, barys[i], 'g') +# pl.plot(x, bary, 'g', label='Discrete MMOT') +for i in range(len(data)): + pl.plot(x, data[i]) +pl.title('Barycenters') +pl.legend() + + +# %% +# Compare the loss between DEMD and LP Barycenter +# --------- +# The barycenter approach does not minize the distance between +# the distributions, while our DEMD does. diff --git a/examples/others/plot_demd_1d.py b/examples/others/plot_demd_1d.py deleted file mode 100644 index afaa06414..000000000 --- a/examples/others/plot_demd_1d.py +++ /dev/null @@ -1,111 +0,0 @@ -# -*- coding: utf-8 -*- -r""" -================================================================================= -1D Wasserstein barycenter: LP Barycenter vs DEMD -================================================================================= - -Compares the performance of two methods for computing the 1D Wasserstein -barycenter: -1. Linear Programming (LP) method -2. Discrete Earth Mover's Distance (DEMD) method - -The comparison is performed by generating random Gaussian distributions with -increasing numbers of bins and measuring the computation time of each method. -The results are then plotted for visualization. -""" - -# Author: Ronak Mehta -# Xizheng Yu -# -# License: MIT License - -import numpy as np -import matplotlib.pyplot as pl -import ot - -# %% -# Define 1d Barycenter Function and Compare Function -# -------------------------------------------------- -# This section defines the functions `lp_1d_bary` and `compare_all`. The -# `lp_1d_bary` function computes the barycenter using the LP method. The -# `compare_all` function compares the LP method and DEMD method in terms of -# computation time and objective values. - - -def lp_1d_bary(data, M, n, d): - A = np.vstack(data).T - - alpha = 1.0 # /d # 0<=alpha<=1 - weights = np.array(d*[alpha]) - - bary, bary_log = ot.lp.barycenter( - A, M, weights, solver='interior-point', verbose=False, log=True) - - return bary_log['fun'], bary - - -def compare_all(data, M, n, d): - print('IP LP Iterations:') - ot.tic() - lp_bary, lp_obj = lp_1d_bary(np.vstack(data), M, n, d) - lp_time = ot.toc('') - print('Obj\t: ', lp_bary) - print('Time\t: ', lp_time) - - print('') - print('D-EMD Algorithm:') - ot.tic() - demd_obj = ot.demd(np.vstack(data), n, d) - demd_time = ot.toc('') - print('Obj\t: ', demd_obj) - print('Time\t: ', demd_time) - return lp_time, demd_time - -# %% -# 2 Random Dists with Increasing Bins -# ----------------------------------- -# Generates two random Gaussian distributions with increasing bin -# sizes and compares the LP and DEMD methods - - -def random2d(n=4): - print('*'*10) - d = 2 - # Gaussian distributions - a1 = ot.datasets.make_1D_gauss(n, m=20, s=5) # m= mean, s= std - a2 = ot.datasets.make_1D_gauss(n, m=60, s=8) - print(a1) - print(a2) - x = np.arange(n, dtype=np.float64).reshape((n, 1)) - M = ot.utils.dist(x, metric='minkowski') - lp_time, demd_time = compare_all([a1, a2], M, n, d) - print('*'*10, '\n') - return lp_time, demd_time - - -def increasing_bins(): - lp_times, demd_times = [], [] - ns = [5, 10, 20, 50, 100] - for n in ns: - lp_time, demd_time = random2d(n=n) - lp_times.append(lp_time) - demd_times.append(demd_time) - return ns, lp_times, demd_times - - -ns, lp_times, demd_times = increasing_bins() - - -# %% -# Plot and Compare data -# --------------------- -# plots the computation times for the LP and DEMD methods for -# different bin sizes - - -pl.plot(ns, lp_times, 'o', linestyle="-", label="LP Barycenter") -pl.plot(ns, demd_times, 'o', linestyle="-", label="DEMD") -# pl.yscale('log') -pl.ylabel('Time Per Epoch (Seconds)') -pl.xlabel('Number of Distributions') -pl.legend() diff --git a/examples/others/plot_demd_gradient_minimize.py b/examples/others/plot_demd_gradient_minimize.py deleted file mode 100644 index c6fbfce8c..000000000 --- a/examples/others/plot_demd_gradient_minimize.py +++ /dev/null @@ -1,132 +0,0 @@ -# -*- coding: utf-8 -*- -r""" -================================================================================= -DEMD vs LP Gradient Decent without Pytorch -================================================================================= - -Compare the loss between LP and DEMD. The comparison is performed using random -Gaussian or uniform distributions and calculating the loss for each method -during the optimization process. -""" - -# Author: Ronak Mehta -# Xizheng Yu -# -# License: MIT License - -import io -import sys -import numpy as np -import matplotlib.pyplot as pl -import ot - -# %% -# Define function to get random (n, d) data -# ------------------------------------------- -# The following function generates random (n, d) data with either -# 'skewedGauss' or 'uniform' distributions - - -def getData(n, d, dist='skewedGauss'): - print(f'Data: {d} Random Dists with {n} Bins ***') - - x = np.arange(n, dtype=np.float64).reshape((n, 1)) - M = ot.utils.dist(x, metric='minkowski') - - data = [] - for i in range(d): - # m = 100*np.random.rand(1) - m = n * (0.5 * np.random.rand(1)) * float(np.random.randint(2) + 1) - if dist == 'skewedGauss': - a = ot.datasets.make_1D_gauss(n, m=m, s=5) - elif dist == 'uniform': - a = np.random.rand(n) - a = a / sum(a) - else: - print('unknown dist') - data.append(a) - - return data, M - -# %% -# Gradient Decent -# --------------- -# The following section performs gradient descent optimization using -# the DEMD method - - -# %% parameters and data -n = 50 # nb bins -d = 7 - -vecsize = n * d - -# data, M = getData(n, d, 'uniform') -data, M = getData(n, d, 'skewedGauss') -data = np.vstack(data) - -# %% demd -# Redirect the standard output to a string buffer -old_stdout = sys.stdout -sys.stdout = output_buffer = io.StringIO() - -x = ot.demd_minimize(ot.demd, data, d, n, vecsize, niters=3000, lr=0.00001) - -# after minimization, any distribution can be used as a estimate of barycenter -bary = x[0] - -sys.stdout = old_stdout -output = output_buffer.getvalue() - -rows = output.strip().split("\n") -demd_loss = [float(row.split()[-3]) for row in rows[1:]] - -print(output) - -# %% lp barycenter -# ---------------- -# The following section computes 1D Wasserstein barycenter using the LP method - - -def lp_1d_bary(data, M, n, d): - - A = np.vstack(data).T - - alpha = 1.0 # /d # 0<=alpha<=1 - weights = np.array(d * [alpha]) - - bary, bary_log = ot.lp.barycenter(A, M, weights, solver='interior-point', - verbose=True, log=True) - - return bary_log['fun'], bary - - -# Redirect the standard output to a string buffer -old_stdout = sys.stdout -sys.stdout = output_buffer = io.StringIO() - -obj, lp_bary = lp_1d_bary(data, M, n, d) - -# Restore the standard output and get value -sys.stdout = old_stdout -output = output_buffer.getvalue() - -rows = output.strip().split("\n") -lp_loss = [float(row.split()[-1]) for row in rows[1:-3]] - -print(output) - - -# %% -# Compare the loss between DEMD and LP Barycenter -# --------- -# The barycenter approach does not minize the distance between -# the distributions, while our DEMD does. -index = [*range(0, len(demd_loss))] - -pl.plot(index, demd_loss, label="DEMD") -pl.plot(index, lp_loss[:len(demd_loss)], label="LP") -pl.yscale('log') -pl.ylabel('Loss') -pl.xlabel('Epochs') -pl.legend() diff --git a/ot/lp/dmmot.py b/ot/lp/dmmot.py index 23351f44a..03a62457f 100644 --- a/ot/lp/dmmot.py +++ b/ot/lp/dmmot.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- """ -DEMD solvers for optimal transport +d-MMOT solvers for optimal transport """ # Author: Ronak Mehta @@ -13,81 +13,50 @@ # M -> obj -def greedy_primal_dual(A, verbose=False): +def discrete_mmot(A, verbose=False, log=False): r""" - The greedy algorithm that solves both primal and dual generalized Earth - mover’s programs. - - The algorithm accepts :math:`d` distributions (i.e., histograms) :math:`p_{1}, - \ldots, p_{d} \in \mathbb{R}_{+}^{n}` with :math:`e^{\prime} p_{j}=1` - for all :math:`j \in[d]`. Although the algorithm states that all - histograms have the same number of bins, the algorithm can be easily - adapted to accept as inputs :math:`p_{i} \in \mathbb{R}_{+}^{n_{i}}` - with :math:`n_{i} \neq n_{j}`. - - Parameters - ---------- - A : list of numpy arrays -> nd array - The input arrays are list of distributions + Compute the discrete multi marginal optimal transport of distributions A + The algorithm solves both primal and dual d-MMOT programs + + The algorithm accepts :math:`d` distributions (i.e., histograms) + :math:`p_{1}, \ldots, p_{d} \in \mathbb{R}_{+}^{n}` with :math:`e^{\prime} + p_{j}=1` for all :math:`j \in[d]`. Although the algorithm states that all + histograms have the same number of bins, the algorithm can be easily + adapted to accept as inputs :math:`p_{i} \in \mathbb{R}_{+}^{n_{i}}` + with :math:`n_{i} \neq n_{j}` [50]. - Returns - ------- - dict : dic - A dictionary containing the solution of the primal-dual problem: - - 'x': a dictionary that maps tuples of indices to the corresponding - primal variables. The tuples are the indices of the entries that are - set to their minimum value during the algorithm. - - 'primal objective': a float, the value of the objective function - evaluated at the solution. - - 'dual': a list of numpy arrays, the dual variables corresponding to - the input arrays. The i-th element of the list is the dual variable - corresponding to the i-th dimension of the input arrays. - - 'dual objective': a float, the value of the dual objective function - evaluated at the solution. - - Examples - -------- - >>> import numpy as np - >>> A = [np.array([[1, 2], [3, 4]]), np.array([[5, 6], [7, 8]])] - >>> result = greedy_primal_dual(A) - >>> result['primal objective'] - -12 - """ - - pass - - -def discrete_mmot(A, verbose=False, log=False): - r""" - Solver of our proposed method: d−Dimensional Earch Mover’s Distance (DEMD). - multi marginal optimal transport + The function solves the following optimization problem[51]: + + Parameters ---------- - A : numpy array, shape (d * n, ) - The input vector containing coordinates of n points in d dimensions. - d : int - The number of dimensions of the points. - n : int - The number of points. + A : nx.ndarray, shape (d * n, ) + The input ndarray containing distributions of n bins in d dimensions. verbose : bool, optional - If True, print debugging information during the execution of the - algorithm. Default is False. - return_dual_vars : bool, optional - If True, also return the dual variables and the dual objective value of - the DEMD problem. Default is False. + If True, print debugging information during execution. Default=False. + log : bool, optional + If True, record log. Default is False. Returns ------- - primal_obj : float + obj : float the value of the primal objective function evaluated at the solution. - dual_vars : numpy array, shape (d, n-1), optional - the values of the dual variables corresponding to the input points. - The i-th column of the array corresponds to the i-th point. - dual_obj : float, optional - the value of the dual objective function evaluated at the solution. + log : dict + A dictionary containing the log of the discrete mmot problem: + - 'A': a dictionary that maps tuples of indices to the corresponding + primal variables. The tuples are the indices of the entries that are + set to their minimum value during the algorithm. + - 'primal objective': a float, the value of the objective function + evaluated at the solution. + - 'dual': a list of arrays, the dual variables corresponding to + the input arrays. The i-th element of the list is the dual variable + corresponding to the i-th dimension of the input arrays. + - 'dual objective': a float, the value of the dual objective function + evaluated at the solution. + References ---------- @@ -95,14 +64,15 @@ def discrete_mmot(A, verbose=False, log=False): Vikas Singh (2023). Efficient Discrete Multi Marginal Optimal Transport Regularization. In The Eleventh International Conference on Learning Representations. - .. [51] Jeffery Kline. Properties of the d-dimensional earth mover’s - problem. Discrete Applied Mathematics, 265: 128–141, 2019. + .. [51] Jeffery Kline. Properties of the d-dimensional earth mover's + problem. Discrete Applied Mathematics, 265: 128-141, 2019. + .. [52] Leonid V Kantorovich. On the translocation of masses. Dokl. Akad. + Nauk SSSR, 37:227-229, 1942. """ def OBJ(i): return max(i) - min(i) - # print(f"A type is: {type(A)}") nx = get_backend(A) AA = [nx.copy(_) for _ in A] @@ -113,8 +83,10 @@ def OBJ(i): idx = [0, ] * len(AA) obj = 0 + if verbose: print('i minval oldidx\t\tobj\t\tvals') + while all([i < _ for _, i in zip(dims, idx)]): vals = [v[i] for v, i in zip(AA, idx)] minval = min(vals) @@ -151,33 +123,69 @@ def OBJ(i): else: return obj - # if return_dual_vars: - # dual = log['dual'] - # return_dual = np.array(dual) - # dualobj = log['dual objective'] - # return log['primal objective'], return_dual, log['dual objective'] - # else: - # return log['primal objective'], log - -def discrete_mmot_converge(A, niters=100, lr=0.1, print_rate=100, log=False): - r""" - Minimize a DEMD function using gradient descent. +def discrete_mmot_converge( + A, niters=100, lr=0.1, print_rate=100, verbose=False, log=False): + r"""Compute a d-MMOT problem using gradient descent. + + Discrete Multi-Marginal Optimal Transport (d-MMOT): Let :math:`p_1, \ldots, + p_d\in\mathbb{R}^n_{+}` be discrete probability distributions. Let + :math:`C_d : \mathbb{R}^{n^{d}}\rightarrow \mathbb{R}_{+}`. + The discrete multi-marginal optimal transport problem (d-MMOT) can be + written as: + + .. math:: + \underset{{X \in \mathbb{R}^{n\times \cdots \times n}}}{\textrm{min}} + \quad C_d(X) \quad \textrm{s.t.}\quad X_i = p_i,\ (\forall i\in [d]), + + where :math:`X_i \in \mathbb{R}^n` is the :math:`i`-th marginal of :math: + `X \in \mathbb{R}^{n\times \cdots \times n}=\mathbb{R}^{n^{d}}`. + + Following the original formulation (Kantorovich 1942), we will restrict the + cost function :math:`C_d(\cdot)` to the linear map, :math:`C_d(X) := + \langle c, X \rangle_{\otimes}`, where :math:`c \in \mathbb{R}_{+}^{n\times + \cdots \times n}` is nonnegative. Here, the d-MMOT is the LP, + + .. math:: + \begin{align}\begin{aligned} + \underset{x\in\mathbb{R}^{n^{d}}_{+}} {\textrm{min}} + \sum_{i_1,\ldots,i_d} c(i_1,\ldots, i_d)\, x(i_1,\ldots,i_d) \quad + \textrm{s.t.} + \sum_{i_2,\ldots,i_d} x(i_1,\ldots,i_d) &= p_1(i_i), + (\forall i_1\in[n])\\ + \qquad\vdots\\ + \sum_{i_1,\ldots,i_{d-1}} x(i_1,\ldots,i_d) &= p_{d}(i_{d}), + (\forall i_d\in[n]). + \end{aligned} + \end{align} + + The dual linear program of the d-MMOT problem is: + + .. math:: + \underset{z_j\in\mathbb{R}^n, j\in[d]}{\textrm{maximize}}\qquad\sum_{j} + p_j'z_j\qquad \textrm{subject to}\qquad z_{1}(i_1)+\cdots+z_{d}(i_{d}) + \leq c(i_1,\ldots,i_{d}), + + + where the indices in the constraints include all :math:`i_j\in[n]`, :math: + `j\in[d]`. Denote by :math:`\phi(p_1,\ldots,p_d)`, the optimal objective + value of the LP in d-MMOT problem. Let :math:`z^*` be an optimal solution + to the dual program. Then, + + .. math:: + \begin{align} + \nabla \phi(p_1,\ldots,p_{d}) &= z^*, + ~~\text{and for any $t\in \mathbb{R}$,}~~ + \phi(p_1,p_2,\ldots,p_{d}) = \sum_{j}p_j' + (z_j^* + t\, \eta), \nonumber \\ + \text{where } \eta &:= (z_1^{*}(n)\,e, z^*_1(n)\,e, \cdots, + z^*_{d}(n)\,e) + \end{align} Parameters ---------- - f : callable - The objective function to minimize. This function must take as input - a matrix x of shape (d, n) and return a scalar value representing - the objective function evaluated at x. It may also return a matrix of - shape (d, n) representing the gradient of the objective function - with respect to x, and/or any other dual variables needed for the - optimization algorithm. The signature of this function should be: - `f(x, d, n, return_dual_vars=False) -> float` - or - `f(x, d, n, return_dual_vars=True) -> (float, ndarray, ...)` - A : ndarray, shape (d, n) - The initial point for the optimization algorithm. + A : nx.ndarray, shape (d, n) + The input ndarray containing distributions of n bins in d dimensions. niters : int, optional (default=100) The maximum number of iterations for the optimization algorithm. lr : float, optional (default=0.1) @@ -185,11 +193,18 @@ def discrete_mmot_converge(A, niters=100, lr=0.1, print_rate=100, log=False): print_rate : int, optional (default=100) The rate at which to print the objective value and gradient norm during the optimization algorithm. + verbose : bool, optional + If True, print debugging information during execution. Default=False. + log : bool, optional + If True, record log. Default is False. Returns ------- - list of ndarrays, each of shape (n,) - The optimal solution as a list of n vectors, each of length vecsize. + a : list of ndarrays, each of shape (n,) + The optimal solution as a list of n approximate barycenters, each of + length vecsize. + log : dict + log dictionary return only if log==True in parameters """ # function body here @@ -197,10 +212,9 @@ def discrete_mmot_converge(A, niters=100, lr=0.1, print_rate=100, log=False): d, n = A.shape def dualIter(A, lr): - funcval, log_dict = discrete_mmot(A, log=True) + funcval, log_dict = discrete_mmot(A, verbose=verbose, log=True) grad = np.array(log_dict['dual']) A_new = nx.reshape(A, (d, n)) - grad * lr - # A_new = A - grad * lr return funcval, A_new, grad, log_dict def renormalize(A): @@ -234,5 +248,4 @@ def listify(A): if log: return a, log_dict else: - return a - + return a \ No newline at end of file From 3c7ab34f1054bba94acf85759cbc8479be84628e Mon Sep 17 00:00:00 2001 From: Ronak Date: Wed, 17 May 2023 14:08:47 -0400 Subject: [PATCH 15/28] bug fix in plot_dmmot, some commenting/documenting edits --- examples/others/plot_d-mmot.py | 2 +- ot/lp/dmmot.py | 54 +++++++++++++++++++++------------- 2 files changed, 34 insertions(+), 22 deletions(-) diff --git a/examples/others/plot_d-mmot.py b/examples/others/plot_d-mmot.py index 2eb2b784a..e3f02459f 100644 --- a/examples/others/plot_d-mmot.py +++ b/examples/others/plot_d-mmot.py @@ -130,7 +130,7 @@ # ---------------- # The following section computes 1D Wasserstein barycenter using the LP method weights = ot.unif(d) -lp_bary, bary_log = ot.lp.barycenter(A, M, weights, solver='interior-point', p +lp_bary, bary_log = ot.lp.barycenter(A, M, weights, solver='interior-point', verbose=True, log=True) # %% diff --git a/ot/lp/dmmot.py b/ot/lp/dmmot.py index 03a62457f..8be4ae17e 100644 --- a/ot/lp/dmmot.py +++ b/ot/lp/dmmot.py @@ -15,9 +15,12 @@ def discrete_mmot(A, verbose=False, log=False): r""" - Compute the discrete multi marginal optimal transport of distributions A + Compute the discrete multi-marginal optimal transport of distributions A. - The algorithm solves both primal and dual d-MMOT programs + The algorithm solves both primal and dual d-MMOT programs concurrently to + produce the optimal transport plan as well as the total (minimal) cost. + The cost is a generalized Monge cost, and the solution is independent of + which Monge cost is desired. The algorithm accepts :math:`d` distributions (i.e., histograms) :math:`p_{1}, \ldots, p_{d} \in \mathbb{R}_{+}^{n}` with :math:`e^{\prime} @@ -25,10 +28,21 @@ def discrete_mmot(A, verbose=False, log=False): histograms have the same number of bins, the algorithm can be easily adapted to accept as inputs :math:`p_{i} \in \mathbb{R}_{+}^{n_{i}}` with :math:`n_{i} \neq n_{j}` [50]. - The function solves the following optimization problem[51]: - + + .. math:: + \begin{align}\begin{aligned} + \underset{x\in\mathbb{R}^{n^{d}}_{+}} {\textrm{min}} + \sum_{i_1,\ldots,i_d} c(i_1,\ldots, i_d)\, x(i_1,\ldots,i_d) \quad + \textrm{s.t.} + \sum_{i_2,\ldots,i_d} x(i_1,\ldots,i_d) &= p_1(i_i), + (\forall i_1\in[n])\\ + \qquad\vdots\\ + \sum_{i_1,\ldots,i_{d-1}} x(i_1,\ldots,i_d) &= p_{d}(i_{d}), + (\forall i_d\in[n]). + \end{aligned} + \end{align} Parameters @@ -68,6 +82,10 @@ def discrete_mmot(A, verbose=False, log=False): problem. Discrete Applied Mathematics, 265: 128-141, 2019. .. [52] Leonid V Kantorovich. On the translocation of masses. Dokl. Akad. Nauk SSSR, 37:227-229, 1942. + + See Also + -------- + ot.lp.discrete_mmot_converge : Minimized the d-Dimensional Earth Mover's Distance (d-MMOT) """ def OBJ(i): @@ -126,25 +144,11 @@ def OBJ(i): def discrete_mmot_converge( A, niters=100, lr=0.1, print_rate=100, verbose=False, log=False): - r"""Compute a d-MMOT problem using gradient descent. + r"""Minimize the d-dimensional EMD using gradient descent. Discrete Multi-Marginal Optimal Transport (d-MMOT): Let :math:`p_1, \ldots, - p_d\in\mathbb{R}^n_{+}` be discrete probability distributions. Let - :math:`C_d : \mathbb{R}^{n^{d}}\rightarrow \mathbb{R}_{+}`. - The discrete multi-marginal optimal transport problem (d-MMOT) can be - written as: - - .. math:: - \underset{{X \in \mathbb{R}^{n\times \cdots \times n}}}{\textrm{min}} - \quad C_d(X) \quad \textrm{s.t.}\quad X_i = p_i,\ (\forall i\in [d]), - - where :math:`X_i \in \mathbb{R}^n` is the :math:`i`-th marginal of :math: - `X \in \mathbb{R}^{n\times \cdots \times n}=\mathbb{R}^{n^{d}}`. - - Following the original formulation (Kantorovich 1942), we will restrict the - cost function :math:`C_d(\cdot)` to the linear map, :math:`C_d(X) := - \langle c, X \rangle_{\otimes}`, where :math:`c \in \mathbb{R}_{+}^{n\times - \cdots \times n}` is nonnegative. Here, the d-MMOT is the LP, + p_d\in\mathbb{R}^n_{+}` be discrete probability distributions. Here, + the d-MMOT is the LP, .. math:: \begin{align}\begin{aligned} @@ -182,6 +186,10 @@ def discrete_mmot_converge( z^*_{d}(n)\,e) \end{align} + Using these dual variables naturally provided by the algorithm in + ot.lp.discrete_mmot, gradient steps move each input distribution + to minimize their d-mmot distance. + Parameters ---------- A : nx.ndarray, shape (d, n) @@ -205,6 +213,10 @@ def discrete_mmot_converge( length vecsize. log : dict log dictionary return only if log==True in parameters + + See Also + -------- + ot.lp.discrete_mmot : d-Dimensional Earth Mover's Solver """ # function body here From 7452379941ad6f8349ffe28ce5cfa607769ab3c9 Mon Sep 17 00:00:00 2001 From: Ronak Date: Wed, 17 May 2023 14:51:10 -0400 Subject: [PATCH 16/28] dmmot example cleanup, some comments/plotting edits --- examples/others/plot_d-mmot.py | 107 +++++++++++++++++---------------- 1 file changed, 55 insertions(+), 52 deletions(-) diff --git a/examples/others/plot_d-mmot.py b/examples/others/plot_d-mmot.py index e3f02459f..61a3d6327 100644 --- a/examples/others/plot_d-mmot.py +++ b/examples/others/plot_d-mmot.py @@ -1,12 +1,14 @@ # -*- coding: utf-8 -*- r""" ================================================================================= -d-MMOT vs LP Gradient Decent without Pytorch +Computing d-dimensional Barycenters via d-MMOT ================================================================================= -Compare the loss convergence between LP and DEMD. The comparison is performed using random -Gaussian or uniform distributions and calculating the loss for each method -during the optimization process. +When the cost is discretized (Monge), the d-MMOT solver can more quickly compute and +minimize the distance between many distributions without the need for intermediate +barycenter computations. This example compares the time to identify, +and the quality of, solutions for the d-MMOT problem using a primal/dual algorithm +and classical LP barycenter approaches. """ # Author: Ronak Mehta @@ -15,12 +17,14 @@ # License: MIT License # %% -# 2 distributions +# Generating 2 distributions # ----- import numpy as np import matplotlib.pyplot as pl import ot +np.random.seed(0) + n = 100 d = 2 # Gaussian distributions @@ -38,8 +42,10 @@ pl.legend() # %% -# Run test +# Minimize the distances among distributions, identify the Barycenter # ----- +# The objective being minimized is different for both methods, so the objective values +# cannot be compared. print('LP Iterations:') ot.tic() @@ -65,32 +71,34 @@ pl.figure(1, figsize=(6.4, 3)) for i in range(len(barys)): if i == 0: - pl.plot(x, barys[i], 'g', label='Discrete MMOT') + pl.plot(x, barys[i], 'g-*', label='Discrete MMOT') else: - pl.plot(x, barys[i], 'g') + continue + #pl.plot(x, barys[i], 'g-*') +pl.plot(x, lp_bary, 'k-', label='LP Barycenter') pl.plot(x, a1, 'b', label='Source distribution') pl.plot(x, a2, 'r', label='Target distribution') pl.title('Barycenters') pl.legend() -# %% -# Compare d-MMOOT with original distributions -# --------- -pl.figure(1, figsize=(6.4, 3)) -for i in range(len(barys)): - if i == 0: - pl.plot(x, barys[i], 'g', label='Discrete MMOT') - else: - pl.plot(x, barys[i], 'g') -# pl.plot(x, bary, 'g', label='Discrete MMOT') -pl.plot(x, lp_bary, 'b', label='LP Wasserstein') -pl.title('Barycenters') -pl.legend() +# # %% +# # Compare d-MMOT with original distributions +# # --------- +# pl.figure(1, figsize=(6.4, 3)) +# for i in range(len(barys)): +# if i == 0: +# pl.plot(x, barys[i], 'g', label='Discrete MMOT') +# else: +# pl.plot(x, barys[i], 'g') +# # pl.plot(x, bary, 'g', label='Discrete MMOT') +# pl.plot(x, lp_bary, 'b', label='LP Wasserstein') +# pl.title('Barycenters') +# pl.legend() # %% -# Define parameters, generate and plot distributions +# More than 2 distributions # -------------------------------------------------- -# The following code generates random (n, d) data with in gauss +# Generate 7 pseudorandom gaussian distributions with 50 bins. n = 50 # nb bins d = 7 vecsize = n * d @@ -115,20 +123,20 @@ pl.legend() # %% -# Gradient Decent +# Minimizing Distances Among Many Distributions # --------------- -# The following section performs gradient descent optimization using -# the DEMD method +# The objective being minimized is different for both methods, so the objective values +# cannot be compared. -barys = ot.lp.discrete_mmot_converge(A.T, niters=9000, lr=0.00001) +# Perform gradient descent optimization using +# the d-MMOT method. -# after minimization, any distribution can be used as a estimate of barycenter -# bary = barys[0] +barys = ot.lp.discrete_mmot_converge(A.T, niters=9000, lr=0.00001) +# after minimization, any distribution can be used as a estimate of barycenter. +bary = barys[0] -# %% lp barycenter -# ---------------- -# The following section computes 1D Wasserstein barycenter using the LP method +# Compute 1D Wasserstein barycenter using the LP method weights = ot.unif(d) lp_bary, bary_log = ot.lp.barycenter(A, M, weights, solver='interior-point', verbose=True, log=True) @@ -137,34 +145,29 @@ # Compare Barycenters in both methods # --------- pl.figure(1, figsize=(6.4, 3)) -for i in range(len(barys)): - if i == 0: - pl.plot(x, barys[i], 'g', label='Discrete MMOT') - else: - pl.plot(x, barys[i], 'g') -# pl.plot(x, bary, 'g', label='Discrete MMOT') -pl.plot(x, lp_bary, 'b', label='LP Wasserstein') +# for i in range(len(barys)): +# if i == 0: +# pl.plot(x, barys[i], 'g', label='Discrete MMOT') +# else: +# pl.plot(x, barys[i], 'g') +pl.plot(x, bary, 'g-*', label='Discrete MMOT') +pl.plot(x, lp_bary, 'k-', label='LP Wasserstein') pl.title('Barycenters') pl.legend() # %% -# Compare d-MMOOT with original distributions +# Compare with original distributions # --------- pl.figure(1, figsize=(6.4, 3)) +for i in range(len(data)): + pl.plot(x, data[i]) for i in range(len(barys)): if i == 0: - pl.plot(x, barys[i], 'g', label='Discrete MMOT') + pl.plot(x, barys[i], 'g-*', label='Discrete MMOT') else: - pl.plot(x, barys[i], 'g') + continue + #pl.plot(x, barys[i], 'g') +pl.plot(x, lp_bary, 'k-', label='LP Wasserstein') # pl.plot(x, bary, 'g', label='Discrete MMOT') -for i in range(len(data)): - pl.plot(x, data[i]) pl.title('Barycenters') -pl.legend() - - -# %% -# Compare the loss between DEMD and LP Barycenter -# --------- -# The barycenter approach does not minize the distance between -# the distributions, while our DEMD does. +pl.legend() \ No newline at end of file From 9c360bbdfc956f33b2101ab2e063b962ca88e2e9 Mon Sep 17 00:00:00 2001 From: x12hengyu Date: Wed, 31 May 2023 21:55:49 -0500 Subject: [PATCH 17/28] add dist_monge method --- examples/others/plot_d-mmot.py | 15 ++++++++------- ot/lp/dmmot.py | 33 ++++++++++++++++++++++++++++++++- 2 files changed, 40 insertions(+), 8 deletions(-) diff --git a/examples/others/plot_d-mmot.py b/examples/others/plot_d-mmot.py index 2eb2b784a..9a90db1fe 100644 --- a/examples/others/plot_d-mmot.py +++ b/examples/others/plot_d-mmot.py @@ -1,12 +1,12 @@ # -*- coding: utf-8 -*- r""" -================================================================================= +=============================================================================== d-MMOT vs LP Gradient Decent without Pytorch -================================================================================= +=============================================================================== -Compare the loss convergence between LP and DEMD. The comparison is performed using random -Gaussian or uniform distributions and calculating the loss for each method -during the optimization process. +Compare the loss convergence between LP and DEMD. The comparison is performed +using random Gaussian or uniform distributions and calculating the loss for +each method during the optimization process. """ # Author: Ronak Mehta @@ -59,6 +59,7 @@ print('Time\t: ', ot.toc('')) print('Obj\t: ', dmmot_obj) + # %% # Compare Barycenters in both methods # --------- @@ -130,8 +131,8 @@ # ---------------- # The following section computes 1D Wasserstein barycenter using the LP method weights = ot.unif(d) -lp_bary, bary_log = ot.lp.barycenter(A, M, weights, solver='interior-point', p - verbose=True, log=True) +lp_bary, bary_log = ot.lp.barycenter(A, M, weights, solver='interior-point', + verbose=True, log=True) # %% # Compare Barycenters in both methods diff --git a/ot/lp/dmmot.py b/ot/lp/dmmot.py index 03a62457f..b8c9d8927 100644 --- a/ot/lp/dmmot.py +++ b/ot/lp/dmmot.py @@ -11,7 +11,38 @@ import numpy as np from ..backend import get_backend -# M -> obj +def dist_monge(i): + r""" + A tensor :math:c is Monge if for all valid :math:i_1, \ldots i_d and + :math:j_1, \ldots, j_d, + + .. math:: + c(s_1, \ldots, s_d) + c(t_1, \ldots t_d) \leq c(i_1, \ldots i_d) + + c(j_1, \ldots, j_d) + + where :math:s_k = \min(i_k, j_k) and :math:t_k = \max(i_k, j_k). + + Our focus is on a specific cost, which is known to be Monge: + + .. math:: + c(i_1,i_2,\ldots,i_d) = \max{i_k:k\in[d]} - \min{i_k:k\in[d]}. + + When :math:d=2, this cost reduces to :math:c(i_1,i_2)=|i_1-i_2|, + which agrees with the classical EMD cost. This choice of :math:c is called + the generalized EMD cost. + + Parameters + ---------- + i : list + The list for which the generalized EMD cost is to be computed. + + Returns + ------- + cost : numeric value + The generalized EMD cost of the tensor. + """ + + return max(i) - min(i) def discrete_mmot(A, verbose=False, log=False): r""" From 697036d1facf7b0ef353674f48d5eff88fedb005 Mon Sep 17 00:00:00 2001 From: x12hengyu Date: Thu, 1 Jun 2023 10:28:51 -0500 Subject: [PATCH 18/28] all dmmot methods takes (n, d) shape A as input (follows POT style) --- examples/others/plot_d-mmot.py | 53 ++++++++++------------------------ ot/lp/dmmot.py | 41 +++++++++++++------------- 2 files changed, 35 insertions(+), 59 deletions(-) diff --git a/examples/others/plot_d-mmot.py b/examples/others/plot_d-mmot.py index 5dd174096..ba230efc2 100644 --- a/examples/others/plot_d-mmot.py +++ b/examples/others/plot_d-mmot.py @@ -1,14 +1,14 @@ # -*- coding: utf-8 -*- r""" -================================================================================= +=============================================================================== Computing d-dimensional Barycenters via d-MMOT -================================================================================= +=============================================================================== -When the cost is discretized (Monge), the d-MMOT solver can more quickly compute and -minimize the distance between many distributions without the need for intermediate -barycenter computations. This example compares the time to identify, -and the quality of, solutions for the d-MMOT problem using a primal/dual algorithm -and classical LP barycenter approaches. +When the cost is discretized (Monge), the d-MMOT solver can more quickly +compute and minimize the distance between many distributions without the need +for intermediate barycenter computations. This example compares the time to +identify, and the quality of, solutions for the d-MMOT problem using a +primal/dual algorithm and classical LP barycenter approaches. """ # Author: Ronak Mehta @@ -32,8 +32,6 @@ a2 = ot.datasets.make_1D_gauss(n, m=60, s=8) A = np.vstack((a1, a2)).T x = np.arange(n, dtype=np.float64) -# M = ot.utils.dist0(n) -# M /= M.max() M = ot.utils.dist(x.reshape((n, 1)), metric='minkowski') pl.figure(1, figsize=(6.4, 3)) @@ -44,8 +42,8 @@ # %% # Minimize the distances among distributions, identify the Barycenter # ----- -# The objective being minimized is different for both methods, so the objective values -# cannot be compared. +# The objective being minimized is different for both methods, so the objective +# values cannot be compared. print('LP Iterations:') ot.tic() @@ -60,7 +58,7 @@ print('Discrete MMOT Algorithm:') ot.tic() # dmmot_obj, log = ot.lp.discrete_mmot(A.T, n, d) -barys, log = ot.lp.discrete_mmot_converge(A.T, niters=3000, lr=0.000002, log=True) +barys, log = ot.lp.discrete_mmot_converge(A, niters=3000, lr=0.000002, log=True) dmmot_obj = log['primal objective'] print('Time\t: ', ot.toc('')) print('Obj\t: ', dmmot_obj) @@ -82,20 +80,6 @@ pl.title('Barycenters') pl.legend() -# # %% -# # Compare d-MMOT with original distributions -# # --------- -# pl.figure(1, figsize=(6.4, 3)) -# for i in range(len(barys)): -# if i == 0: -# pl.plot(x, barys[i], 'g', label='Discrete MMOT') -# else: -# pl.plot(x, barys[i], 'g') -# # pl.plot(x, bary, 'g', label='Discrete MMOT') -# pl.plot(x, lp_bary, 'b', label='LP Wasserstein') -# pl.title('Barycenters') -# pl.legend() - # %% # More than 2 distributions # -------------------------------------------------- @@ -126,13 +110,11 @@ # %% # Minimizing Distances Among Many Distributions # --------------- -# The objective being minimized is different for both methods, so the objective values -# cannot be compared. - -# Perform gradient descent optimization using -# the d-MMOT method. +# The objective being minimized is different for both methods, so the objective +# values cannot be compared. -barys = ot.lp.discrete_mmot_converge(A.T, niters=9000, lr=0.00001) +# Perform gradient descent optimization using the d-MMOT method. +barys = ot.lp.discrete_mmot_converge(A, niters=9000, lr=0.00001) # after minimization, any distribution can be used as a estimate of barycenter. bary = barys[0] @@ -146,11 +128,6 @@ # Compare Barycenters in both methods # --------- pl.figure(1, figsize=(6.4, 3)) -# for i in range(len(barys)): -# if i == 0: -# pl.plot(x, barys[i], 'g', label='Discrete MMOT') -# else: -# pl.plot(x, barys[i], 'g') pl.plot(x, bary, 'g-*', label='Discrete MMOT') pl.plot(x, lp_bary, 'k-', label='LP Wasserstein') pl.title('Barycenters') @@ -171,4 +148,4 @@ pl.plot(x, lp_bary, 'k-', label='LP Wasserstein') # pl.plot(x, bary, 'g', label='Discrete MMOT') pl.title('Barycenters') -pl.legend() \ No newline at end of file +pl.legend() diff --git a/ot/lp/dmmot.py b/ot/lp/dmmot.py index 171a79609..0bc1b7706 100644 --- a/ot/lp/dmmot.py +++ b/ot/lp/dmmot.py @@ -78,7 +78,7 @@ def discrete_mmot(A, verbose=False, log=False): Parameters ---------- - A : nx.ndarray, shape (d * n, ) + A : nx.ndarray, shape (n * d, ) The input ndarray containing distributions of n bins in d dimensions. verbose : bool, optional If True, print debugging information during execution. Default=False. @@ -116,15 +116,14 @@ def discrete_mmot(A, verbose=False, log=False): See Also -------- - ot.lp.discrete_mmot_converge : Minimized the d-Dimensional Earth Mover's Distance (d-MMOT) + ot.lp.discrete_mmot_converge : Minimized the d-Dimensional Earth Mover's + Distance (d-MMOT) """ - - def OBJ(i): - return max(i) - min(i) nx = get_backend(A) - AA = [nx.copy(_) for _ in A] + # AA = [nx.copy(_) for _ in A] + AA = [nx.copy(A[:, j]) for j in range(A.shape[1])] dims = tuple([len(_) for _ in AA]) xx = {} @@ -141,13 +140,13 @@ def OBJ(i): minval = min(vals) i = vals.index(minval) xx[tuple(idx)] = minval - obj += (OBJ(idx)) * minval + obj += (dist_monge(idx)) * minval for v, j in zip(AA, idx): v[j] -= minval oldidx = nx.copy(idx) idx[i] += 1 if idx[i] < dims[i]: - dual[i][idx[i]] += OBJ(idx) - OBJ(oldidx) + dual[i][idx[i]-1] + dual[i][idx[i]] += dist_monge(idx) - dist_monge(oldidx) + dual[i][idx[i]-1] if verbose: print(i, minval, oldidx, obj, '\t', vals) @@ -160,7 +159,7 @@ def OBJ(i): except Exception: pass - dualobj = nx.sum([nx.dot(arr, dual_arr) for arr, dual_arr in zip(A, dual)]) + dualobj = nx.sum([nx.dot(A[:, i], arr) for i, arr in enumerate(dual)]) log_dict = {'A': xx, 'primal objective': obj, @@ -223,7 +222,7 @@ def discrete_mmot_converge( Parameters ---------- - A : nx.ndarray, shape (d, n) + A : nx.ndarray, shape (n, d) The input ndarray containing distributions of n bins in d dimensions. niters : int, optional (default=100) The maximum number of iterations for the optimization algorithm. @@ -252,25 +251,25 @@ def discrete_mmot_converge( # function body here nx = get_backend(A) - d, n = A.shape + n, d = A.shape def dualIter(A, lr): funcval, log_dict = discrete_mmot(A, verbose=verbose, log=True) - grad = np.array(log_dict['dual']) - A_new = nx.reshape(A, (d, n)) - grad * lr + grad = np.column_stack(log_dict['dual']) + A_new = nx.reshape(A, (n, d)) - grad * lr return funcval, A_new, grad, log_dict def renormalize(A): - A = nx.reshape(A, (d, n)) - for i in range(A.shape[0]): - if min(A[i, :]) < 0: - A[i, :] -= min(A[i, :]) - A[i, :] /= nx.sum(A[i, :]) + A = nx.reshape(A, (n, d)) + for i in range(A.shape[1]): + if min(A[:, i]) < 0: + A[:, i] -= min(A[:, i]) + A[:, i] /= nx.sum(A[:, i]) return A def listify(A): - return [A[i, :] for i in range(A.shape[0])] - + return [A[:, i] for i in range(A.shape[1])] + funcval, _, grad, log_dict = dualIter(A, lr) gn = nx.norm(grad) @@ -286,7 +285,7 @@ def listify(A): print(f'Iter {i:2.0f}:\tObj:\t{funcval:.4f}\tGradNorm:\t{gn:.4f}') A = renormalize(A) - a = listify(nx.reshape(A, (d, n))) + a = listify(A) if log: return a, log_dict From 70326a62ef4549a0605c3c71c63349f132e6bd59 Mon Sep 17 00:00:00 2001 From: x12hengyu Date: Thu, 1 Jun 2023 12:08:39 -0500 Subject: [PATCH 19/28] passed pep8 and pyflake checks --- examples/others/plot_d-mmot.py | 13 ++-- ot/lp/dmmot.py | 130 ++++++++++++++++++++------------- 2 files changed, 84 insertions(+), 59 deletions(-) diff --git a/examples/others/plot_d-mmot.py b/examples/others/plot_d-mmot.py index ba230efc2..f3dcc1f3e 100644 --- a/examples/others/plot_d-mmot.py +++ b/examples/others/plot_d-mmot.py @@ -58,7 +58,8 @@ print('Discrete MMOT Algorithm:') ot.tic() # dmmot_obj, log = ot.lp.discrete_mmot(A.T, n, d) -barys, log = ot.lp.discrete_mmot_converge(A, niters=3000, lr=0.000002, log=True) +barys, log = ot.lp.discrete_mmot_converge( + A, niters=3000, lr=0.000002, log=True) dmmot_obj = log['primal objective'] print('Time\t: ', ot.toc('')) print('Obj\t: ', dmmot_obj) @@ -73,7 +74,7 @@ pl.plot(x, barys[i], 'g-*', label='Discrete MMOT') else: continue - #pl.plot(x, barys[i], 'g-*') + # pl.plot(x, barys[i], 'g-*') pl.plot(x, lp_bary, 'k-', label='LP Barycenter') pl.plot(x, a1, 'b', label='Source distribution') pl.plot(x, a2, 'r', label='Target distribution') @@ -93,13 +94,11 @@ m = n * (0.5 * np.random.rand(1)) * float(np.random.randint(2) + 1) a = ot.datasets.make_1D_gauss(n, m=m, s=5) data.append(a) - + x = np.arange(n, dtype=np.float64) M = ot.utils.dist(x.reshape((n, 1)), metric='minkowski') A = np.vstack(data).T -print(A.shape) - pl.figure(1, figsize=(6.4, 3)) for i in range(len(data)): pl.plot(x, data[i]) @@ -122,7 +121,7 @@ # Compute 1D Wasserstein barycenter using the LP method weights = ot.unif(d) lp_bary, bary_log = ot.lp.barycenter(A, M, weights, solver='interior-point', - verbose=True, log=True) + verbose=True, log=True) # %% # Compare Barycenters in both methods @@ -144,7 +143,7 @@ pl.plot(x, barys[i], 'g-*', label='Discrete MMOT') else: continue - #pl.plot(x, barys[i], 'g') + # pl.plot(x, barys[i], 'g') pl.plot(x, lp_bary, 'k-', label='LP Wasserstein') # pl.plot(x, bary, 'g', label='Discrete MMOT') pl.title('Barycenters') diff --git a/ot/lp/dmmot.py b/ot/lp/dmmot.py index 0bc1b7706..e73af2d5c 100644 --- a/ot/lp/dmmot.py +++ b/ot/lp/dmmot.py @@ -11,74 +11,87 @@ import numpy as np from ..backend import get_backend + def dist_monge(i): r""" - A tensor :math:c is Monge if for all valid :math:i_1, \ldots i_d and + A tensor :math:c is Monge if for all valid :math:i_1, \ldots i_d and :math:j_1, \ldots, j_d, - + .. math:: - c(s_1, \ldots, s_d) + c(t_1, \ldots t_d) \leq c(i_1, \ldots i_d) + + c(s_1, \ldots, s_d) + c(t_1, \ldots t_d) \leq c(i_1, \ldots i_d) + c(j_1, \ldots, j_d) - + where :math:s_k = \min(i_k, j_k) and :math:t_k = \max(i_k, j_k). Our focus is on a specific cost, which is known to be Monge: - + .. math:: c(i_1,i_2,\ldots,i_d) = \max{i_k:k\in[d]} - \min{i_k:k\in[d]}. - - When :math:d=2, this cost reduces to :math:c(i_1,i_2)=|i_1-i_2|, + + When :math:d=2, this cost reduces to :math:c(i_1,i_2)=|i_1-i_2|, which agrees with the classical EMD cost. This choice of :math:c is called the generalized EMD cost. - + Parameters ---------- i : list The list for which the generalized EMD cost is to be computed. - + Returns ------- cost : numeric value The generalized EMD cost of the tensor. + + References + ---------- + .. [51] Jeffery Kline. Properties of the d-dimensional earth mover's + problem. Discrete Applied Mathematics, 265: 128-141, 2019. + .. [53] Wolfgang W. Bein, Peter Brucker, James K. Park, and Pramod K. + Pathak. A monge property for the d- dimensional transportation problem. + Discrete Applied Mathematics, 58(2):97-109, 1995. ISSN 0166-218X. doi: + https://doi.org/10.1016/0166-218X(93)E0121-E. URL + https://www.sciencedirect.com/ science/article/pii/0166218X93E0121E. + Workshop on Discrete Algoritms. """ - + return max(i) - min(i) + def discrete_mmot(A, verbose=False, log=False): r""" Compute the discrete multi-marginal optimal transport of distributions A. - + The algorithm solves both primal and dual d-MMOT programs concurrently to produce the optimal transport plan as well as the total (minimal) cost. The cost is a generalized Monge cost, and the solution is independent of which Monge cost is desired. - - The algorithm accepts :math:`d` distributions (i.e., histograms) - :math:`p_{1}, \ldots, p_{d} \in \mathbb{R}_{+}^{n}` with :math:`e^{\prime} - p_{j}=1` for all :math:`j \in[d]`. Although the algorithm states that all - histograms have the same number of bins, the algorithm can be easily - adapted to accept as inputs :math:`p_{i} \in \mathbb{R}_{+}^{n_{i}}` + + The algorithm accepts :math:`d` distributions (i.e., histograms) + :math:`p_{1}, \ldots, p_{d} \in \mathbb{R}_{+}^{n}` with :math:`e^{\prime} + p_{j}=1` for all :math:`j \in[d]`. Although the algorithm states that all + histograms have the same number of bins, the algorithm can be easily + adapted to accept as inputs :math:`p_{i} \in \mathbb{R}_{+}^{n_{i}}` with :math:`n_{i} \neq n_{j}` [50]. - + The function solves the following optimization problem[51]: .. math:: \begin{align}\begin{aligned} \underset{x\in\mathbb{R}^{n^{d}}_{+}} {\textrm{min}} - \sum_{i_1,\ldots,i_d} c(i_1,\ldots, i_d)\, x(i_1,\ldots,i_d) \quad + \sum_{i_1,\ldots,i_d} c(i_1,\ldots, i_d)\, x(i_1,\ldots,i_d) \quad \textrm{s.t.} - \sum_{i_2,\ldots,i_d} x(i_1,\ldots,i_d) &= p_1(i_i), + \sum_{i_2,\ldots,i_d} x(i_1,\ldots,i_d) &= p_1(i_i), (\forall i_1\in[n])\\ \qquad\vdots\\ - \sum_{i_1,\ldots,i_{d-1}} x(i_1,\ldots,i_d) &= p_{d}(i_{d}), + \sum_{i_1,\ldots,i_{d-1}} x(i_1,\ldots,i_d) &= p_{d}(i_{d}), (\forall i_d\in[n]). \end{aligned} - \end{align} + \end{align} Parameters ---------- - A : nx.ndarray, shape (n * d, ) + A : nx.ndarray, shape (dim, n_hists) The input ndarray containing distributions of n bins in d dimensions. verbose : bool, optional If True, print debugging information during execution. Default=False. @@ -113,10 +126,10 @@ def discrete_mmot(A, verbose=False, log=False): problem. Discrete Applied Mathematics, 265: 128-141, 2019. .. [52] Leonid V Kantorovich. On the translocation of masses. Dokl. Akad. Nauk SSSR, 37:227-229, 1942. - + See Also -------- - ot.lp.discrete_mmot_converge : Minimized the d-Dimensional Earth Mover's + ot.lp.discrete_mmot_converge : Minimized the d-Dimensional Earth Mover's Distance (d-MMOT) """ @@ -131,10 +144,10 @@ def discrete_mmot(A, verbose=False, log=False): idx = [0, ] * len(AA) obj = 0 - + if verbose: print('i minval oldidx\t\tobj\t\tvals') - + while all([i < _ for _, i in zip(dims, idx)]): vals = [v[i] for v, i in zip(AA, idx)] minval = min(vals) @@ -146,7 +159,8 @@ def discrete_mmot(A, verbose=False, log=False): oldidx = nx.copy(idx) idx[i] += 1 if idx[i] < dims[i]: - dual[i][idx[i]] += dist_monge(idx) - dist_monge(oldidx) + dual[i][idx[i]-1] + temp = dist_monge(idx) - dist_monge(oldidx) + dual[i][idx[i] - 1] + dual[i][idx[i]] += temp if verbose: print(i, minval, oldidx, obj, '\t', vals) @@ -161,11 +175,11 @@ def discrete_mmot(A, verbose=False, log=False): dualobj = nx.sum([nx.dot(A[:, i], arr) for i, arr in enumerate(dual)]) - log_dict = {'A': xx, - 'primal objective': obj, - 'dual': dual, - 'dual objective': dualobj} - + log_dict = {'A': xx, + 'primal objective': obj, + 'dual': dual, + 'dual objective': dualobj} + if log: return obj, log_dict else: @@ -173,42 +187,42 @@ def discrete_mmot(A, verbose=False, log=False): def discrete_mmot_converge( - A, niters=100, lr=0.1, print_rate=100, verbose=False, log=False): + A, niters=100, lr=0.1, print_rate=100, verbose=False, log=False): r"""Minimize the d-dimensional EMD using gradient descent. - + Discrete Multi-Marginal Optimal Transport (d-MMOT): Let :math:`p_1, \ldots, p_d\in\mathbb{R}^n_{+}` be discrete probability distributions. Here, the d-MMOT is the LP, - + .. math:: \begin{align}\begin{aligned} \underset{x\in\mathbb{R}^{n^{d}}_{+}} {\textrm{min}} - \sum_{i_1,\ldots,i_d} c(i_1,\ldots, i_d)\, x(i_1,\ldots,i_d) \quad + \sum_{i_1,\ldots,i_d} c(i_1,\ldots, i_d)\, x(i_1,\ldots,i_d) \quad \textrm{s.t.} - \sum_{i_2,\ldots,i_d} x(i_1,\ldots,i_d) &= p_1(i_i), + \sum_{i_2,\ldots,i_d} x(i_1,\ldots,i_d) &= p_1(i_i), (\forall i_1\in[n])\\ \qquad\vdots\\ - \sum_{i_1,\ldots,i_{d-1}} x(i_1,\ldots,i_d) &= p_{d}(i_{d}), + \sum_{i_1,\ldots,i_{d-1}} x(i_1,\ldots,i_d) &= p_{d}(i_{d}), (\forall i_d\in[n]). \end{aligned} \end{align} - + The dual linear program of the d-MMOT problem is: - + .. math:: \underset{z_j\in\mathbb{R}^n, j\in[d]}{\textrm{maximize}}\qquad\sum_{j} p_j'z_j\qquad \textrm{subject to}\qquad z_{1}(i_1)+\cdots+z_{d}(i_{d}) \leq c(i_1,\ldots,i_{d}), - - + + where the indices in the constraints include all :math:`i_j\in[n]`, :math: - `j\in[d]`. Denote by :math:`\phi(p_1,\ldots,p_d)`, the optimal objective - value of the LP in d-MMOT problem. Let :math:`z^*` be an optimal solution + `j\in[d]`. Denote by :math:`\phi(p_1,\ldots,p_d)`, the optimal objective + value of the LP in d-MMOT problem. Let :math:`z^*` be an optimal solution to the dual program. Then, .. math:: \begin{align} - \nabla \phi(p_1,\ldots,p_{d}) &= z^*, + \nabla \phi(p_1,\ldots,p_{d}) &= z^*, ~~\text{and for any $t\in \mathbb{R}$,}~~ \phi(p_1,p_2,\ldots,p_{d}) = \sum_{j}p_j' (z_j^* + t\, \eta), \nonumber \\ @@ -222,7 +236,7 @@ def discrete_mmot_converge( Parameters ---------- - A : nx.ndarray, shape (n, d) + A : nx.ndarray, shape (dim, n_hists) The input ndarray containing distributions of n bins in d dimensions. niters : int, optional (default=100) The maximum number of iterations for the optimization algorithm. @@ -239,11 +253,23 @@ def discrete_mmot_converge( Returns ------- a : list of ndarrays, each of shape (n,) - The optimal solution as a list of n approximate barycenters, each of + The optimal solution as a list of n approximate barycenters, each of length vecsize. log : dict log dictionary return only if log==True in parameters + References + ---------- + .. [50] Ronak Mehta, Jeffery Kline, Vishnu Suresh Lokhande, Glenn Fung, & + Vikas Singh (2023). Efficient Discrete Multi Marginal Optimal + Transport Regularization. In The Eleventh International + Conference on Learning Representations. + .. [54] Olvi L Mangasarian and RR Meyer. Nonlinear perturbation of linear + programs. SIAM Journal on Control and Optimization, 17(6):745-752, 1979 + .. [55] Michael C Ferris and Olvi L Mangasarian. Finite perturbation of + convex programs. Applied Mathematics and Optimization, 23(1):263-273, + 1991. + See Also -------- ot.lp.discrete_mmot : d-Dimensional Earth Mover's Solver @@ -251,7 +277,7 @@ def discrete_mmot_converge( # function body here nx = get_backend(A) - n, d = A.shape + n, d = A.shape # n is dim, d is n_hists def dualIter(A, lr): funcval, log_dict = discrete_mmot(A, verbose=verbose, log=True) @@ -269,7 +295,7 @@ def renormalize(A): def listify(A): return [A[:, i] for i in range(A.shape[1])] - + funcval, _, grad, log_dict = dualIter(A, lr) gn = nx.norm(grad) @@ -286,8 +312,8 @@ def listify(A): A = renormalize(A) a = listify(A) - + if log: return a, log_dict else: - return a \ No newline at end of file + return a From 6de193cd32054836208d30994c3b15d68799dad9 Mon Sep 17 00:00:00 2001 From: x12hengyu Date: Mon, 12 Jun 2023 11:19:20 -0500 Subject: [PATCH 20/28] resolve test fail issue --- test/{test_demd.py => test_dmmot.py} | 37 ++++++++++------------------ 1 file changed, 13 insertions(+), 24 deletions(-) rename test/{test_demd.py => test_dmmot.py} (59%) diff --git a/test/test_demd.py b/test/test_dmmot.py similarity index 59% rename from test/test_demd.py rename to test/test_dmmot.py index 642b449eb..fe482843a 100644 --- a/test/test_demd.py +++ b/test/test_dmmot.py @@ -1,4 +1,4 @@ -"""Tests for ot.demd module """ +"""Tests for ot.lp.dmmot module """ # Author: Ronak Mehta # Xizheng Yu @@ -15,27 +15,15 @@ def create_test_data(): n = 4 a1 = ot.datasets.make_1D_gauss(n, m=20, s=5) a2 = ot.datasets.make_1D_gauss(n, m=60, s=8) - aa = np.vstack([a1, a2]) + A = np.vstack([a1, a2]) x = np.arange(n, dtype=np.float64).reshape((n, 1)) - return aa, x, d, n + return A.T, x -def test_greedy_primal_dual(): - # test greedy_primal_dual object calculation - aa, _, _, _ = create_test_data() - result = ot.greedy_primal_dual(aa) - expected_primal_obj = 0.13667759626298503 - np.testing.assert_allclose(result['primal objective'], - expected_primal_obj, - rtol=1e-7, - err_msg="Test failed: \ - Expected different primal objective value") - - -def test_demd(): - # test one demd iteration result - aa, _, d, n = create_test_data() - primal_obj = ot.demd(aa, n, d) +def test_discrete_mmot(): + # test one discrete_mmot iteration result + A, _ = create_test_data() + primal_obj = ot.lp.discrete_mmot(A) expected_primal_obj = 0.13667759626298503 np.testing.assert_allclose(primal_obj, expected_primal_obj, @@ -44,12 +32,13 @@ def test_demd(): Expected different primal objective value") -def test_demd_minimize(): - # test demd_minimize result - aa, _, d, n = create_test_data() +def test_discrete_mmot_converge(): + # test discrete_mmot_converge result + A, _ = create_test_data() + d = 2 niters = 10 - result = ot.demd_minimize(ot.demd, aa, d, n, 2, niters, 0.001, 5) - + result = ot.lp.discrete_mmot_converge(A, niters, 0.001, 5) + expected_obj = np.array([[0.05553516, 0.13082618, 0.27327479, 0.54036388], [0.04185365, 0.09570724, 0.24384705, 0.61859206]]) From e98c7eeeec28faa23ce03a37a2bc7a95a4b4ec1e Mon Sep 17 00:00:00 2001 From: x12hengyu Date: Tue, 13 Jun 2023 09:53:56 -0500 Subject: [PATCH 21/28] fix pep8 error --- test/test_dmmot.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/test_dmmot.py b/test/test_dmmot.py index fe482843a..59977cb2d 100644 --- a/test/test_dmmot.py +++ b/test/test_dmmot.py @@ -11,7 +11,6 @@ def create_test_data(): np.random.seed(1234) - d = 2 n = 4 a1 = ot.datasets.make_1D_gauss(n, m=20, s=5) a2 = ot.datasets.make_1D_gauss(n, m=60, s=8) @@ -38,7 +37,7 @@ def test_discrete_mmot_converge(): d = 2 niters = 10 result = ot.lp.discrete_mmot_converge(A, niters, 0.001, 5) - + expected_obj = np.array([[0.05553516, 0.13082618, 0.27327479, 0.54036388], [0.04185365, 0.09570724, 0.24384705, 0.61859206]]) From 7339e8a18251113235f81bc0e34aee56fa90abb3 Mon Sep 17 00:00:00 2001 From: x12hengyu Date: Wed, 5 Jul 2023 00:00:55 -0500 Subject: [PATCH 22/28] resolve issues from last review, pyflake and pep8 checked --- examples/others/plot_d-mmot.py | 46 ++++++++++------- ot/lp/dmmot.py | 90 +++++++++++++++++++--------------- test/test_dmmot.py | 44 ++++++++++++----- 3 files changed, 113 insertions(+), 67 deletions(-) diff --git a/examples/others/plot_d-mmot.py b/examples/others/plot_d-mmot.py index f3dcc1f3e..b7a840dbd 100644 --- a/examples/others/plot_d-mmot.py +++ b/examples/others/plot_d-mmot.py @@ -47,19 +47,20 @@ print('LP Iterations:') ot.tic() -alpha = 1 # /d # 0<=alpha<=1 -weights = np.array(d * [alpha]) -lp_bary, lp_log = ot.lp.barycenter( - A, M, weights, solver='interior-point', verbose=False, log=True) -print('Time\t: ', ot.toc('')) -print('Obj\t: ', lp_log['fun']) +# alpha = 1 # /d # 0<=alpha<=1 +# weights = np.array(d * [alpha]) +weights = np.ones(d)/d +l2_bary = A.dot(weights) +# lp_bary, lp_log = ot.lp.barycenter( +# A, M, weights, solver='interior-point', verbose=False, log=True) +# print('Time\t: ', ot.toc('')) +# print('Obj\t: ', lp_log['fun']) print('') print('Discrete MMOT Algorithm:') ot.tic() -# dmmot_obj, log = ot.lp.discrete_mmot(A.T, n, d) -barys, log = ot.lp.discrete_mmot_converge( - A, niters=3000, lr=0.000002, log=True) +barys, log = ot.lp.dmmot_monge_1dgrid_optimize( + A, niters=4000, lr=0.000002, log=True) dmmot_obj = log['primal objective'] print('Time\t: ', ot.toc('')) print('Obj\t: ', dmmot_obj) @@ -75,7 +76,8 @@ else: continue # pl.plot(x, barys[i], 'g-*') -pl.plot(x, lp_bary, 'k-', label='LP Barycenter') +# pl.plot(x, lp_bary, 'k-', label='LP Barycenter') +pl.plot(x, l2_bary, 'k', label='L2 Barycenter') pl.plot(x, a1, 'b', label='Source distribution') pl.plot(x, a2, 'r', label='Target distribution') pl.title('Barycenters') @@ -113,25 +115,34 @@ # values cannot be compared. # Perform gradient descent optimization using the d-MMOT method. -barys = ot.lp.discrete_mmot_converge(A, niters=9000, lr=0.00001) +barys = ot.lp.dmmot_monge_1dgrid_optimize(A, niters=9000, lr=0.00001) # after minimization, any distribution can be used as a estimate of barycenter. bary = barys[0] -# Compute 1D Wasserstein barycenter using the LP method +# Compute 1D Wasserstein barycenter using the L2/LP method weights = ot.unif(d) -lp_bary, bary_log = ot.lp.barycenter(A, M, weights, solver='interior-point', - verbose=True, log=True) +l2_bary = A.dot(weights) +# lp_bary, bary_log = ot.lp.barycenter(A, M, weights, solver='interior-point', +# verbose=True, log=True) # %% # Compare Barycenters in both methods # --------- pl.figure(1, figsize=(6.4, 3)) pl.plot(x, bary, 'g-*', label='Discrete MMOT') -pl.plot(x, lp_bary, 'k-', label='LP Wasserstein') +pl.plot(x, l2_bary, 'k', label='L2 Barycenter') +# pl.plot(x, lp_bary, 'k-', label='LP Wasserstein') pl.title('Barycenters') pl.legend() +# %% +# Compare all converged distributions +# --------- +pl.figure(1, figsize=(6.4, 3)) +for i in range(len(barys)): + pl.plot(x, barys[i], 'g', label='Discrete MMOT') + # %% # Compare with original distributions # --------- @@ -144,7 +155,8 @@ else: continue # pl.plot(x, barys[i], 'g') -pl.plot(x, lp_bary, 'k-', label='LP Wasserstein') -# pl.plot(x, bary, 'g', label='Discrete MMOT') +pl.plot(x, l2_bary, 'k', label='L2 Barycenter') +# pl.plot(x, lp_bary, 'k-', label='LP Wasserstein') pl.title('Barycenters') pl.legend() +# %% diff --git a/ot/lp/dmmot.py b/ot/lp/dmmot.py index e73af2d5c..6e08e90bd 100644 --- a/ot/lp/dmmot.py +++ b/ot/lp/dmmot.py @@ -12,7 +12,7 @@ from ..backend import get_backend -def dist_monge(i): +def dist_monge_max_min(i): r""" A tensor :math:c is Monge if for all valid :math:i_1, \ldots i_d and :math:j_1, \ldots, j_d, @@ -35,12 +35,12 @@ def dist_monge(i): Parameters ---------- i : list - The list for which the generalized EMD cost is to be computed. + The list of integer indexes. Returns ------- cost : numeric value - The generalized EMD cost of the tensor. + The ground cost (generalized EMD cost) of the tensor. References ---------- @@ -57,33 +57,36 @@ def dist_monge(i): return max(i) - min(i) -def discrete_mmot(A, verbose=False, log=False): +def dmmot_monge_1dgrid_loss(A, verbose=False, log=False): r""" Compute the discrete multi-marginal optimal transport of distributions A. + This function operates on distributions whose supports are real numbers on + the real line. + The algorithm solves both primal and dual d-MMOT programs concurrently to produce the optimal transport plan as well as the total (minimal) cost. - The cost is a generalized Monge cost, and the solution is independent of + The cost is a ground cost, and the solution is independent of which Monge cost is desired. The algorithm accepts :math:`d` distributions (i.e., histograms) - :math:`p_{1}, \ldots, p_{d} \in \mathbb{R}_{+}^{n}` with :math:`e^{\prime} - p_{j}=1` for all :math:`j \in[d]`. Although the algorithm states that all + :math:`a_{1}, \ldots, a_{d} \in \mathbb{R}_{+}^{n}` with :math:`e^{\prime} + a_{j}=1` for all :math:`j \in[d]`. Although the algorithm states that all histograms have the same number of bins, the algorithm can be easily - adapted to accept as inputs :math:`p_{i} \in \mathbb{R}_{+}^{n_{i}}` + adapted to accept as inputs :math:`a_{i} \in \mathbb{R}_{+}^{n_{i}}` with :math:`n_{i} \neq n_{j}` [50]. The function solves the following optimization problem[51]: .. math:: \begin{align}\begin{aligned} - \underset{x\in\mathbb{R}^{n^{d}}_{+}} {\textrm{min}} - \sum_{i_1,\ldots,i_d} c(i_1,\ldots, i_d)\, x(i_1,\ldots,i_d) \quad - \textrm{s.t.} - \sum_{i_2,\ldots,i_d} x(i_1,\ldots,i_d) &= p_1(i_i), + \underset{\gamma\in\mathbb{R}^{n^{d}}_{+}} {\textrm{min}} + \sum_{i_1,\ldots,i_d} c(i_1,\ldots, i_d)\, \gamma(i_1,\ldots,i_d) + \quad \textrm{s.t.} + \sum_{i_2,\ldots,i_d} \gamma(i_1,\ldots,i_d) &= a_1(i_i), (\forall i_1\in[n])\\ \qquad\vdots\\ - \sum_{i_1,\ldots,i_{d-1}} x(i_1,\ldots,i_d) &= p_{d}(i_{d}), + \sum_{i_1,\ldots,i_{d-1}} \gamma(i_1,\ldots,i_d) &= a_{d}(i_{d}), (\forall i_d\in[n]). \end{aligned} \end{align} @@ -129,18 +132,18 @@ def discrete_mmot(A, verbose=False, log=False): See Also -------- - ot.lp.discrete_mmot_converge : Minimized the d-Dimensional Earth Mover's - Distance (d-MMOT) + ot.lp.dmmot_monge_1dgrid_optimize : Optimize the d-Dimensional Earth + Mover's Distance (d-MMOT) """ nx = get_backend(A) + A = nx.to_numpy(A) - # AA = [nx.copy(_) for _ in A] - AA = [nx.copy(A[:, j]) for j in range(A.shape[1])] + AA = [np.copy(A[:, j]) for j in range(A.shape[1])] dims = tuple([len(_) for _ in AA]) xx = {} - dual = [nx.zeros(d) for d in dims] + dual = [np.zeros(d) for d in dims] idx = [0, ] * len(AA) obj = 0 @@ -153,13 +156,16 @@ def discrete_mmot(A, verbose=False, log=False): minval = min(vals) i = vals.index(minval) xx[tuple(idx)] = minval - obj += (dist_monge(idx)) * minval + obj += (dist_monge_max_min(idx)) * minval for v, j in zip(AA, idx): v[j] -= minval - oldidx = nx.copy(idx) + # oldidx = nx.copy(idx) + oldidx = idx.copy() idx[i] += 1 if idx[i] < dims[i]: - temp = dist_monge(idx) - dist_monge(oldidx) + dual[i][idx[i] - 1] + temp = (dist_monge_max_min(idx) - + dist_monge_max_min(oldidx) + + dual[i][idx[i] - 1]) dual[i][idx[i]] += temp if verbose: print(i, minval, oldidx, obj, '\t', vals) @@ -173,25 +179,29 @@ def discrete_mmot(A, verbose=False, log=False): except Exception: pass - dualobj = nx.sum([nx.dot(A[:, i], arr) for i, arr in enumerate(dual)]) + dualobj = sum([np.dot(A[:, i], arr) for i, arr in enumerate(dual)]) + obj = nx.from_numpy(obj) log_dict = {'A': xx, 'primal objective': obj, 'dual': dual, 'dual objective': dualobj} + # define forward/backward relations for pytorch + obj = nx.set_gradients(obj, (nx.from_numpy(A)), (dual)) + if log: return obj, log_dict else: return obj -def discrete_mmot_converge( +def dmmot_monge_1dgrid_optimize( A, niters=100, lr=0.1, print_rate=100, verbose=False, log=False): r"""Minimize the d-dimensional EMD using gradient descent. - Discrete Multi-Marginal Optimal Transport (d-MMOT): Let :math:`p_1, \ldots, - p_d\in\mathbb{R}^n_{+}` be discrete probability distributions. Here, + Discrete Multi-Marginal Optimal Transport (d-MMOT): Let :math:`a_1, \ldots, + a_d\in\mathbb{R}^n_{+}` be discrete probability distributions. Here, the d-MMOT is the LP, .. math:: @@ -199,10 +209,10 @@ def discrete_mmot_converge( \underset{x\in\mathbb{R}^{n^{d}}_{+}} {\textrm{min}} \sum_{i_1,\ldots,i_d} c(i_1,\ldots, i_d)\, x(i_1,\ldots,i_d) \quad \textrm{s.t.} - \sum_{i_2,\ldots,i_d} x(i_1,\ldots,i_d) &= p_1(i_i), + \sum_{i_2,\ldots,i_d} x(i_1,\ldots,i_d) &= a_1(i_i), (\forall i_1\in[n])\\ \qquad\vdots\\ - \sum_{i_1,\ldots,i_{d-1}} x(i_1,\ldots,i_d) &= p_{d}(i_{d}), + \sum_{i_1,\ldots,i_{d-1}} x(i_1,\ldots,i_d) &= a_{d}(i_{d}), (\forall i_d\in[n]). \end{aligned} \end{align} @@ -211,27 +221,27 @@ def discrete_mmot_converge( .. math:: \underset{z_j\in\mathbb{R}^n, j\in[d]}{\textrm{maximize}}\qquad\sum_{j} - p_j'z_j\qquad \textrm{subject to}\qquad z_{1}(i_1)+\cdots+z_{d}(i_{d}) + a_j'z_j\qquad \textrm{subject to}\qquad z_{1}(i_1)+\cdots+z_{d}(i_{d}) \leq c(i_1,\ldots,i_{d}), where the indices in the constraints include all :math:`i_j\in[n]`, :math: - `j\in[d]`. Denote by :math:`\phi(p_1,\ldots,p_d)`, the optimal objective + `j\in[d]`. Denote by :math:`\phi(a_1,\ldots,a_d)`, the optimal objective value of the LP in d-MMOT problem. Let :math:`z^*` be an optimal solution to the dual program. Then, .. math:: \begin{align} - \nabla \phi(p_1,\ldots,p_{d}) &= z^*, + \nabla \phi(a_1,\ldots,a_{d}) &= z^*, ~~\text{and for any $t\in \mathbb{R}$,}~~ - \phi(p_1,p_2,\ldots,p_{d}) = \sum_{j}p_j' + \phi(a_1,a_2,\ldots,a_{d}) = \sum_{j}a_j' (z_j^* + t\, \eta), \nonumber \\ \text{where } \eta &:= (z_1^{*}(n)\,e, z^*_1(n)\,e, \cdots, z^*_{d}(n)\,e) \end{align} Using these dual variables naturally provided by the algorithm in - ot.lp.discrete_mmot, gradient steps move each input distribution + ot.lp.dmmot_monge_1dgrid_loss, gradient steps move each input distribution to minimize their d-mmot distance. Parameters @@ -272,32 +282,34 @@ def discrete_mmot_converge( See Also -------- - ot.lp.discrete_mmot : d-Dimensional Earth Mover's Solver + ot.lp.dmmot_monge_1dgrid_loss: d-Dimensional Earth Mover's Solver """ # function body here nx = get_backend(A) + A = nx.to_numpy(A) n, d = A.shape # n is dim, d is n_hists def dualIter(A, lr): - funcval, log_dict = discrete_mmot(A, verbose=verbose, log=True) + funcval, log_dict = dmmot_monge_1dgrid_loss( + A, verbose=verbose, log=True) grad = np.column_stack(log_dict['dual']) - A_new = nx.reshape(A, (n, d)) - grad * lr + A_new = np.reshape(A, (n, d)) - grad * lr return funcval, A_new, grad, log_dict def renormalize(A): - A = nx.reshape(A, (n, d)) + A = np.reshape(A, (n, d)) for i in range(A.shape[1]): if min(A[:, i]) < 0: A[:, i] -= min(A[:, i]) - A[:, i] /= nx.sum(A[:, i]) + A[:, i] /= np.sum(A[:, i]) return A def listify(A): return [A[:, i] for i in range(A.shape[1])] funcval, _, grad, log_dict = dualIter(A, lr) - gn = nx.norm(grad) + gn = np.linalg.norm(grad) print(f'Inital:\t\tObj:\t{funcval:.4f}\tGradNorm:\t{gn:.4f}') @@ -305,7 +317,7 @@ def listify(A): A = renormalize(A) funcval, A, grad, log_dict = dualIter(A, lr) - gn = nx.norm(grad) + gn = np.linalg.norm(grad) if i % print_rate == 0: print(f'Iter {i:2.0f}:\tObj:\t{funcval:.4f}\tGradNorm:\t{gn:.4f}') diff --git a/test/test_dmmot.py b/test/test_dmmot.py index 59977cb2d..915078eaa 100644 --- a/test/test_dmmot.py +++ b/test/test_dmmot.py @@ -9,34 +9,56 @@ import ot -def create_test_data(): +def create_test_data(nx): np.random.seed(1234) n = 4 a1 = ot.datasets.make_1D_gauss(n, m=20, s=5) a2 = ot.datasets.make_1D_gauss(n, m=60, s=8) - A = np.vstack([a1, a2]) + A = np.vstack([a1, a2]).T x = np.arange(n, dtype=np.float64).reshape((n, 1)) - return A.T, x + A, x = nx.from_numpy(A, x) + return A, x -def test_discrete_mmot(): - # test one discrete_mmot iteration result - A, _ = create_test_data() - primal_obj = ot.lp.discrete_mmot(A) +def test_dmmot_monge_1dgrid_loss(nx): + A, x = create_test_data(nx) + + # Compute loss using dmmot_monge_1dgrid_loss + primal_obj = ot.lp.dmmot_monge_1dgrid_loss(A) + primal_obj = nx.to_numpy(primal_obj) expected_primal_obj = 0.13667759626298503 + np.testing.assert_allclose(primal_obj, expected_primal_obj, rtol=1e-7, err_msg="Test failed: \ - Expected different primal objective value") + Expected different primal objective value") + + # Compute loss using exact OT solver with absolute ground metric + A, x = nx.to_numpy(A, x) + M = ot.utils.dist(x, metric='cityblock') # absolute ground metric + bary, _ = ot.barycenter(A, M, 1e-2, weights=None, verbose=False, log=True) + ot_obj = 0.0 + for x in A.T: + # deal with C-contiguous error from tensorflow backend (not sure why) + x = np.ascontiguousarray(x) + # compute loss + _, log = ot.lp.emd(x, np.array(bary/np.sum(bary)), M, log=True) + ot_obj += log['cost'] + + np.testing.assert_allclose(primal_obj, + ot_obj, + rtol=1e-7, + err_msg="Test failed: \ + Expected different primal objective value") -def test_discrete_mmot_converge(): +def test_dmmot_monge_1dgrid_optimize(nx): # test discrete_mmot_converge result - A, _ = create_test_data() + A, _ = create_test_data(nx) d = 2 niters = 10 - result = ot.lp.discrete_mmot_converge(A, niters, 0.001, 5) + result = ot.lp.dmmot_monge_1dgrid_optimize(A, niters, 0.001, 5) expected_obj = np.array([[0.05553516, 0.13082618, 0.27327479, 0.54036388], [0.04185365, 0.09570724, 0.24384705, 0.61859206]]) From fd444b7840417603036a8efc7ca16e1b724d5ab4 Mon Sep 17 00:00:00 2001 From: x12hengyu Date: Fri, 7 Jul 2023 16:02:28 -0500 Subject: [PATCH 23/28] add lr decay --- examples/others/plot_d-mmot.py | 18 ++++++++++-------- ot/lp/dmmot.py | 30 +++++++++++++++++++++--------- test/test_dmmot.py | 13 ++++++++----- 3 files changed, 39 insertions(+), 22 deletions(-) diff --git a/examples/others/plot_d-mmot.py b/examples/others/plot_d-mmot.py index b7a840dbd..5ed923bd2 100644 --- a/examples/others/plot_d-mmot.py +++ b/examples/others/plot_d-mmot.py @@ -45,12 +45,13 @@ # The objective being minimized is different for both methods, so the objective # values cannot be compared. -print('LP Iterations:') -ot.tic() +# L2 Iteration +weights = np.ones(d) / d +l2_bary = A.dot(weights) + +# print('LP Iterations:') # alpha = 1 # /d # 0<=alpha<=1 # weights = np.array(d * [alpha]) -weights = np.ones(d)/d -l2_bary = A.dot(weights) # lp_bary, lp_log = ot.lp.barycenter( # A, M, weights, solver='interior-point', verbose=False, log=True) # print('Time\t: ', ot.toc('')) @@ -59,8 +60,8 @@ print('') print('Discrete MMOT Algorithm:') ot.tic() -barys, log = ot.lp.dmmot_monge_1dgrid_optimize( - A, niters=4000, lr=0.000002, log=True) +barys, log = ot.lp.dmmot_monge_ddgrid_optimize( + A, niters=4000, lr_init=1e-5, lr_decay=0.997, log=True) dmmot_obj = log['primal objective'] print('Time\t: ', ot.toc('')) print('Obj\t: ', dmmot_obj) @@ -68,7 +69,7 @@ # %% # Compare Barycenters in both methods -# --------- +# ----- pl.figure(1, figsize=(6.4, 3)) for i in range(len(barys)): if i == 0: @@ -115,7 +116,8 @@ # values cannot be compared. # Perform gradient descent optimization using the d-MMOT method. -barys = ot.lp.dmmot_monge_1dgrid_optimize(A, niters=9000, lr=0.00001) +barys = ot.lp.dmmot_monge_ddgrid_optimize( + A, niters=3000, lr_init=1e-4, lr_decay=0.997) # after minimization, any distribution can be used as a estimate of barycenter. bary = barys[0] diff --git a/ot/lp/dmmot.py b/ot/lp/dmmot.py index 6e08e90bd..a4d56a02f 100644 --- a/ot/lp/dmmot.py +++ b/ot/lp/dmmot.py @@ -57,7 +57,7 @@ def dist_monge_max_min(i): return max(i) - min(i) -def dmmot_monge_1dgrid_loss(A, verbose=False, log=False): +def dmmot_monge_ddgrid_loss(A, verbose=False, log=False): r""" Compute the discrete multi-marginal optimal transport of distributions A. @@ -132,7 +132,7 @@ def dmmot_monge_1dgrid_loss(A, verbose=False, log=False): See Also -------- - ot.lp.dmmot_monge_1dgrid_optimize : Optimize the d-Dimensional Earth + ot.lp.dmmot_monge_ddgrid_optimize : Optimize the d-Dimensional Earth Mover's Distance (d-MMOT) """ @@ -196,8 +196,14 @@ def dmmot_monge_1dgrid_loss(A, verbose=False, log=False): return obj -def dmmot_monge_1dgrid_optimize( - A, niters=100, lr=0.1, print_rate=100, verbose=False, log=False): +def dmmot_monge_ddgrid_optimize( + A, + niters=100, + lr_init=1e-5, + lr_decay=0.995, + print_rate=100, + verbose=False, + log=False): r"""Minimize the d-dimensional EMD using gradient descent. Discrete Multi-Marginal Optimal Transport (d-MMOT): Let :math:`a_1, \ldots, @@ -241,7 +247,7 @@ def dmmot_monge_1dgrid_optimize( \end{align} Using these dual variables naturally provided by the algorithm in - ot.lp.dmmot_monge_1dgrid_loss, gradient steps move each input distribution + ot.lp.dmmot_monge_ddgrid_loss, gradient steps move each input distribution to minimize their d-mmot distance. Parameters @@ -250,8 +256,10 @@ def dmmot_monge_1dgrid_optimize( The input ndarray containing distributions of n bins in d dimensions. niters : int, optional (default=100) The maximum number of iterations for the optimization algorithm. - lr : float, optional (default=0.1) - The learning rate (step size) for the optimization algorithm. + lr_init : float, optional (default=1e-5) + The initial learning rate (step size) for the optimization algorithm. + lr_decay : float, optional (default=0.995) + The learning rate decay rate in each iteration. print_rate : int, optional (default=100) The rate at which to print the objective value and gradient norm during the optimization algorithm. @@ -282,7 +290,7 @@ def dmmot_monge_1dgrid_optimize( See Also -------- - ot.lp.dmmot_monge_1dgrid_loss: d-Dimensional Earth Mover's Solver + ot.lp.dmmot_monge_ddgrid_loss: d-Dimensional Earth Mover's Solver """ # function body here @@ -291,7 +299,7 @@ def dmmot_monge_1dgrid_optimize( n, d = A.shape # n is dim, d is n_hists def dualIter(A, lr): - funcval, log_dict = dmmot_monge_1dgrid_loss( + funcval, log_dict = dmmot_monge_ddgrid_loss( A, verbose=verbose, log=True) grad = np.column_stack(log_dict['dual']) A_new = np.reshape(A, (n, d)) - grad * lr @@ -308,6 +316,8 @@ def renormalize(A): def listify(A): return [A[:, i] for i in range(A.shape[1])] + lr = lr_init + funcval, _, grad, log_dict = dualIter(A, lr) gn = np.linalg.norm(grad) @@ -322,6 +332,8 @@ def listify(A): if i % print_rate == 0: print(f'Iter {i:2.0f}:\tObj:\t{funcval:.4f}\tGradNorm:\t{gn:.4f}') + lr *= lr_decay + A = renormalize(A) a = listify(A) diff --git a/test/test_dmmot.py b/test/test_dmmot.py index 915078eaa..e62e77400 100644 --- a/test/test_dmmot.py +++ b/test/test_dmmot.py @@ -20,11 +20,11 @@ def create_test_data(nx): return A, x -def test_dmmot_monge_1dgrid_loss(nx): +def test_dmmot_monge_ddgrid_loss(nx): A, x = create_test_data(nx) - # Compute loss using dmmot_monge_1dgrid_loss - primal_obj = ot.lp.dmmot_monge_1dgrid_loss(A) + # Compute loss using dmmot_monge_ddgrid_loss + primal_obj = ot.lp.dmmot_monge_ddgrid_loss(A) primal_obj = nx.to_numpy(primal_obj) expected_primal_obj = 0.13667759626298503 @@ -53,12 +53,15 @@ def test_dmmot_monge_1dgrid_loss(nx): Expected different primal objective value") -def test_dmmot_monge_1dgrid_optimize(nx): +def test_dmmot_monge_ddgrid_optimize(nx): # test discrete_mmot_converge result A, _ = create_test_data(nx) d = 2 niters = 10 - result = ot.lp.dmmot_monge_1dgrid_optimize(A, niters, 0.001, 5) + result = ot.lp.dmmot_monge_ddgrid_optimize(A, + niters, + lr_init=1e-3, + lr_decay=1) expected_obj = np.array([[0.05553516, 0.13082618, 0.27327479, 0.54036388], [0.04185365, 0.09570724, 0.24384705, 0.61859206]]) From f531b9eb35ecf3ddbc5377b72e6eb91b76dd81fa Mon Sep 17 00:00:00 2001 From: x12hengyu Date: Wed, 26 Jul 2023 16:26:14 -0500 Subject: [PATCH 24/28] add more examples, ground cost options, test for uniqueness --- .../others/{plot_d-mmot.py => plot_dmmot.py} | 44 +++-- examples/others/plot_dmmot_cost.py | 160 ++++++++++++++++++ ot/lp/dmmot.py | 67 ++++++-- test/test_dmmot.py | 12 +- 4 files changed, 242 insertions(+), 41 deletions(-) rename examples/others/{plot_d-mmot.py => plot_dmmot.py} (79%) create mode 100644 examples/others/plot_dmmot_cost.py diff --git a/examples/others/plot_d-mmot.py b/examples/others/plot_dmmot.py similarity index 79% rename from examples/others/plot_d-mmot.py rename to examples/others/plot_dmmot.py index 5ed923bd2..1548ba470 100644 --- a/examples/others/plot_d-mmot.py +++ b/examples/others/plot_dmmot.py @@ -49,24 +49,22 @@ weights = np.ones(d) / d l2_bary = A.dot(weights) -# print('LP Iterations:') -# alpha = 1 # /d # 0<=alpha<=1 -# weights = np.array(d * [alpha]) -# lp_bary, lp_log = ot.lp.barycenter( -# A, M, weights, solver='interior-point', verbose=False, log=True) -# print('Time\t: ', ot.toc('')) -# print('Obj\t: ', lp_log['fun']) +print('LP Iterations:') +weights = np.ones(d) / d +lp_bary, lp_log = ot.lp.barycenter( + A, M, weights, solver='interior-point', verbose=False, log=True) +print('Time\t: ', ot.toc('')) +print('Obj\t: ', lp_log['fun']) print('') print('Discrete MMOT Algorithm:') ot.tic() -barys, log = ot.lp.dmmot_monge_ddgrid_optimize( +barys, log = ot.lp.dmmot_monge_1dgrid_optimize( A, niters=4000, lr_init=1e-5, lr_decay=0.997, log=True) dmmot_obj = log['primal objective'] print('Time\t: ', ot.toc('')) print('Obj\t: ', dmmot_obj) - # %% # Compare Barycenters in both methods # ----- @@ -77,13 +75,14 @@ else: continue # pl.plot(x, barys[i], 'g-*') -# pl.plot(x, lp_bary, 'k-', label='LP Barycenter') -pl.plot(x, l2_bary, 'k', label='L2 Barycenter') +pl.plot(x, lp_bary, label='LP Barycenter') +pl.plot(x, l2_bary, label='L2 Barycenter') pl.plot(x, a1, 'b', label='Source distribution') pl.plot(x, a2, 'r', label='Target distribution') -pl.title('Barycenters') +pl.title('Monge Cost: Barycenters from LP Solver and dmmot solver') pl.legend() + # %% # More than 2 distributions # -------------------------------------------------- @@ -116,7 +115,7 @@ # values cannot be compared. # Perform gradient descent optimization using the d-MMOT method. -barys = ot.lp.dmmot_monge_ddgrid_optimize( +barys = ot.lp.dmmot_monge_1dgrid_optimize( A, niters=3000, lr_init=1e-4, lr_decay=0.997) # after minimization, any distribution can be used as a estimate of barycenter. @@ -125,8 +124,8 @@ # Compute 1D Wasserstein barycenter using the L2/LP method weights = ot.unif(d) l2_bary = A.dot(weights) -# lp_bary, bary_log = ot.lp.barycenter(A, M, weights, solver='interior-point', -# verbose=True, log=True) +lp_bary, bary_log = ot.lp.barycenter(A, M, weights, solver='interior-point', + verbose=False, log=True) # %% # Compare Barycenters in both methods @@ -134,17 +133,10 @@ pl.figure(1, figsize=(6.4, 3)) pl.plot(x, bary, 'g-*', label='Discrete MMOT') pl.plot(x, l2_bary, 'k', label='L2 Barycenter') -# pl.plot(x, lp_bary, 'k-', label='LP Wasserstein') +pl.plot(x, lp_bary, 'k-', label='LP Wasserstein') pl.title('Barycenters') pl.legend() -# %% -# Compare all converged distributions -# --------- -pl.figure(1, figsize=(6.4, 3)) -for i in range(len(barys)): - pl.plot(x, barys[i], 'g', label='Discrete MMOT') - # %% # Compare with original distributions # --------- @@ -157,8 +149,10 @@ else: continue # pl.plot(x, barys[i], 'g') -pl.plot(x, l2_bary, 'k', label='L2 Barycenter') -# pl.plot(x, lp_bary, 'k-', label='LP Wasserstein') +pl.plot(x, l2_bary, 'k^', label='L2') +pl.plot(x, lp_bary, 'o', color='grey', label='LP') pl.title('Barycenters') pl.legend() +pl.show() + # %% diff --git a/examples/others/plot_dmmot_cost.py b/examples/others/plot_dmmot_cost.py new file mode 100644 index 000000000..f1873d486 --- /dev/null +++ b/examples/others/plot_dmmot_cost.py @@ -0,0 +1,160 @@ +# -*- coding: utf-8 -*- +r""" +=============================================================================== +Comparation of LP, dMMOT solvers under different Monge Matrics +=============================================================================== + +We also provided uniqueness test betweeen Entropic Regularization Barycenter, +LP Barycenter, and dMMOT Barycenter. +""" + +# Author: Xizheng Yu +# +# License: MIT License + +# %% +# Generating distributions and functions setup +# ----- +import numpy as np +import ot +import matplotlib.pyplot as pl + +n = 100 # number of bins +d = 2 + + +def monge_cost_matrix(matric): + MM = np.zeros((n, n)) + for i in range(n): + for j in range(n): + MM[i, j] = ot.lp.dmmot.ground_cost([i, j], matric) + return MM + +labels = ['monge', 'monge_mean', "monge_square", + 'monge_sqrt', 'monge_log', 'monge_exp'] +Ms = [monge_cost_matrix(label) for label in labels] + +a1 = ot.datasets.make_1D_gauss(n, m=20, s=5) +a2 = ot.datasets.make_1D_gauss(n, m=60, s=8) + +A = np.vstack((a1, a2)).T +x = np.arange(n, dtype=np.float64) +weights = np.ones(d) / d + +l2_bary = A.dot(weights) + +pl.figure(1, figsize=(6.4, 3)) +pl.plot(x, a1, 'b', label='Source distribution') +pl.plot(x, a2, 'r', label='Target distribution') +pl.legend() + + +# %% +# Minimize using monge costs in LP solver and dmmot solver +# ----- +lp_barys = [] + +for M, label in zip(Ms, labels): + lp_bary_temp, _ = ot.lp.barycenter( + A, M, weights, solver='interior-point', verbose=False, log=True) + lp_barys.append(lp_bary_temp) + +barys = ot.lp.dmmot_monge_1dgrid_optimize( + A, niters=4000, lr_init=1e-5, lr_decay=0.997) +barys_mean = ot.lp.dmmot_monge_1dgrid_optimize( + A, niters=3000, lr_init=1e-5, lr_decay=0.999, metric="monge_mean") +barys_square = ot.lp.dmmot_monge_1dgrid_optimize( + A, niters=3000, lr_init=1e-6, lr_decay=0.999, metric="monge_square") +barys_sqrt = ot.lp.dmmot_monge_1dgrid_optimize( + A, niters=3000, lr_init=1e-5, lr_decay=0.999, metric="monge_sqrt") +barys_log = ot.lp.dmmot_monge_1dgrid_optimize( + A, niters=3000, lr_init=1e-5, lr_decay=0.9995, metric="monge_log") +barys_exp = ot.lp.dmmot_monge_1dgrid_optimize( + A, niters=3000, lr_init=1e-3, lr_decay=0.999, metric="monge_exp") + +dmmot_barys = [barys[0], barys_mean[0], barys_square[0], barys_sqrt[0], + barys_log[0], barys_exp[0]] + +# %% +# Compare Barycenters with different monge costs in LP Solver +# ----- +fig, axes = pl.subplots(2, 3, figsize=(6.4, 3)) +axes = axes.ravel() + +for i in range(6): # iterate over each subplot + axes[i].plot(x, a1, 'b', label='Source distribution') + axes[i].plot(x, a2, 'r', label='Target distribution') + axes[i].plot(x, lp_barys[i], 'g-', label=labels[i]) + axes[i].set_title(labels[i]) + axes[i].set_xticklabels([]) + axes[i].set_yticklabels([]) + +fig.suptitle('LP Solver: Barycenters with Different Monge Costs') +pl.tight_layout() +pl.show() + +# %% +# Compare Barycenters with different monge costs +# ----- +fig, axes = pl.subplots(2, 3, figsize=(6.4, 3)) +axes = axes.ravel() + +for i in range(6): + axes[i].plot(x, a1, 'b', label='Source distribution') + axes[i].plot(x, a2, 'r', label='Target distribution') + axes[i].plot(x, dmmot_barys[i], 'g-', label=labels[i]) + axes[i].plot(x, l2_bary, 'k', label='L2 Bary') + axes[i].set_title(labels[i]) + axes[i].set_xticklabels([]) + axes[i].set_yticklabels([]) + +fig.suptitle('dmmot Solver: Barycenters with Different Monge Costs') +pl.tight_layout() +pl.show() + + +# %% +# Compare Barycenters with different monge costs +# ----- +pl.figure(1, figsize=(6.4, 3)) + +pl.plot(x, barys[0], label='Monge') +pl.plot(x, barys_mean[0], label='Monge Mean') +pl.plot(x, barys_square[0], label='Monge Square') +pl.plot(x, barys_sqrt[0], label='Monge Sqrt') +pl.plot(x, barys_log[0], label='Monge Log') +pl.plot(x, barys_exp[0], label='Monge Exp') + +pl.plot(x, l2_bary, 'k', label='L2 Bary') +# pl.plot(x, a1, 'b', label='Source') +# pl.plot(x, a2, 'r', label='Target') +pl.title('Barycenters of Different Monge Costs') +pl.legend() + + +# %% +# Uniqueness Test betweeen Entropic Regularization Barycenter, LP Barycenter, +# dMMOT Barycenter +# ----- +def obj(A, M, bary): + tmp = 0.0 + for x in A.T: + _, log = ot.lp.emd( + x, np.array(bary / np.sum(bary)), M, log=True) + tmp += log['cost'] + return tmp + + +def entropy_reg(A, M): + # Entropic Regularization Barycenter + bary_wass, _ = ot.bregman.barycenter(A, M, 1e-2, weights, log=True) + return bary_wass + +print("\t\tReg\t\tLP\t\tdMMOT\t") +for M, dmmot_bary, lp_bary, label in zip(Ms, dmmot_barys, lp_barys, labels): + M /= M.max() + bary_wass = entropy_reg(A, M) + print(f'{label}\t', f'{obj(A, M, bary_wass):.7f}\t', + f'{obj(A, M, lp_bary):.7f}\t', f'{obj(A, M, dmmot_bary):.7f}') + +# %% diff --git a/ot/lp/dmmot.py b/ot/lp/dmmot.py index a4d56a02f..ddef55818 100644 --- a/ot/lp/dmmot.py +++ b/ot/lp/dmmot.py @@ -12,6 +12,46 @@ from ..backend import get_backend +def ground_cost(i, metric="monge"): + r""" + Calculate cost based on selected cost function. + + Parameters + ---------- + i : list + The list of integer indexes. + metric : str, optional, (default="monge") + The cost function to use. Options: "monge", "monge_square", + "monge_sqrt", "monge_log", "monge_exp", "monge_mean". + + Returns + ------- + cost : numeric value + The ground cost of the tensor. + + See Also + -------- + ot.lp.dist_monge_max_min : Monge Cost. + """ + + if metric == "monge": + return dist_monge_max_min(i) + elif metric == "monge_square": + return dist_monge_max_min(i) ** 2 + elif metric == "monge_sqrt": + return np.sqrt(dist_monge_max_min(i)) + elif metric == "monge_log": + return np.log(dist_monge_max_min(i) + 1) + elif metric == "monge_exp": + # numerical instability + scaling_factor = 0.01 + return np.exp(scaling_factor * dist_monge_max_min(i)) + elif metric == "monge_mean": + return np.mean(dist_monge_max_min(i)) + else: + raise ValueError(f"Unknown cost function: {metric}") + + def dist_monge_max_min(i): r""" A tensor :math:c is Monge if for all valid :math:i_1, \ldots i_d and @@ -57,7 +97,7 @@ def dist_monge_max_min(i): return max(i) - min(i) -def dmmot_monge_ddgrid_loss(A, verbose=False, log=False): +def dmmot_monge_1dgrid_loss(A, metric='monge', verbose=False, log=False): r""" Compute the discrete multi-marginal optimal transport of distributions A. @@ -96,6 +136,9 @@ def dmmot_monge_ddgrid_loss(A, verbose=False, log=False): ---------- A : nx.ndarray, shape (dim, n_hists) The input ndarray containing distributions of n bins in d dimensions. + metric : str, optional, (default="monge") + The cost function to use. Options: "monge", "monge_square", + "monge_sqrt", "monge_log", "monge_exp", "monge_mean". verbose : bool, optional If True, print debugging information during execution. Default=False. log : bool, optional @@ -132,7 +175,7 @@ def dmmot_monge_ddgrid_loss(A, verbose=False, log=False): See Also -------- - ot.lp.dmmot_monge_ddgrid_optimize : Optimize the d-Dimensional Earth + ot.lp.dmmot_monge_1dgrid_optimize : Optimize the d-Dimensional Earth Mover's Distance (d-MMOT) """ @@ -156,15 +199,15 @@ def dmmot_monge_ddgrid_loss(A, verbose=False, log=False): minval = min(vals) i = vals.index(minval) xx[tuple(idx)] = minval - obj += (dist_monge_max_min(idx)) * minval + obj += (ground_cost(idx, metric)) * minval for v, j in zip(AA, idx): v[j] -= minval # oldidx = nx.copy(idx) oldidx = idx.copy() idx[i] += 1 if idx[i] < dims[i]: - temp = (dist_monge_max_min(idx) - - dist_monge_max_min(oldidx) + + temp = (ground_cost(idx, metric) - + ground_cost(oldidx, metric) + dual[i][idx[i] - 1]) dual[i][idx[i]] += temp if verbose: @@ -196,12 +239,13 @@ def dmmot_monge_ddgrid_loss(A, verbose=False, log=False): return obj -def dmmot_monge_ddgrid_optimize( +def dmmot_monge_1dgrid_optimize( A, niters=100, lr_init=1e-5, lr_decay=0.995, print_rate=100, + metric='monge', verbose=False, log=False): r"""Minimize the d-dimensional EMD using gradient descent. @@ -247,7 +291,7 @@ def dmmot_monge_ddgrid_optimize( \end{align} Using these dual variables naturally provided by the algorithm in - ot.lp.dmmot_monge_ddgrid_loss, gradient steps move each input distribution + ot.lp.dmmot_monge_1dgrid_loss, gradient steps move each input distribution to minimize their d-mmot distance. Parameters @@ -263,6 +307,9 @@ def dmmot_monge_ddgrid_optimize( print_rate : int, optional (default=100) The rate at which to print the objective value and gradient norm during the optimization algorithm. + metric : str, optional, (default="monge") + The cost function to use. Options: "monge", "monge_square", + "monge_sqrt", "monge_log", "monge_exp", "monge_mean". verbose : bool, optional If True, print debugging information during execution. Default=False. log : bool, optional @@ -290,7 +337,7 @@ def dmmot_monge_ddgrid_optimize( See Also -------- - ot.lp.dmmot_monge_ddgrid_loss: d-Dimensional Earth Mover's Solver + ot.lp.dmmot_monge_1dgrid_loss: d-Dimensional Earth Mover's Solver """ # function body here @@ -299,8 +346,8 @@ def dmmot_monge_ddgrid_optimize( n, d = A.shape # n is dim, d is n_hists def dualIter(A, lr): - funcval, log_dict = dmmot_monge_ddgrid_loss( - A, verbose=verbose, log=True) + funcval, log_dict = dmmot_monge_1dgrid_loss( + A, metric, verbose=verbose, log=True) grad = np.column_stack(log_dict['dual']) A_new = np.reshape(A, (n, d)) - grad * lr return funcval, A_new, grad, log_dict diff --git a/test/test_dmmot.py b/test/test_dmmot.py index e62e77400..fa8dc6b89 100644 --- a/test/test_dmmot.py +++ b/test/test_dmmot.py @@ -20,11 +20,11 @@ def create_test_data(nx): return A, x -def test_dmmot_monge_ddgrid_loss(nx): +def test_dmmot_monge_1dgrid_loss(nx): A, x = create_test_data(nx) - # Compute loss using dmmot_monge_ddgrid_loss - primal_obj = ot.lp.dmmot_monge_ddgrid_loss(A) + # Compute loss using dmmot_monge_1dgrid_loss + primal_obj = ot.lp.dmmot_monge_1dgrid_loss(A) primal_obj = nx.to_numpy(primal_obj) expected_primal_obj = 0.13667759626298503 @@ -43,7 +43,7 @@ def test_dmmot_monge_ddgrid_loss(nx): # deal with C-contiguous error from tensorflow backend (not sure why) x = np.ascontiguousarray(x) # compute loss - _, log = ot.lp.emd(x, np.array(bary/np.sum(bary)), M, log=True) + _, log = ot.lp.emd(x, np.array(bary / np.sum(bary)), M, log=True) ot_obj += log['cost'] np.testing.assert_allclose(primal_obj, @@ -53,12 +53,12 @@ def test_dmmot_monge_ddgrid_loss(nx): Expected different primal objective value") -def test_dmmot_monge_ddgrid_optimize(nx): +def test_dmmot_monge_1dgrid_optimize(nx): # test discrete_mmot_converge result A, _ = create_test_data(nx) d = 2 niters = 10 - result = ot.lp.dmmot_monge_ddgrid_optimize(A, + result = ot.lp.dmmot_monge_1dgrid_optimize(A, niters, lr_init=1e-3, lr_decay=1) From b3cb89660387ef2fbd276e5dc0dd744d0ba3d689 Mon Sep 17 00:00:00 2001 From: x12hengyu Date: Fri, 28 Jul 2023 12:56:20 -0500 Subject: [PATCH 25/28] remove additional experiment setting, not needed in this PR --- examples/others/plot_dmmot_cost.py | 160 ----------------------------- ot/lp/dmmot.py | 74 +++---------- 2 files changed, 13 insertions(+), 221 deletions(-) delete mode 100644 examples/others/plot_dmmot_cost.py diff --git a/examples/others/plot_dmmot_cost.py b/examples/others/plot_dmmot_cost.py deleted file mode 100644 index f1873d486..000000000 --- a/examples/others/plot_dmmot_cost.py +++ /dev/null @@ -1,160 +0,0 @@ -# -*- coding: utf-8 -*- -r""" -=============================================================================== -Comparation of LP, dMMOT solvers under different Monge Matrics -=============================================================================== - -We also provided uniqueness test betweeen Entropic Regularization Barycenter, -LP Barycenter, and dMMOT Barycenter. -""" - -# Author: Xizheng Yu -# -# License: MIT License - -# %% -# Generating distributions and functions setup -# ----- -import numpy as np -import ot -import matplotlib.pyplot as pl - -n = 100 # number of bins -d = 2 - - -def monge_cost_matrix(matric): - MM = np.zeros((n, n)) - for i in range(n): - for j in range(n): - MM[i, j] = ot.lp.dmmot.ground_cost([i, j], matric) - return MM - -labels = ['monge', 'monge_mean', "monge_square", - 'monge_sqrt', 'monge_log', 'monge_exp'] -Ms = [monge_cost_matrix(label) for label in labels] - -a1 = ot.datasets.make_1D_gauss(n, m=20, s=5) -a2 = ot.datasets.make_1D_gauss(n, m=60, s=8) - -A = np.vstack((a1, a2)).T -x = np.arange(n, dtype=np.float64) -weights = np.ones(d) / d - -l2_bary = A.dot(weights) - -pl.figure(1, figsize=(6.4, 3)) -pl.plot(x, a1, 'b', label='Source distribution') -pl.plot(x, a2, 'r', label='Target distribution') -pl.legend() - - -# %% -# Minimize using monge costs in LP solver and dmmot solver -# ----- -lp_barys = [] - -for M, label in zip(Ms, labels): - lp_bary_temp, _ = ot.lp.barycenter( - A, M, weights, solver='interior-point', verbose=False, log=True) - lp_barys.append(lp_bary_temp) - -barys = ot.lp.dmmot_monge_1dgrid_optimize( - A, niters=4000, lr_init=1e-5, lr_decay=0.997) -barys_mean = ot.lp.dmmot_monge_1dgrid_optimize( - A, niters=3000, lr_init=1e-5, lr_decay=0.999, metric="monge_mean") -barys_square = ot.lp.dmmot_monge_1dgrid_optimize( - A, niters=3000, lr_init=1e-6, lr_decay=0.999, metric="monge_square") -barys_sqrt = ot.lp.dmmot_monge_1dgrid_optimize( - A, niters=3000, lr_init=1e-5, lr_decay=0.999, metric="monge_sqrt") -barys_log = ot.lp.dmmot_monge_1dgrid_optimize( - A, niters=3000, lr_init=1e-5, lr_decay=0.9995, metric="monge_log") -barys_exp = ot.lp.dmmot_monge_1dgrid_optimize( - A, niters=3000, lr_init=1e-3, lr_decay=0.999, metric="monge_exp") - -dmmot_barys = [barys[0], barys_mean[0], barys_square[0], barys_sqrt[0], - barys_log[0], barys_exp[0]] - -# %% -# Compare Barycenters with different monge costs in LP Solver -# ----- -fig, axes = pl.subplots(2, 3, figsize=(6.4, 3)) -axes = axes.ravel() - -for i in range(6): # iterate over each subplot - axes[i].plot(x, a1, 'b', label='Source distribution') - axes[i].plot(x, a2, 'r', label='Target distribution') - axes[i].plot(x, lp_barys[i], 'g-', label=labels[i]) - axes[i].set_title(labels[i]) - axes[i].set_xticklabels([]) - axes[i].set_yticklabels([]) - -fig.suptitle('LP Solver: Barycenters with Different Monge Costs') -pl.tight_layout() -pl.show() - -# %% -# Compare Barycenters with different monge costs -# ----- -fig, axes = pl.subplots(2, 3, figsize=(6.4, 3)) -axes = axes.ravel() - -for i in range(6): - axes[i].plot(x, a1, 'b', label='Source distribution') - axes[i].plot(x, a2, 'r', label='Target distribution') - axes[i].plot(x, dmmot_barys[i], 'g-', label=labels[i]) - axes[i].plot(x, l2_bary, 'k', label='L2 Bary') - axes[i].set_title(labels[i]) - axes[i].set_xticklabels([]) - axes[i].set_yticklabels([]) - -fig.suptitle('dmmot Solver: Barycenters with Different Monge Costs') -pl.tight_layout() -pl.show() - - -# %% -# Compare Barycenters with different monge costs -# ----- -pl.figure(1, figsize=(6.4, 3)) - -pl.plot(x, barys[0], label='Monge') -pl.plot(x, barys_mean[0], label='Monge Mean') -pl.plot(x, barys_square[0], label='Monge Square') -pl.plot(x, barys_sqrt[0], label='Monge Sqrt') -pl.plot(x, barys_log[0], label='Monge Log') -pl.plot(x, barys_exp[0], label='Monge Exp') - -pl.plot(x, l2_bary, 'k', label='L2 Bary') -# pl.plot(x, a1, 'b', label='Source') -# pl.plot(x, a2, 'r', label='Target') -pl.title('Barycenters of Different Monge Costs') -pl.legend() - - -# %% -# Uniqueness Test betweeen Entropic Regularization Barycenter, LP Barycenter, -# dMMOT Barycenter -# ----- -def obj(A, M, bary): - tmp = 0.0 - for x in A.T: - _, log = ot.lp.emd( - x, np.array(bary / np.sum(bary)), M, log=True) - tmp += log['cost'] - return tmp - - -def entropy_reg(A, M): - # Entropic Regularization Barycenter - bary_wass, _ = ot.bregman.barycenter(A, M, 1e-2, weights, log=True) - return bary_wass - -print("\t\tReg\t\tLP\t\tdMMOT\t") -for M, dmmot_bary, lp_bary, label in zip(Ms, dmmot_barys, lp_barys, labels): - M /= M.max() - bary_wass = entropy_reg(A, M) - print(f'{label}\t', f'{obj(A, M, bary_wass):.7f}\t', - f'{obj(A, M, lp_bary):.7f}\t', f'{obj(A, M, dmmot_bary):.7f}') - -# %% diff --git a/ot/lp/dmmot.py b/ot/lp/dmmot.py index ddef55818..2e102f1d9 100644 --- a/ot/lp/dmmot.py +++ b/ot/lp/dmmot.py @@ -11,47 +11,6 @@ import numpy as np from ..backend import get_backend - -def ground_cost(i, metric="monge"): - r""" - Calculate cost based on selected cost function. - - Parameters - ---------- - i : list - The list of integer indexes. - metric : str, optional, (default="monge") - The cost function to use. Options: "monge", "monge_square", - "monge_sqrt", "monge_log", "monge_exp", "monge_mean". - - Returns - ------- - cost : numeric value - The ground cost of the tensor. - - See Also - -------- - ot.lp.dist_monge_max_min : Monge Cost. - """ - - if metric == "monge": - return dist_monge_max_min(i) - elif metric == "monge_square": - return dist_monge_max_min(i) ** 2 - elif metric == "monge_sqrt": - return np.sqrt(dist_monge_max_min(i)) - elif metric == "monge_log": - return np.log(dist_monge_max_min(i) + 1) - elif metric == "monge_exp": - # numerical instability - scaling_factor = 0.01 - return np.exp(scaling_factor * dist_monge_max_min(i)) - elif metric == "monge_mean": - return np.mean(dist_monge_max_min(i)) - else: - raise ValueError(f"Unknown cost function: {metric}") - - def dist_monge_max_min(i): r""" A tensor :math:c is Monge if for all valid :math:i_1, \ldots i_d and @@ -84,9 +43,9 @@ def dist_monge_max_min(i): References ---------- - .. [51] Jeffery Kline. Properties of the d-dimensional earth mover's + .. [56] Jeffery Kline. Properties of the d-dimensional earth mover's problem. Discrete Applied Mathematics, 265: 128-141, 2019. - .. [53] Wolfgang W. Bein, Peter Brucker, James K. Park, and Pramod K. + .. [57] Wolfgang W. Bein, Peter Brucker, James K. Park, and Pramod K. Pathak. A monge property for the d- dimensional transportation problem. Discrete Applied Mathematics, 58(2):97-109, 1995. ISSN 0166-218X. doi: https://doi.org/10.1016/0166-218X(93)E0121-E. URL @@ -97,7 +56,7 @@ def dist_monge_max_min(i): return max(i) - min(i) -def dmmot_monge_1dgrid_loss(A, metric='monge', verbose=False, log=False): +def dmmot_monge_1dgrid_loss(A, verbose=False, log=False): r""" Compute the discrete multi-marginal optimal transport of distributions A. @@ -136,9 +95,6 @@ def dmmot_monge_1dgrid_loss(A, metric='monge', verbose=False, log=False): ---------- A : nx.ndarray, shape (dim, n_hists) The input ndarray containing distributions of n bins in d dimensions. - metric : str, optional, (default="monge") - The cost function to use. Options: "monge", "monge_square", - "monge_sqrt", "monge_log", "monge_exp", "monge_mean". verbose : bool, optional If True, print debugging information during execution. Default=False. log : bool, optional @@ -164,13 +120,13 @@ def dmmot_monge_1dgrid_loss(A, metric='monge', verbose=False, log=False): References ---------- - .. [50] Ronak Mehta, Jeffery Kline, Vishnu Suresh Lokhande, Glenn Fung, & + .. [55] Ronak Mehta, Jeffery Kline, Vishnu Suresh Lokhande, Glenn Fung, & Vikas Singh (2023). Efficient Discrete Multi Marginal Optimal Transport Regularization. In The Eleventh International Conference on Learning Representations. - .. [51] Jeffery Kline. Properties of the d-dimensional earth mover's + .. [56] Jeffery Kline. Properties of the d-dimensional earth mover's problem. Discrete Applied Mathematics, 265: 128-141, 2019. - .. [52] Leonid V Kantorovich. On the translocation of masses. Dokl. Akad. + .. [58] Leonid V Kantorovich. On the translocation of masses. Dokl. Akad. Nauk SSSR, 37:227-229, 1942. See Also @@ -199,15 +155,15 @@ def dmmot_monge_1dgrid_loss(A, metric='monge', verbose=False, log=False): minval = min(vals) i = vals.index(minval) xx[tuple(idx)] = minval - obj += (ground_cost(idx, metric)) * minval + obj += (dist_monge_max_min(idx)) * minval for v, j in zip(AA, idx): v[j] -= minval # oldidx = nx.copy(idx) oldidx = idx.copy() idx[i] += 1 if idx[i] < dims[i]: - temp = (ground_cost(idx, metric) - - ground_cost(oldidx, metric) + + temp = (dist_monge_max_min(idx) - + dist_monge_max_min(oldidx) + dual[i][idx[i] - 1]) dual[i][idx[i]] += temp if verbose: @@ -245,7 +201,6 @@ def dmmot_monge_1dgrid_optimize( lr_init=1e-5, lr_decay=0.995, print_rate=100, - metric='monge', verbose=False, log=False): r"""Minimize the d-dimensional EMD using gradient descent. @@ -307,9 +262,6 @@ def dmmot_monge_1dgrid_optimize( print_rate : int, optional (default=100) The rate at which to print the objective value and gradient norm during the optimization algorithm. - metric : str, optional, (default="monge") - The cost function to use. Options: "monge", "monge_square", - "monge_sqrt", "monge_log", "monge_exp", "monge_mean". verbose : bool, optional If True, print debugging information during execution. Default=False. log : bool, optional @@ -325,13 +277,13 @@ def dmmot_monge_1dgrid_optimize( References ---------- - .. [50] Ronak Mehta, Jeffery Kline, Vishnu Suresh Lokhande, Glenn Fung, & + .. [55] Ronak Mehta, Jeffery Kline, Vishnu Suresh Lokhande, Glenn Fung, & Vikas Singh (2023). Efficient Discrete Multi Marginal Optimal Transport Regularization. In The Eleventh International Conference on Learning Representations. - .. [54] Olvi L Mangasarian and RR Meyer. Nonlinear perturbation of linear + .. [60] Olvi L Mangasarian and RR Meyer. Nonlinear perturbation of linear programs. SIAM Journal on Control and Optimization, 17(6):745-752, 1979 - .. [55] Michael C Ferris and Olvi L Mangasarian. Finite perturbation of + .. [59] Michael C Ferris and Olvi L Mangasarian. Finite perturbation of convex programs. Applied Mathematics and Optimization, 23(1):263-273, 1991. @@ -347,7 +299,7 @@ def dmmot_monge_1dgrid_optimize( def dualIter(A, lr): funcval, log_dict = dmmot_monge_1dgrid_loss( - A, metric, verbose=verbose, log=True) + A, verbose=verbose, log=True) grad = np.column_stack(log_dict['dual']) A_new = np.reshape(A, (n, d)) - grad * lr return funcval, A_new, grad, log_dict From 2d22fc9ed1501e8986b24b0f1fbf047ab7c0ffc0 Mon Sep 17 00:00:00 2001 From: x12hengyu Date: Sat, 29 Jul 2023 15:55:20 -0500 Subject: [PATCH 26/28] fixed line 14 1 blank line --- ot/lp/dmmot.py | 1 + 1 file changed, 1 insertion(+) diff --git a/ot/lp/dmmot.py b/ot/lp/dmmot.py index 2e102f1d9..6e2099e0e 100644 --- a/ot/lp/dmmot.py +++ b/ot/lp/dmmot.py @@ -11,6 +11,7 @@ import numpy as np from ..backend import get_backend + def dist_monge_max_min(i): r""" A tensor :math:c is Monge if for all valid :math:i_1, \ldots i_d and From a7bde6642fe038fc50ebdd328322db3f496080ed Mon Sep 17 00:00:00 2001 From: x12hengyu Date: Wed, 2 Aug 2023 14:32:09 -0500 Subject: [PATCH 27/28] fix gradient computation link --- ot/lp/dmmot.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/ot/lp/dmmot.py b/ot/lp/dmmot.py index 6e2099e0e..bc12d4d49 100644 --- a/ot/lp/dmmot.py +++ b/ot/lp/dmmot.py @@ -47,7 +47,7 @@ def dist_monge_max_min(i): .. [56] Jeffery Kline. Properties of the d-dimensional earth mover's problem. Discrete Applied Mathematics, 265: 128-141, 2019. .. [57] Wolfgang W. Bein, Peter Brucker, James K. Park, and Pramod K. - Pathak. A monge property for the d- dimensional transportation problem. + Pathak. A monge property for the d-dimensional transportation problem. Discrete Applied Mathematics, 58(2):97-109, 1995. ISSN 0166-218X. doi: https://doi.org/10.1016/0166-218X(93)E0121-E. URL https://www.sciencedirect.com/ science/article/pii/0166218X93E0121E. @@ -137,6 +137,7 @@ def dmmot_monge_1dgrid_loss(A, verbose=False, log=False): """ nx = get_backend(A) + A_copy = nx.copy(A) A = nx.to_numpy(A) AA = [np.copy(A[:, j]) for j in range(A.shape[1])] @@ -188,7 +189,7 @@ def dmmot_monge_1dgrid_loss(A, verbose=False, log=False): 'dual objective': dualobj} # define forward/backward relations for pytorch - obj = nx.set_gradients(obj, (nx.from_numpy(A)), (dual)) + obj = nx.set_gradients(obj, (A_copy), (dual)) if log: return obj, log_dict From 24a69c0d8f6b33dc684d899be44f6863e0a32b8c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Thu, 3 Aug 2023 08:56:19 +0200 Subject: [PATCH 28/28] Update ot/lp/dmmot.py Store input variable instead of copying it --- ot/lp/dmmot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ot/lp/dmmot.py b/ot/lp/dmmot.py index bc12d4d49..8576c3c61 100644 --- a/ot/lp/dmmot.py +++ b/ot/lp/dmmot.py @@ -137,7 +137,7 @@ def dmmot_monge_1dgrid_loss(A, verbose=False, log=False): """ nx = get_backend(A) - A_copy = nx.copy(A) + A_copy = A A = nx.to_numpy(A) AA = [np.copy(A[:, j]) for j in range(A.shape[1])]