
Commit de0d83e

Committed Apr 9, 2024
Add lr safety mechanism to avoid exploding loss function during optimization
1 parent: 1175a80

2 files changed: +44 −33 lines

 

docs/src/plotting.md

+4 lines
@@ -0,0 +1,4 @@
+# Plotting
+
+AdaptiveFlows.jl offers plotting recipes to visualize the input samples and the transformed output of a normalizing flow.
+

src/optimize_flow.jl

+40 −33 lines
@@ -35,16 +35,10 @@ function KLDiv_flow_loss(flow::F, x::AbstractMatrix{<:Real}, logd_orig::Abstract
 
     KLDiv = sum(exp.(logd_orig - vec(ladj)) .* (logd_orig - vec(ladj) - logpdf_y(y))) / nsamples #(1)
 
-    #KLDiv = sum(exp.(logpdf_y(y) + vec(ladj)) .* (logpdf_y(y) + vec(ladj) - logd_orig)) / nsamples #(MALA PAPER)
-
-
-
+    # KLDiv = sum(exp.(logpdf_y(y) + vec(ladj)) .* (logpdf_y(y) + vec(ladj) - logd_orig)) / nsamples #(MALA PAPER)
     # KLDiv = sum(exp.(logpdf_y(y) + vec(ladj)) .* (logpdf_y(y) + vec(ladj) - logd_orig)) / nsamples
-
     # KLDiv = sum(exp.(logd_orig) .* (logd_orig - vec(ladj) - logpdf_y(y))) / nsamples #(to tight)
-
-
-    #KLDiv = sum(exp.(logd_orig - vec(ladj)) .* (logd_orig - vec(ladj) - logpdf_y(y))) / nsamples #(1)
+    # KLDiv = sum(exp.(logd_orig - vec(ladj)) .* (logd_orig - vec(ladj) - logpdf_y(y))) / nsamples #(1)
     # KLDiv = sum(exp.(logpdf_y(y) + vec(ladj)) .* (logpdf_y(y) + vec(ladj) - logd_orig)) / nsamples #(2)/ (3) with logpdfs[2] = target
     # KLDiv = sum(exp.(logpdf_y(y) + vec(ladj)) .* (vec(ladj) - logd_orig)) / nsamples
     # KLDiv = sum(exp.(logpdf_y(y) + vec(ladj)) .* (logpdf_y(y) + vec(ladj) - logd_orig - logpdf_y(y))) / nsamples
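For reference, the one estimator that remains active above is a Monte Carlo estimate of a forward KL divergence between the pushed-forward sample density and the flow's reference density. Writing p for the density behind logd_orig, q for the reference density behind logpdf_y, y_i for the transformed samples and ladj_i for the accumulated log-abs-det Jacobians (p, q and N are notation introduced here, not identifiers from the code), the active line reads

\mathrm{KLDiv} \approx \frac{1}{N} \sum_{i=1}^{N} \exp\bigl(\log p(x_i) - \mathrm{ladj}_i\bigr)\,\bigl(\log p(x_i) - \mathrm{ladj}_i - \log q(y_i)\bigr),

where \log p(x_i) - \mathrm{ladj}_i is, by the change-of-variables formula, the log-density of y_i under the pushforward of p through the flow. The commented-out variants appear to be the same divergence with the roles of the two densities exchanged or with alternative weightings, kept for comparison.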
@@ -61,14 +55,15 @@ export KLDiv_flow
 
 function optimize_flow(samples::Union{Matrix, Tuple{Matrix, Matrix}},
                        initial_flow::F where F<:AbstractFlow,
-                       optimizer;
+                       optimizer = Adam(5f-3);
                        sequential::Bool = true,
                        loss::Function = negll_flow,
                        logpdf::Union{Function, Tuple{Function, Function}} = std_normal_logpdf,
                        nbatches::Integer = 10,
                        nepochs::Integer = 100,
                        loss_history = Vector{Float64}(),
-                       shuffle_samples::Bool = false
+                       shuffle_samples::Bool = false,
+                       lr_safety::Bool = true
                        )
     optimize_flow(nestedview(samples),
                   initial_flow,
@@ -79,20 +74,22 @@ function optimize_flow(samples::Union{Matrix, Tuple{Matrix, Matrix}},
                   nbatches = nbatches,
                   nepochs = nepochs,
                   loss_history = loss_history,
-                  shuffle_samples = shuffle_samples
+                  shuffle_samples = shuffle_samples,
+                  lr_safety = lr_safety
                   )
 end
 
 function optimize_flow(samples::Union{AbstractArray, Tuple{AbstractArray, AbstractArray}},
                        initial_flow::F where F<:AbstractFlow,
-                       optimizer;
+                       optimizer = Adam(5f-3);
                        sequential::Bool = true,
                        loss::Function = negll_flow_grad,
                        logpdf::Union{Function, Tuple{Function, Function}},
                        nbatches::Integer = 10,
                        nepochs::Integer = 100,
                        loss_history = Vector{Float64}(),
-                       shuffle_samples::Bool = false
+                       shuffle_samples::Bool = false,
+                       lr_safety::Bool = true
                        )
     if !_is_trainable(initial_flow)
         return (result = initial_flow, optimizer_state = nothing, loss_history = nothing)
@@ -108,9 +105,9 @@ function optimize_flow(samples::Union{AbstractArray, Tuple{AbstractArray, Abstra
     end
 
     if sequential
-        flow, state, loss_hist = _train_flow_sequentially(samples, initial_flow, optimizer, nepochs, nbatches, loss, pushfwd_logpdf, logd_orig, shuffle_samples)
+        flow, state, loss_hist = _train_flow_sequentially(samples, initial_flow, optimizer, nepochs, nbatches, loss, pushfwd_logpdf, logd_orig, shuffle_samples, lr_safety)
     else
-        flow, state, loss_hist = _train_flow(samples, initial_flow, optimizer, nepochs, nbatches, loss, pushfwd_logpdf, logd_orig, shuffle_samples)
+        flow, state, loss_hist = _train_flow(samples, initial_flow, optimizer, nepochs, nbatches, loss, pushfwd_logpdf, logd_orig, shuffle_samples, lr_safety)
     end
 
     return (result = flow, optimizer_state = state, loss_hist = vcat(loss_history, loss_hist), training_metadata = Dict(:nepochs => nepochs, :nbatches => nbatches, :shuffle_samples => shuffle_samples, :sequential => sequential, :optimizer => optimizer, :loss => loss))
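For orientation, here is a minimal usage sketch of the public entry point after this change. The sample matrix, its dimensionality and the build_flow constructor call are illustrative assumptions; only the optimize_flow keyword interface and the fields of the returned named tuple are taken from the diff above.

using AdaptiveFlows, Optimisers

samples = randn(4, 10_000)            # hypothetical training samples
flow = build_flow(samples)            # assumed AdaptiveFlows.jl flow constructor

# the optimizer argument now defaults to Adam(5f-3); lr_safety defaults to true
res = optimize_flow(samples, flow, Adam(5f-3);
                    nbatches = 10,
                    nepochs = 100,
                    lr_safety = true)

trained_flow = res.result             # trained flow
history = res.loss_hist               # per-batch loss values recorded during training

Passing lr_safety = false restores the previous behaviour of applying every optimizer step unconditionally.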
@@ -126,8 +123,9 @@ function _train_flow_sequentially(samples::Union{AbstractArray, Tuple{AbstractAr
                                   pushfwd_logpdf::Union{Function,
                                                         Tuple{Function, Function}},
                                   logd_orig::AbstractVector,
-                                  shuffle_samples::Bool;
-                                  cum_ladj::AbstractVector = zeros(length(logd_orig))
+                                  shuffle_samples::Bool,
+                                  lr_safety::Bool;
+                                  cum_ladj::AbstractVector = zeros(length(logd_orig)),
                                   )
 
     if !_is_trainable(initial_flow)
@@ -149,7 +147,8 @@ function _train_flow_sequentially(samples::Union{AbstractArray, Tuple{AbstractAr
                 loss,
                 pushfwd_logpdf,
                 logd_orig,
-                shuffle_samples;
+                shuffle_samples,
+                lr_safety;
                 cum_ladj
                 )
             push!(trained_components, trained_flow_component)
@@ -159,21 +158,19 @@ function _train_flow_sequentially(samples::Union{AbstractArray, Tuple{AbstractAr
             if samples isa Tuple
                 x_int, ladj = with_logabsdet_jacobian(trained_flow_component, intermediate_samples[1])
                 intermediate_samples = (x_int, trained_flow_component(intermediate_samples[2]))
-                # fix AffineMaps to return row matrix ladj
                 ladj = ladj isa Real ? fill(ladj, length(logd_orig_intermediate)) : vec(ladj)
                 cum_ladj += ladj
             else
                 intermediate_samples, ladj = with_logabsdet_jacobian(trained_flow_component, intermediate_samples)
                 ladj = ladj isa Real ? fill(ladj, length(logd_orig_intermediate)) : vec(ladj)
                 cum_ladj += ladj
-                end
+            end
         end
         return typeof(initial_flow)(trained_components), component_optstates, component_loss_hists
     end
-    _train_flow(samples, initial_flow, optimizer, nepochs, nbatches, loss, pushfwd_logpdf, logd_orig, shuffle_samples; cum_ladj)
+    _train_flow(samples, initial_flow, optimizer, nepochs, nbatches, loss, pushfwd_logpdf, logd_orig, shuffle_samples, lr_safety; cum_ladj)
 end
 
-
 function _train_flow(samples::Union{AbstractArray, Tuple{AbstractArray, AbstractArray}},
                      initial_flow::AbstractFlow,
                      optimizer,
@@ -182,8 +179,9 @@ function _train_flow(samples::Union{AbstractArray, Tuple{AbstractArray, Abstract
                      loss::Function,
                      pushfwd_logpdf::Union{Function, Tuple{Function, Function}},
                      logd_orig::AbstractVector,
-                     shuffle_samples::Bool;
-                     cum_ladj::AbstractVector = zeros(length(logd_orig))
+                     shuffle_samples::Bool,
+                     lr_safety::Bool;
+                     cum_ladj::AbstractVector = zeros(length(logd_orig)),
                      )
 
     if !_is_trainable(initial_flow)
@@ -195,23 +193,32 @@ function _train_flow(samples::Union{AbstractArray, Tuple{AbstractArray, Abstract
     logd_orig_batches = collect(Iterators.partition(logd_orig, batchsize))
     cum_ladj_batches = collect(Iterators.partition(cum_ladj, batchsize))
     flow = deepcopy(initial_flow)
+    flow_tmp = deepcopy(flow)
     state = Optimisers.setup(optimizer, deepcopy(initial_flow))
+    state_tmp = deepcopy(state)
     loss_hist = Vector{Float64}()
+
+    eta = optimizer.eta
+
     for i in 1:nepochs
         for j in 1:nbatches
             training_samples = batches isa Tuple ? (Matrix(flatview(batches[1][j])), Matrix(flatview(batches[2][j]))) : Matrix(flatview(batches[j]))
-            loss_val, d_flow = loss(flow, training_samples, logd_orig_batches[j], cum_ladj_batches[j], pushfwd_logpdf)
-            if i == 1 && j == 1 && flow.mask[1]
-                global g_state_gradient_1 = (loss_val, d_flow)
+            loss_cache, d_flow = loss(flow, training_samples, logd_orig_batches[j], cum_ladj_batches[j], pushfwd_logpdf)
+
+            state_tmp, flow_tmp = Optimisers.update(state, flow, d_flow)
+            loss_val, d_flow_tmp = loss(flow_tmp, training_samples, logd_orig_batches[j], cum_ladj_batches[j], pushfwd_logpdf)
+
+            while (lr_safety && i+j>2) && ((loss_val - loss_cache) / loss_cache > 0.3)
+                @info "Learning Rate too large, automatically reduced by 20%. Was: $(eta), Epoch: $(i), Batch: $(j)"
+                Optimisers.adjust!(state, 0.8 * eta)
+                eta *= 0.8
+                state_tmp, flow_tmp = Optimisers.update(state, flow, d_flow)
+                loss_val, d_flow_tmp = loss(flow_tmp, training_samples, logd_orig_batches[j], cum_ladj_batches[j], pushfwd_logpdf)
             end
 
-            if i == 1 && j == 2 && flow.mask[1]
-                global g_state_gradient_2 = (loss_val, d_flow)
-            end
-
-
             state, flow = Optimisers.update(state, flow, d_flow)
-            push!(loss_hist, loss_val)
+
+            push!(loss_hist, loss_cache)
         end
         if shuffle_samples
             batches = collect(Iterators.partition(shuffle(samples), batchsize))
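The safety mechanism added in the last hunk is independent of normalizing flows: compute the gradient, try the step on throw-away copies of the optimizer state and model, and only commit it once the trial loss no longer exceeds the current loss by more than 30%, shrinking the learning rate by 20% per retry. Below is a minimal self-contained sketch of that idea using Optimisers.jl and Zygote; the names safe_step!, model, loss_fn and batch are illustrative and not part of AdaptiveFlows.jl, while the 0.3 threshold and 0.8 decay factor mirror the diff.

using Optimisers, Zygote

function safe_step!(state, model, loss_fn, batch; eta, threshold = 0.3, decay = 0.8)
    # loss and gradient at the current parameters
    loss_cache, grads = Zygote.withgradient(m -> loss_fn(m, batch), model)
    d_model = grads[1]

    # trial update on copies; `state` and `model` themselves are not committed yet
    _, model_tmp = Optimisers.update(state, model, d_model)
    loss_val = loss_fn(model_tmp, batch)

    # back off: shrink the learning rate until the relative loss increase is acceptable
    # (like the diff, this assumes a positive loss value)
    while (loss_val - loss_cache) / loss_cache > threshold
        eta *= decay
        Optimisers.adjust!(state, eta)
        _, model_tmp = Optimisers.update(state, model, d_model)
        loss_val = loss_fn(model_tmp, batch)
    end

    # commit the (possibly rescaled) step
    state, model = Optimisers.update(state, model, d_model)
    return state, model, loss_cache, eta
end

# usage sketch: state = Optimisers.setup(Adam(5f-3), model)
#               state, model, l, eta = safe_step!(state, model, loss_fn, batch; eta = 5f-3)

In the committed code the check is additionally skipped for the very first batch (the i + j > 2 guard), and the value pushed to loss_hist is the pre-update loss_cache rather than the trial loss.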
