Skip to content

Commit 20e06f1

Browse files
alimuldaldiegolascasas
authored andcommitted
Enable Travis tests for geomancer
* Added a flag to disable plotting when running headless on Travis PiperOrigin-RevId: 368229292
1 parent d4a9a68 commit 20e06f1

File tree

3 files changed

+57
-50
lines changed

3 files changed

+57
-50
lines changed

Diff for: .travis.yml

+1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ env:
1212
- PROJECT="avae"
1313
# - PROJECT="cs_gan" # TODO(b/184845450): Fix and re-enable
1414
- PROJECT="gated_linear_networks"
15+
- PROJECT="geomancer"
1516
- PROJECT="iodine"
1617
- PROJECT="kfac_ferminet_alpha"
1718
- PROJECT="learning_to_simulate"

Diff for: geomancer/run.sh

100644100755
+6-6
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
python3 -m venv geomancer-venv
17-
source geomancer-venv/bin/activate
18-
pip3 install .
19-
python3 geomancer_test.py
20-
python3 train.py
21-
deactivate
16+
python3 -m venv /tmp/geomancer-venv
17+
source /tmp/geomancer-venv/bin/activate
18+
pip3 install -U pip
19+
pip3 install geomancer/
20+
python3 -m geomancer.geomancer_test
21+
python3 geomancer/train.py --plot=False

Diff for: geomancer/train.py

+50-44
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,14 @@
2828
from scipy.stats import special_ortho_group
2929
from tqdm import tqdm
3030

31-
flags.DEFINE_list('specification', ['S^2', 'S^2'], 'List of submanifolds')
32-
flags.DEFINE_integer('npts', 1000, 'Number of data points')
33-
flags.DEFINE_boolean('rotate', False, 'Apply random rotation to the data')
34-
35-
FLAGS = flags.FLAGS
31+
SPECIFICATION = flags.DEFINE_list(
32+
name='specification', default=['S^2', 'S^2'], help='List of submanifolds')
33+
NPTS = flags.DEFINE_integer(
34+
name='npts', default=1000, help='Number of data points')
35+
ROTATE = flags.DEFINE_boolean(
36+
name='rotate', default=False, help='Apply random rotation to the data')
37+
PLOT = flags.DEFINE_boolean(
38+
name='plot', default=True, help='Whether to enable plotting')
3639

3740

3841
def make_so_tangent(q):
@@ -139,8 +142,8 @@ def make_product_manifold(specification, npts):
139142

140143
def main(_):
141144
# Generate data and run GEOMANCER
142-
data, dim, tangents = make_product_manifold(FLAGS.specification, FLAGS.npts)
143-
if FLAGS.rotate:
145+
data, dim, tangents = make_product_manifold(SPECIFICATION.value, NPTS.value)
146+
if ROTATE.value:
144147
rot, _ = np.linalg.qr(np.random.randn(data.shape[1], data.shape[1]))
145148
data_rot = data @ rot.T
146149
components, spectrum = geomancer.fit(data_rot, dim)
@@ -149,46 +152,49 @@ def main(_):
149152
components, spectrum = geomancer.fit(data, dim)
150153
errors = geomancer.eval_aligned(components, tangents)
151154

152-
# Plot spectrum
153-
plt.figure(figsize=(8, 6))
154-
plt.scatter(np.arange(len(spectrum)), spectrum, s=100)
155-
largest_gap = np.argmax(spectrum[1:]-spectrum[:-1]) + 1
156-
plt.axvline(largest_gap, linewidth=2, c='r')
157-
plt.xticks([])
158-
plt.yticks(fontsize=18)
159-
plt.xlabel('Index', fontsize=24)
160-
plt.ylabel('Eigenvalue', fontsize=24)
161-
plt.title('GeoManCEr Eigenvalue Spectrum', fontsize=24)
162-
163-
# Plot subspace bases
164-
fig = plt.figure(figsize=(8, 6))
165-
bases = components[0]
166-
gs = gridspec.GridSpec(1, len(bases),
167-
width_ratios=[b.shape[1] for b in bases])
168-
for i in range(len(bases)):
169-
ax = plt.subplot(gs[i])
170-
ax.imshow(bases[i])
171-
ax.set_xticks([])
172-
ax.set_yticks([])
173-
ax.set_title(r'$T_{\mathbf{x}_1}\mathcal{M}_%d$' % (i+1), fontsize=18)
174-
fig.canvas.set_window_title('GeoManCEr Results')
175-
176-
# Plot ground truth
177-
fig = plt.figure(figsize=(8, 6))
178-
gs = gridspec.GridSpec(1, len(tangents),
179-
width_ratios=[b.shape[2] for b in tangents])
180-
for i, spec in enumerate(FLAGS.specification):
181-
ax = plt.subplot(gs[i])
182-
ax.imshow(tangents[i][0])
183-
ax.set_xticks([])
184-
ax.set_yticks([])
185-
ax.set_title(r'$T_{\mathbf{x}_1}%s$' % spec, fontsize=18)
186-
fig.canvas.set_window_title('Ground Truth')
187-
188155
logging.info('Error between subspaces: %.2f +/- %.2f radians',
189156
np.mean(errors),
190157
np.std(errors))
191-
plt.show()
158+
159+
if PLOT.value:
160+
161+
# Plot spectrum
162+
plt.figure(figsize=(8, 6))
163+
plt.scatter(np.arange(len(spectrum)), spectrum, s=100)
164+
largest_gap = np.argmax(spectrum[1:]-spectrum[:-1]) + 1
165+
plt.axvline(largest_gap, linewidth=2, c='r')
166+
plt.xticks([])
167+
plt.yticks(fontsize=18)
168+
plt.xlabel('Index', fontsize=24)
169+
plt.ylabel('Eigenvalue', fontsize=24)
170+
plt.title('GeoManCEr Eigenvalue Spectrum', fontsize=24)
171+
172+
# Plot subspace bases
173+
fig = plt.figure(figsize=(8, 6))
174+
bases = components[0]
175+
gs = gridspec.GridSpec(1, len(bases),
176+
width_ratios=[b.shape[1] for b in bases])
177+
for i in range(len(bases)):
178+
ax = plt.subplot(gs[i])
179+
ax.imshow(bases[i])
180+
ax.set_xticks([])
181+
ax.set_yticks([])
182+
ax.set_title(r'$T_{\mathbf{x}_1}\mathcal{M}_%d$' % (i+1), fontsize=18)
183+
fig.canvas.set_window_title('GeoManCEr Results')
184+
185+
# Plot ground truth
186+
fig = plt.figure(figsize=(8, 6))
187+
gs = gridspec.GridSpec(1, len(tangents),
188+
width_ratios=[b.shape[2] for b in tangents])
189+
for i, spec in enumerate(SPECIFICATION.value):
190+
ax = plt.subplot(gs[i])
191+
ax.imshow(tangents[i][0])
192+
ax.set_xticks([])
193+
ax.set_yticks([])
194+
ax.set_title(r'$T_{\mathbf{x}_1}%s$' % spec, fontsize=18)
195+
fig.canvas.set_window_title('Ground Truth')
196+
197+
plt.show()
192198

193199

194200
if __name__ == '__main__':

0 commit comments

Comments
 (0)