
sinkhorn divergence appears to be calculated incorrectly #383

@jacksonloper

Describe the bug

As far as I can tell, ot.bregman.empirical_sinkhorn_divergence returns only the transport costs and ignores the entropic regularization costs, even though the documentation says it includes both. This was also mentioned in #255, but the concern in that issue was whether the function matches particular papers (and there have been a lot of papers!).

My concern here is narrower: the documentation of ot.bregman.empirical_sinkhorn_divergence gives a precise formula for what it computes, but the function itself appears to compute something quite different.
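For reference, this is my reading of the formula in the docstring (paraphrased; M, M_a and M_b are the cross and within-sample cost matrices, and reg is the regularization strength, called eps in my code below):

```latex
\begin{aligned}
W   &= \min_{\gamma}   \langle \gamma,   M   \rangle_F + \mathrm{reg}\cdot\Omega(\gamma), \\
W_a &= \min_{\gamma_a} \langle \gamma_a, M_a \rangle_F + \mathrm{reg}\cdot\Omega(\gamma_a), \\
W_b &= \min_{\gamma_b} \langle \gamma_b, M_b \rangle_F + \mathrm{reg}\cdot\Omega(\gamma_b), \\
S   &= W - \tfrac{1}{2}\left(W_a + W_b\right),
\qquad \Omega(\gamma) = \sum_{i,j} \gamma_{i,j} \log \gamma_{i,j}.
\end{aligned}
```

A function that returns only the transport part, ⟨γ, M⟩_F - ½(⟨γ_a, M_a⟩_F + ⟨γ_b, M_b⟩_F), drops every reg·Ω term from S.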

To Reproduce

Steps to reproduce the behavior:

  1. Run the code below.
  2. Feel confused.

Code sample

```python
import numpy as np
import ot
from scipy.spatial.distance import cdist

# set up the problem: two 1-D point clouds with uniform weights
ptsA = np.r_[0:5:10j][:, None]
ptsB = np.r_[1:3:20j][:, None]
eps = 1

# squared Euclidean cost matrices (POT's default metric is 'sqeuclidean')
C1 = cdist(ptsA, ptsB) ** 2
C2 = cdist(ptsA, ptsA) ** 2
C3 = cdist(ptsB, ptsB) ** 2

# entropic transport plans
pot_plan1 = ot.bregman.empirical_sinkhorn(ptsA, ptsB, eps)
pot_plan2 = ot.bregman.empirical_sinkhorn(ptsA, ptsA, eps)
pot_plan3 = ot.bregman.empirical_sinkhorn(ptsB, ptsB, eps)

# transport-cost part of the sinkhorn divergence
transport_costs = (np.sum(C1 * pot_plan1)
                   - .5 * np.sum(C2 * pot_plan2)
                   - .5 * np.sum(C3 * pot_plan3))

# entropic part: reg * Omega(gamma), with Omega(gamma) = sum(gamma * log(gamma))
# (the factor eps is a no-op here because eps = 1)
entropic_costs = eps * (np.sum(pot_plan1 * np.log(pot_plan1))
                        - .5 * np.sum(pot_plan2 * np.log(pot_plan2))
                        - .5 * np.sum(pot_plan3 * np.log(pot_plan3)))

# print results
print('transport costs'.rjust(30), transport_costs)
print('entropic costs'.rjust(30), entropic_costs)
print('sinkhorn divergence'.rjust(30), transport_costs + entropic_costs)

# compare with the result from ot
print('result from ot'.rjust(30), ot.bregman.empirical_sinkhorn_divergence(ptsA, ptsB, eps))
```
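To check the documented formula independently of POT, here is a minimal NumPy-only sketch. It assumes uniform weights and the squared Euclidean cost (as in the reproduction above), and it swaps in a plain Sinkhorn scaling loop of my own in place of POT's solver. The helper names (`sqeuclidean`, `sinkhorn_plan`, `entropic_ot`, `sinkhorn_divergences`) are hypothetical and not part of POT. It prints the divergence built from the transport term alone and from the full regularized cost.

```python
import numpy as np


def sqeuclidean(x, y):
    # Pairwise squared Euclidean distances between rows of x and rows of y.
    return ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)


def sinkhorn_plan(a, b, M, reg, n_iter=5000, tol=1e-12):
    # Plain (non-log-domain) Sinkhorn scaling; adequate for reg = 1 on this
    # small problem, but exp(-M / reg) can underflow when reg is small.
    K = np.exp(-M / reg)
    u = np.ones_like(a)
    v = np.ones_like(b)
    for _ in range(n_iter):
        u_prev = u
        u = a / (K @ v)
        v = b / (K.T @ u)
        if np.max(np.abs(u - u_prev)) < tol:
            break
    return u[:, None] * K * v[None, :]


def entropic_ot(x, y, reg):
    # Returns (<P, M>, <P, M> + reg * sum(P * log(P))) for uniform weights.
    a = np.full(len(x), 1.0 / len(x))
    b = np.full(len(y), 1.0 / len(y))
    M = sqeuclidean(x, y)
    P = sinkhorn_plan(a, b, M, reg)
    transport = np.sum(P * M)
    omega = np.sum(P * np.log(P))
    return transport, transport + reg * omega


def sinkhorn_divergences(x, y, reg):
    # S = W(x, y) - (W(x, x) + W(y, y)) / 2, evaluated twice: once with the
    # transport term only and once with the full regularized cost.
    t_xy, w_xy = entropic_ot(x, y, reg)
    t_xx, w_xx = entropic_ot(x, x, reg)
    t_yy, w_yy = entropic_ot(y, y, reg)
    transport_only = t_xy - 0.5 * (t_xx + t_yy)
    full = w_xy - 0.5 * (w_xx + w_yy)
    return transport_only, full


ptsA = np.linspace(0, 5, 10)[:, None]   # same points as np.r_[0:5:10j]
ptsB = np.linspace(1, 3, 20)[:, None]   # same points as np.r_[1:3:20j]
transport_only, full = sinkhorn_divergences(ptsA, ptsB, reg=1.0)
print('transport term only'.rjust(30), transport_only)
print('documented formula'.rjust(30), full)
```

If POT's output lines up with the first printed value rather than the second, that is the discrepancy described in this issue. Small differences in the last digits are expected, since the stopping rule here is not POT's.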

Expected behavior

I expect the result from ot.bregman.empirical_sinkhorn_divergence to equal the sinkhorn divergence as I have calculated it. Instead, it appears to be identical to the transport-cost portion of the divergence alone.

Environment:

  • OS: Ubuntu
  • Python version: 3.9.7
  • How POT was installed: pip
Linux-5.10.47-linuxkit-x86_64-with-glibc2.31
Python 3.9.7 | packaged by conda-forge | (default, Sep 29 2021, 19:20:46) 
[GCC 9.4.0]
NumPy 1.20.3
SciPy 1.7.2
POT 0.8.2
