@@ -35,16 +35,10 @@ function KLDiv_flow_loss(flow::F, x::AbstractMatrix{<:Real}, logd_orig::Abstract
 
     KLDiv = sum(exp.(logd_orig - vec(ladj)) .* (logd_orig - vec(ladj) - logpdf_y(y))) / nsamples # (1)
 
-    # KLDiv = sum(exp.(logpdf_y(y) + vec(ladj)) .* (logpdf_y(y) + vec(ladj) - logd_orig)) / nsamples #(MALA PAPER)
-
-
-
+    # KLDiv = sum(exp.(logpdf_y(y) + vec(ladj)) .* (logpdf_y(y) + vec(ladj) - logd_orig)) / nsamples #(MALA PAPER)
     # KLDiv = sum(exp.(logpdf_y(y) + vec(ladj)) .* (logpdf_y(y) + vec(ladj) - logd_orig)) / nsamples
-
     # KLDiv = sum(exp.(logd_orig) .* (logd_orig - vec(ladj) - logpdf_y(y))) / nsamples #(too tight)
-
-
-    # KLDiv = sum(exp.(logd_orig - vec(ladj)) .* (logd_orig - vec(ladj) - logpdf_y(y))) / nsamples #(1)
+    # KLDiv = sum(exp.(logd_orig - vec(ladj)) .* (logd_orig - vec(ladj) - logpdf_y(y))) / nsamples #(1)
     # KLDiv = sum(exp.(logpdf_y(y) + vec(ladj)) .* (logpdf_y(y) + vec(ladj) - logd_orig)) / nsamples #(2)/(3) with logpdfs[2] = target
     # KLDiv = sum(exp.(logpdf_y(y) + vec(ladj)) .* (vec(ladj) - logd_orig)) / nsamples
     # KLDiv = sum(exp.(logpdf_y(y) + vec(ladj)) .* (logpdf_y(y) + vec(ladj) - logd_orig - logpdf_y(y))) / nsamples
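
A hedged reading of the active line (1): with $y_i = f(x_i)$ and $\mathrm{ladj}_i = \log\lvert\det J_f(x_i)\rvert$, the quantity $\log p(x_i) - \mathrm{ladj}_i$ is the log density of the pushforward of the original density $p$ (`logd_orig`) under the flow $f$, evaluated at $y_i$. The sum

$$\widehat{D} = \frac{1}{N}\sum_{i=1}^{N} e^{\log p(x_i) - \mathrm{ladj}_i}\,\bigl(\log p(x_i) - \mathrm{ladj}_i - \log q(y_i)\bigr)$$

is then a weighted Monte Carlo estimate of $D_{\mathrm{KL}}(f_{*}p \,\|\, q)$ against the base density $q$ (`logpdf_y`), with the exponential factor acting as an importance weight. The commented-out variants differ in which density supplies the weights and in the direction of the divergence.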
@@ -61,14 +55,15 @@ export KLDiv_flow
 
 function optimize_flow(samples::Union{Matrix, Tuple{Matrix, Matrix}},
     initial_flow::F where F<:AbstractFlow,
-    optimizer;
+    optimizer = Adam(5f-3);
     sequential::Bool = true,
     loss::Function = negll_flow,
     logpdf::Union{Function, Tuple{Function, Function}} = std_normal_logpdf,
     nbatches::Integer = 10,
     nepochs::Integer = 100,
     loss_history = Vector{Float64}(),
-    shuffle_samples::Bool = false
+    shuffle_samples::Bool = false,
+    lr_safety::Bool = true
 )
     optimize_flow(nestedview(samples),
         initial_flow,
@@ -79,20 +74,22 @@ function optimize_flow(samples::Union{Matrix, Tuple{Matrix, Matrix}},
         nbatches = nbatches,
         nepochs = nepochs,
         loss_history = loss_history,
-        shuffle_samples = shuffle_samples
+        shuffle_samples = shuffle_samples,
+        lr_safety = lr_safety
     )
 end
 
 function optimize_flow(samples::Union{AbstractArray, Tuple{AbstractArray, AbstractArray}},
     initial_flow::F where F<:AbstractFlow,
-    optimizer;
+    optimizer = Adam(5f-3);
     sequential::Bool = true,
     loss::Function = negll_flow_grad,
     logpdf::Union{Function, Tuple{Function, Function}},
     nbatches::Integer = 10,
     nepochs::Integer = 100,
     loss_history = Vector{Float64}(),
-    shuffle_samples::Bool = false
+    shuffle_samples::Bool = false,
+    lr_safety::Bool = true
 )
     if !_is_trainable(initial_flow)
         return (result = initial_flow, optimizer_state = nothing, loss_history = nothing)
@@ -108,9 +105,9 @@ function optimize_flow(samples::Union{AbstractArray, Tuple{AbstractArray, Abstra
     end
 
     if sequential
-        flow, state, loss_hist = _train_flow_sequentially(samples, initial_flow, optimizer, nepochs, nbatches, loss, pushfwd_logpdf, logd_orig, shuffle_samples)
+        flow, state, loss_hist = _train_flow_sequentially(samples, initial_flow, optimizer, nepochs, nbatches, loss, pushfwd_logpdf, logd_orig, shuffle_samples, lr_safety)
     else
-        flow, state, loss_hist = _train_flow(samples, initial_flow, optimizer, nepochs, nbatches, loss, pushfwd_logpdf, logd_orig, shuffle_samples)
+        flow, state, loss_hist = _train_flow(samples, initial_flow, optimizer, nepochs, nbatches, loss, pushfwd_logpdf, logd_orig, shuffle_samples, lr_safety)
     end
 
     return (result = flow, optimizer_state = state, loss_hist = vcat(loss_history, loss_hist), training_metadata = Dict(:nepochs => nepochs, :nbatches => nbatches, :shuffle_samples => shuffle_samples, :sequential => sequential, :optimizer => optimizer, :loss => loss))
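
With the new defaults, the optimizer argument can be omitted entirely. A usage sketch (hedged: `samples`, `flow0`, and the result handling are illustrative names, not from this commit; `Adam` is the Optimisers.jl rule the new default constructs):

    using Optimisers  # provides Adam, used by the new optimizer default

    # samples: a d × n matrix of training points; flow0: any trainable AbstractFlow.
    # optimizer defaults to Adam(5f-3); lr_safety = true enables the
    # automatic learning-rate backoff added in this commit.
    res = optimize_flow(samples, flow0;
        nbatches = 10,
        nepochs = 100,
        lr_safety = true
    )
    trained_flow = res.result
    history = res.loss_hist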
@@ -126,8 +123,9 @@ function _train_flow_sequentially(samples::Union{AbstractArray, Tuple{AbstractAr
     pushfwd_logpdf::Union{Function,
                           Tuple{Function, Function}},
     logd_orig::AbstractVector,
-    shuffle_samples::Bool;
-    cum_ladj::AbstractVector = zeros(length(logd_orig))
+    shuffle_samples::Bool,
+    lr_safety::Bool;
+    cum_ladj::AbstractVector = zeros(length(logd_orig)),
 )
 
     if !_is_trainable(initial_flow)
@@ -149,7 +147,8 @@ function _train_flow_sequentially(samples::Union{AbstractArray, Tuple{AbstractAr
                 loss,
                 pushfwd_logpdf,
                 logd_orig,
-                shuffle_samples;
+                shuffle_samples,
+                lr_safety;
                 cum_ladj
             )
             push!(trained_components, trained_flow_component)
@@ -159,21 +158,19 @@ function _train_flow_sequentially(samples::Union{AbstractArray, Tuple{AbstractAr
             if samples isa Tuple
                 x_int, ladj = with_logabsdet_jacobian(trained_flow_component, intermediate_samples[1])
                 intermediate_samples = (x_int, trained_flow_component(intermediate_samples[2]))
-                # fix AffineMaps to return row matrix ladj
                 ladj = ladj isa Real ? fill(ladj, length(logd_orig_intermediate)) : vec(ladj)
                 cum_ladj += ladj
             else
                 intermediate_samples, ladj = with_logabsdet_jacobian(trained_flow_component, intermediate_samples)
                 ladj = ladj isa Real ? fill(ladj, length(logd_orig_intermediate)) : vec(ladj)
                 cum_ladj += ladj
-                end
+            end
         end
         return typeof(initial_flow)(trained_components), component_optstates, component_loss_hists
     end
-    _train_flow(samples, initial_flow, optimizer, nepochs, nbatches, loss, pushfwd_logpdf, logd_orig, shuffle_samples; cum_ladj)
+    _train_flow(samples, initial_flow, optimizer, nepochs, nbatches, loss, pushfwd_logpdf, logd_orig, shuffle_samples, lr_safety; cum_ladj)
 end
 
-
 function _train_flow(samples::Union{AbstractArray, Tuple{AbstractArray, AbstractArray}},
     initial_flow::AbstractFlow,
     optimizer,
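
Background on the `cum_ladj` bookkeeping above: for a composition of flow components, the log-abs-det Jacobians add by the chain rule,

$$\log\left|\det J_{f_2 \circ f_1}(x)\right| = \log\left|\det J_{f_2}(f_1(x))\right| + \log\left|\det J_{f_1}(x)\right|$$

so after each component is trained, its `ladj` on the current intermediate samples is added to the running total that the next component's training sees.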
@@ -182,8 +179,9 @@ function _train_flow(samples::Union{AbstractArray, Tuple{AbstractArray, Abstract
     loss::Function,
     pushfwd_logpdf::Union{Function, Tuple{Function, Function}},
     logd_orig::AbstractVector,
-    shuffle_samples::Bool;
-    cum_ladj::AbstractVector = zeros(length(logd_orig))
+    shuffle_samples::Bool,
+    lr_safety::Bool;
+    cum_ladj::AbstractVector = zeros(length(logd_orig)),
 )
 
     if !_is_trainable(initial_flow)
@@ -195,23 +193,32 @@ function _train_flow(samples::Union{AbstractArray, Tuple{AbstractArray, Abstract
     logd_orig_batches = collect(Iterators.partition(logd_orig, batchsize))
     cum_ladj_batches = collect(Iterators.partition(cum_ladj, batchsize))
     flow = deepcopy(initial_flow)
+    flow_tmp = deepcopy(flow)
     state = Optimisers.setup(optimizer, deepcopy(initial_flow))
+    state_tmp = deepcopy(state)
     loss_hist = Vector{Float64}()
+
+    eta = optimizer.eta
+
     for i in 1:nepochs
         for j in 1:nbatches
             training_samples = batches isa Tuple ? (Matrix(flatview(batches[1][j])), Matrix(flatview(batches[2][j]))) : Matrix(flatview(batches[j]))
-            loss_val, d_flow = loss(flow, training_samples, logd_orig_batches[j], cum_ladj_batches[j], pushfwd_logpdf)
-            if i == 1 && j == 1 && flow.mask[1]
-                global g_state_gradient_1 = (loss_val, d_flow)
+            loss_cache, d_flow = loss(flow, training_samples, logd_orig_batches[j], cum_ladj_batches[j], pushfwd_logpdf)
+
+            state_tmp, flow_tmp = Optimisers.update(state, flow, d_flow)
+            loss_val, d_flow_tmp = loss(flow_tmp, training_samples, logd_orig_batches[j], cum_ladj_batches[j], pushfwd_logpdf)
+
+            while (lr_safety && i + j > 2) && ((loss_val - loss_cache) / loss_cache > 0.3)
+                @info "Learning Rate too large, automatically reduced by 20%. Was: $(eta), Epoch: $(i), Batch: $(j)"
+                Optimisers.adjust!(state, 0.8 * eta)
+                eta *= 0.8
+                state_tmp, flow_tmp = Optimisers.update(state, flow, d_flow)
+                loss_val, d_flow_tmp = loss(flow_tmp, training_samples, logd_orig_batches[j], cum_ladj_batches[j], pushfwd_logpdf)
             end
 
-            if i == 1 && j == 2 && flow.mask[1]
-                global g_state_gradient_2 = (loss_val, d_flow)
-            end
-
-
             state, flow = Optimisers.update(state, flow, d_flow)
-            push!(loss_hist, loss_val)
+
+            push!(loss_hist, loss_cache)
         end
         if shuffle_samples
             batches = collect(Iterators.partition(shuffle(samples), batchsize))
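
The `lr_safety` loop above is a backtracking scheme: evaluate the loss at the current parameters, take a trial `Optimisers.update` on copies, re-evaluate, and shrink the learning rate by 20% until the relative loss increase stays below 30% (the `i + j > 2` guard skips the very first batch, and the relative test assumes a positive loss). A minimal self-contained sketch of the same pattern on a toy problem (hypothetical `loss_fn`, `grad_fn`, and `safe_step`; only the Optimisers.jl calls mirror the code above):

    using Optimisers

    loss_fn(w) = sum(abs2, w)   # toy quadratic, stands in for the flow loss
    grad_fn(w) = 2 .* w

    function safe_step(w, eta)
        state = Optimisers.setup(Descent(eta), w)
        loss0, g = loss_fn(w), grad_fn(w)
        state_tmp, w_tmp = Optimisers.update(state, w, g)       # trial step
        while (loss_fn(w_tmp) - loss0) / loss0 > 0.3            # same 30% criterion
            eta *= 0.8                                          # back off by 20%
            Optimisers.adjust!(state, eta)
            state_tmp, w_tmp = Optimisers.update(state, w, g)   # retry the step
        end
        return Optimisers.update(state, w, g)                   # accept with safe eta
    end

    state, w = safe_step([10.0, -7.0], 1.5)  # backs off 1.5 → 1.2 → 0.96

One design note: `Optimisers.update` returns fresh state and parameters without mutating its inputs, so the accepted `Optimisers.update(state, flow, d_flow)` in the batch loop repeats the already-validated step with the adjusted learning rate rather than reusing `flow_tmp`.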