diff --git a/RELEASES.md b/RELEASES.md index c7e3f598b..106042af2 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -1,12 +1,14 @@ # Releases -## 0.9.3dev +## 0.9.4dev #### New features -+ `ot.gromov._gw.solve_gromov_linesearch` now has an argument to specifify if the matrices are symmetric in which case the computation can be done faster. ++ `ot.gromov._gw.solve_gromov_linesearch` now has an argument to specify if the matrices are symmetric in which case the computation can be done faster (PR #607). ++ Continuous entropic mapping (PR #613) ++ New general unbalanced solvers for `ot.solve` and BFGS solver and illustrative example (PR #620) ++ Add gradient computation with envelope theorem to sinkhorn solver of `ot.solve` with `grad='envelope'` (PR #605). #### Closed issues -- Fixed an issue with cost correction for mismatched labels in `ot.da.BaseTransport` fit methods. This fix addresses the original issue introduced PR #587 (PR #593) - Fix gpu compatibility of sr(F)GW solvers when `G0 is not None`(PR #596) - Fix doc and example for lowrank sinkhorn (PR #601) - Fix issue with empty weights for `ot.emd2` (PR #606, Issue #534) @@ -14,6 +16,14 @@ - Fix same sign error for sr(F)GW conditional gradient solvers (PR #611) - Split `test/test_gromov.py` into `test/gromov/` (PR #619) +## 0.9.3 +*January 2024* + + +#### Closed issues +- Fixed an issue with cost correction for mismatched labels in `ot.da.BaseTransport` fit methods. 
This fix addresses the original issue introduced in PR #587 (PR #593) + + ## 0.9.2 *December 2023* diff --git a/examples/plot_solve_variants.py b/examples/plot_solve_variants.py new file mode 100644 index 000000000..82f892a52 --- /dev/null +++ b/examples/plot_solve_variants.py @@ -0,0 +1,150 @@ +# -*- coding: utf-8 -*- +""" +====================================== +Optimal Transport solvers comparison +====================================== + +This example illustrates the solutions returned by different variants of exact, +regularized and unbalanced OT solvers. +""" + +# Author: Remi Flamary +# +# License: MIT License +# sphinx_gallery_thumbnail_number = 3 + +#%% + +import numpy as np +import matplotlib.pylab as pl +import ot +import ot.plot +from ot.datasets import make_1D_gauss as gauss + +############################################################################## +# Generate data +# ------------- + + +#%% parameters + +n = 50 # nb bins + +# bin positions +x = np.arange(n, dtype=np.float64) + +# Gaussian distributions +a = 0.6 * gauss(n, m=15, s=5) + 0.4 * gauss(n, m=35, s=5) # m= mean, s= std +b = gauss(n, m=25, s=5) + +# loss matrix +M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1))) +M /= M.max() + + +############################################################################## +# Plot distributions and loss matrix +# ---------------------------------- + +#%% plot the distributions + +pl.figure(1, figsize=(6.4, 3)) +pl.plot(x, a, 'b', label='Source distribution') +pl.plot(x, b, 'r', label='Target distribution') +pl.legend() + +#%% plot distributions and loss matrix + +pl.figure(2, figsize=(5, 5)) +ot.plot.plot1D_mat(a, b, M, 'Cost matrix M') + +############################################################################## +# Define Group lasso regularization and gradient +# ------------------------------------------------ +# The groups are the first and second half of the columns of G + + +def reg_gl(G): # group lasso + small l2 reg + G1 = G[:n // 2, :]**2 + G2 
= G[n // 2:, :]**2 + gl1 = np.sum(np.sqrt(np.sum(G1, 0))) + gl2 = np.sum(np.sqrt(np.sum(G2, 0))) + return gl1 + gl2 + 0.1 * np.sum(G**2) + + +def grad_gl(G): # gradient of group lasso + small l2 reg + G1 = G[:n // 2, :] + G2 = G[n // 2:, :] + gl1 = G1 / np.sqrt(np.sum(G1**2, 0, keepdims=True) + 1e-8) + gl2 = G2 / np.sqrt(np.sum(G2**2, 0, keepdims=True) + 1e-8) + return np.concatenate((gl1, gl2), axis=0) + 0.2 * G + + +reg_type_gl = (reg_gl, grad_gl) + +# %% +# Set up parameters for solvers and solve +# --------------------------------------- + +lst_regs = ["No Reg.", "Entropic", "L2", "Group Lasso + L2"] +lst_unbalanced = ["Balanced", "Unbalanced KL", 'Unbalanced L2', 'Unb. TV (Partial)'] # ["Balanced", "Unb. KL", "Unb. L2", "Unb L1 (partial)"] + +lst_solvers = [ # name, param for ot.solve function + # balanced OT + ('Exact OT', dict()), + ('Entropic Reg. OT', dict(reg=0.005)), + ('L2 Reg OT', dict(reg=1, reg_type='l2')), + ('Group Lasso Reg. OT', dict(reg=0.1, reg_type=reg_type_gl)), + + + # unbalanced OT KL + ('Unbalanced KL No Reg.', dict(unbalanced=0.005)), + ('Unbalanced KL wit KL Reg.', dict(reg=0.0005, unbalanced=0.005, unbalanced_type='kl', reg_type='kl')), + ('Unbalanced KL with L2 Reg.', dict(reg=0.5, reg_type='l2', unbalanced=0.005, unbalanced_type='kl')), + ('Unbalanced KL with Group Lasso Reg.', dict(reg=0.1, reg_type=reg_type_gl, unbalanced=0.05, unbalanced_type='kl')), + + # unbalanced OT L2 + ('Unbalanced L2 No Reg.', dict(unbalanced=0.5, unbalanced_type='l2')), + ('Unbalanced L2 with KL Reg.', dict(reg=0.001, unbalanced=0.2, unbalanced_type='l2')), + ('Unbalanced L2 with L2 Reg.', dict(reg=0.1, reg_type='l2', unbalanced=0.2, unbalanced_type='l2')), + ('Unbalanced L2 with Group Lasso Reg.', dict(reg=0.05, reg_type=reg_type_gl, unbalanced=0.7, unbalanced_type='l2')), + + # unbalanced OT TV + ('Unbalanced TV No Reg.', dict(unbalanced=0.1, unbalanced_type='tv')), + ('Unbalanced TV with KL Reg.', dict(reg=0.001, unbalanced=0.01, unbalanced_type='tv')), 
+ ('Unbalanced TV with L2 Reg.', dict(reg=0.1, reg_type='l2', unbalanced=0.01, unbalanced_type='tv')), + ('Unbalanced TV with Group Lasso Reg.', dict(reg=0.02, reg_type=reg_type_gl, unbalanced=0.01, unbalanced_type='tv')), + +] + +lst_plans = [] +for (name, param) in lst_solvers: + G = ot.solve(M, a, b, **param).plan + lst_plans.append(G) + +############################################################################## +# Plot plans +# ---------- + +pl.figure(3, figsize=(9, 9)) + +for i, bname in enumerate(lst_unbalanced): + for j, rname in enumerate(lst_regs): + pl.subplot(len(lst_unbalanced), len(lst_regs), i * len(lst_regs) + j + 1) + + plan = lst_plans[i * len(lst_regs) + j] + m2 = plan.sum(0) + m1 = plan.sum(1) + m1, m2 = m1 / a.max(), m2 / b.max() + pl.imshow(plan, cmap='Greys') + pl.plot(x, m2 * 10, 'r') + pl.plot(m1 * 10, x, 'b') + pl.plot(x, b / b.max() * 10, 'r', alpha=0.3) + pl.plot(a / a.max() * 10, x, 'b', alpha=0.3) + #pl.axis('off') + pl.tick_params(left=False, right=False, labelleft=False, + labelbottom=False, bottom=False) + if i == 0: + pl.title(rname) + if j == 0: + pl.ylabel(bname, fontsize=14) diff --git a/ot/__init__.py b/ot/__init__.py index 1c10efafd..609f9ff37 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -58,7 +58,7 @@ # utils functions from .utils import dist, unif, tic, toc, toq -__version__ = "0.9.3dev" +__version__ = "0.9.4dev" __all__ = ['emd', 'emd2', 'emd_1d', 'sinkhorn', 'sinkhorn2', 'utils', 'datasets', 'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov', diff --git a/ot/solvers.py b/ot/solvers.py index de817d7f7..95165ea11 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -23,6 +23,7 @@ from .gaussian import empirical_bures_wasserstein_distance from .factored import factored_optimal_transport from .lowrank import lowrank_sinkhorn +from .optim import cg lst_method_lazy = ['1d', 'gaussian', 'lowrank', 'factored', 'geomloss', 'geomloss_auto', 'geomloss_tensorized', 'geomloss_online', 'geomloss_multiscale'] @@ -57,13 +58,15 @@ def 
solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None, Regularization weight :math:`\lambda_r`, by default None (no reg., exact OT) reg_type : str, optional - Type of regularization :math:`R` either "KL", "L2", "entropy", by default "KL" + Type of regularization :math:`R` either "KL", "L2", "entropy", + by default "KL". A tuple of functions can be provided for the general + solver (see :any:`cg`). This is only used when ``reg!=None``. unbalanced : float, optional Unbalanced penalization weight :math:`\lambda_u`, by default None (balanced OT) unbalanced_type : str, optional Type of unbalanced penalization function :math:`U` either "KL", "L2", - "TV", by default "KL" + "TV", by default "KL". method : str, optional Method for solving the problem when multiple algorithms are available, default None for automatic selection. @@ -80,10 +83,10 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None, verbose : bool, optional Print information in the solver, by default False grad : str, optional - Type of gradient computation, either or 'autodiff' or 'implicit' used only for + Type of gradient computation, either 'autodiff' or 'envelope', used only for Sinkhorn solver. By default 'autodiff' provides gradients wrt all outputs (`plan, value, value_linear`) but with important memory cost. - 'implicit' provides gradients only for `value` and and other outputs are + 'envelope' provides gradients only for `value` and the other outputs are detached. This is useful for memory saving when only the value is needed. 
Returns @@ -140,13 +143,13 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None, # or for original Sinkhorn paper formulation [2] res = ot.solve(M, a, b, reg=1.0, reg_type='entropy') - # Use implicit differentiation for memory saving - res = ot.solve(M, a, b, reg=1.0, grad='implicit') # M, a, b are torch tensors + # Use envelope theorem differentiation for memory saving + res = ot.solve(M, a, b, reg=1.0, grad='envelope') # M, a, b are torch tensors res.value.backward() # only the value is differentiable Note that by default the Sinkhorn solver uses automatic differentiation to compute the gradients of the values and plan. This can be changed with the - `grad` parameter. The `implicit` mode computes the implicit gradients only + `grad` parameter. The `envelope` mode computes the gradients only for the value and the other outputs are detached. This is useful for memory saving when only the gradient of value is needed. @@ -311,9 +314,22 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None, if unbalanced is None: # Balanced regularized OT - if reg_type.lower() in ['entropy', 'kl']: + if isinstance(reg_type, tuple): # general solver + + if max_iter is None: + max_iter = 1000 + if tol is None: + tol = 1e-9 + + plan, log = cg(a, b, M, reg=reg, f=reg_type[0], df=reg_type[1], numItermax=max_iter, stopThr=tol, log=True, verbose=verbose, G0=plan_init) + + value_linear = nx.sum(M * plan) + value = log['loss'][-1] + potentials = (log['u'], log['v']) + + elif reg_type.lower() in ['entropy', 'kl']: - if grad == 'implicit': # if implicit then detach the input + if grad == 'envelope': # if envelope then detach the input M0, a0, b0 = M, a, b M, a, b = nx.detach(M, a, b) @@ -336,7 +352,7 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None, potentials = (log['log_u'], log['log_v']) - if grad == 'implicit': # set the gradient at convergence + if grad == 'envelope': # set the gradient at convergence value = 
nx.set_gradients(value, (M0, a0, b0), (plan, reg * (potentials[0] - potentials[0].mean()), reg * (potentials[1] - potentials[1].mean()))) @@ -359,7 +375,7 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None, else: # unbalanced AND regularized OT - if reg_type.lower() in ['kl'] and unbalanced_type.lower() == 'kl': + if not isinstance(reg_type, tuple) and reg_type.lower() in ['kl'] and unbalanced_type.lower() == 'kl': if max_iter is None: max_iter = 1000 @@ -374,14 +390,16 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None, potentials = (log['logu'], log['logv']) - elif reg_type.lower() in ['kl', 'l2', 'entropy'] and unbalanced_type.lower() in ['kl', 'l2']: + elif (isinstance(reg_type, tuple) or reg_type.lower() in ['kl', 'l2', 'entropy']) and unbalanced_type.lower() in ['kl', 'l2', 'tv']: if max_iter is None: max_iter = 1000 if tol is None: tol = 1e-12 + if isinstance(reg_type, str): + reg_type = reg_type.lower() - plan, log = lbfgsb_unbalanced(a, b, M, reg=reg, reg_m=unbalanced, reg_div=reg_type.lower(), regm_div=unbalanced_type.lower(), numItermax=max_iter, stopThr=tol, verbose=verbose, log=True) + plan, log = lbfgsb_unbalanced(a, b, M, reg=reg, reg_m=unbalanced, reg_div=reg_type, regm_div=unbalanced_type.lower(), numItermax=max_iter, stopThr=tol, verbose=verbose, log=True, G0=plan_init) value_linear = nx.sum(M * plan) @@ -962,10 +980,10 @@ def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_t verbose : bool, optional Print information in the solver, by default False grad : str, optional - Type of gradient computation, either or 'autodiff' or 'implicit' used only for + Type of gradient computation, either 'autodiff' or 'envelope', used only for Sinkhorn solver. By default 'autodiff' provides gradients wrt all outputs (`plan, value, value_linear`) but with important memory cost. 
- 'implicit' provides gradients only for `value` and and other outputs are + 'envelope' provides gradients only for `value` and and other outputs are detached. This is useful for memory saving when only the value is needed. Returns @@ -1034,13 +1052,13 @@ def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_t # lazy OT plan lazy_plan = res.lazy_plan - # Use implicit differentiation for memory saving - res = ot.solve_sample(xa, xb, a, b, reg=1.0, grad='implicit') + # Use envelope theorem differentiation for memory saving + res = ot.solve_sample(xa, xb, a, b, reg=1.0, grad='envelope') res.value.backward() # only the value is differentiable Note that by default the Sinkhorn solver uses automatic differentiation to compute the gradients of the values and plan. This can be changed with the - `grad` parameter. The `implicit` mode computes the implicit gradients only + `grad` parameter. The `envelope` mode computes the gradients only for the value and the other outputs are detached. This is useful for memory saving when only the gradient of value is needed. 
diff --git a/ot/unbalanced.py b/ot/unbalanced.py index 73667b324..c39888a31 100644 --- a/ot/unbalanced.py +++ b/ot/unbalanced.py @@ -1432,6 +1432,9 @@ def grad_entropy(G): elif reg_div == 'entropy': reg_fun = reg_entropy grad_reg_fun = grad_entropy + elif isinstance(reg_div, tuple): + reg_fun = reg_div[0] + grad_reg_fun = reg_div[1] else: reg_fun = reg_l2 grad_reg_fun = grad_l2 @@ -1451,9 +1454,20 @@ def grad_marg_kl(G): return reg_m1 * np.outer(np.log(G.sum(1) / a + 1e-16), np.ones(n)) + \ reg_m2 * np.outer(np.ones(m), np.log(G.sum(0) / b + 1e-16)) + def marg_tv(G): + return reg_m1 * np.sum(np.abs(G.sum(1) - a)) + \ + reg_m2 * np.sum(np.abs(G.sum(0) - b)) + + def grad_marg_tv(G): + return reg_m1 * np.outer(np.sign(G.sum(1) - a), np.ones(n)) + \ + reg_m2 * np.outer(np.ones(m), np.sign(G.sum(0) - b)) + if regm_div == 'kl': regm_fun = marg_kl grad_regm_fun = grad_marg_kl + elif regm_div == 'tv': + regm_fun = marg_tv + grad_regm_fun = grad_marg_tv else: regm_fun = marg_l2 grad_regm_fun = grad_marg_l2 @@ -1518,7 +1532,10 @@ def lbfgsb_unbalanced(a, b, M, reg, reg_m, c=None, reg_div='kl', regm_div='kl', reg_div: string, optional Divergence used for regularization. Can take three values: 'entropy' (negative entropy), or - 'kl' (Kullback-Leibler) or 'l2' (quadratic). + 'kl' (Kullback-Leibler) or 'l2' (quadratic) or a tuple + of two callable functions returning the reg term and its derivative. + Note that the callable functions should be able to handle numpy arrays + and not tensors from the backend. regm_div: string, optional Divergence to quantify the difference between the marginals. 
Can take two values: 'kl' (Kullback-Leibler) or 'l2' (quadratic) @@ -1574,6 +1591,23 @@ def lbfgsb_unbalanced(a, b, M, reg, reg_m, c=None, reg_div='kl', regm_div='kl', G0 = np.zeros(M.shape) if G0 is None else nx.to_numpy(G0) c = a[:, None] * b[None, :] if c is None else nx.to_numpy(c) + # wrap the callable function to handle numpy arrays + if isinstance(reg_div, tuple): + f0, df0 = reg_div + try: + f0(G0) + df0(G0) + except BaseException: + warnings.warn("The callable functions should be able to handle numpy arrays, wrapper ar added to handle this which comes with overhead") + + def f(x): + return nx.to_numpy(f0(nx.from_numpy(x, type_as=M0))) + + def df(x): + return nx.to_numpy(df0(nx.from_numpy(x, type_as=M0))) + + reg_div = (f, df) + reg_m1, reg_m2 = get_parameter_pair(reg_m) _func = _get_loss_unbalanced(a, b, c, M, reg, reg_m1, reg_m2, reg_div, regm_div) diff --git a/test/test_solvers.py b/test/test_solvers.py index 168b111e4..16e6df295 100644 --- a/test/test_solvers.py +++ b/test/test_solvers.py @@ -14,8 +14,9 @@ from ot.bregman import geomloss from ot.backend import torch + lst_reg = [None, 1] -lst_reg_type = ['KL', 'entropy', 'L2'] +lst_reg_type = ['KL', 'entropy', 'L2', 'tuple'] lst_unbalanced = [None, 0.9] lst_unbalanced_type = ['KL', 'L2', 'TV'] @@ -109,7 +110,7 @@ def test_solve(nx): @pytest.mark.skipif(not torch, reason="torch no installed") -def test_solve_implicit(): +def test_solve_envelope(): n_samples_s = 10 n_samples_t = 7 @@ -126,7 +127,7 @@ def test_solve_implicit(): b = torch.tensor(b, requires_grad=True) M = torch.tensor(M, requires_grad=True) - sol0 = ot.solve(M, a, b, reg=10, grad='implicit') + sol0 = ot.solve(M, a, b, reg=10, grad='envelope') sol0.value.backward() gM0 = M.grad.clone() @@ -166,6 +167,15 @@ def test_solve_grid(nx, reg, reg_type, unbalanced, unbalanced_type): try: + if reg_type == 'tuple': + def f(G): + return np.sum(G**2) + + def df(G): + return 2 * G + + reg_type = (f, df) + # solve unif weights sol0 = ot.solve(M, reg=reg, 
reg_type=reg_type, unbalanced=unbalanced, unbalanced_type=unbalanced_type) @@ -176,9 +186,20 @@ def test_solve_grid(nx, reg, reg_type, unbalanced, unbalanced_type): # solve in backend ab, bb, Mb = nx.from_numpy(a, b, M) - solb = ot.solve(M, a, b, reg=reg, reg_type=reg_type, unbalanced=unbalanced, unbalanced_type=unbalanced_type) + + if isinstance(reg_type, tuple): + def f(G): + return nx.sum(G**2) + + def df(G): + return 2 * G + + reg_type = (f, df) + + solb = ot.solve(Mb, ab, bb, reg=reg, reg_type=reg_type, unbalanced=unbalanced, unbalanced_type=unbalanced_type) assert_allclose_sol(sol, solb) + except NotImplementedError: pytest.skip("Not implemented") @@ -201,10 +222,6 @@ def test_solve_not_implemented(nx): with pytest.raises(NotImplementedError): ot.solve(M, unbalanced=1.0, unbalanced_type='cryptic divergence') - # pairs of incompatible divergences - with pytest.raises(NotImplementedError): - ot.solve(M, reg=1.0, reg_type='kl', unbalanced=1.0, unbalanced_type='tv') - def test_solve_gromov(nx):