Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ Accessors = "0.1"
AdvancedHMC = "0.3.0, 0.4.0, 0.5.2, 0.6, 0.7, 0.8"
AdvancedMH = "0.8"
AdvancedPS = "0.7"
AdvancedVI = "0.4"
AdvancedVI = "0.5"
BangBang = "0.4.2"
Bijectors = "0.14, 0.15"
Compat = "4.15.0"
Expand Down
3 changes: 3 additions & 0 deletions src/Turing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,9 @@ export
q_locationscale,
q_meanfield_gaussian,
q_fullrank_gaussian,
KLMinRepGradProxDescent,
KLMinRepGradDescent,
KLMinScoreGradDescent,
# ADTypes
AutoForwardDiff,
AutoReverseDiff,
Expand Down
80 changes: 34 additions & 46 deletions src/variational/VariationalInference.jl
Original file line number Diff line number Diff line change
@@ -1,21 +1,24 @@

module Variational

using DynamicPPL
using AdvancedVI:
AdvancedVI, KLMinRepGradDescent, KLMinRepGradProxDescent, KLMinScoreGradDescent
using ADTypes
using Bijectors: Bijectors
using Distributions
using DynamicPPL
using LinearAlgebra
using LogDensityProblems
using Random
using ..Turing: DEFAULT_ADTYPE, PROGRESS

import ..Turing: DEFAULT_ADTYPE, PROGRESS

import AdvancedVI
import Bijectors

export vi, q_locationscale, q_meanfield_gaussian, q_fullrank_gaussian

include("deprecated.jl")
export vi,
q_locationscale,
q_meanfield_gaussian,
q_fullrank_gaussian,
KLMinRepGradProxDescent,
KLMinRepGradDescent,
KLMinScoreGradDescent

"""
q_initialize_scale(
Expand Down Expand Up @@ -248,76 +251,61 @@ end
"""
vi(
[rng::Random.AbstractRNG,]
model::DynamicPPL.Model;
model::DynamicPPL.Model,
q,
n_iterations::Int;
objective::AdvancedVI.AbstractVariationalObjective = AdvancedVI.RepGradELBO(
10; entropy = AdvancedVI.ClosedFormEntropyZeroGradient()
),
max_iter::Int;
algorithm::AdvancedVI.AbstractVariationalAlgorithm = KLMinRepGradProxDescent(DEFAULT_ADTYPE; n_samples=10),
show_progress::Bool = Turing.PROGRESS[],
optimizer::Optimisers.AbstractRule = AdvancedVI.DoWG(),
averager::AdvancedVI.AbstractAverager = AdvancedVI.PolynomialAveraging(),
operator::AdvancedVI.AbstractOperator = AdvancedVI.ProximalLocationScaleEntropy(),
adtype::ADTypes.AbstractADType = Turing.DEFAULT_ADTYPE,
kwargs...
)

Approximating the target `model` via variational inference by optimizing `objective` with the initialization `q`.
Approximate the target `model` via the variational inference algorithm `algorithm` by starting from the initial variational approximation `q`.
This is a thin wrapper around `AdvancedVI.optimize`.
The default `algorithm` assumes `q` uses `AdvancedVI.MvLocationScale`, which can be constructed by invoking `q_fullrank_gaussian` or `q_meanfield_gaussian`.
For other variational families, refer to `AdvancedVI` to determine the best algorithm and options.

# Arguments
- `model`: The target `DynamicPPL.Model`.
- `q`: The initial variational approximation.
- `n_iterations`: Number of optimization steps.
- `max_iter`: Maximum number of steps.

# Keyword Arguments
- `objective`: Variational objective to be optimized.
- `algorithm`: Variational inference algorithm.
- `show_progress`: Whether to show the progress bar.
- `optimizer`: Optimization algorithm.
- `averager`: Parameter averaging strategy.
- `operator`: Operator applied after each optimization step.
- `adtype`: Automatic differentiation backend.
- `adtype`: Automatic differentiation backend to be applied to the log-density. The default value for `algorithm` also uses this backend for differentiating the variational objective.

