Skip to content

Commit 695a8a8

Browse files
Circle CICircle CI
Circle CI
authored and
Circle CI
committed
CircleCI update of dev docs (2525).
1 parent d979453 commit 695a8a8

File tree

259 files changed

+728883
-727946
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

259 files changed

+728883
-727946
lines changed
Binary file not shown.
Binary file not shown.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {},
6+
"source": [
7+
"\n# Optimal Transport solvers comparison\n\nThis example illustrates the solutions returns for diffrent variants of exact,\nregularized and unbalanced OT solvers.\n"
8+
]
9+
},
10+
{
11+
"cell_type": "code",
12+
"execution_count": null,
13+
"metadata": {
14+
"collapsed": false
15+
},
16+
"outputs": [],
17+
"source": [
18+
"# Author: Remi Flamary <[email protected]>\n#\n# License: MIT License\n# sphinx_gallery_thumbnail_number = 3"
19+
]
20+
},
21+
{
22+
"cell_type": "code",
23+
"execution_count": null,
24+
"metadata": {
25+
"collapsed": false
26+
},
27+
"outputs": [],
28+
"source": [
29+
"import numpy as np\nimport matplotlib.pylab as pl\nimport ot\nimport ot.plot\nfrom ot.datasets import make_1D_gauss as gauss"
30+
]
31+
},
32+
{
33+
"cell_type": "markdown",
34+
"metadata": {},
35+
"source": [
36+
"## Generate data\n\n"
37+
]
38+
},
39+
{
40+
"cell_type": "code",
41+
"execution_count": null,
42+
"metadata": {
43+
"collapsed": false
44+
},
45+
"outputs": [],
46+
"source": [
47+
"n = 50 # nb bins\n\n# bin positions\nx = np.arange(n, dtype=np.float64)\n\n# Gaussian distributions\na = 0.6 * gauss(n, m=15, s=5) + 0.4 * gauss(n, m=35, s=5) # m= mean, s= std\nb = gauss(n, m=25, s=5)\n\n# loss matrix\nM = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)))\nM /= M.max()"
48+
]
49+
},
50+
{
51+
"cell_type": "markdown",
52+
"metadata": {},
53+
"source": [
54+
"## Plot distributions and loss matrix\n\n"
55+
]
56+
},
57+
{
58+
"cell_type": "code",
59+
"execution_count": null,
60+
"metadata": {
61+
"collapsed": false
62+
},
63+
"outputs": [],
64+
"source": [
65+
"pl.figure(1, figsize=(6.4, 3))\npl.plot(x, a, 'b', label='Source distribution')\npl.plot(x, b, 'r', label='Target distribution')\npl.legend()"
66+
]
67+
},
68+
{
69+
"cell_type": "code",
70+
"execution_count": null,
71+
"metadata": {
72+
"collapsed": false
73+
},
74+
"outputs": [],
75+
"source": [
76+
"pl.figure(2, figsize=(5, 5))\not.plot.plot1D_mat(a, b, M, 'Cost matrix M')"
77+
]
78+
},
79+
{
80+
"cell_type": "markdown",
81+
"metadata": {},
82+
"source": [
83+
"## Define Group lasso regularization and gradient\nThe groups are the first and second half of the columns of G\n\n"
84+
]
85+
},
86+
{
87+
"cell_type": "code",
88+
"execution_count": null,
89+
"metadata": {
90+
"collapsed": false
91+
},
92+
"outputs": [],
93+
"source": [
94+
"def reg_gl(G): # group lasso + small l2 reg\n G1 = G[:n // 2, :]**2\n G2 = G[n // 2:, :]**2\n gl1 = np.sum(np.sqrt(np.sum(G1, 0)))\n gl2 = np.sum(np.sqrt(np.sum(G2, 0)))\n return gl1 + gl2 + 0.1 * np.sum(G**2)\n\n\ndef grad_gl(G): # gradient of group lasso + small l2 reg\n G1 = G[:n // 2, :]\n G2 = G[n // 2:, :]\n gl1 = G1 / np.sqrt(np.sum(G1**2, 0, keepdims=True) + 1e-8)\n gl2 = G2 / np.sqrt(np.sum(G2**2, 0, keepdims=True) + 1e-8)\n return np.concatenate((gl1, gl2), axis=0) + 0.2 * G\n\n\nreg_type_gl = (reg_gl, grad_gl)"
95+
]
96+
},
97+
{
98+
"cell_type": "markdown",
99+
"metadata": {},
100+
"source": [
101+
"## Set up parameters for solvers and solve\n\n"
102+
]
103+
},
104+
{
105+
"cell_type": "code",
106+
"execution_count": null,
107+
"metadata": {
108+
"collapsed": false
109+
},
110+
"outputs": [],
111+
"source": [
112+
"lst_regs = [\"No Reg.\", \"Entropic\", \"L2\", \"Group Lasso + L2\"]\nlst_unbalanced = [\"Balanced\", \"Unbalanced KL\", 'Unbalanced L2', 'Unb. TV (Partial)'] # [\"Balanced\", \"Unb. KL\", \"Unb. L2\", \"Unb L1 (partial)\"]\n\nlst_solvers = [ # name, param for ot.solve function\n # balanced OT\n ('Exact OT', dict()),\n ('Entropic Reg. OT', dict(reg=0.005)),\n ('L2 Reg OT', dict(reg=1, reg_type='l2')),\n ('Group Lasso Reg. OT', dict(reg=0.1, reg_type=reg_type_gl)),\n\n\n # unbalanced OT KL\n ('Unbalanced KL No Reg.', dict(unbalanced=0.005)),\n ('Unbalanced KL wit KL Reg.', dict(reg=0.0005, unbalanced=0.005, unbalanced_type='kl', reg_type='kl')),\n ('Unbalanced KL with L2 Reg.', dict(reg=0.5, reg_type='l2', unbalanced=0.005, unbalanced_type='kl')),\n ('Unbalanced KL with Group Lasso Reg.', dict(reg=0.1, reg_type=reg_type_gl, unbalanced=0.05, unbalanced_type='kl')),\n\n # unbalanced OT L2\n ('Unbalanced L2 No Reg.', dict(unbalanced=0.5, unbalanced_type='l2')),\n ('Unbalanced L2 with KL Reg.', dict(reg=0.001, unbalanced=0.2, unbalanced_type='l2')),\n ('Unbalanced L2 with L2 Reg.', dict(reg=0.1, reg_type='l2', unbalanced=0.2, unbalanced_type='l2')),\n ('Unbalanced L2 with Group Lasso Reg.', dict(reg=0.05, reg_type=reg_type_gl, unbalanced=0.7, unbalanced_type='l2')),\n\n # unbalanced OT TV\n ('Unbalanced TV No Reg.', dict(unbalanced=0.1, unbalanced_type='tv')),\n ('Unbalanced TV with KL Reg.', dict(reg=0.001, unbalanced=0.01, unbalanced_type='tv')),\n ('Unbalanced TV with L2 Reg.', dict(reg=0.1, reg_type='l2', unbalanced=0.01, unbalanced_type='tv')),\n ('Unbalanced TV with Group Lasso Reg.', dict(reg=0.02, reg_type=reg_type_gl, unbalanced=0.01, unbalanced_type='tv')),\n\n]\n\nlst_plans = []\nfor (name, param) in lst_solvers:\n G = ot.solve(M, a, b, **param).plan\n lst_plans.append(G)"
113+
]
114+
},
115+
{
116+
"cell_type": "markdown",
117+
"metadata": {},
118+
"source": [
119+
"## Plot plans\n\n"
120+
]
121+
},
122+
{
123+
"cell_type": "code",
124+
"execution_count": null,
125+
"metadata": {
126+
"collapsed": false
127+
},
128+
"outputs": [],
129+
"source": [
130+
"pl.figure(3, figsize=(9, 9))\n\nfor i, bname in enumerate(lst_unbalanced):\n for j, rname in enumerate(lst_regs):\n pl.subplot(len(lst_unbalanced), len(lst_regs), i * len(lst_regs) + j + 1)\n\n plan = lst_plans[i * len(lst_regs) + j]\n m2 = plan.sum(0)\n m1 = plan.sum(1)\n m1, m2 = m1 / a.max(), m2 / b.max()\n pl.imshow(plan, cmap='Greys')\n pl.plot(x, m2 * 10, 'r')\n pl.plot(m1 * 10, x, 'b')\n pl.plot(x, b / b.max() * 10, 'r', alpha=0.3)\n pl.plot(a / a.max() * 10, x, 'b', alpha=0.3)\n #pl.axis('off')\n pl.tick_params(left=False, right=False, labelleft=False,\n labelbottom=False, bottom=False)\n if i == 0:\n pl.title(rname)\n if j == 0:\n pl.ylabel(bname, fontsize=14)"
131+
]
132+
}
133+
],
134+
"metadata": {
135+
"kernelspec": {
136+
"display_name": "Python 3",
137+
"language": "python",
138+
"name": "python3"
139+
},
140+
"language_info": {
141+
"codemirror_mode": {
142+
"name": "ipython",
143+
"version": 3
144+
},
145+
"file_extension": ".py",
146+
"mimetype": "text/x-python",
147+
"name": "python",
148+
"nbconvert_exporter": "python",
149+
"pygments_lexer": "ipython3",
150+
"version": "3.10.14"
151+
}
152+
},
153+
"nbformat": 4,
154+
"nbformat_minor": 0
155+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
======================================
4+
Optimal Transport solvers comparison
5+
======================================
6+
7+
This example illustrates the solutions returns for diffrent variants of exact,
8+
regularized and unbalanced OT solvers.
9+
"""
10+
11+
# Author: Remi Flamary <[email protected]>
12+
#
13+
# License: MIT License
14+
# sphinx_gallery_thumbnail_number = 3
15+
16+
#%%
17+
18+
import numpy as np
19+
import matplotlib.pylab as pl
20+
import ot
21+
import ot.plot
22+
from ot.datasets import make_1D_gauss as gauss
23+
24+
##############################################################################
25+
# Generate data
26+
# -------------
27+
28+
29+
#%% parameters
30+
31+
n = 50 # nb bins
32+
33+
# bin positions
34+
x = np.arange(n, dtype=np.float64)
35+
36+
# Gaussian distributions
37+
a = 0.6 * gauss(n, m=15, s=5) + 0.4 * gauss(n, m=35, s=5) # m= mean, s= std
38+
b = gauss(n, m=25, s=5)
39+
40+
# loss matrix
41+
M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)))
42+
M /= M.max()
43+
44+
45+
##############################################################################
46+
# Plot distributions and loss matrix
47+
# ----------------------------------
48+
49+
#%% plot the distributions
50+
51+
pl.figure(1, figsize=(6.4, 3))
52+
pl.plot(x, a, 'b', label='Source distribution')
53+
pl.plot(x, b, 'r', label='Target distribution')
54+
pl.legend()
55+
56+
#%% plot distributions and loss matrix
57+
58+
pl.figure(2, figsize=(5, 5))
59+
ot.plot.plot1D_mat(a, b, M, 'Cost matrix M')
60+
61+
##############################################################################
62+
# Define Group lasso regularization and gradient
63+
# ------------------------------------------------
64+
# The groups are the first and second half of the columns of G
65+
66+
67+
def reg_gl(G): # group lasso + small l2 reg
68+
G1 = G[:n // 2, :]**2
69+
G2 = G[n // 2:, :]**2
70+
gl1 = np.sum(np.sqrt(np.sum(G1, 0)))
71+
gl2 = np.sum(np.sqrt(np.sum(G2, 0)))
72+
return gl1 + gl2 + 0.1 * np.sum(G**2)
73+
74+
75+
def grad_gl(G): # gradient of group lasso + small l2 reg
76+
G1 = G[:n // 2, :]
77+
G2 = G[n // 2:, :]
78+
gl1 = G1 / np.sqrt(np.sum(G1**2, 0, keepdims=True) + 1e-8)
79+
gl2 = G2 / np.sqrt(np.sum(G2**2, 0, keepdims=True) + 1e-8)
80+
return np.concatenate((gl1, gl2), axis=0) + 0.2 * G
81+
82+
83+
reg_type_gl = (reg_gl, grad_gl)
84+
85+
# %%
86+
# Set up parameters for solvers and solve
87+
# ---------------------------------------
88+
89+
lst_regs = ["No Reg.", "Entropic", "L2", "Group Lasso + L2"]
90+
lst_unbalanced = ["Balanced", "Unbalanced KL", 'Unbalanced L2', 'Unb. TV (Partial)'] # ["Balanced", "Unb. KL", "Unb. L2", "Unb L1 (partial)"]
91+
92+
lst_solvers = [ # name, param for ot.solve function
93+
# balanced OT
94+
('Exact OT', dict()),
95+
('Entropic Reg. OT', dict(reg=0.005)),
96+
('L2 Reg OT', dict(reg=1, reg_type='l2')),
97+
('Group Lasso Reg. OT', dict(reg=0.1, reg_type=reg_type_gl)),
98+
99+
100+
# unbalanced OT KL
101+
('Unbalanced KL No Reg.', dict(unbalanced=0.005)),
102+
('Unbalanced KL wit KL Reg.', dict(reg=0.0005, unbalanced=0.005, unbalanced_type='kl', reg_type='kl')),
103+
('Unbalanced KL with L2 Reg.', dict(reg=0.5, reg_type='l2', unbalanced=0.005, unbalanced_type='kl')),
104+
('Unbalanced KL with Group Lasso Reg.', dict(reg=0.1, reg_type=reg_type_gl, unbalanced=0.05, unbalanced_type='kl')),
105+
106+
# unbalanced OT L2
107+
('Unbalanced L2 No Reg.', dict(unbalanced=0.5, unbalanced_type='l2')),
108+
('Unbalanced L2 with KL Reg.', dict(reg=0.001, unbalanced=0.2, unbalanced_type='l2')),
109+
('Unbalanced L2 with L2 Reg.', dict(reg=0.1, reg_type='l2', unbalanced=0.2, unbalanced_type='l2')),
110+
('Unbalanced L2 with Group Lasso Reg.', dict(reg=0.05, reg_type=reg_type_gl, unbalanced=0.7, unbalanced_type='l2')),
111+
112+
# unbalanced OT TV
113+
('Unbalanced TV No Reg.', dict(unbalanced=0.1, unbalanced_type='tv')),
114+
('Unbalanced TV with KL Reg.', dict(reg=0.001, unbalanced=0.01, unbalanced_type='tv')),
115+
('Unbalanced TV with L2 Reg.', dict(reg=0.1, reg_type='l2', unbalanced=0.01, unbalanced_type='tv')),
116+
('Unbalanced TV with Group Lasso Reg.', dict(reg=0.02, reg_type=reg_type_gl, unbalanced=0.01, unbalanced_type='tv')),
117+
118+
]
119+
120+
lst_plans = []
121+
for (name, param) in lst_solvers:
122+
G = ot.solve(M, a, b, **param).plan
123+
lst_plans.append(G)
124+
125+
##############################################################################
126+
# Plot plans
127+
# ----------
128+
129+
pl.figure(3, figsize=(9, 9))
130+
131+
for i, bname in enumerate(lst_unbalanced):
132+
for j, rname in enumerate(lst_regs):
133+
pl.subplot(len(lst_unbalanced), len(lst_regs), i * len(lst_regs) + j + 1)
134+
135+
plan = lst_plans[i * len(lst_regs) + j]
136+
m2 = plan.sum(0)
137+
m1 = plan.sum(1)
138+
m1, m2 = m1 / a.max(), m2 / b.max()
139+
pl.imshow(plan, cmap='Greys')
140+
pl.plot(x, m2 * 10, 'r')
141+
pl.plot(m1 * 10, x, 'b')
142+
pl.plot(x, b / b.max() * 10, 'r', alpha=0.3)
143+
pl.plot(a / a.max() * 10, x, 'b', alpha=0.3)
144+
#pl.axis('off')
145+
pl.tick_params(left=False, right=False, labelleft=False,
146+
labelbottom=False, bottom=False)
147+
if i == 0:
148+
pl.title(rname)
149+
if j == 0:
150+
pl.ylabel(bname, fontsize=14)
-686 Bytes
-233 Bytes
-258 Bytes
-259 Bytes
123 Bytes
-58 Bytes
142 Bytes

master/_modules/index.html

+3-3
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
<head>
44
<meta charset="utf-8" />
55
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
6-
<title>Overview: module code &mdash; POT Python Optimal Transport 0.9.3dev documentation</title>
6+
<title>Overview: module code &mdash; POT Python Optimal Transport 0.9.4dev documentation</title>
77
<link rel="stylesheet" type="text/css" href="../_static/pygments.css?v=80d5e7a1" />
88
<link rel="stylesheet" type="text/css" href="../_static/css/theme.css?v=19f00094" />
99
<link rel="stylesheet" type="text/css" href="../_static/sg_gallery.css?v=61a4c737" />
@@ -18,7 +18,7 @@
1818

1919
<script src="../_static/jquery.js?v=5d32c60e"></script>
2020
<script src="../_static/_sphinx_javascript_frameworks_compat.js?v=2cd50e6c"></script>
21-
<script src="../_static/documentation_options.js?v=8099e02f"></script>
21+
<script src="../_static/documentation_options.js?v=fccb5469"></script>
2222
<script src="../_static/doctools.js?v=9a2dae69"></script>
2323
<script src="../_static/sphinx_highlight.js?v=dc90522c"></script>
2424
<script src="../_static/js/theme.js"></script>
@@ -39,7 +39,7 @@
3939
<img src="../_static/logo_dark.svg" class="logo" alt="Logo"/>
4040
</a>
4141
<div class="version">
42-
0.9.3dev
42+
0.9.4dev
4343
</div>
4444
<div role="search">
4545
<form id="rtd-search-form" class="wy-form" action="../search.html" method="get">

master/_modules/ot/backend.html

+3-3
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
<head>
44
<meta charset="utf-8" />
55
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
6-
<title>ot.backend &mdash; POT Python Optimal Transport 0.9.3dev documentation</title>
6+
<title>ot.backend &mdash; POT Python Optimal Transport 0.9.4dev documentation</title>
77
<link rel="stylesheet" type="text/css" href="../../_static/pygments.css?v=80d5e7a1" />
88
<link rel="stylesheet" type="text/css" href="../../_static/css/theme.css?v=19f00094" />
99
<link rel="stylesheet" type="text/css" href="../../_static/sg_gallery.css?v=61a4c737" />
@@ -18,7 +18,7 @@
1818

1919
<script src="../../_static/jquery.js?v=5d32c60e"></script>
2020
<script src="../../_static/_sphinx_javascript_frameworks_compat.js?v=2cd50e6c"></script>
21-
<script src="../../_static/documentation_options.js?v=8099e02f"></script>
21+
<script src="../../_static/documentation_options.js?v=fccb5469"></script>
2222
<script src="../../_static/doctools.js?v=9a2dae69"></script>
2323
<script src="../../_static/sphinx_highlight.js?v=dc90522c"></script>
2424
<script src="../../_static/js/theme.js"></script>
@@ -39,7 +39,7 @@
3939
<img src="../../_static/logo_dark.svg" class="logo" alt="Logo"/>
4040
</a>
4141
<div class="version">
42-
0.9.3dev
42+
0.9.4dev
4343
</div>
4444
<div role="search">
4545
<form id="rtd-search-form" class="wy-form" action="../../search.html" method="get">

master/_modules/ot/bregman/_barycenter.html

+3-3
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
<head>
44
<meta charset="utf-8" />
55
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
6-
<title>ot.bregman._barycenter &mdash; POT Python Optimal Transport 0.9.3dev documentation</title>
6+
<title>ot.bregman._barycenter &mdash; POT Python Optimal Transport 0.9.4dev documentation</title>
77
<link rel="stylesheet" type="text/css" href="../../../_static/pygments.css?v=80d5e7a1" />
88
<link rel="stylesheet" type="text/css" href="../../../_static/css/theme.css?v=19f00094" />
99
<link rel="stylesheet" type="text/css" href="../../../_static/sg_gallery.css?v=61a4c737" />
@@ -18,7 +18,7 @@
1818

1919
<script src="../../../_static/jquery.js?v=5d32c60e"></script>
2020
<script src="../../../_static/_sphinx_javascript_frameworks_compat.js?v=2cd50e6c"></script>
21-
<script src="../../../_static/documentation_options.js?v=8099e02f"></script>
21+
<script src="../../../_static/documentation_options.js?v=fccb5469"></script>
2222
<script src="../../../_static/doctools.js?v=9a2dae69"></script>
2323
<script src="../../../_static/sphinx_highlight.js?v=dc90522c"></script>
2424
<script src="../../../_static/js/theme.js"></script>
@@ -39,7 +39,7 @@
3939
<img src="../../../_static/logo_dark.svg" class="logo" alt="Logo"/>
4040
</a>
4141
<div class="version">
42-
0.9.3dev
42+
0.9.4dev
4343
</div>
4444
<div role="search">
4545
<form id="rtd-search-form" class="wy-form" action="../../../search.html" method="get">

0 commit comments

Comments
 (0)