-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgenerative_results.py
101 lines (90 loc) · 2.8 KB
/
generative_results.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import os
import pickle
import warnings

import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import sem
# Human-readable label for each model key. A model key is the basename of a
# saved ``<model>.pickle`` file. Referenced by the (currently disabled)
# plotting code for legend labels.
names = dict(
    maf='MAF',
    maf3='MAF-L',
    np_maf='GEMF-T',
    c100_np_maf='GEMF-T',
    sandwich='GEMF-M',
    c100_sandwich='GEMF-M',
    np_maf_smoothness='GEMF-T(s)',
    np_maf_continuity='GEMF-T(c)',
    bottom='B-MAF',
    splines='NSF',
    c100_np_splines='NSF-EMF-T',
    c100_sandwich_splines='NSF-EMF-M',
    np_splines_continuity='NSF-EMF-T(c)',
)

# Plot-title label for each dataset key (used by the disabled plotting code).
d_names = dict(
    brownian='Brownian motion',
    ornstein='Ornstein-Uhlenbeck process',
    lorenz='Lorenz system',
)
# Results directory to summarize; it also selects the dataset/model grid below.
base_dir = 'time_series_results'

# (datasets, models) evaluated for each supported results directory.
# Model names are the basenames of the ``<model>.pickle`` files saved per run.
_EXPERIMENTS = {
    'mnist': (
        ['mnist'],
        [
            'c100_np_maf',
            'c100_sandwich_bn',
            'c100_np_splines',
            'c100_sandwich_splines_bn',
            'maf',
            'maf3',
            'splines',
            # Previously also compared: 'c100_sandwich', 'maf_bn',
            # 'splines_bn', 'c100_np_maf_bn', 'c10_np_maf'.
        ],
    ),
    '2d_toy_results': (
        ['checkerboard', '8gaussians'],
        ['c100_np_splines', 'c100_sandwich_splines', 'splines'],
    ),
    'time_series_results': (
        ['brownian', 'ornstein', 'lorenz', 'van_der_pol'],
        ['np_maf_continuity', 'np_splines_continuity', 'maf', 'maf3',
         'splines', 'bottom'],
    ),
    'hierarchical_results': (
        ['digits'],
        ['np_maf', 'sandwich', 'maf', 'maf3'],
    ),
}

# Fail fast with a clear message; an unknown directory used to leave
# `datasets`/`models` unbound and surface later as a NameError.
if base_dir not in _EXPERIMENTS:
    raise ValueError(
        f'Unknown base_dir {base_dir!r}; expected one of {sorted(_EXPERIMENTS)}'
    )
datasets, models = (list(grid) for grid in _EXPERIMENTS[base_dir])
def load_eval_losses(base_dir, datasets, models):
    """Collect the ``'loss_eval'`` entry of every saved model pickle.

    Scans every entry of *base_dir* whose name contains ``'run'`` and, for
    each (dataset, model) pair, reads
    ``<base_dir>/<run>/<dataset>/<model>.pickle``. When *base_dir* is
    ``'mnist'`` the dataset level is skipped and
    ``<base_dir>/<run>/<model>.pickle`` is read instead.

    Missing pickles are skipped silently, because not every run contains
    every model. Any other failure while loading a file is reported with
    ``warnings.warn``, and that file is skipped.

    Returns:
        dict mapping dataset -> model -> list of ``loss_eval`` values, one
        per run in which the pickle was found. A pair found in no run maps
        to an empty list.

    Raises:
        FileNotFoundError: if *base_dir* itself does not exist.
    """
    losses = {dataset: {model: [] for model in models} for dataset in datasets}
    # Sorted so values accumulate in the same order on every platform.
    for run in sorted(os.listdir(base_dir)):
        if 'run' not in run:
            continue
        for dataset in datasets:
            for model in models:
                if base_dir == 'mnist':
                    full_path = os.path.join(base_dir, run, f'{model}.pickle')
                else:
                    full_path = os.path.join(base_dir, run, dataset,
                                             f'{model}.pickle')
                try:
                    # NOTE: pickle.load can execute arbitrary code; only point
                    # this at result files you produced yourself.
                    with open(full_path, 'rb') as handle:
                        res = pickle.load(handle)
                    losses[dataset][model].append(res['loss_eval'])
                except FileNotFoundError:
                    # Expected: this run did not train this dataset/model.
                    continue
                except Exception as err:
                    # Best-effort summary: report a corrupt or incompatible
                    # pickle instead of silently dropping it (the old bare
                    # `except:` also swallowed KeyboardInterrupt).
                    warnings.warn(f'Skipping {full_path}: {err!r}')
    return losses


results = load_eval_losses(base_dir, datasets, models)

# Disabled plotting snippet, kept for reference. It plotted res['loss'] per
# model and saved one figure per dataset. It relies on the loop variables
# res, model and dataset, so it must be moved back inside the loading loop
# to work:
# plt.plot(res['loss'], label=names[model], alpha=0.9)
# # if dataset == 'lorenz':
# plt.ylim(bottom=-250, top=100)
# plt.title(f'{d_names[dataset]}')
# plt.legend()
# plt.savefig(f'{base_dir}/loss_{dataset}.png')
# plt.close()
def _format_cell(values, bold=False):
    """Return one LaTeX table cell for the eval losses of a single model.

    The cell reads `` & $<mean> \\pm <sem>$``, with the mean shown to 3
    decimals and ``scipy.stats.sem`` to 4. When *bold* is true, the contents
    are wrapped in ``\\boldsymbol{...}``. A model with no recorded values
    yields `` & --`` instead of ``nan \\pm nan``.
    """
    if not values:
        return ' & --'
    cell = f'{np.mean(values):.3f} \\pm {sem(values):.4f}'
    if bold:
        cell = f'\\boldsymbol{{{cell}}}'
    return f' & ${cell}$'


# For each dataset, print its name, a header line, then one LaTeX cell per
# model in `models` order; the model with the lowest mean loss is bolded.
for dataset_name, per_model in results.items():
    print(dataset_name)
    means = [np.mean(per_model[m]) if per_model[m] else np.nan for m in models]
    # np.argmin returns the index of the first NaN, so a model with no runs
    # would be bolded; ignore NaNs, and bold nothing if every model is missing.
    if np.all(np.isnan(means)):
        bold_idx = None
    else:
        bold_idx = int(np.nanargmin(means))
    print("& -LOGP")
    for i, model in enumerate(models):
        print(_format_cell(per_model[model], bold=(i == bold_idx)))