Skip to content

Commit eb32ae8

Browse files
Circle CICircle CI
Circle CI
authored and
Circle CI
committed
CircleCI update of dev docs (2804).
1 parent 95373ce commit eb32ae8

File tree

277 files changed

+734614
-732468
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

277 files changed

+734614
-732468
lines changed
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,266 @@
1+
# -*- coding: utf-8 -*-
2+
3+
r"""
4+
=========================================================================
5+
Semi-relaxed (Fused) Gromov-Wasserstein Barycenter as Dictionary Learning
6+
=========================================================================
7+
8+
In this example, we illustrate how to learn a semi-relaxed Gromov-Wasserstein
9+
(srGW) barycenter using a Block-Coordinate Descent algorithm, on a dataset of
10+
structured data such as graphs, denoted :math:`\{ \mathbf{C_s} \}_{s \in [S]}`
11+
where all nodes have uniform weights :math:`\{ \mathbf{p_s} \}_{s \in [S]}`.
12+
Given a barycenter structure matrix :math:`\mathbf{C}` with N nodes,
13+
each graph :math:`(\mathbf{C_s}, \mathbf{p_s})` is modeled as a reweighed subgraph
14+
with structure :math:`\mathbf{C}` and weights :math:`\mathbf{w_s} \in \Sigma_N`
15+
where each :math:`\mathbf{w_s}` corresponds to the second marginal of the OT
16+
:math:`\mathbf{T_s}` (s.t :math:`\mathbf{w_s} = \mathbf{T_s}^\top \mathbf{1}`)
17+
minimizing the srGW loss between the :math:`s^{th}` input and the barycenter.
18+
19+
20+
First, we consider a dataset composed of graphs generated by Stochastic Block models
21+
with variable sizes taken in :math:`\{30, ... , 50\}` and number of clusters
22+
varying in :math:`\{ 1, 2, 3\}` with random proportions. We learn a srGW barycenter
23+
with 3 nodes and visualize the learned structure and the embeddings for some inputs.
24+
25+
Second, we illustrate the extension of this framework to graphs endowed
26+
with node features by using the semi-relaxed Fused Gromov-Wasserstein
27+
divergence (srFGW). Starting from the aforementioned dataset of unattributed graphs, we
28+
add discrete labels uniformly depending on the number of clusters. We then conduct
28+
the analogous analysis.
30+
31+
32+
[48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty.
33+
"Semi-relaxed Gromov-Wasserstein divergence and applications on graphs".
34+
International Conference on Learning Representations (ICLR), 2022.
35+
36+
"""
37+
# Author: Cédric Vincent-Cuaz <[email protected]>
38+
#
39+
# License: MIT License
40+
41+
# sphinx_gallery_thumbnail_number = 2
42+
43+
import numpy as np
44+
import matplotlib.pylab as pl
45+
from sklearn.manifold import MDS
46+
from ot.gromov import (
47+
semirelaxed_gromov_barycenters, semirelaxed_fgw_barycenters)
48+
import ot
49+
import networkx
50+
from networkx.generators.community import stochastic_block_model as sbm
51+
52+
#############################################################################
53+
#
54+
# Generate a dataset composed of graphs following Stochastic Block models of 1, 2 and 3 clusters.
55+
# -----------------------------------------------------------------------------------------------
56+
57+
# Fix the global NumPy seed so that graph sizes and cluster proportions
# are reproducible from one run to the next.
np.random.seed(42)

n_samples = 60  # total number of graphs in the dataset
clusters = [1, 2, 3]  # numbers of SBM blocks considered
nlabels = len(clusters)
Nc = n_samples // nlabels  # number of graphs generated per number of clusters

# Edge probabilities are fixed across the dataset; only the graph size and
# the cluster proportions vary from one graph to another.
p_inter = 0.1  # probability of an edge between two different blocks
p_intra = 0.9  # probability of an edge inside a block

