Skip to content

Commit 2601b7f

Browse files
Circle CICircle CI
Circle CI
authored and
Circle CI
committed
CircleCI update of dev docs (2555).
1 parent 8fd4226 commit 2601b7f

File tree

195 files changed

+733710
-731647
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

195 files changed

+733710
-731647
lines changed
Binary file not shown.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
========================================
4+
Low rank Gromov-Wasterstein between samples
5+
========================================
6+
7+
Comparaison between entropic Gromov-Wasserstein and Low Rank Gromov Wasserstein [67]
8+
on two curves in 2D and 3D, both sampled with 200 points.
9+
10+
The squared Euclidean distance is considered as the ground cost for both samples.
11+
12+
[67] Scetbon, M., Peyré, G. & Cuturi, M. (2022).
13+
"Linear-Time GromovWasserstein Distances using Low Rank Couplings and Costs".
14+
In International Conference on Machine Learning (ICML), 2022.
15+
"""
16+
17+
# Author: Laurène David <[email protected]>
18+
#
19+
# License: MIT License
20+
#
21+
# sphinx_gallery_thumbnail_number = 3
22+
23+
#%%
24+
import numpy as np
25+
import matplotlib.pylab as pl
26+
import ot.plot
27+
import time
28+
29+
##############################################################################
30+
# Generate data
31+
# -------------
32+
33+
#%% parameters
34+
n_samples = 200
35+
36+
# Generate 2D and 3D curves
37+
theta = np.linspace(-4 * np.pi, 4 * np.pi, n_samples)
38+
z = np.linspace(1, 2, n_samples)
39+
r = z**2 + 1
40+
x = r * np.sin(theta)
41+
y = r * np.cos(theta)
42+
43+
# Source and target distribution
44+
X = np.concatenate([x.reshape(-1, 1), z.reshape(-1, 1)], axis=1)
45+
Y = np.concatenate([x.reshape(-1, 1), y.reshape(-1, 1), z.reshape(-1, 1)], axis=1)
46+
47+
48+
##############################################################################
49+
# Plot data
50+
# ------------
51+
52+
#%%
53+
# Plot the source and target samples
54+
fig = pl.figure(1, figsize=(10, 4))
55+
56+
ax = fig.add_subplot(121)
57+
ax.plot(X[:, 0], X[:, 1], color="blue", linewidth=6)
58+
ax.tick_params(left=False, right=False, labelleft=False,
59+
labelbottom=False, bottom=False)
60+
ax.set_title("2D curve (source)")
61+
62+
ax2 = fig.add_subplot(122, projection="3d")
63+
ax2.plot(Y[:, 0], Y[:, 1], Y[:, 2], c='red', linewidth=6)
64+
ax2.tick_params(left=False, right=False, labelleft=False,
65+
labelbottom=False, bottom=False)
66+
ax2.view_init(15, -50)
67+
ax2.set_title("3D curve (target)")
68+
69+
pl.tight_layout()
70+
pl.show()
71+
72+
73+
##############################################################################
74+
# Entropic Gromov-Wasserstein
75+
# ------------
76+
77+
#%%
78+
79+
# Compute cost matrices
80+
C1 = ot.dist(X, X, metric="sqeuclidean")
81+
C2 = ot.dist(Y, Y, metric="sqeuclidean")
82+
83+
# Scale cost matrices
84+
r1 = C1.max()
85+
r2 = C2.max()
86+
87+
C1 = C1 / r1
88+
C2 = C2 / r2
89+
90+
91+
# Solve entropic gw
92+
reg = 5 * 1e-3
93+
94+
start = time.time()
95+
gw, log = ot.gromov.entropic_gromov_wasserstein(
96+
C1, C2, tol=1e-3, epsilon=reg,
97+
log=True, verbose=False)
98+
99+
end = time.time()
100+
time_entropic = end - start
101+
102+
entropic_gw_loss = np.round(log['gw_dist'], 3)
103+
104+
# Plot entropic gw
105+
pl.figure(2)
106+
pl.imshow(gw, interpolation="nearest", aspect="auto")
107+
pl.title("Entropic Gromov-Wasserstein (loss={})".format(entropic_gw_loss))
108+
pl.show()
109+
110+
111+
##############################################################################
112+
# Low rank squared euclidean cost matrices
113+
# ------------
114+
# %%
115+
116+
# Compute the low rank sqeuclidean cost decompositions
117+
A1, A2 = ot.lowrank.compute_lr_sqeuclidean_matrix(X, X, rescale_cost=False)
118+
B1, B2 = ot.lowrank.compute_lr_sqeuclidean_matrix(Y, Y, rescale_cost=False)
119+
120+
# Scale the low rank cost matrices
121+
A1, A2 = A1 / np.sqrt(r1), A2 / np.sqrt(r1)
122+
B1, B2 = B1 / np.sqrt(r2), B2 / np.sqrt(r2)
123+
124+
125+
##############################################################################
126+
# Low rank Gromov-Wasserstein
127+
# ------------
128+
# %%
129+
130+
# Solve low rank gromov-wasserstein with different ranks
131+
list_rank = [10, 50]
132+
list_P_GW = []
133+
list_loss_GW = []
134+
list_time_GW = []
135+
136+
for rank in list_rank:
137+
start = time.time()
138+
139+
Q, R, g, log = ot.lowrank_gromov_wasserstein_samples(
140+
X, Y, reg=0, rank=rank, rescale_cost=False, cost_factorized_Xs=(A1, A2),
141+
cost_factorized_Xt=(B1, B2), seed_init=49, numItermax=1000, log=True, stopThr=1e-6,
142+
)
143+
end = time.time()
144+
145+
P = log["lazy_plan"][:]
146+
loss = log["value"]
147+
148+
list_P_GW.append(P)
149+
list_loss_GW.append(np.round(loss, 3))
150+
list_time_GW.append(end - start)
151+
152+
153+
# %%
154+
# Plot low rank GW with different ranks
155+
pl.figure(3, figsize=(10, 4))
156+
157+
pl.subplot(1, 2, 1)
158+
pl.imshow(list_P_GW[0], interpolation="nearest", aspect="auto")
159+
pl.title('Low rank GW (rank=10, loss={})'.format(list_loss_GW[0]))
160+
161+
pl.subplot(1, 2, 2)
162+
pl.imshow(list_P_GW[1], interpolation="nearest", aspect="auto")
163+
pl.title('Low rank GW (rank=50, loss={})'.format(list_loss_GW[1]))
164+
165+
pl.tight_layout()
166+
pl.show()
167+
168+
169+
# %%
170+
# Compare computation time between entropic GW and low rank GW
171+
print("Entropic GW: {:.2f}s".format(time_entropic))
172+
print("Low rank GW (rank=10): {:.2f}s".format(list_time_GW[0]))
173+
print("Low rank GW (rank=50): {:.2f}s".format(list_time_GW[1]))
Binary file not shown.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {},
6+
"source": [
7+
"\n# Low rank Gromov-Wasterstein between samples\n\nComparaison between entropic Gromov-Wasserstein and Low Rank Gromov Wasserstein [67]\non two curves in 2D and 3D, both sampled with 200 points.\n\nThe squared Euclidean distance is considered as the ground cost for both samples.\n\n[67] Scetbon, M., Peyr\u00e9, G. & Cuturi, M. (2022).\n\"Linear-Time GromovWasserstein Distances using Low Rank Couplings and Costs\".\nIn International Conference on Machine Learning (ICML), 2022.\n"
8+
]
9+
},
10+
{
11+
"cell_type": "code",
12+
"execution_count": null,
13+
"metadata": {
14+
"collapsed": false
15+
},
16+
"outputs": [],
17+
"source": [
18+
"# Author: Laur\u00e8ne David <[email protected]>\n#\n# License: MIT License\n#\n# sphinx_gallery_thumbnail_number = 3"
19+
]
20+
},
21+
{
22+
"cell_type": "code",
23+
"execution_count": null,
24+
"metadata": {
25+
"collapsed": false
26+
},
27+
"outputs": [],
28+
"source": [
29+
"import numpy as np\nimport matplotlib.pylab as pl\nimport ot.plot\nimport time"
30+
]
31+
},
32+
{
33+
"cell_type": "markdown",
34+
"metadata": {},
35+
"source": [
36+
"## Generate data\n\n"
37+
]
38+
},
39+
{
40+
"cell_type": "code",
41+
"execution_count": null,
42+
"metadata": {
43+
"collapsed": false
44+
},
45+
"outputs": [],
46+
"source": [
47+
"n_samples = 200\n\n# Generate 2D and 3D curves\ntheta = np.linspace(-4 * np.pi, 4 * np.pi, n_samples)\nz = np.linspace(1, 2, n_samples)\nr = z**2 + 1\nx = r * np.sin(theta)\ny = r * np.cos(theta)\n\n# Source and target distribution\nX = np.concatenate([x.reshape(-1, 1), z.reshape(-1, 1)], axis=1)\nY = np.concatenate([x.reshape(-1, 1), y.reshape(-1, 1), z.reshape(-1, 1)], axis=1)"
48+
]
49+
},
50+
{
51+
"cell_type": "markdown",
52+
"metadata": {},
53+
"source": [
54+
"## Plot data\n\n"
55+
]
56+
},
57+
{
58+
"cell_type": "markdown",
59+
"metadata": {},
60+
"source": [
61+
"Plot the source and target samples\n\n"
62+
]
63+
},
64+
{
65+
"cell_type": "code",
66+
"execution_count": null,
67+
"metadata": {
68+
"collapsed": false
69+
},
70+
"outputs": [],
71+
"source": [
72+
"fig = pl.figure(1, figsize=(10, 4))\n\nax = fig.add_subplot(121)\nax.plot(X[:, 0], X[:, 1], color=\"blue\", linewidth=6)\nax.tick_params(left=False, right=False, labelleft=False,\n labelbottom=False, bottom=False)\nax.set_title(\"2D curve (source)\")\n\nax2 = fig.add_subplot(122, projection=\"3d\")\nax2.plot(Y[:, 0], Y[:, 1], Y[:, 2], c='red', linewidth=6)\nax2.tick_params(left=False, right=False, labelleft=False,\n labelbottom=False, bottom=False)\nax2.view_init(15, -50)\nax2.set_title(\"3D curve (target)\")\n\npl.tight_layout()\npl.show()"
73+
]
74+
},
75+
{
76+
"cell_type": "markdown",
77+
"metadata": {},
78+
"source": [
79+
"## Entropic Gromov-Wasserstein\n\n"
80+
]
81+
},
82+
{
83+
"cell_type": "code",
84+
"execution_count": null,
85+
"metadata": {
86+
"collapsed": false
87+
},
88+
"outputs": [],
89+
"source": [
90+
"# Compute cost matrices\nC1 = ot.dist(X, X, metric=\"sqeuclidean\")\nC2 = ot.dist(Y, Y, metric=\"sqeuclidean\")\n\n# Scale cost matrices\nr1 = C1.max()\nr2 = C2.max()\n\nC1 = C1 / r1\nC2 = C2 / r2\n\n\n# Solve entropic gw\nreg = 5 * 1e-3\n\nstart = time.time()\ngw, log = ot.gromov.entropic_gromov_wasserstein(\n C1, C2, tol=1e-3, epsilon=reg,\n log=True, verbose=False)\n\nend = time.time()\ntime_entropic = end - start\n\nentropic_gw_loss = np.round(log['gw_dist'], 3)\n\n# Plot entropic gw\npl.figure(2)\npl.imshow(gw, interpolation=\"nearest\", aspect=\"auto\")\npl.title(\"Entropic Gromov-Wasserstein (loss={})\".format(entropic_gw_loss))\npl.show()"
91+
]
92+
},
93+
{
94+
"cell_type": "markdown",
95+
"metadata": {},
96+
"source": [
97+
"## Low rank squared euclidean cost matrices\n%%\n\n"
98+
]
99+
},
100+
{
101+
"cell_type": "code",
102+
"execution_count": null,
103+
"metadata": {
104+
"collapsed": false
105+
},
106+
"outputs": [],
107+
"source": [
108+
"# Compute the low rank sqeuclidean cost decompositions\nA1, A2 = ot.lowrank.compute_lr_sqeuclidean_matrix(X, X, rescale_cost=False)\nB1, B2 = ot.lowrank.compute_lr_sqeuclidean_matrix(Y, Y, rescale_cost=False)\n\n# Scale the low rank cost matrices\nA1, A2 = A1 / np.sqrt(r1), A2 / np.sqrt(r1)\nB1, B2 = B1 / np.sqrt(r2), B2 / np.sqrt(r2)"
109+
]
110+
},
111+
{
112+
"cell_type": "markdown",
113+
"metadata": {},
114+
"source": [
115+
"## Low rank Gromov-Wasserstein\n%%\n\n"
116+
]
117+
},
118+
{
119+
"cell_type": "code",
120+
"execution_count": null,
121+
"metadata": {
122+
"collapsed": false
123+
},
124+
"outputs": [],
125+
"source": [
126+
"# Solve low rank gromov-wasserstein with different ranks\nlist_rank = [10, 50]\nlist_P_GW = []\nlist_loss_GW = []\nlist_time_GW = []\n\nfor rank in list_rank:\n start = time.time()\n\n Q, R, g, log = ot.lowrank_gromov_wasserstein_samples(\n X, Y, reg=0, rank=rank, rescale_cost=False, cost_factorized_Xs=(A1, A2),\n cost_factorized_Xt=(B1, B2), seed_init=49, numItermax=1000, log=True, stopThr=1e-6,\n )\n end = time.time()\n\n P = log[\"lazy_plan\"][:]\n loss = log[\"value\"]\n\n list_P_GW.append(P)\n list_loss_GW.append(np.round(loss, 3))\n list_time_GW.append(end - start)"
127+
]
128+
},
129+
{
130+
"cell_type": "markdown",
131+
"metadata": {},
132+
"source": [
133+
"Plot low rank GW with different ranks\n\n"
134+
]
135+
},
136+
{
137+
"cell_type": "code",
138+
"execution_count": null,
139+
"metadata": {
140+
"collapsed": false
141+
},
142+
"outputs": [],
143+
"source": [
144+
"pl.figure(3, figsize=(10, 4))\n\npl.subplot(1, 2, 1)\npl.imshow(list_P_GW[0], interpolation=\"nearest\", aspect=\"auto\")\npl.title('Low rank GW (rank=10, loss={})'.format(list_loss_GW[0]))\n\npl.subplot(1, 2, 2)\npl.imshow(list_P_GW[1], interpolation=\"nearest\", aspect=\"auto\")\npl.title('Low rank GW (rank=50, loss={})'.format(list_loss_GW[1]))\n\npl.tight_layout()\npl.show()"
145+
]
146+
},
147+
{
148+
"cell_type": "markdown",
149+
"metadata": {},
150+
"source": [
151+
"Compare computation time between entropic GW and low rank GW\n\n"
152+
]
153+
},
154+
{
155+
"cell_type": "code",
156+
"execution_count": null,
157+
"metadata": {
158+
"collapsed": false
159+
},
160+
"outputs": [],
161+
"source": [
162+
"print(\"Entropic GW: {:.2f}s\".format(time_entropic))\nprint(\"Low rank GW (rank=10): {:.2f}s\".format(list_time_GW[0]))\nprint(\"Low rank GW (rank=50): {:.2f}s\".format(list_time_GW[1]))"
163+
]
164+
}
165+
],
166+
"metadata": {
167+
"kernelspec": {
168+
"display_name": "Python 3",
169+
"language": "python",
170+
"name": "python3"
171+
},
172+
"language_info": {
173+
"codemirror_mode": {
174+
"name": "ipython",
175+
"version": 3
176+
},
177+
"file_extension": ".py",
178+
"mimetype": "text/x-python",
179+
"name": "python",
180+
"nbconvert_exporter": "python",
181+
"pygments_lexer": "ipython3",
182+
"version": "3.10.14"
183+
}
184+
},
185+
"nbformat": 4,
186+
"nbformat_minor": 0
187+
}
Loading
-330 Bytes
Loading
Loading
Loading
192 Bytes
Loading
Loading
198 Bytes
Loading
Loading
Loading
Loading
Loading
Loading
Loading
Loading
Loading
Loading
-213 Bytes
Loading
Loading
Loading
Loading
Loading
-339 Bytes
Loading
Loading
Loading
Loading
Loading
Loading
58.3 KB
Loading
75.9 KB
Loading
Loading
Loading
Loading
-182 Bytes
Loading
202 Bytes
Loading
Loading

master/_modules/index.html

+1
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ <h1>All modules for which code is available</h1>
104104
<li><a href="ot/gromov/_dictionary.html">ot.gromov._dictionary</a></li>
105105
<li><a href="ot/gromov/_estimators.html">ot.gromov._estimators</a></li>
106106
<li><a href="ot/gromov/_gw.html">ot.gromov._gw</a></li>
107+
<li><a href="ot/gromov/_lowrank.html">ot.gromov._lowrank</a></li>
107108
<li><a href="ot/gromov/_semirelaxed.html">ot.gromov._semirelaxed</a></li>
108109
<li><a href="ot/gromov/_utils.html">ot.gromov._utils</a></li>
109110
<li><a href="ot/lowrank.html">ot.lowrank</a></li>

0 commit comments

Comments
 (0)