28
28
from scipy .stats import special_ortho_group
29
29
from tqdm import tqdm
30
30
31
- flags .DEFINE_list ('specification' , ['S^2' , 'S^2' ], 'List of submanifolds' )
32
- flags .DEFINE_integer ('npts' , 1000 , 'Number of data points' )
33
- flags .DEFINE_boolean ('rotate' , False , 'Apply random rotation to the data' )
34
-
35
- FLAGS = flags .FLAGS
31
+ SPECIFICATION = flags .DEFINE_list (
32
+ name = 'specification' , default = ['S^2' , 'S^2' ], help = 'List of submanifolds' )
33
+ NPTS = flags .DEFINE_integer (
34
+ name = 'npts' , default = 1000 , help = 'Number of data points' )
35
+ ROTATE = flags .DEFINE_boolean (
36
+ name = 'rotate' , default = False , help = 'Apply random rotation to the data' )
37
+ PLOT = flags .DEFINE_boolean (
38
+ name = 'plot' , default = True , help = 'Whether to enable plotting' )
36
39
37
40
38
41
def make_so_tangent (q ):
@@ -139,8 +142,8 @@ def make_product_manifold(specification, npts):
139
142
140
143
def main (_ ):
141
144
# Generate data and run GEOMANCER
142
- data , dim , tangents = make_product_manifold (FLAGS . specification , FLAGS . npts )
143
- if FLAGS . rotate :
145
+ data , dim , tangents = make_product_manifold (SPECIFICATION . value , NPTS . value )
146
+ if ROTATE . value :
144
147
rot , _ = np .linalg .qr (np .random .randn (data .shape [1 ], data .shape [1 ]))
145
148
data_rot = data @ rot .T
146
149
components , spectrum = geomancer .fit (data_rot , dim )
@@ -149,46 +152,49 @@ def main(_):
149
152
components , spectrum = geomancer .fit (data , dim )
150
153
errors = geomancer .eval_aligned (components , tangents )
151
154
152
- # Plot spectrum
153
- plt .figure (figsize = (8 , 6 ))
154
- plt .scatter (np .arange (len (spectrum )), spectrum , s = 100 )
155
- largest_gap = np .argmax (spectrum [1 :]- spectrum [:- 1 ]) + 1
156
- plt .axvline (largest_gap , linewidth = 2 , c = 'r' )
157
- plt .xticks ([])
158
- plt .yticks (fontsize = 18 )
159
- plt .xlabel ('Index' , fontsize = 24 )
160
- plt .ylabel ('Eigenvalue' , fontsize = 24 )
161
- plt .title ('GeoManCEr Eigenvalue Spectrum' , fontsize = 24 )
162
-
163
- # Plot subspace bases
164
- fig = plt .figure (figsize = (8 , 6 ))
165
- bases = components [0 ]
166
- gs = gridspec .GridSpec (1 , len (bases ),
167
- width_ratios = [b .shape [1 ] for b in bases ])
168
- for i in range (len (bases )):
169
- ax = plt .subplot (gs [i ])
170
- ax .imshow (bases [i ])
171
- ax .set_xticks ([])
172
- ax .set_yticks ([])
173
- ax .set_title (r'$T_{\mathbf{x}_1}\mathcal{M}_%d$' % (i + 1 ), fontsize = 18 )
174
- fig .canvas .set_window_title ('GeoManCEr Results' )
175
-
176
- # Plot ground truth
177
- fig = plt .figure (figsize = (8 , 6 ))
178
- gs = gridspec .GridSpec (1 , len (tangents ),
179
- width_ratios = [b .shape [2 ] for b in tangents ])
180
- for i , spec in enumerate (FLAGS .specification ):
181
- ax = plt .subplot (gs [i ])
182
- ax .imshow (tangents [i ][0 ])
183
- ax .set_xticks ([])
184
- ax .set_yticks ([])
185
- ax .set_title (r'$T_{\mathbf{x}_1}%s$' % spec , fontsize = 18 )
186
- fig .canvas .set_window_title ('Ground Truth' )
187
-
188
155
logging .info ('Error between subspaces: %.2f +/- %.2f radians' ,
189
156
np .mean (errors ),
190
157
np .std (errors ))
191
- plt .show ()
158
+
159
+ if PLOT .value :
160
+
161
+ # Plot spectrum
162
+ plt .figure (figsize = (8 , 6 ))
163
+ plt .scatter (np .arange (len (spectrum )), spectrum , s = 100 )
164
+ largest_gap = np .argmax (spectrum [1 :]- spectrum [:- 1 ]) + 1
165
+ plt .axvline (largest_gap , linewidth = 2 , c = 'r' )
166
+ plt .xticks ([])
167
+ plt .yticks (fontsize = 18 )
168
+ plt .xlabel ('Index' , fontsize = 24 )
169
+ plt .ylabel ('Eigenvalue' , fontsize = 24 )
170
+ plt .title ('GeoManCEr Eigenvalue Spectrum' , fontsize = 24 )
171
+
172
+ # Plot subspace bases
173
+ fig = plt .figure (figsize = (8 , 6 ))
174
+ bases = components [0 ]
175
+ gs = gridspec .GridSpec (1 , len (bases ),
176
+ width_ratios = [b .shape [1 ] for b in bases ])
177
+ for i in range (len (bases )):
178
+ ax = plt .subplot (gs [i ])
179
+ ax .imshow (bases [i ])
180
+ ax .set_xticks ([])
181
+ ax .set_yticks ([])
182
+ ax .set_title (r'$T_{\mathbf{x}_1}\mathcal{M}_%d$' % (i + 1 ), fontsize = 18 )
183
+ fig .canvas .set_window_title ('GeoManCEr Results' )
184
+
185
+ # Plot ground truth
186
+ fig = plt .figure (figsize = (8 , 6 ))
187
+ gs = gridspec .GridSpec (1 , len (tangents ),
188
+ width_ratios = [b .shape [2 ] for b in tangents ])
189
+ for i , spec in enumerate (SPECIFICATION .value ):
190
+ ax = plt .subplot (gs [i ])
191
+ ax .imshow (tangents [i ][0 ])
192
+ ax .set_xticks ([])
193
+ ax .set_yticks ([])
194
+ ax .set_title (r'$T_{\mathbf{x}_1}%s$' % spec , fontsize = 18 )
195
+ fig .canvas .set_window_title ('Ground Truth' )
196
+
197
+ plt .show ()
192
198
193
199
194
200
if __name__ == '__main__' :
0 commit comments