dataset, node_labels, labels = [], [], []

for n_cluster in clusters:
    for i in range(Nc):
        # Real draw in [30, 50) truncated to an integer number of nodes.
        n_nodes = int(np.random.uniform(low=30, high=50))

        if n_cluster == 1:
            # Single block: one intra-block probability, all nodes together.
            P = p_intra * np.eye(1)
            sizes = [n_nodes]
        else:
            P = np.full((n_cluster, n_cluster), p_inter)
            np.fill_diagonal(P, p_intra)
            # Random block proportions, normalised to sum to one.
            props = np.random.uniform(0.2, 1, size=(n_cluster,))
            props = props / props.sum()
            # Rounding means the block sizes may not sum exactly to n_nodes.
            sizes = np.round(n_nodes * props).astype(np.int32)

        G = sbm(sizes, P, seed=i, directed=False)
        # Block membership of every node, read from the 'block' node attribute.
        part = np.array([G.nodes[node]['block'] for node in range(np.sum(sizes))])
        C = networkx.to_numpy_array(G)

        dataset.append(C)
        node_labels.append(part)
        labels.append(n_cluster)
91+
92+
93+
# Visualize samples
94+
95+
def plot_graph(x, C, binary=True, color='C0', s=None):
    """Draw a graph on the current matplotlib axes.

    Edges are drawn as black segments between node positions and nodes are
    drawn on top of them as a scatter plot.

    Parameters
    ----------
    x : array-like, indexed as ``x[i, 0]`` / ``x[i, 1]``
        2D coordinates of the nodes.
    C : array-like, indexed as ``C[i, j]``
        Structure matrix. Only entries with ``i < j`` are read.
    binary : bool, optional
        If True, a pair is connected with a fixed opacity whenever
        ``C[i, j] > 0``. If False, the opacity of the segment is ``C[i, j]``
        clamped to ``[0, 1]``.
    color : optional
        Forwarded to ``pl.scatter`` as ``c``.
    s : optional
        Forwarded to ``pl.scatter`` as ``s``.
    """
    for j in range(C.shape[0]):
        for i in range(j):
            if binary:
                if not C[i, j] > 0:
                    continue
                alpha = 0.2
            else:
                # Connection intensity proportional to C[i, j]. matplotlib
                # rejects alpha values outside [0, 1], so clamp the weight
                # instead of letting an unnormalised matrix raise.
                alpha = float(np.clip(C[i, j], 0., 1.))
                if alpha == 0.:
                    # A fully transparent segment is invisible: skip it.
                    continue
            pl.plot([x[i, 0], x[j, 0]], [x[i, 1], x[j, 1]], alpha=alpha, color='k')

    # zorder keeps the nodes above the edges.
    pl.scatter(x[:, 0], x[:, 1], c=color, s=s, zorder=10, edgecolors='k', cmap='tab10', vmax=9)
105+
106+
107+
# Figure 1: for each number of clusters, show one sample graph (top row)
# and its adjacency matrix (bottom row).
pl.figure(1, (12, 8))
pl.clf()
for idx_c, c in enumerate(clusters):
    # The dataset stores Nc consecutive graphs per entry of `clusters`, so the
    # first graph generated with c clusters sits at idx_c * Nc. Indexing by
    # position (not by the label value c) keeps the figure valid even if
    # `clusters` is not exactly [1, 2, ...].
    C = dataset[idx_c * Nc]
    # get 2d position for nodes
    x = MDS(dissimilarity='precomputed', random_state=0).fit_transform(1 - C)

    pl.subplot(2, nlabels, idx_c + 1)
    pl.title('(graph) sample from label ' + str(c), fontsize=14)
    plot_graph(x, C, binary=True, color='C0', s=50.)
    pl.axis("off")

    pl.subplot(2, nlabels, nlabels + idx_c + 1)
    pl.title('(matrix) sample from label %s \n' % c, fontsize=14)
    pl.imshow(C, interpolation='nearest')
    pl.axis("off")
