Description
Describe the bug
As far as I can tell, ot.bregman.empirical_sinkhorn_divergence returns only the transport-cost terms and drops the entropic regularization terms, even though its documentation says the result includes both. This was also mentioned in #255, but the concern in that issue was whether the function matches particular papers (and there have been a lot of papers!).
My concern here is that the documentation of ot.bregman.empirical_sinkhorn_divergence gives a precise formula for what it does, but then the function itself appears to compute something quite different.
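For reference, this is how I read the formula in the docstring. M is the cost matrix (squared Euclidean by default), reg is the regularization strength (eps in the code below), Ω is the entropic term, and γ* is the entropic transport plan with marginals a and b (uniform weights when none are passed):

```latex
\begin{aligned}
W(a,b) &= \langle \gamma^\star, M \rangle_F + \mathrm{reg}\cdot\Omega(\gamma^\star),
\qquad \Omega(\gamma) = \sum_{i,j} \gamma_{ij}\log\gamma_{ij},\\
\gamma^\star &= \operatorname*{arg\,min}_{\gamma \ge 0,\ \gamma\mathbf{1} = a,\ \gamma^{\top}\mathbf{1} = b}
\ \langle \gamma, M \rangle_F + \mathrm{reg}\cdot\Omega(\gamma),\\
S(a,b) &= W(a,b) - \tfrac{1}{2}\bigl(W(a,a) + W(b,b)\bigr).
\end{aligned}
```

The value the function returns appears to match only the ⟨γ*, M⟩_F parts of the three W terms, without the reg·Ω(γ*) parts.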
To Reproduce
Steps to reproduce the behavior:
- Run the code below.
- Feel confused.
Code sample
```python
import numpy as np
import ot
from scipy.spatial.distance import cdist

# setup problem
ptsA = np.r_[0:5:10j][:, None]
ptsB = np.r_[1:3:20j][:, None]
eps = 1

# get cost matrices (squared Euclidean, matching the default metric='sqeuclidean')
C1 = cdist(ptsA, ptsB) ** 2
C2 = cdist(ptsA, ptsA) ** 2
C3 = cdist(ptsB, ptsB) ** 2

# get entropic transport plans (uniform weights by default)
pot_plan1 = ot.bregman.empirical_sinkhorn(ptsA, ptsB, eps)
pot_plan2 = ot.bregman.empirical_sinkhorn(ptsA, ptsA, eps)
pot_plan3 = ot.bregman.empirical_sinkhorn(ptsB, ptsB, eps)

# compute transport costs for sinkhorn divergence
transport_costs = (np.sum(C1 * pot_plan1)
                   - .5 * np.sum(C2 * pot_plan2)
                   - .5 * np.sum(C3 * pot_plan3))

# compute entropic costs, eps * sum(P log P), for sinkhorn divergence
entropic_costs = eps * (np.sum(pot_plan1 * np.log(pot_plan1))
                        - .5 * np.sum(pot_plan2 * np.log(pot_plan2))
                        - .5 * np.sum(pot_plan3 * np.log(pot_plan3)))

# print results
print('transport costs'.rjust(30), transport_costs)
print('entropic costs'.rjust(30), entropic_costs)
print('sinkhorn divergence'.rjust(30), transport_costs + entropic_costs)

# compare with result from ot
print('result from ot'.rjust(30),
      ot.bregman.empirical_sinkhorn_divergence(ptsA, ptsB, eps))
```
Expected behavior
I expect the result from ot.bregman.empirical_sinkhorn_divergence to match the Sinkhorn divergence I computed above. Instead, it appears to equal only the transport-cost portion of the divergence.
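As a cross-check that does not depend on POT's own plans, the same transport/entropic split can be reproduced with a plain NumPy Sinkhorn loop. This is only a minimal sketch, assuming uniform weights, the squared Euclidean cost, and the Ω(γ) = Σ γ log γ convention. The helper names (sinkhorn_plan, entropic_ot_terms, sinkhorn_divergence_terms) are hypothetical and not part of POT, and there is no log-domain stabilization, so it is only meant for small problems with a moderate eps like this one.

```python
import numpy as np

# Hypothetical helpers for illustration; not part of POT.


def sinkhorn_plan(C, a, b, eps, n_iter=10000, tol=1e-10):
    """Entropic OT plan via plain Sinkhorn scaling (no log-domain stabilization)."""
    K = np.exp(-C / eps)
    u = np.ones_like(a)
    v = np.ones_like(b)
    for _ in range(n_iter):
        u = a / (K @ v)
        v = b / (K.T @ u)
        # after the v-update the column sums are exact, so test the row sums
        if np.max(np.abs(u * (K @ v) - a)) < tol:
            break
    return u[:, None] * K * v[None, :]


def entropic_ot_terms(X, Y, eps):
    """Return (<P, C>, eps * sum(P log P)) for uniform weights and squared Euclidean cost."""
    C = ((X[:, None, :] - Y[None, :, :]) ** 2).sum(axis=-1)
    a = np.full(X.shape[0], 1.0 / X.shape[0])
    b = np.full(Y.shape[0], 1.0 / Y.shape[0])
    P = sinkhorn_plan(C, a, b, eps)
    return np.sum(P * C), eps * np.sum(P * np.log(P))


def sinkhorn_divergence_terms(X, Y, eps):
    """Split S = W(X,Y) - (W(X,X) + W(Y,Y)) / 2 into transport and entropic parts."""
    t_xy, e_xy = entropic_ot_terms(X, Y, eps)
    t_xx, e_xx = entropic_ot_terms(X, X, eps)
    t_yy, e_yy = entropic_ot_terms(Y, Y, eps)
    transport = t_xy - 0.5 * (t_xx + t_yy)
    entropic = e_xy - 0.5 * (e_xx + e_yy)
    return transport, entropic, transport + entropic


ptsA = np.r_[0:5:10j][:, None]
ptsB = np.r_[1:3:20j][:, None]
labels = ("transport costs", "entropic costs", "sinkhorn divergence")
for name, value in zip(labels, sinkhorn_divergence_terms(ptsA, ptsB, 1.0)):
    print(name.rjust(30), value)
```

Because every plan sums to 1, using Σ γ(log γ − 1) instead of Σ γ log γ shifts each W by −reg and leaves S unchanged (1 − ½ − ½ = 0), so the choice of entropy convention cannot account for the gap. The stopping rule here differs from POT's stopThr, so small numerical differences from POT's plans are expected.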
Environment (please complete the following information):
- OS (e.g. MacOS, Windows, Linux): Ubuntu
- Python version: 3.9.7
- How was POT installed (source, pip, conda): pip
Linux-5.10.47-linuxkit-x86_64-with-glibc2.31
Python 3.9.7 | packaged by conda-forge | (default, Sep 29 2021, 19:20:46)
[GCC 9.4.0]
NumPy 1.20.3
SciPy 1.7.2
POT 0.8.2