Commit 034daa0

Merge branch 'master' into dev
2 parents 371e3e7 + 8aed2d7 commit 034daa0

12 files changed, +848 -73 lines changed

CONTRIBUTORS.md (+3 -5)
@@ -2,12 +2,10 @@
 
 ## Creators and Maintainers
 
-This toolbox has been created by
+This toolbox has been created by [Rémi Flamary](https://remi.flamary.com/)
+and [Nicolas Courty](http://people.irisa.fr/Nicolas.Courty/).
 
-* [Rémi Flamary](https://remi.flamary.com/)
-* [Nicolas Courty](http://people.irisa.fr/Nicolas.Courty/)
-
-It is currently maintained by
+It is currently maintained by :
 
 * [Rémi Flamary](https://remi.flamary.com/)
 * [Cédric Vincent-Cuaz](https://cedricvincentcuaz.github.io/)

README.md (+2 -7)
@@ -204,12 +204,9 @@ The examples folder contain several examples and use case for the library. The f
 
 ## Acknowledgements
 
-This toolbox has been created by
+This toolbox has been created by [Rémi Flamary](https://remi.flamary.com/) and [Nicolas Courty](http://people.irisa.fr/Nicolas.Courty/).
 
-* [Rémi Flamary](https://remi.flamary.com/)
-* [Nicolas Courty](http://people.irisa.fr/Nicolas.Courty/)
-
-It is currently maintained by
+It is currently maintained by :
 
 * [Rémi Flamary](https://remi.flamary.com/)
 * [Cédric Vincent-Cuaz](https://cedricvincentcuaz.github.io/)
@@ -220,8 +217,6 @@ POT has benefited from the financing or manpower from the following partners:
 
 <img src="https://pythonot.github.io/master/_static/images/logo_anr.jpg" alt="ANR" style="height:60px;"/><img src="https://pythonot.github.io/master/_static/images/logo_cnrs.jpg" alt="CNRS" style="height:60px;"/><img src="https://pythonot.github.io/master/_static/images/logo_3ia.jpg" alt="3IA" style="height:60px;"/><img src="https://pythonot.github.io/master/_static/images/logo_hiparis.png" alt="Hi!PARIS" style="height:60px;"/>
 
-
-
 ## Contributions and code of conduct
 
 Every contribution is welcome and should respect the [contribution guidelines](https://pythonot.github.io/master/contributing.html). Each member of the project is expected to follow the [code of conduct](https://pythonot.github.io/master/code_of_conduct.html).

RELEASES.md (+2)
@@ -17,6 +17,8 @@
 - Added `ot.gaussian.bures_barycenter_gradient_descent` (PR #680)
 - Added `ot.gaussian.bures_wasserstein_distance` (PR #680)
 - `ot.gaussian.bures_wasserstein_distance` can be batched (PR #680)
+- Backend implementation of `ot.dist` for (PR #701)
+- Updated documentation Quickstart guide and User guide with new API (PR #726)
 
 #### Closed issues
 - Fixed `ot.mapping` solvers which depended on deprecated `cvxpy` `ECOS` solver (PR #692, Issue #668)

docs/source/conf.py (+1 -1)
@@ -347,7 +347,7 @@ def __getattr__(cls, name):
 }
 
 sphinx_gallery_conf = {
-"examples_dirs": ["../../examples", "../../examples/da"],
+"examples_dirs": ["../../examples"],
 "gallery_dirs": "auto_examples",
 "filename_pattern": "plot_", # (?!barycenter_fgw)
 "nested_sections": False,

docs/source/index.rst (+3 -2)
@@ -17,9 +17,10 @@ Contents
 :maxdepth: 1
 
 self
-quickstart
-all
+auto_examples/plot_quickstart_guide
 auto_examples/index
+user_guide
+all
 releases
 contributors
 contributing

docs/source/quickstart.rst renamed to docs/source/user_guide.rst (+35 -35)
@@ -1,6 +1,6 @@
 
-Quick start guide
-=================
+User guide
+==========
 
 In the following we provide some pointers about which functions and classes
 to use for different problems related to optimal transport (OT) and machine
@@ -136,12 +136,12 @@ instance the memory cost for an OT problem is always :math:`\mathcal{O}(n^2)` in
 memory because the cost matrix has to be computed. The exact solver in of time
 complexity :math:`\mathcal{O}(n^3\log(n))` and the Sinkhorn solver has been
 proven to be nearly :math:`\mathcal{O}(n^2)` which is still too complex for very
-large scale solvers.
+large scale solvers. For all the generic solvers we need to compute the cost
+matrix and the OT matrix of memory size :math:`\mathcal{O}(n^2)` which can be
+prohibitive for very large scale problems.
 
-
-If you need to solve OT with large number of samples, we recommend to use
-entropic regularization and memory efficient implementation of Sinkhorn as
-proposed in `GeomLoss <https://www.kernel-operations.io/geomloss/>`_. This
+If you need to solve OT with large number of samples, we provide "lazy" memory efficient implementation of Sinkhorn in pure
+python and using `GeomLoss <https://www.kernel-operations.io/geomloss/>`_. This
 implementation is compatible with Pytorch and can handle large number of
 samples. Another approach to estimate the Wasserstein distance for very large
 number of sample is to use the trick from `Wasserstein GAN
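
Note on the hunk above: the :math:`\mathcal{O}(n^2)` memory claim can be made concrete with a little arithmetic. The sketch below assumes dense float64 storage (8 bytes per entry); `dense_ot_memory_bytes` is a hypothetical helper written for this note, not a POT function.

```python
def dense_ot_memory_bytes(n, bytes_per_entry=8):
    """Bytes needed to store one dense n x n matrix (float64 by default)."""
    return n * n * bytes_per_entry


# A generic solver holds at least the cost matrix and the OT plan,
# so the real footprint is roughly twice the figure printed here.
for n in (1_000, 10_000, 100_000):
    gigabytes = dense_ot_memory_bytes(n) / 1e9
    print(f"n = {n:>7}: {gigabytes:g} GB per dense matrix")
```

At n = 100,000 a single dense matrix already needs 80 GB, which is the regime where the "lazy" solvers mentioned in the added lines become relevant.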
@@ -193,15 +193,19 @@ that will return the optimal transport matrix :math:`\gamma^*`:
 
 # a and b are 1D histograms (sum to 1 and positive)
 # M is the ground cost matrix
+
+# unified API
+T = ot.solve(M, a, b).plan # exact linear program
+
+# classical API
 T = ot.emd(a, b, M) # exact linear program
 
 The method implemented for solving the OT problem is the network simplex. It is
 implemented in C from [1]_. It has a complexity of :math:`O(n^3)` but the
 solver is quite efficient and uses sparsity of the solution.
 
 
-
-.. minigallery:: ot.emd
+.. minigallery:: ot.emd, ot.solve
 :add-heading: Examples of use for :any:`ot.emd`
 :heading-level: "
 
@@ -226,7 +230,12 @@ It can computed from an already estimated OT matrix with
 
 # a and b are 1D histograms (sum to 1 and positive)
 # M is the ground cost matrix
-W = ot.emd2(a, b, M) # Wasserstein distance / EMD value
+
+# Wasserstein distance / EMD value with unified API
+W = ot.solve(M, a, b, return_matrix=False).value
+
+# with classical API
+W = ot.emd2(a, b, M)
 
 Note that the well known `Wasserstein distance
 <https://en.wikipedia.org/wiki/Wasserstein_metric>`_ between distributions a and
@@ -246,7 +255,7 @@ the :math:`W_1` Wasserstein distance can be done directly with :any:`ot.emd2`
 when providing :code:`M = ot.dist(xs, xt, metric='euclidean')` to use the Euclidean
 distance.
 
-.. minigallery:: ot.emd2
+.. minigallery:: ot.emd2, ot.solve
 :add-heading: Examples of use for :any:`ot.emd2`
 :heading-level: "
 
@@ -274,6 +283,10 @@ distributions. In the case when the finite sample dataset is supposed Gaussian,
 we provide :any:`ot.gaussian.bures_wasserstein_mapping` that returns the parameters for the
 Monge mapping.
 
+All those special cases are accessible with the unified API of POT through the
+function :any:`ot.solve_sample` with the parameter :code:`method` that allows to
+choose the method used to solve the problem (with :code:`method='1D'` or :code:`method='gaussian'`).
+
 
 Regularized Optimal Transport
 -----------------------------
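
Note on the hunk above: the 1D special case it refers to has a well-known closed form, since optimal transport on the real line matches sorted samples. The following is a minimal pure-Python sketch under the assumption of equal-size samples with uniform weights; `sorted_wasserstein_1d` is a hypothetical name chosen for this note and is not the POT implementation behind :code:`method='1D'`.

```python
def sorted_wasserstein_1d(xs, xt, p=1):
    """W_p between two equal-size 1D samples with uniform weights.

    In one dimension the optimal plan pairs the i-th smallest source
    point with the i-th smallest target point, so sorting is enough.
    """
    if len(xs) != len(xt):
        raise ValueError("this sketch only handles samples of equal size")
    pairs = zip(sorted(xs), sorted(xt))
    mean_cost = sum(abs(x - y) ** p for x, y in pairs) / len(xs)
    return mean_cost ** (1.0 / p)


# Shifting every point by 2 gives a W_1 distance of exactly 2.
print(sorted_wasserstein_1d([0.0, 1.0, 2.0], [4.0, 2.0, 3.0]))
```

The sort dominates the cost, so this runs in O(n log n) time and O(n) memory, without building any cost matrix.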
@@ -330,13 +343,15 @@ The Sinkhorn-Knopp algorithm is implemented in :any:`ot.sinkhorn` and
 linear term. Note that the regularization parameter :math:`\lambda` in the
 equation above is given to those functions with the parameter :code:`reg`.
 
->>> import ot
->>> a = [.5, .5]
->>> b = [.5, .5]
->>> M = [[0., 1.], [1., 0.]]
->>> ot.sinkhorn(a, b, M, 1)
-array([[ 0.36552929, 0.13447071],
-[ 0.13447071, 0.36552929]])
+.. code:: python
+
+# unified API
+P = ot.solve(M, a, b, reg=1).plan # OT Sinkhorn matrix
+loss = ot.solve(M, a, b, reg=1).value # OT Sinkhorn value
+
+# classical API
+P = ot.sinkhorn(a, b, M, reg=1) # OT Sinkhorn matrix
+loss = ot.sinkhorn2(a, b, M, reg=1) # OT Sinkhorn value
 
 More details about the algorithms used are given in the following note.
 
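
Note on the hunk above: the doctest it removes (`ot.sinkhorn(a, b, M, 1)` on a 2x2 problem) can be reproduced without POT. The following is a minimal pure-Python sketch of Sinkhorn-Knopp scaling, assuming the Gibbs kernel K = exp(-M / reg). `sinkhorn_plan` is a hypothetical helper written for this note, not part of POT, and it omits the stopping criteria and numerical safeguards a real solver needs.

```python
import math


def sinkhorn_plan(a, b, M, reg, n_iter=100):
    """Entropic OT plan via plain Sinkhorn-Knopp scaling iterations."""
    n, m = len(a), len(b)
    # Gibbs kernel derived from the ground cost matrix.
    K = [[math.exp(-M[i][j] / reg) for j in range(m)] for i in range(n)]
    u = [1.0] * n
    v = [1.0] * m
    for _ in range(n_iter):
        # Rescale rows, then columns, to match the marginals a and b.
        u = [a[i] / sum(K[i][j] * v[j] for j in range(m)) for i in range(n)]
        v = [b[j] / sum(K[i][j] * u[i] for i in range(n)) for j in range(m)]
    return [[u[i] * K[i][j] * v[j] for j in range(m)] for i in range(n)]


a = [0.5, 0.5]
b = [0.5, 0.5]
M = [[0.0, 1.0], [1.0, 0.0]]
P = sinkhorn_plan(a, b, M, reg=1)
for row in P:
    print([round(value, 8) for value in row])
```

On this input the plan has approximately 0.36552929 on the diagonal and 0.13447071 off the diagonal, matching the array shown in the removed doctest.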

@@ -406,13 +421,10 @@ implementations are not optimized for speed but provide a robust implementation
 of algorithms in [18]_ [19]_.
 
 
-.. minigallery:: ot.sinkhorn
-:add-heading: Examples of use for :any:`ot.sinkhorn`
+.. minigallery:: ot.sinkhorn ot.sinkhorn2
+:add-heading: Examples of use for Sinkhorn algorithm
 :heading-level: "
 
-.. minigallery:: ot.sinkhorn2
-:add-heading: Examples of use for :any:`ot.sinkhorn2`
-:heading-level: "
 
 
 Other regularizations
@@ -969,18 +981,6 @@ For instance, to disable TensorFlow, set `export POT_BACKEND_DISABLE_TENSORFLOW=
 It's important to note that the `numpy` backend cannot be disabled.
 
 
-List of compatible modules
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-
-This list will get longer for new releases and will hopefully disappear when POT
-become fully implemented with the backend.
-
-- :any:`ot.bregman`
-- :any:`ot.gromov` (some functions use CPU only solvers with copy overhead)
-- :any:`ot.optim` (some functions use CPU only solvers with copy overhead)
-- :any:`ot.sliced`
-- :any:`ot.utils` (partial)
-
 
 FAQ
 ---

examples/plot_OT_2D_samples.py (+1 -1)
@@ -65,7 +65,7 @@
 
 # %% EMD
 
-G0 = ot.emd(a, b, M)
+G0 = ot.solve(M, a, b).plan
 
 pl.figure(3)
 pl.imshow(G0, interpolation="nearest")