[WIP] Spectral-Grassmann OT#792

Open
thibaut-germain wants to merge 16 commits into PythonOT:master from thibaut-germain:sgot

@thibaut-germain

Types of changes

Adding sgot file in the ot folder.

Motivation and context / Related issue

Keep track of SGOT implementation in POT.

How has this been tested (if it applies)

Not tested yet.

PR checklist

  • I have read the CONTRIBUTING document.
  • The documentation is up-to-date with the changes I made (check build artifacts).
  • [] All tests passed, and additional code has been covered with new tests.
  • I have added the PR and Issue fix to the RELEASES.md file.

@rflamary changed the title from "Sgot" to "[WIP] Spactral-Gromov OT" on Feb 9, 2026
@rflamary changed the title from "[WIP] Spactral-Gromov OT" to "[WIP] Spectral-Grassman OT" on Feb 9, 2026
@rflamary changed the title from "[WIP] Spectral-Grassman OT" to "[WIP] Spectral-Grassmann OT" on Feb 9, 2026
codecov bot commented Feb 11, 2026

Codecov Report

❌ Patch coverage is 7.17131% with 233 lines in your changes missing coverage. Please review.
✅ Project coverage is 95.77%. Comparing base (e164e78) to head (3f10111).
⚠️ Report is 2 commits behind head on master.

Additional details and impacted files
@@            Coverage Diff             @@
##           master     #792      +/-   ##
==========================================
- Coverage   96.77%   95.77%   -1.00%     
==========================================
  Files         107      108       +1     
  Lines       22342    22621     +279     
==========================================
+ Hits        21622    21666      +44     
- Misses        720      955     +235     

@rflamary (Collaborator) left a comment

Hello @osheasienna and @thibaut-germain, this is a nice first step.

Below are a few comments that we can discuss together.

ot/sgot.py Outdated
return C


def metric(
Collaborator

Suggested change
- def metric(
+ def sgot_metric(

ot/sgot.py Outdated
return prod ** (q / 2)


def ot_plan(C, Ws=None, Wt=None, nx=None):
Collaborator

This function is not needed: it is only two lines, and the normalization with respect to Ws and Wt is not OK because it can return very weird things.

ot/sgot.py Outdated
### SPECTRAL-GRASSMANNIAN WASSERSTEIN METRIC ###
#####################################################################################################################################
#####################################################################################################################################
def cost(
Collaborator

Suggested change
- def cost(
+ def sgot_cost_matrix(

ot/sgot.py Outdated
imag_scale=1.0,
nx=None,
):
"""Compute the SGOT cost matrix between two spectral decompositions.
Collaborator

Recall the equation with eta here, and define mathematically the different acceptable metrics.
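
For reference, the quantities combined by the code snippets quoted elsewhere in this thread can be written as below. This is only a sketch reconstructed from those snippets (the `eta` mix, the exponents `p` and `q`, and the `procrustes` and `martin` branches). The symbols $\lambda^s_i$, $\lambda^t_j$, $\delta_{ij}$ and $d_G$ are notation assumed here, the eigenvalue scaling is left out, and the `geodesic` and `chordal` formulas are omitted because they do not appear in the thread.

```latex
C_{ij} = \Big( \eta\, C^{\lambda}_{ij} + (1-\eta)\, d_G^{2}(\delta_{ij}) \Big)^{p/2},
\qquad
C^{\lambda}_{ij} = \big| \lambda^{s}_{i} - \lambda^{t}_{j} \big|^{q},

d_G^{2}(\delta) =
\begin{cases}
2\,(1-\delta) & \text{procrustes},\\
-\log\!\big(\delta^{2}\big) & \text{martin}.
\end{cases}
```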

ot/sgot.py Outdated
raise ValueError(f"cost() expects Dt to be 1D (n,), got shape {Dt.shape}")
lam2 = Dt

lam1 = nx.astype(lam1, "complex128")
Collaborator

Is that necessary? It seems overkill to add a function to the backend for that. When and why does it fail?

logits_s = rng.randn(r)
logits_t = rng.randn(r)

Ws = np.exp(logits_s)
Collaborator

Simpler, and returns only positive values:

Suggested change
- Ws = np.exp(logits_s)
+ Ws = rng.rand(r)

"""Create test_cost for each trial: sweep over HPs and run cost()."""
grassmann_types = ["geodesic", "chordal", "procrustes", "martin"]
n_trials = 10
for _ in range(n_trials):
Collaborator

no need for trials ;)

def test_hyperparameter_sweep():
grassmann_types = ["geodesic", "chordal", "procrustes", "martin"]

for _ in range(10):
Collaborator

same here

This new release adds support for sparse cost matrices and a new lazy EMD solver that computes distances on-the-fly from coordinates, reducing memory usage from O(n×m) to O(n+m). Both implementations are backend-agnostic and preserve gradient computation for automatic differentiation.

#### New features
- Add lazy EMD solver with on-the-fly distance computation from coordinates (PR #788)
Collaborator

add feature here

RELEASES.md Outdated
## Upcomming 0.9.7.post1

#### New features
The next release will add cost functions between linear operators following [A Spectral-Grassmann Wasserstein metric for operator representations of dynamical systems](https://arxiv.org/pdf/2509.24920).
Collaborator

Move this text to the new features of 0.9.7.dev0; this is what we are working on. Also add a line to the itemized list with the PR number.

@rflamary (Collaborator) left a comment

A few comments from talking together

ot/sgot.py Outdated
if grassman_metric == "procrustes":
return 2.0 * (1.0 - delta)
if grassman_metric == "martin":
return -nx.log(nx.clip(delta**2, eps, 1e300))
Collaborator

remove upper threshold

ot/sgot.py Outdated
C_grass = _grassmann_distance_squared(delta, grassman_metric=grassman_metric, nx=nx)

C2 = eta * C_lambda + (1.0 - eta) * C_grass
C = C2 ** (p / 2.0)
Collaborator

Suggested change
- C = C2 ** (p / 2.0)
+ C = nx.real(C2) ** (p / 2.0)

ot/sgot.py Outdated
q=1,
r=2,
grassman_metric="chordal",
real_scale=1.0,
Collaborator

Let's call this eigen_scaling and set it to None by default.

nx=None,
):
"""Compute the SGOT metric between two spectral decompositions.

Collaborator

Add an equation that illustrates p, q and r.

import numpy as np
import pytest

from ot.backend import get_backend
Collaborator

Suggested change
- from ot.backend import get_backend
+ from ot.backend import get_backend, torch, jax

rng = np.random.RandomState(0)


def rand_complex(shape):
Collaborator

Suggested change
- def rand_complex(shape):
+ def rand_complex(shape,rng):

return real + 1j * imag


def random_atoms(d=8, r=4):
Collaborator

Suggested change
- def random_atoms(d=8, r=4):
+ def random_atoms(d=8, r=4,seed=42):
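
The two suggestions above (passing `rng` into `rand_complex` and giving the data helper a `seed`) can be sketched as follows. The body of `rand_complex` follows the `return real + 1j * imag` line quoted in the thread; `random_eigenvalues` is a hypothetical stand-in for `random_atoms`, whose body is not shown here.

```python
import numpy as np


def rand_complex(shape, rng):
    """Draw a complex array from the generator supplied by the caller."""
    real = rng.randn(*shape)
    imag = rng.randn(*shape)
    return real + 1j * imag


def random_eigenvalues(r=4, seed=42):
    """Hypothetical helper: a fresh RandomState per call makes the data reproducible."""
    rng = np.random.RandomState(seed)
    return rand_complex((r,), rng)


a = random_eigenvalues(seed=42)
b = random_eigenvalues(seed=42)
print(np.array_equal(a, b))
```

With no module-level generator, the data a test sees no longer depends on which tests ran before it.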



@pytest.mark.parametrize("backend_name", ["numpy", "torch", "jax"])
def test_cost_backend_consistency(backend_name):
Collaborator

Suggested change
- def test_cost_backend_consistency(backend_name):
+ def test_cost_backend_consistency(nx):

# ---------------------------------------------------------------------


def test_hyperparameter_sweep_cost(nx):
Collaborator

Suggested change
- def test_hyperparameter_sweep_cost(nx):
+ def test_hyperparameter_sweep_cost(nx,grassmann_types,p,q,r,eta):

ot/sgot.py Outdated
Ws = Ws / nx.sum(Ws)
Wt = Wt / nx.sum(Wt)

P = ot.emd2(Ws, Wt, nx.real(C))
Collaborator

emd2 returns the objective value directly; no need to compute it again below.

else:
real_scale, imag_scale = eigen_scaling[0], eigen_scaling[1]

Dsn = nx.real(Ds) * real_scale + 1j * nx.imag(Ds) * imag_scale
Author

Suggested change
- Dsn = nx.real(Ds) * real_scale + 1j * nx.imag(Ds) * imag_scale
+ C_real = nx.real(Dsn)[:,None] - nx.real(Dtn)[None,:]
+ C_real = C_real**2
+ C_imag = nx.imag(Dsn)[:,None] - nx.imag(Dtn)[None,:]
+ C_imag = C_imag**2
+ prod = C_real + C_imag
+ return prod ** (q / 2)
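
A minimal NumPy sketch of the eigenvalue cost built this way, under a few assumptions: NumPy stands in for the backend `nx`, `eigen_scaling=None` is taken to mean unit scales, and the scaling is applied to both spectra. It is meant to illustrate the broadcasting, not to mirror the final code in `ot/sgot.py`.

```python
import numpy as np


def eigenvalue_cost_matrix(Ds, Dt, q=1, eigen_scaling=None):
    """Pairwise |lambda_s - lambda_t|**q from real and imaginary parts."""
    Ds = np.asarray(Ds, dtype=complex)
    Dt = np.asarray(Dt, dtype=complex)
    # Assumption: None means no rescaling of either axis
    real_scale, imag_scale = (1.0, 1.0) if eigen_scaling is None else eigen_scaling

    # Work on real arrays only, so the result is real by construction
    C_real = (Ds.real[:, None] - Dt.real[None, :]) * real_scale
    C_imag = (Ds.imag[:, None] - Dt.imag[None, :]) * imag_scale
    prod = C_real**2 + C_imag**2
    return prod ** (q / 2)


Ds = np.array([1 + 1j, 0])
Dt = np.array([0, 3 + 4j])
C = eigenvalue_cost_matrix(Ds, Dt, q=1)
print(C.dtype, C.shape)
```

Because every intermediate array is real, the output has a real dtype and no later `nx.real` call is required on this term.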

A_norm: array-like, shape (d, n)
Column-normalized array.
"""
nrm = nx.sqrt(nx.sum(A * nx.conj(A), axis=0, keepdims=True))
Author

Suggested change
- nrm = nx.sqrt(nx.sum(A * nx.conj(A), axis=0, keepdims=True))
+ nrm = nx.norm(A, axis=0, keepdims=True)

You can replace it with the function nx.norm, which handles complex numbers.
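
A quick NumPy check of this equivalence, using `np.linalg.norm` as a stand-in for the backend's `nx.norm` (whether every backend behaves identically is an assumption here, not something verified in this thread):

```python
import numpy as np

rng = np.random.RandomState(0)
A = rng.randn(5, 3) + 1j * rng.randn(5, 3)

# Manual column norms: the result keeps a complex dtype with zero imaginary part
manual = np.sqrt(np.sum(A * np.conj(A), axis=0, keepdims=True))

# Built-in column norms: the result is real-valued
builtin = np.linalg.norm(A, axis=0, keepdims=True)

print(np.allclose(manual.real, builtin))
```

Besides being shorter, the built-in norm returns a real array, which avoids carrying a complex dtype into the normalization.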

return delta


def _grassmann_distance_squared(delta, grassman_metric="chordal", nx=None, eps=1e-300):
Author

Epsilon is too small for machine precision; you can set it to 1e-12, for instance.
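
One possible way to see the problem, sketched with NumPy: `1e-300` is representable in float64 but underflows to zero in float32, so a clamp at that value no longer protects the logarithm. The helper name `martin_upper_bound` is hypothetical and only serves the illustration.

```python
import numpy as np


def martin_upper_bound(eps, dtype):
    """Largest value -log(max(delta**2, eps)) can take once eps is cast to dtype."""
    eps_cast = np.asarray(eps, dtype=dtype)
    with np.errstate(divide="ignore"):
        return float(-np.log(eps_cast))


for eps in (1e-300, 1e-12):
    for dtype in (np.float64, np.float32):
        print(eps, np.dtype(dtype).name, martin_upper_bound(eps, dtype))
```

In single precision the `1e-300` clamp becomes a clamp at zero and the bound is infinite, while `1e-12` stays finite in both precisions.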

if nx is None:
nx = get_backend(delta)

delta = nx.clip(delta, 0.0, 1.0)
Author

If delta is not in [0,1], it should raise an error; this indicates an issue in the computation of delta outside of this function.
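
A sketch of such a check in NumPy. The name `check_delta` and the round-off tolerance `tol` are assumptions added for illustration; the comment above only asks that out-of-range values raise instead of being silently clipped.

```python
import numpy as np


def check_delta(delta, tol=1e-8):
    """Raise if delta leaves [0, 1] by more than tol, else absorb round-off."""
    delta = np.asarray(delta, dtype=float)
    if np.any(delta < -tol) or np.any(delta > 1.0 + tol):
        raise ValueError(
            f"delta must lie in [0, 1], got range "
            f"[{delta.min():.3g}, {delta.max():.3g}]"
        )
    # Only values within tol of the interval reach this point
    return np.clip(delta, 0.0, 1.0)


print(check_delta([0.0, 0.3, 1.0 + 1e-12]))
```

Clipping after the check keeps harmless floating-point overshoot from reaching the logarithm or square root, while a genuinely wrong delta still fails loudly.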

### SPECTRAL-GRASSMANNIAN WASSERSTEIN METRIC ###
#####################################################################################################################################
#####################################################################################################################################
def sgot_cost_matrix(
Author

You should add eps to the function signature, as this parameter appears in downstream functions. Keep the same epsilon for all functions.

@thibaut-germain (Author) left a comment

Thibaut's review

# information-geometric interpretation in Germain et al. (2025).
delta2 = nx.maximum(delta**2, eps)
return -nx.log(delta2)
raise ValueError(f"Unknown grassman_metric: {grassman_metric}")
Author

In this function the power q should also be a parameter. For any distance you can set:
result = square_distance(delta)
then
return nx.real(result)**(q/2)
By default, set q to the same value as for the eigenvalue cost.
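
A NumPy sketch of this idea, restricted to the two branches quoted in this thread (`procrustes` and `martin`). The function name `grassmann_cost` and its defaults are hypothetical, and NumPy stands in for the backend `nx`.

```python
import numpy as np


def grassmann_cost(delta, grassmann_metric="procrustes", q=1, eps=1e-12):
    """Squared Grassmann distance raised to q / 2, i.e. distance**q."""
    delta = np.asarray(delta, dtype=float)
    if grassmann_metric == "procrustes":
        dist2 = 2.0 * (1.0 - delta)
    elif grassmann_metric == "martin":
        # Lower clamp only: delta**2 never exceeds 1 for delta in [0, 1]
        dist2 = -np.log(np.maximum(delta**2, eps))
    else:
        raise ValueError(f"Unknown grassmann_metric: {grassmann_metric}")
    return dist2 ** (q / 2)


print(grassmann_cost([0.875], "procrustes", q=1))
print(grassmann_cost([0.875], "procrustes", q=2))
```

With `q=2` the function returns the squared distance unchanged, so the behaviour before this change corresponds to that setting.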

C_lambda = eigenvalue_cost_matrix(Ds, Dt, q=q, eigen_scaling=eigen_scaling, nx=nx)

delta = _delta_matrix_1d(Rs, Ls, Rt, Lt, nx=nx)
C_grass = _grassmann_distance_squared(delta, grassman_metric=grassman_metric, nx=nx)
Author

the power parameter q should also affect the Grassmann cost

C_grass = _grassmann_distance_squared(delta, grassman_metric=grassman_metric, nx=nx)

C2 = eta * C_lambda + (1.0 - eta) * C_grass
C = nx.real(C2) ** (p / 2.0)
Author

Your cost function already returns a real value, so there is no need for nx.real here.

return C


def _validate_sgot_metric_inputs(Ds, Dt):
Author

You can add the verifications you wrote in lines 272-290 to this function, and also add verifications that source and target have the same shapes.
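
A sketch of what such a shared validation helper might look like for the eigenvalue arrays. The 1D check follows the error message quoted earlier in the thread; the name `validate_spectra` and the exact messages are assumptions, and the checks on the eigenvector arrays are left out because their code is not shown here.

```python
import numpy as np


def validate_spectra(Ds, Dt):
    """Check that both eigenvalue arrays are 1D and have matching shapes."""
    Ds, Dt = np.asarray(Ds), np.asarray(Dt)
    for name, D in (("Ds", Ds), ("Dt", Dt)):
        if D.ndim != 1:
            raise ValueError(f"expected {name} to be 1D (n,), got shape {D.shape}")
    if Ds.shape != Dt.shape:
        raise ValueError(
            f"Ds and Dt must have the same shape, got {Ds.shape} and {Dt.shape}"
        )
    return Ds, Dt


validate_spectra(np.zeros(3, dtype=complex), np.ones(3, dtype=complex))
```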

)


def sgot_metric(
Author

You will need to add eps also in this function.

3 participants