pl.tight_layout()
pl.show()
123+
124+
#############################################################################
125+
#
126+
# Estimate the srGW barycenter from the dataset and visualize embeddings
127+
# ----------------------------------------------------------------------
128+
129+
130+
np.random.seed(0)
# Uniform weights on the nodes of every input graph.
ps = [ot.unif(C.shape[0]) for C in dataset]
# Every input graph contributes equally to the barycenter.
lambdas = [1. / n_samples] * n_samples
N = 3  # number of nodes in the barycenter

# Here we use the Fluid partitioning method to deduce initial transport plans
# for the barycenter problem. An initial structure is also deduced from these
# initial transport plans. Then a warmstart strategy is used iteratively to
# init each individual srGW problem within the BCD algorithm.

# Several init options are implemented in `ot.gromov.semirelaxed_init_plan`.
init_plan = 'fluid'
warmstartT = True

C, log = semirelaxed_gromov_barycenters(
    N=N,
    Cs=dataset,
    ps=ps,
    lambdas=lambdas,
    loss_fun='square_loss',
    tol=1e-6,
    stop_criterion='loss',
    warmstartT=warmstartT,
    log=True,
    G0=init_plan,
    verbose=False,
)

print('barycenter structure:', C)

unmixings = log['p']
# Map each 3-component embedding w onto the plane: the three barycenter nodes
# become the corners of an equilateral triangle (the 2-simplex).
unmixings2D = np.zeros(shape=(n_samples, 2))
for i, w in enumerate(unmixings):
    unmixings2D[i] = ((2. * w[1] + w[2]) / 2., (np.sqrt(3.) * w[2]) / 2.)

# Corners of the triangle.
x = [0., 0.]
y = [1., 0.]
z = [0.5, np.sqrt(3) / 2.]
extremities = np.stack([x, y, z])

pl.figure(2, (4, 4))
pl.clf()
pl.title('Embedding space', fontsize=14)
for cluster in range(nlabels):
    # Graphs are stored in consecutive groups of Nc per number of clusters.
    start, end = Nc * cluster, Nc * (cluster + 1)
    legend_label = '1 cluster' if cluster == 0 else '%s clusters' % (cluster + 1)
    pl.scatter(unmixings2D[start:end, 0], unmixings2D[start:end, 1],
               c='C' + str(cluster), marker='o', s=80., label=legend_label)
pl.scatter(extremities[:, 0], extremities[:, 1], c='black', marker='x', s=100., label='bary. nodes')
# Edges of the simplex.
for a, b in ((x, y), (x, z), (y, z)):
    pl.plot([a[0], b[0]], [a[1], b[1]], color='black', linewidth=2.)
pl.axis('off')
pl.legend(fontsize=11)
pl.tight_layout()
pl.show()
178+
179+
#############################################################################
180+
#
181+
# Endow the dataset with node features
182+
# ------------------------------------
183+
# node labels, corresponding to the true SBM cluster assignments,
184+
# are set for each graph as one-hot encoded node features.
185+
186+
# One-hot encode the SBM block of every node as a 3-dimensional feature
# vector: row k of the 3x3 identity is the encoding of block k.
one_hot = np.eye(3)
dataset_features = [one_hot[part] for part in node_labels]
192+
193+
# Figure 3: same samples as before, with nodes colored by their feature
# (i.e. by the SBM block each node belongs to).
pl.figure(3, (12, 8))
pl.clf()
for idx_c, c in enumerate(clusters):
    # First graph generated with c clusters (Nc consecutive graphs per entry
    # of `clusters`), and its one-hot node features.
    C = dataset[idx_c * Nc]
    F = dataset_features[idx_c * Nc]
    # One color per node, given by the index of its one-hot feature. The
    # graph-level `labels` list must not be used here: it holds one entry
    # per graph, not per node.
    colors = [f'C{block}' for block in F.argmax(axis=1)]
    # get 2d position for nodes
    x = MDS(dissimilarity='precomputed', random_state=0).fit_transform(1 - C)

    pl.subplot(2, nlabels, idx_c + 1)
    pl.title('(graph) sample from label ' + str(c), fontsize=14)
    plot_graph(x, C, binary=True, color=colors, s=50)
    pl.axis("off")

    pl.subplot(2, nlabels, nlabels + idx_c + 1)
    pl.title('(matrix) sample from label %s \n' % c, fontsize=14)
    pl.imshow(C, interpolation='nearest')
    pl.axis("off")
