Skip to content

Commit b9ed7b1

Browse files
[MRG] Make partial_wasserstein, partial_wasserstein2 and entropic_partial_wasserstein work with backend (#449)
* add test of partial_wasserstein with torch tensors * WIP: differentiable ot.partial.partial_wasserstein * change test of torch partial * make partial_wasserstein2 work with torch * test backward through ot.partial.partial_wasserstein2 * add test of entropic_partial_wasserstein with torch tensors * make entropic_partial_wasserstein work with torch tensors * add test of backward through entropic_partial_wasserstein * rm unused import * test partial_wasserstein with all backends * tests of partial fcts: check if torch is available * partial: check if marginals are empty arrays * add tests when marginals are empty arrays and/or m=None * add PR to RELEASES.md --------- Co-authored-by: Antoine Collas <[email protected]> Co-authored-by: Rémi Flamary <[email protected]>
1 parent c48cd76 commit b9ed7b1

File tree

3 files changed

+148
-62
lines changed

3 files changed

+148
-62
lines changed

RELEASES.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
- Added Generalized Wasserstein Barycenter solver + example (PR #372), fixed graphical details on the example (PR #376)
1515
- Added Free Support Sinkhorn Barycenter + example (PR #387)
1616
- New API for OT solver using function `ot.solve` (PR #388)
17-
- Backend version of `ot.partial` and `ot.smooth` (PR #388)
17+
- Backend version of `ot.partial` and `ot.smooth` (PR #388 and #449)
1818
- Added argument for warmstart of dual potentials in Sinkhorn-based methods in `ot.bregman` (PR #437)
1919
- Add parameters method in `ot.da.SinkhornTransport` (PR #440)
2020
- `ot.dr` now uses the new Pymanopt API and POT is compatible with current Pymanopt (PR #443)

ot/partial.py

+55-40
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False,
120120

121121
nx = get_backend(a, b, M)
122122

123-
if nx.sum(a) > 1 or nx.sum(b) > 1:
123+
if nx.sum(a) > 1 + 1e-15 or nx.sum(b) > 1 + 1e-15: # 1e-15 for numerical errors
124124
raise ValueError("Problem infeasible. Check that a and b are in the "
125125
"simplex")
126126

@@ -270,36 +270,43 @@ def partial_wasserstein(a, b, M, m=None, nb_dummies=1, log=False, **kwargs):
270270

271271
nx = get_backend(a, b, M)
272272

273+
dim_a, dim_b = M.shape
274+
if len(a) == 0:
275+
a = nx.ones(dim_a, type_as=a) / dim_a
276+
if len(b) == 0:
277+
b = nx.ones(dim_b, type_as=b) / dim_b
278+
273279
if m is None:
274280
return partial_wasserstein_lagrange(a, b, M, log=log, **kwargs)
275281
elif m < 0:
276282
raise ValueError("Problem infeasible. Parameter m should be greater"
277283
" than 0.")
278-
elif m > nx.min((nx.sum(a), nx.sum(b))):
284+
elif m > nx.min(nx.stack((nx.sum(a), nx.sum(b)))):
279285
raise ValueError("Problem infeasible. Parameter m should lower or"
280286
" equal than min(|a|_1, |b|_1).")
281287

282-
a0, b0, M0 = a, b, M
283-
# convert to humpy
284-
a, b, M = nx.to_numpy(a, b, M)
285-
286-
b_extended = np.append(b, [(np.sum(a) - m) / nb_dummies] * nb_dummies)
287-
a_extended = np.append(a, [(np.sum(b) - m) / nb_dummies] * nb_dummies)
288-
M_extended = np.zeros((len(a_extended), len(b_extended)))
289-
M_extended[-nb_dummies:, -nb_dummies:] = np.max(M) * 2
290-
M_extended[:len(a), :len(b)] = M
288+
b_extension = nx.ones(nb_dummies, type_as=b) * (nx.sum(a) - m) / nb_dummies
289+
b_extended = nx.concatenate((b, b_extension))
290+
a_extension = nx.ones(nb_dummies, type_as=a) * (nx.sum(b) - m) / nb_dummies
291+
a_extended = nx.concatenate((a, a_extension))
292+
M_extension = nx.ones((nb_dummies, nb_dummies), type_as=M) * nx.max(M) * 2
293+
M_extended = nx.concatenate(
294+
(nx.concatenate((M, nx.zeros((M.shape[0], M_extension.shape[1]))), axis=1),
295+
nx.concatenate((nx.zeros((M_extension.shape[0], M.shape[1])), M_extension), axis=1)),
296+
axis=0
297+
)
291298

292299
gamma, log_emd = emd(a_extended, b_extended, M_extended, log=True,
293300
**kwargs)
294301

295-
gamma = nx.from_numpy(gamma[:len(a), :len(b)], type_as=M)
302+
gamma = gamma[:len(a), :len(b)]
296303

297304
if log_emd['warning'] is not None:
298305
raise ValueError("Error in the EMD resolution: try to increase the"
299306
" number of dummy points")
300-
log_emd['partial_w_dist'] = nx.sum(M0 * gamma)
301-
log_emd['u'] = nx.from_numpy(log_emd['u'][:len(a)], type_as=a0)
302-
log_emd['v'] = nx.from_numpy(log_emd['v'][:len(b)], type_as=b0)
307+
log_emd['partial_w_dist'] = nx.sum(M * gamma)
308+
log_emd['u'] = log_emd['u'][:len(a)]
309+
log_emd['v'] = log_emd['v'][:len(b)]
303310

304311
if log:
305312
return gamma, log_emd
@@ -389,14 +396,18 @@ def partial_wasserstein2(a, b, M, m=None, nb_dummies=1, log=False, **kwargs):
389396
NeurIPS.
390397
"""
391398

399+
a, b, M = list_to_array(a, b, M)
400+
401+
nx = get_backend(a, b, M)
402+
392403
partial_gw, log_w = partial_wasserstein(a, b, M, m, nb_dummies, log=True,
393404
**kwargs)
394405
log_w['T'] = partial_gw
395406

396407
if log:
397-
return np.sum(partial_gw * M), log_w
408+
return nx.sum(partial_gw * M), log_w
398409
else:
399-
return np.sum(partial_gw * M)
410+
return nx.sum(partial_gw * M)
400411

401412

402413
def gwgrad_partial(C1, C2, T):
@@ -838,60 +849,64 @@ def entropic_partial_wasserstein(a, b, M, reg, m=None, numItermax=1000,
838849
ot.partial.partial_wasserstein: exact Partial Wasserstein
839850
"""
840851

841-
a = np.asarray(a, dtype=np.float64)
842-
b = np.asarray(b, dtype=np.float64)
843-
M = np.asarray(M, dtype=np.float64)
852+
a, b, M = list_to_array(a, b, M)
853+
854+
nx = get_backend(a, b, M)
844855

845856
dim_a, dim_b = M.shape
846-
dx = np.ones(dim_a, dtype=np.float64)
847-
dy = np.ones(dim_b, dtype=np.float64)
857+
dx = nx.ones(dim_a, type_as=a)
858+
dy = nx.ones(dim_b, type_as=b)
848859

849860
if len(a) == 0:
850-
a = np.ones(dim_a, dtype=np.float64) / dim_a
861+
a = nx.ones(dim_a, type_as=a) / dim_a
851862
if len(b) == 0:
852-
b = np.ones(dim_b, dtype=np.float64) / dim_b
863+
b = nx.ones(dim_b, type_as=b) / dim_b
853864

854865
if m is None:
855-
m = np.min((np.sum(a), np.sum(b))) * 1.0
866+
m = nx.min(nx.stack((nx.sum(a), nx.sum(b)))) * 1.0
856867
if m < 0:
857868
raise ValueError("Problem infeasible. Parameter m should be greater"
858869
" than 0.")
859-
if m > np.min((np.sum(a), np.sum(b))):
870+
if m > nx.min(nx.stack((nx.sum(a), nx.sum(b)))):
860871
raise ValueError("Problem infeasible. Parameter m should lower or"
861872
" equal than min(|a|_1, |b|_1).")
862873

863874
log_e = {'err': []}
864875

865-
# Next 3 lines equivalent to K=np.exp(-M/reg), but faster to compute
866-
K = np.empty(M.shape, dtype=M.dtype)
867-
np.divide(M, -reg, out=K)
868-
np.exp(K, out=K)
869-
np.multiply(K, m / np.sum(K), out=K)
876+
if type(a) == type(b) == type(M) == np.ndarray:
877+
# Next 3 lines equivalent to K=nx.exp(-M/reg), but faster to compute
878+
K = np.empty(M.shape, dtype=M.dtype)
879+
np.divide(M, -reg, out=K)
880+
np.exp(K, out=K)
881+
np.multiply(K, m / np.sum(K), out=K)
882+
else:
883+
K = nx.exp(-M / reg)
884+
K = K * m / nx.sum(K)
870885

871886
err, cpt = 1, 0
872-
q1 = np.ones(K.shape)
873-
q2 = np.ones(K.shape)
874-
q3 = np.ones(K.shape)
887+
q1 = nx.ones(K.shape, type_as=K)
888+
q2 = nx.ones(K.shape, type_as=K)
889+
q3 = nx.ones(K.shape, type_as=K)
875890

876891
while (err > stopThr and cpt < numItermax):
877892
Kprev = K
878893
K = K * q1
879-
K1 = np.dot(np.diag(np.minimum(a / np.sum(K, axis=1), dx)), K)
894+
K1 = nx.dot(nx.diag(nx.minimum(a / nx.sum(K, axis=1), dx)), K)
880895
q1 = q1 * Kprev / K1
881896
K1prev = K1
882897
K1 = K1 * q2
883-
K2 = np.dot(K1, np.diag(np.minimum(b / np.sum(K1, axis=0), dy)))
898+
K2 = nx.dot(K1, nx.diag(nx.minimum(b / nx.sum(K1, axis=0), dy)))
884899
q2 = q2 * K1prev / K2
885900
K2prev = K2
886901
K2 = K2 * q3
887-
K = K2 * (m / np.sum(K2))
902+
K = K2 * (m / nx.sum(K2))
888903
q3 = q3 * K2prev / K
889904

890-
if np.any(np.isnan(K)) or np.any(np.isinf(K)):
905+
if nx.any(nx.isnan(K)) or nx.any(nx.isinf(K)):
891906
print('Warning: numerical errors at iteration', cpt)
892907
break
893908
if cpt % 10 == 0:
894-
err = np.linalg.norm(Kprev - K)
909+
err = nx.norm(Kprev - K)
895910
if log:
896911
log_e['err'].append(err)
897912
if verbose:
@@ -901,7 +916,7 @@ def entropic_partial_wasserstein(a, b, M, reg, m=None, numItermax=1000,
901916
print('{:5d}|{:8e}|'.format(cpt, err))
902917

903918
cpt = cpt + 1
904-
log_e['partial_w_dist'] = np.sum(M * K)
919+
log_e['partial_w_dist'] = nx.sum(M * K)
905920
if log:
906921
return K, log_e
907922
else:

test/test_partial.py

+92-21
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import numpy as np
99
import scipy as sp
1010
import ot
11+
from ot.backend import to_numpy, torch
1112
import pytest
1213

1314

@@ -82,7 +83,7 @@ def test_partial_wasserstein_lagrange():
8283
w0, log0 = ot.partial.partial_wasserstein_lagrange(p, q, M, 100, log=True)
8384

8485

85-
def test_partial_wasserstein():
86+
def test_partial_wasserstein(nx):
8687

8788
n_samples = 20 # nb samples (gaussian)
8889
n_noise = 20 # nb of samples (noise)
@@ -102,25 +103,20 @@ def test_partial_wasserstein():
102103

103104
m = 0.5
104105

106+
p, q, M = nx.from_numpy(p, q, M)
107+
105108
w0, log0 = ot.partial.partial_wasserstein(p, q, M, m=m, log=True)
106-
w, log = ot.partial.entropic_partial_wasserstein(p, q, M, reg=1, m=m,
107-
log=True, verbose=True)
109+
w, log = ot.partial.entropic_partial_wasserstein(p, q, M, reg=1, m=m, log=True, verbose=True)
108110

109111
# check constraints
110-
np.testing.assert_equal(
111-
w0.sum(1) - p <= 1e-5, [True] * len(p)) # cf convergence wasserstein
112-
np.testing.assert_equal(
113-
w0.sum(0) - q <= 1e-5, [True] * len(q)) # cf convergence wasserstein
114-
np.testing.assert_equal(
115-
w.sum(1) - p <= 1e-5, [True] * len(p)) # cf convergence wasserstein
116-
np.testing.assert_equal(
117-
w.sum(0) - q <= 1e-5, [True] * len(q)) # cf convergence wasserstein
112+
np.testing.assert_equal(to_numpy(nx.sum(w0, axis=1) - p) <= 1e-5, [True] * len(p))
113+
np.testing.assert_equal(to_numpy(nx.sum(w0, axis=0) - q) <= 1e-5, [True] * len(q))
114+
np.testing.assert_equal(to_numpy(nx.sum(w0, axis=1) - p) <= 1e-5, [True] * len(p))
115+
np.testing.assert_equal(to_numpy(nx.sum(w0, axis=0) - q) <= 1e-5, [True] * len(q))
118116

119117
# check transported mass
120-
np.testing.assert_allclose(
121-
np.sum(w0), m, atol=1e-04)
122-
np.testing.assert_allclose(
123-
np.sum(w), m, atol=1e-04)
118+
np.testing.assert_allclose(np.sum(to_numpy(w0)), m, atol=1e-04)
119+
np.testing.assert_allclose(np.sum(to_numpy(w)), m, atol=1e-04)
124120

125121
w0, log0 = ot.partial.partial_wasserstein2(p, q, M, m=m, log=True)
126122
w0_val = ot.partial.partial_wasserstein2(p, q, M, m=m, log=False)
@@ -130,12 +126,87 @@ def test_partial_wasserstein():
130126
np.testing.assert_allclose(w0, w0_val, atol=1e-1, rtol=1e-1)
131127

132128
# check constraints
133-
np.testing.assert_equal(
134-
G.sum(1) - p <= 1e-5, [True] * len(p)) # cf convergence wasserstein
135-
np.testing.assert_equal(
136-
G.sum(0) - q <= 1e-5, [True] * len(q)) # cf convergence wasserstein
137-
np.testing.assert_allclose(
138-
np.sum(G), m, atol=1e-04)
129+
np.testing.assert_equal(to_numpy(nx.sum(G, axis=1) - p) <= 1e-5, [True] * len(p))
130+
np.testing.assert_equal(to_numpy(nx.sum(G, axis=0) - q) <= 1e-5, [True] * len(q))
131+
np.testing.assert_allclose(np.sum(to_numpy(G)), m, atol=1e-04)
132+
133+
empty_array = nx.zeros(0, type_as=M)
134+
w = ot.partial.partial_wasserstein(empty_array, empty_array, M=M, m=None)
135+
136+
# check constraints
137+
np.testing.assert_equal(to_numpy(nx.sum(w, axis=1) - p) <= 1e-5, [True] * len(p))
138+
np.testing.assert_equal(to_numpy(nx.sum(w, axis=0) - q) <= 1e-5, [True] * len(q))
139+
np.testing.assert_equal(to_numpy(nx.sum(w, axis=1) - p) <= 1e-5, [True] * len(p))
140+
np.testing.assert_equal(to_numpy(nx.sum(w, axis=0) - q) <= 1e-5, [True] * len(q))
141+
142+
# check transported mass
143+
np.testing.assert_allclose(np.sum(to_numpy(w)), 1, atol=1e-04)
144+
145+
w0 = ot.partial.entropic_partial_wasserstein(empty_array, empty_array, M=M, reg=10, m=None)
146+
147+
# check constraints
148+
np.testing.assert_equal(to_numpy(nx.sum(w0, axis=1) - p) <= 1e-5, [True] * len(p))
149+
np.testing.assert_equal(to_numpy(nx.sum(w0, axis=0) - q) <= 1e-5, [True] * len(q))
150+
np.testing.assert_equal(to_numpy(nx.sum(w0, axis=1) - p) <= 1e-5, [True] * len(p))
151+
np.testing.assert_equal(to_numpy(nx.sum(w0, axis=0) - q) <= 1e-5, [True] * len(q))
152+
153+
# check transported mass
154+
np.testing.assert_allclose(np.sum(to_numpy(w0)), 1, atol=1e-04)
155+
156+
157+
def test_partial_wasserstein2_gradient():
158+
if torch:
159+
n_samples = 40
160+
161+
mu = np.array([0, 0])
162+
cov = np.array([[1, 0], [0, 2]])
163+
164+
xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov)
165+
xt = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov)
166+
167+
M = torch.tensor(ot.dist(xs, xt), requires_grad=True, dtype=torch.float64)
168+
169+
p = torch.tensor(ot.unif(n_samples), dtype=torch.float64)
170+
q = torch.tensor(ot.unif(n_samples), dtype=torch.float64)
171+
172+
m = 0.5
173+
174+
w, log = ot.partial.partial_wasserstein2(p, q, M, m=m, log=True)
175+
176+
w.backward()
177+
178+
assert M.grad is not None
179+
assert M.grad.shape == M.shape
180+
181+
182+
def test_entropic_partial_wasserstein_gradient():
183+
if torch:
184+
n_samples = 40
185+
186+
mu = np.array([0, 0])
187+
cov = np.array([[1, 0], [0, 2]])
188+
189+
xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov)
190+
xt = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov)
191+
192+
M = torch.tensor(ot.dist(xs, xt), requires_grad=True, dtype=torch.float64)
193+
194+
p = torch.tensor(ot.unif(n_samples), requires_grad=True, dtype=torch.float64)
195+
q = torch.tensor(ot.unif(n_samples), requires_grad=True, dtype=torch.float64)
196+
197+
m = 0.5
198+
reg = 1
199+
200+
_, log = ot.partial.entropic_partial_wasserstein(p, q, M, m=m, reg=reg, log=True)
201+
202+
log['partial_w_dist'].backward()
203+
204+
assert M.grad is not None
205+
assert p.grad is not None
206+
assert q.grad is not None
207+
assert M.grad.shape == M.shape
208+
assert p.grad.shape == p.shape
209+
assert q.grad.shape == q.shape
139210

140211

141212
def test_partial_gromov_wasserstein():

0 commit comments

Comments
 (0)