- Quick start guide
- =================
+ User guide
+ ==========

In the following we provide some pointers about which functions and classes
to use for different problems related to optimal transport (OT) and machine
@@ -136,12 +136,12 @@ instance the memory cost for an OT problem is always :math:`\mathcal{O}(n^2)` in
memory because the cost matrix has to be computed. The exact solver is of time
complexity :math:`\mathcal{O}(n^3 \log(n))` and the Sinkhorn solver has been
proven to be nearly :math:`\mathcal{O}(n^2)` which is still too complex for very
- large scale solvers.
+ large scale solvers. For all the generic solvers we need to compute the cost
+ matrix and the OT matrix of memory size :math:`\mathcal{O}(n^2)` which can be
+ prohibitive for very large scale problems.
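
To give a feeling for the quadratic memory cost mentioned above, the following back-of-the-envelope sketch estimates the size of the dense cost and OT matrices. It is plain Python and does not call POT; the helper name ``dense_matrix_gib`` is hypothetical, and 8 bytes per entry (float64 storage) is an assumption.

```python
def dense_matrix_gib(n_source, n_target, bytes_per_entry=8):
    """Memory in GiB of one dense n_source x n_target matrix.

    bytes_per_entry=8 assumes float64 storage (an assumption of this sketch).
    """
    return n_source * n_target * bytes_per_entry / 2**30


for n in (1_000, 10_000, 100_000):
    # A generic solver stores at least two such matrices:
    # the cost matrix M and the OT matrix.
    total = 2 * dense_matrix_gib(n, n)
    print(f"n = {n:>7}: about {total:.2f} GiB for cost + OT matrix")
```

With these assumptions, a single matrix for 100 000 samples on each side already needs roughly 75 GiB, which is why lazy or sample-based solvers are preferable at that scale.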

-
- If you need to solve OT with large number of samples, we recommend to use
- entropic regularization and memory efficient implementation of Sinkhorn as
- proposed in `GeomLoss <https://www.kernel-operations.io/geomloss/>`_. This
+ If you need to solve OT with a large number of samples, we provide a "lazy" memory-efficient implementation of Sinkhorn in pure
+ Python and using `GeomLoss <https://www.kernel-operations.io/geomloss/>`_. This
implementation is compatible with PyTorch and can handle a large number of
samples. Another approach to estimate the Wasserstein distance for a very large
number of samples is to use the trick from `Wasserstein GAN
@@ -193,15 +193,19 @@ that will return the optimal transport matrix :math:`\gamma^*`:

    # a and b are 1D histograms (sum to 1 and positive)
    # M is the ground cost matrix
+
+     # unified API
+     T = ot.solve(M, a, b).plan # exact linear program
+
+     # classical API
    T = ot.emd(a, b, M) # exact linear program

The method implemented for solving the OT problem is the network simplex. It is
implemented in C from [1]_. It has a complexity of :math:`O(n^3)` but the
solver is quite efficient and uses sparsity of the solution.


-
- .. minigallery:: ot.emd
+ .. minigallery:: ot.emd, ot.solve
    :add-heading: Examples of use for :any:`ot.emd`
    :heading-level: "

@@ -226,7 +230,12 @@ It can computed from an already estimated OT matrix with

    # a and b are 1D histograms (sum to 1 and positive)
    # M is the ground cost matrix
-     W = ot.emd2(a, b, M) # Wasserstein distance / EMD value
+
+     # Wasserstein distance / EMD value with unified API
+     W = ot.solve(M, a, b, return_matrix=False).value
+
+     # with classical API
+     W = ot.emd2(a, b, M)

Note that the well known `Wasserstein distance
<https://en.wikipedia.org/wiki/Wasserstein_metric>`_ between distributions a and
@@ -246,7 +255,7 @@ the :math:`W_1` Wasserstein distance can be done directly with :any:`ot.emd2`
when providing :code:`M = ot.dist(xs, xt, metric='euclidean')` to use the Euclidean
distance.

- .. minigallery:: ot.emd2
+ .. minigallery:: ot.emd2, ot.solve
    :add-heading: Examples of use for :any:`ot.emd2`
    :heading-level: "

@@ -274,6 +283,10 @@ distributions. In the case when the finite sample dataset is supposed Gaussian,
we provide :any:`ot.gaussian.bures_wasserstein_mapping` that returns the parameters for the
Monge mapping.

+ All those special cases are accessible with the unified API of POT through the
+ function :any:`ot.solve_sample` with the parameter :code:`method` that selects
+ the method used to solve the problem (with :code:`method='1D'` or :code:`method='gaussian'`).
+
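The 1D and Gaussian special cases mentioned above have closed forms that are easy to check by hand. The sketch below is plain Python and does not call POT; all helper names are hypothetical. It assumes equal-size samples with uniform weights for the 1D case and scalar (one-dimensional) Gaussians for the Gaussian case.

```python
import math


def wasserstein_1d(xs, xt, p=1):
    """W_p between two equal-size 1D samples with uniform weights."""
    if len(xs) != len(xt):
        raise ValueError("this sketch assumes samples of equal size")
    # In 1D the optimal coupling matches the sorted samples in order.
    pairs = zip(sorted(xs), sorted(xt))
    cost = sum(abs(x - y) ** p for x, y in pairs) / len(xs)
    return cost ** (1.0 / p)


def gaussian_w2_1d(m_s, sigma_s, m_t, sigma_t):
    """W_2 between N(m_s, sigma_s^2) and N(m_t, sigma_t^2)."""
    return math.sqrt((m_s - m_t) ** 2 + (sigma_s - sigma_t) ** 2)


def gaussian_monge_map_1d(m_s, sigma_s, m_t, sigma_t):
    """Return (A, b) of the affine Monge map x -> A * x + b."""
    A = sigma_t / sigma_s
    return A, m_t - A * m_s


print(wasserstein_1d([2.0, 0.0, 1.0], [1.0, 3.0, 2.0]))  # 1.0
print(gaussian_w2_1d(0.0, 1.0, 3.0, 5.0))                # 5.0
print(gaussian_monge_map_1d(0.0, 1.0, 3.0, 5.0))         # (5.0, 3.0)
```

Shifting a sample by one unit gives a 1D distance of exactly one, and the Gaussian map sends the source mean to the target mean while rescaling by the ratio of standard deviations.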

Regularized Optimal Transport
-----------------------------
@@ -330,13 +343,15 @@ The Sinkhorn-Knopp algorithm is implemented in :any:`ot.sinkhorn` and
linear term. Note that the regularization parameter :math:`\lambda` in the
equation above is given to those functions with the parameter :code:`reg`.

-     >>> import ot
-     >>> a = [.5, .5]
-     >>> b = [.5, .5]
-     >>> M = [[0., 1.], [1., 0.]]
-     >>> ot.sinkhorn(a, b, M, 1)
-     array([[ 0.36552929,  0.13447071],
-            [ 0.13447071,  0.36552929]])
+ .. code:: python
+
+     # unified API
+     P = ot.solve(M, a, b, reg=1).plan # OT Sinkhorn matrix
+     loss = ot.solve(M, a, b, reg=1).value # OT Sinkhorn value
+
+     # classical API
+     P = ot.sinkhorn(a, b, M, reg=1) # OT Sinkhorn matrix
+     loss = ot.sinkhorn2(a, b, M, reg=1) # OT Sinkhorn value
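
To make the role of :code:`reg` concrete, here is a minimal sketch of the Sinkhorn-Knopp scaling iterations in plain Python. It is not POT's implementation; the function name ``sinkhorn_plan`` and the fixed iteration count are assumptions of this sketch. On the toy problem from the doctest that this change removes, it yields the same matrix.

```python
import math


def sinkhorn_plan(a, b, M, reg, n_iter=1000):
    """Plain Sinkhorn-Knopp scaling iterations on nested Python lists."""
    n, m = len(a), len(b)
    # Gibbs kernel K = exp(-M / reg)
    K = [[math.exp(-M[i][j] / reg) for j in range(m)] for i in range(n)]
    u, v = [1.0] * n, [1.0] * m
    for _ in range(n_iter):
        # Rescale rows, then columns, to match the marginals a and b.
        u = [a[i] / sum(K[i][j] * v[j] for j in range(m)) for i in range(n)]
        v = [b[j] / sum(K[i][j] * u[i] for i in range(n)) for j in range(m)]
    # Transport plan diag(u) K diag(v)
    return [[u[i] * K[i][j] * v[j] for j in range(m)] for i in range(n)]


a = [0.5, 0.5]
b = [0.5, 0.5]
M = [[0.0, 1.0], [1.0, 0.0]]

P = sinkhorn_plan(a, b, M, reg=1)
# Linear transport cost <P, M> of the regularized plan
loss = sum(P[i][j] * M[i][j] for i in range(2) for j in range(2))
print(P)     # approximately [[0.3655, 0.1345], [0.1345, 0.3655]]
print(loss)  # approximately 0.2689
```

A larger :code:`reg` spreads mass more evenly across the plan, while a smaller value moves it closer to the exact OT solution at the price of slower and less stable iterations.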

More details about the algorithms used are given in the following note.

@@ -406,13 +421,10 @@ implementations are not optimized for speed but provide a robust implementation
of algorithms in [18]_ [19]_.


- .. minigallery:: ot.sinkhorn
-     :add-heading: Examples of use for :any:`ot.sinkhorn`
+ .. minigallery:: ot.sinkhorn ot.sinkhorn2
+     :add-heading: Examples of use for Sinkhorn algorithm
    :heading-level: "

- .. minigallery:: ot.sinkhorn2
-     :add-heading: Examples of use for :any:`ot.sinkhorn2`
-     :heading-level: "


Other regularizations
Other regularizations
@@ -969,18 +981,6 @@ For instance, to disable TensorFlow, set `export POT_BACKEND_DISABLE_TENSORFLOW=
It's important to note that the `numpy` backend cannot be disabled.


- List of compatible modules
- ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-
- This list will get longer for new releases and will hopefully disappear when POT
- become fully implemented with the backend.
-
- - :any:`ot.bregman`
- - :any:`ot.gromov` (some functions use CPU only solvers with copy overhead)
- - :any:`ot.optim` (some functions use CPU only solvers with copy overhead)
- - :any:`ot.sliced`
- - :any:`ot.utils` (partial)
-

FAQ
---