pl.tight_layout()
pl.show()
211+
212+
#############################################################################
213+
#
214+
# Estimate the srFGW barycenter from the attributed graphs and visualize embeddings
215+
# ---------------------------------------------------------------------------------
216+
# We emphasize the dependence on the trade-off parameter alpha that weights the
217+
# relative importance between structures (alpha=1) and features (alpha=0),
218+
# knowing that embeddings that perfectly cluster graphs w.r.t their features
219+
# should ease the identification of the number of clusters in the graphs.
220+
221+
# Trade-off values, ranging from (almost) features only to (almost)
# structure only.
list_alphas = [0.0001, 0.5, 0.9999]
list_unmixings2D = []  # one (n_samples, 2) embedding array per alpha

for alpha in list_alphas:
    print('--- alpha:', alpha)
    C, F, log = semirelaxed_fgw_barycenters(
        N=N,
        Ys=dataset_features,
        Cs=dataset,
        ps=ps,
        lambdas=lambdas,
        alpha=alpha,
        loss_fun='square_loss',
        tol=1e-6,
        stop_criterion='loss',
        warmstartT=warmstartT,
        log=True,
        G0=init_plan,
    )

    print('barycenter structure:', C)
    print('barycenter features:', F)

    unmixings = log['p']
    # Planar coordinates of the embeddings inside the 2-simplex; a fresh
    # array is allocated for every alpha before being stored.
    unmixings2D = np.zeros(shape=(n_samples, 2))
    for i, w in enumerate(unmixings):
        unmixings2D[i] = ((2. * w[1] + w[2]) / 2., (np.sqrt(3.) * w[2]) / 2.)
    list_unmixings2D.append(unmixings2D)
241+
242+
# Corners of the equilateral triangle representing the 2-simplex.
x = [0., 0.]
y = [1., 0.]
z = [0.5, np.sqrt(3) / 2.]
extremities = np.stack([x, y, z])

# Figure 4: one embedding space per value of alpha, side by side.
pl.figure(4, (12, 4))
pl.clf()
pl.suptitle('Embedding spaces', fontsize=14)
n_panels = len(list_alphas)
for ialpha, alpha in enumerate(list_alphas):
    embedding = list_unmixings2D[ialpha]
    pl.subplot(1, n_panels, ialpha + 1)
    pl.title(f'alpha = {alpha}', fontsize=14)
    for cluster in range(nlabels):
        # Graphs are stored in consecutive groups of Nc per number of clusters.
        start, end = Nc * cluster, Nc * (cluster + 1)
        legend_label = '1 cluster' if cluster == 0 else '%s clusters' % (cluster + 1)
        pl.scatter(embedding[start:end, 0], embedding[start:end, 1],
                   c='C' + str(cluster), marker='o', s=80., label=legend_label)
    pl.scatter(extremities[:, 0], extremities[:, 1], c='black', marker='x', s=100., label='bary. nodes')
    # Edges of the simplex.
    for a, b in ((x, y), (x, z), (y, z)):
        pl.plot([a[0], b[0]], [a[1], b[1]], color='black', linewidth=2.)
    pl.axis('off')
    pl.legend(fontsize=11)
pl.tight_layout()
pl.show()
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.

0 commit comments

Comments
 (0)