[WIP] Spectral-Grassmann OT#792
Conversation
Codecov Report❌ Patch coverage is Additional details and impacted files@@ Coverage Diff @@
## master #792 +/- ##
==========================================
- Coverage 96.77% 95.77% -1.00%
==========================================
Files 107 108 +1
Lines 22342 22621 +279
==========================================
+ Hits 21622 21666 +44
- Misses 720 955 +235 🚀 New features to boost your workflow:
|
rflamary
left a comment
There was a problem hiding this comment.
Hello @osheasienna and @thibaut-germain this is a nice first step.
Below are a few comments that we can discuss together.
ot/sgot.py
Outdated
| return C | ||
|
|
||
|
|
||
| def metric( |
There was a problem hiding this comment.
| def metric( | |
| def sgot_metric( |
ot/sgot.py
Outdated
| return prod ** (q / 2) | ||
|
|
||
|
|
||
| def ot_plan(C, Ws=None, Wt=None, nx=None): |
There was a problem hiding this comment.
This function is not needed: it is only two lines, and the normalization w.r.t. Ws and Wt is not OK because it can return very weird things.
ot/sgot.py
Outdated
| ### SPECTRAL-GRASSMANNIAN WASSERSTEIN METRIC ### | ||
| ##################################################################################################################################### | ||
| ##################################################################################################################################### | ||
| def cost( |
There was a problem hiding this comment.
| def cost( | |
| def sgot_cost_matrix( |
ot/sgot.py
Outdated
| imag_scale=1.0, | ||
| nx=None, | ||
| ): | ||
| """Compute the SGOT cost matrix between two spectral decompositions. |
There was a problem hiding this comment.
Recall the equation with eta here and define mathematically the different acceptable metrics.
ot/sgot.py
Outdated
| raise ValueError(f"cost() expects Dt to be 1D (n,), got shape {Dt.shape}") | ||
| lam2 = Dt | ||
|
|
||
| lam1 = nx.astype(lam1, "complex128") |
There was a problem hiding this comment.
Is that necessary? It seems overkill to add a function to the backend for that. When and why does it fail?
test/test_sgot.py
Outdated
| logits_s = rng.randn(r) | ||
| logits_t = rng.randn(r) | ||
|
|
||
| Ws = np.exp(logits_s) |
There was a problem hiding this comment.
Simpler, and it returns only positive values.
| Ws = np.exp(logits_s) | |
| Ws = rng.rand(r) |
test/test_sgot.py
Outdated
| """Create test_cost for each trial: sweep over HPs and run cost().""" | ||
| grassmann_types = ["geodesic", "chordal", "procrustes", "martin"] | ||
| n_trials = 10 | ||
| for _ in range(n_trials): |
test/test_sgot.py
Outdated
| def test_hyperparameter_sweep(): | ||
| grassmann_types = ["geodesic", "chordal", "procrustes", "martin"] | ||
|
|
||
| for _ in range(10): |
| This new release adds support for sparse cost matrices and a new lazy EMD solver that computes distances on-the-fly from coordinates, reducing memory usage from O(n×m) to O(n+m). Both implementations are backend-agnostic and preserve gradient computation for automatic differentiation. | ||
|
|
||
| #### New features | ||
| - Add lazy EMD solver with on-the-fly distance computation from coordinates (PR #788) |
RELEASES.md
Outdated
| ## Upcomming 0.9.7.post1 | ||
|
|
||
| #### New features | ||
| The next release will add cost functions between linear operators following [A Spectral-Grassmann Wasserstein metric for operator representations of dynamical systems](https://arxiv.org/pdf/2509.24920). |
There was a problem hiding this comment.
Move this text to the new features of 0.9.7.dev0, since this is what we are working on. Also add a line in the itemized list with the PR number.
rflamary
left a comment
There was a problem hiding this comment.
A few comments from talking together
ot/sgot.py
Outdated
| if grassman_metric == "procrustes": | ||
| return 2.0 * (1.0 - delta) | ||
| if grassman_metric == "martin": | ||
| return -nx.log(nx.clip(delta**2, eps, 1e300)) |
ot/sgot.py
Outdated
| C_grass = _grassmann_distance_squared(delta, grassman_metric=grassman_metric, nx=nx) | ||
|
|
||
| C2 = eta * C_lambda + (1.0 - eta) * C_grass | ||
| C = C2 ** (p / 2.0) |
There was a problem hiding this comment.
| C = C2 ** (p / 2.0) | |
| C = nx.real(C2) ** (p / 2.0) |
ot/sgot.py
Outdated
| q=1, | ||
| r=2, | ||
| grassman_metric="chordal", | ||
| real_scale=1.0, |
There was a problem hiding this comment.
lets call this eigen_scaling and set it to None by default
| nx=None, | ||
| ): | ||
| """Compute the SGOT metric between two spectral decompositions. | ||
|
|
There was a problem hiding this comment.
add equation that illustrate p q and r
test/test_sgot.py
Outdated
| import numpy as np | ||
| import pytest | ||
|
|
||
| from ot.backend import get_backend |
There was a problem hiding this comment.
| from ot.backend import get_backend | |
| from ot.backend import get_backend, torch, jax |
test/test_sgot.py
Outdated
| rng = np.random.RandomState(0) | ||
|
|
||
|
|
||
| def rand_complex(shape): |
There was a problem hiding this comment.
| def rand_complex(shape): | |
| def rand_complex(shape,rng): |
test/test_sgot.py
Outdated
| return real + 1j * imag | ||
|
|
||
|
|
||
| def random_atoms(d=8, r=4): |
There was a problem hiding this comment.
| def random_atoms(d=8, r=4): | |
| def random_atoms(d=8, r=4,seed=42): |
test/test_sgot.py
Outdated
|
|
||
|
|
||
| @pytest.mark.parametrize("backend_name", ["numpy", "torch", "jax"]) | ||
| def test_cost_backend_consistency(backend_name): |
There was a problem hiding this comment.
| def test_cost_backend_consistency(backend_name): | |
| def test_cost_backend_consistency(nx): |
test/test_sgot.py
Outdated
| # --------------------------------------------------------------------- | ||
|
|
||
|
|
||
| def test_hyperparameter_sweep_cost(nx): |
There was a problem hiding this comment.
| def test_hyperparameter_sweep_cost(nx): | |
| def test_hyperparameter_sweep_cost(nx,grassmann_types,p,q,r,eta): |
ot/sgot.py
Outdated
| Ws = Ws / nx.sum(Ws) | ||
| Wt = Wt / nx.sum(Wt) | ||
|
|
||
| P = ot.emd2(Ws, Wt, nx.real(C)) |
There was a problem hiding this comment.
emd2 directly returns the objective value, so there is no need to compute it again below.
| else: | ||
| real_scale, imag_scale = eigen_scaling[0], eigen_scaling[1] | ||
|
|
||
| Dsn = nx.real(Ds) * real_scale + 1j * nx.imag(Ds) * imag_scale |
There was a problem hiding this comment.
| Dsn = nx.real(Ds) * real_scale + 1j * nx.imag(Ds) * imag_scale | |
| C_real = nx.real(Dsn)[:,None] - nx.real(Dtn)[None,:] | |
| C_real = C_real**2 | |
| C_imag = nx.imag(Dsn)[:,None] - nx.imag(Dtn)[None,:] | |
| C_imag = C_imag**2 | |
| prod = C_real + C_imag | |
| return prod ** (q / 2) |
| A_norm: array-like, shape (d, n) | ||
| Column-normalized array. | ||
| """ | ||
| nrm = nx.sqrt(nx.sum(A * nx.conj(A), axis=0, keepdims=True)) |
There was a problem hiding this comment.
| nrm = nx.sqrt(nx.sum(A * nx.conj(A), axis=0, keepdims=True)) | |
| nrm = nx.norm(A, axis=0, keepdims=True) |
You can replace it with the function nx.norm, which handles the case of complex numbers.
| return delta | ||
|
|
||
|
|
||
| def _grassmann_distance_squared(delta, grassman_metric="chordal", nx=None, eps=1e-300): |
There was a problem hiding this comment.
Epsilon is too small for the machine precision, you can set it to 1e-12 for instance.
| if nx is None: | ||
| nx = get_backend(delta) | ||
|
|
||
| delta = nx.clip(delta, 0.0, 1.0) |
There was a problem hiding this comment.
If delta is not in [0,1] it should raise an error, this is an issue in the computation of delta outside of this function.
| ### SPECTRAL-GRASSMANNIAN WASSERSTEIN METRIC ### | ||
| ##################################################################################################################################### | ||
| ##################################################################################################################################### | ||
| def sgot_cost_matrix( |
There was a problem hiding this comment.
You should add eps to the definition of the function, as this parameter appears in downstream functions. Keep the same epsilon for all functions.
| # information-geometric interpretation in Germain et al. (2025). | ||
| delta2 = nx.maximum(delta**2, eps) | ||
| return -nx.log(delta2) | ||
| raise ValueError(f"Unknown grassman_metric: {grassman_metric}") |
There was a problem hiding this comment.
In this function the power q should also be a parameter:
for any distance you can set:
result = square_distance(delta)
then
return nx.real(result)**(q/2)
Set by default q to the same value as for eigenvalue cost
| C_lambda = eigenvalue_cost_matrix(Ds, Dt, q=q, eigen_scaling=eigen_scaling, nx=nx) | ||
|
|
||
| delta = _delta_matrix_1d(Rs, Ls, Rt, Lt, nx=nx) | ||
| C_grass = _grassmann_distance_squared(delta, grassman_metric=grassman_metric, nx=nx) |
There was a problem hiding this comment.
the power parameter q should also affect the Grassmann cost
| C_grass = _grassmann_distance_squared(delta, grassman_metric=grassman_metric, nx=nx) | ||
|
|
||
| C2 = eta * C_lambda + (1.0 - eta) * C_grass | ||
| C = nx.real(C2) ** (p / 2.0) |
There was a problem hiding this comment.
Your cost function already returns a real value, so there is no need for nx.real here.
| return C | ||
|
|
||
|
|
||
| def _validate_sgot_metric_inputs(Ds, Dt): |
There was a problem hiding this comment.
You can add the verifications you wrote in lines 272-290 to this function, and also add verifications that source and target have the same shapes.
| ) | ||
|
|
||
|
|
||
| def sgot_metric( |
There was a problem hiding this comment.
You will need to add eps also in this function.
Types of changes
Adding sgot file in the ot folder.
Motivation and context / Related issue
Keep track of SGOT implementation in POT.
How has this been tested (if it applies)
Not tested yet.
PR checklist