3
3
std_normal_logpdf (x:: Real ) = - (abs2 (x) + log2π)/ 2
4
4
std_normal_logpdf (x:: AbstractArray ) = vec (sum (std_normal_logpdf .(flatview (x)), dims = 1 ))
5
5
6
- function negll_flow_loss (flow:: F , x:: AbstractMatrix{<:Real} , logd_orig :: AbstractVector , logpdf:: Function ) where F<: AbstractFlow
6
+ function negll_flow_loss (flow:: F , x:: AbstractMatrix{<:Real} , logpdf:: Function ) where F<: AbstractFlow
7
7
nsamples = size (x, 2 )
8
8
flow_corr = fchain (flow,logpdf. f)
9
9
y, ladj = with_logabsdet_jacobian (flow_corr, x)
@@ -12,15 +12,15 @@ function negll_flow_loss(flow::F, x::AbstractMatrix{<:Real}, logd_orig::Abstract
12
12
end
13
13
14
14
function negll_flow (flow:: F , x:: AbstractMatrix{<:Real} , logd_orig:: AbstractVector , logpdf:: Tuple{Function, Function} ) where F<: AbstractFlow
15
- negll, back = Zygote. pullback (negll_flow , flow, x, logd_orig , logpdf[2 ])
15
+ negll, back = Zygote. pullback (negll_flow_loss , flow, x, logpdf[2 ])
16
16
d_flow = back (one (eltype (x)))[1 ]
17
17
return negll, d_flow
18
18
end
19
19
export negll_flow
20
20
21
21
function KLDiv_flow_loss (flow:: F , x:: AbstractMatrix{<:Real} , logd_orig:: AbstractVector , logpdfs:: Tuple{Function, Function} ) where F<: AbstractFlow
22
22
nsamples = size (x, 2 )
23
- flow_corr = fchain (flow,logpdfs[2 ]. f)
23
+ flow_corr = fchain (flow, logpdfs[2 ]. f)
24
24
logpdf_y = logpdfs[2 ]. logdensity
25
25
y, ladj = with_logabsdet_jacobian (flow_corr, x)
26
26
KLDiv = sum (exp .(logd_orig - vec (ladj)) .* (logd_orig - vec (ladj) - logpdf_y (y))) / nsamples
@@ -38,7 +38,7 @@ function optimize_flow(samples::Union{Matrix, Tuple{Matrix, Matrix}},
38
38
initial_flow:: F where F<: AbstractFlow ,
39
39
optimizer;
40
40
sequential:: Bool = true ,
41
- loss:: Function = negll_flow_grad ,
41
+ loss:: Function = negll_flow ,
42
42
logpdf:: Union{Function, Tuple{Function, Function}} = std_normal_logpdf,
43
43
nbatches:: Integer = 10 ,
44
44
nepochs:: Integer = 100 ,
@@ -75,12 +75,17 @@ function optimize_flow(samples::Union{AbstractArray, Tuple{AbstractArray, Abstra
75
75
76
76
n_dims = _get_n_dims (samples)
77
77
logd_orig = samples isa Tuple ? logpdf[1 ](samples[1 ]) : logpdf[1 ](samples)
78
- pushfwd_logpdf = logpdf[2 ] == std_normal_logpdf ? (PushForwardLogDensity (first (initial_flow. flow. fs), logpdf[1 ]), PushForwardLogDensity (FlowModule (InvMulAdd (I (n_dims), zeros (n_dims)), false ), logpdf[2 ])) : (PushForwardLogDensity (first (initial_flow. flow. fs), logpdf[1 ]), PushForwardLogDensity (last (initial_flow. flow. fs), logpdf[2 ]))
78
+
79
+ if ! (initial_flow isa AbstractFlowBlock)
80
+ pushfwd_logpdf = logpdf[2 ] == std_normal_logpdf ? (PushForwardLogDensity (first (initial_flow. flow. fs), logpdf[1 ]), PushForwardLogDensity (FlowModule (InvMulAdd (I (n_dims), zeros (n_dims)), false ), logpdf[2 ])) : (PushForwardLogDensity (first (initial_flow. flow. fs), logpdf[1 ]), PushForwardLogDensity (last (initial_flow. flow. fs), logpdf[2 ]))
81
+ else
82
+ pushfwd_logpdf = (PushForwardLogDensity (InvMulAdd (I (n_dims), zeros (n_dims)), logpdf[1 ]), PushForwardLogDensity (InvMulAdd (I (n_dims), zeros (n_dims)), logpdf[2 ]))
83
+ end
79
84
80
85
if sequential
81
86
flow, state, loss_hist = _train_flow_sequentially (samples, initial_flow, optimizer, nepochs, nbatches, loss, pushfwd_logpdf, logd_orig, shuffle_samples)
82
87
else
83
- flow, state, loss_hist = _train_flow (samples, initial_flow, optimizer, nepochs, nbatches, loss, pushfwd_logpd , logd_orig, shuffle_samples)
88
+ flow, state, loss_hist = _train_flow (samples, initial_flow, optimizer, nepochs, nbatches, loss, pushfwd_logpdf , logd_orig, shuffle_samples)
84
89
end
85
90
86
91
return (result = flow, optimizer_state = state, loss_hist = vcat (loss_history, loss_hist))
0 commit comments