See the docs of `AdvancedVI.optimize` for additional keyword arguments.

# Returns
- `q`: Variational distribution formed by the last iterate of the optimization run.
- `q_avg`: Variational distribution formed by the averaged iterates according to `averager`.
- `state`: Collection of states used for optimization. This can be used to resume from a past call to `vi`.
- `info`: Information generated during the optimization run.
- `q`: Output variational distribution of `algorithm`.
- `state`: Collection of states used by `algorithm`. This can be used to resume from a past call to `vi`.
- `info`: Information generated while executing `algorithm`.
"""
function vi(
rng::Random.AbstractRNG,
model::DynamicPPL.Model,
q,
n_iterations::Int;
objective=AdvancedVI.RepGradELBO(
10; entropy=AdvancedVI.ClosedFormEntropyZeroGradient()
),
show_progress::Bool=PROGRESS[],
optimizer=AdvancedVI.DoWG(),
averager=AdvancedVI.PolynomialAveraging(),
operator=AdvancedVI.ProximalLocationScaleEntropy(),
max_iter::Int,
args...;
adtype::ADTypes.AbstractADType=DEFAULT_ADTYPE,
algorithm=KLMinRepGradProxDescent(adtype; n_samples=10),
show_progress::Bool=PROGRESS[],
kwargs...,
)
return AdvancedVI.optimize(
rng,
LogDensityFunction(model),
objective,
algorithm,
max_iter,
LogDensityFunction(model; adtype),
q,
n_iterations;
args...;
show_progress=show_progress,
adtype,
optimizer,
averager,
operator,
kwargs...,
)
end

function vi(model::DynamicPPL.Model, q, n_iterations::Int; kwargs...)
return vi(Random.default_rng(), model, q, n_iterations; kwargs...)
function vi(model::DynamicPPL.Model, q, max_iter::Int; kwargs...)
return vi(Random.default_rng(), model, q, max_iter; kwargs...)
end

end
61 changes: 0 additions & 61 deletions src/variational/deprecated.jl

This file was deleted.

2 changes: 1 addition & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ AbstractMCMC = "5"
AbstractPPL = "0.11, 0.12, 0.13"
AdvancedMH = "0.6, 0.7, 0.8"
AdvancedPS = "0.7"
AdvancedVI = "0.4"
AdvancedVI = "0.5"
Aqua = "0.8"
BangBang = "0.4"
Bijectors = "0.14, 0.15"
Expand Down
72 changes: 26 additions & 46 deletions test/variational/advi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,16 @@ using Distributions: Dirichlet, Normal
using LinearAlgebra
using MCMCChains: Chains
using Random
using ReverseDiff
using StableRNGs: StableRNG
using Test: @test, @testset
using Turing
using Turing.Variational

@testset "ADVI" begin
adtype = AutoReverseDiff()
operator = AdvancedVI.ClipScale()

@testset "q initialization" begin
m = gdemo_default
d = length(Turing.DynamicPPL.VarInfo(m)[:])
Expand All @@ -41,86 +45,62 @@ using Turing.Variational

@testset "default interface" begin
for q0 in [q_meanfield_gaussian(gdemo_default), q_fullrank_gaussian(gdemo_default)]
_, q, _, _ = vi(gdemo_default, q0, 100; show_progress=Turing.PROGRESS[])
q, _, _ = vi(gdemo_default, q0, 100; show_progress=Turing.PROGRESS[], adtype)
c1 = rand(q, 10)
end
end

@testset "custom interface $name" for (name, objective, operator, optimizer) in [
(
"ADVI with closed-form entropy",
AdvancedVI.RepGradELBO(10),
AdvancedVI.ProximalLocationScaleEntropy(),
AdvancedVI.DoG(),
),
@testset "custom algorithm $name" for (name, algorithm) in [
(
"ADVI with proximal entropy",
AdvancedVI.RepGradELBO(10; entropy=AdvancedVI.ClosedFormEntropyZeroGradient()),
AdvancedVI.ClipScale(),
AdvancedVI.DoG(),
"KLMinRepGradProxDescent",
KLMinRepGradProxDescent(AutoReverseDiff(); n_samples=10),
),
(
"ADVI with STL entropy",
AdvancedVI.RepGradELBO(10; entropy=AdvancedVI.StickingTheLandingEntropy()),
AdvancedVI.ClipScale(),
AdvancedVI.DoG(),
"KLMinRepGradDescent",
KLMinRepGradDescent(AutoReverseDiff(); operator, n_samples=10),
),
]
T = 1000
q, q_avg, _, _ = vi(
q, _, _ = vi(
gdemo_default,
q_meanfield_gaussian(gdemo_default),
T;
objective,
optimizer,
operator,
algorithm,
adtype,
show_progress=Turing.PROGRESS[],
)

N = 1000
c1 = rand(q_avg, N)
c2 = rand(q, N)
end

@testset "inference $name" for (name, objective, operator, optimizer) in [
@testset "inference $name" for (name, algorithm) in [
(
"ADVI with closed-form entropy",
AdvancedVI.RepGradELBO(10),
AdvancedVI.ProximalLocationScaleEntropy(),
AdvancedVI.DoG(),
"KLMinRepGradProxDescent",
KLMinRepGradProxDescent(AutoReverseDiff(); n_samples=10),
),
(
"ADVI with proximal entropy",
RepGradELBO(10; entropy=AdvancedVI.ClosedFormEntropyZeroGradient()),
AdvancedVI.ClipScale(),
AdvancedVI.DoG(),
),
(
"ADVI with STL entropy",
AdvancedVI.RepGradELBO(10; entropy=AdvancedVI.StickingTheLandingEntropy()),
AdvancedVI.ClipScale(),
AdvancedVI.DoG(),
"KLMinRepGradDescent",
KLMinRepGradDescent(AutoReverseDiff(); operator, n_samples=10),
),
]
rng = StableRNG(0x517e1d9bf89bf94f)

T = 1000
q, q_avg, _, _ = vi(
q, _, _ = vi(
rng,
gdemo_default,
q_meanfield_gaussian(gdemo_default),
T;
optimizer,
algorithm,
adtype,
show_progress=Turing.PROGRESS[],
)

N = 1000
for q_out in [q_avg, q]
samples = transpose(rand(rng, q_out, N))
chn = Chains(reshape(samples, size(samples)..., 1), ["s", "m"])
samples = transpose(rand(rng, q, N))
chn = Chains(reshape(samples, size(samples)..., 1), ["s", "m"])

check_gdemo(chn; atol=0.5)
end
check_gdemo(chn; atol=0.5)
end

# regression test for:
Expand All @@ -143,7 +123,7 @@ using Turing.Variational
@test all(x0 .≈ x0_inv)

# And regression for https://github.com/TuringLang/Turing.jl/issues/2160.
_, q, _, _ = vi(rng, m, q_meanfield_gaussian(m), 1000)
q, _, _ = vi(rng, m, q_meanfield_gaussian(m), 1000; adtype)
x = rand(rng, q, 1000)
@test mean(eachcol(x)) ≈ [0.5, 0.5] atol = 0.1
end
Expand All @@ -158,7 +138,7 @@ using Turing.Variational
end

model = demo_issue2205() | (y=1.0,)
_, q, _, _ = vi(rng, model, q_meanfield_gaussian(model), 1000)
q, _, _ = vi(rng, model, q_meanfield_gaussian(model), 1000; adtype)
# True mean.
mean_true = 1 / 2
var_true = 1 / 2
Expand Down
Loading