
Commit b314e91

Implement LogDensityFunctionWithGrad
1 parent 99d40c0 commit b314e91

10 files changed, +115 -101 lines changed

HISTORY.md

+9
@@ -49,6 +49,15 @@ This release removes the feature of `VarInfo` where it kept track of which varia
 
 This change also affects sampling in Turing.jl.
 
+**Other changes**
+
+LogDensityProblemsAD is now removed as a dependency.
+Instead of constructing a `LogDensityProblemsAD.ADgradient` object, we now directly use `DifferentiationInterface` to calculate the gradient of the log density with respect to model parameters.
+
+In practice, this means that if you want to calculate the gradient for a model, you can do:
+
+TODO(penelopeysm): Finish this
+
 ## 0.34.2
 
 - Fixed bugs in ValuesAsInModelContext as well as DebugContext where underlying PrefixContexts were not being applied.
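The HISTORY entry above still carries a TODO where the usage snippet should go. As a placeholder, here is a rough sketch of the new workflow using the API added in this commit (the demo model, data, and parameter values are purely illustrative and not part of the commit):

    using DynamicPPL, Distributions, LogDensityProblems
    using ADTypes: AutoForwardDiff
    using ForwardDiff  # the chosen AD backend package must be loaded

    @model function demo(y)
        m ~ Normal()
        y ~ Normal(m, 1)
    end

    model = demo(1.0)

    # 0th order: log density only
    ldf = DynamicPPL.LogDensityFunction(model)
    LogDensityProblems.logdensity(ldf, [0.5])

    # 1st order: pair the LogDensityFunction with an AD backend
    ldf_grad = DynamicPPL.LogDensityFunctionWithGrad(ldf, AutoForwardDiff())
    logp, grad = LogDensityProblems.logdensity_and_gradient(ldf_grad, [0.5])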

Project.toml

-3
@@ -29,7 +29,6 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
 [weakdeps]
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
-ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
 MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
 Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
@@ -38,7 +37,6 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
 [extensions]
 DynamicPPLChainRulesCoreExt = ["ChainRulesCore"]
 DynamicPPLEnzymeCoreExt = ["EnzymeCore"]
-DynamicPPLForwardDiffExt = ["ForwardDiff"]
 DynamicPPLJETExt = ["JET"]
 DynamicPPLMCMCChainsExt = ["MCMCChains"]
 DynamicPPLMooncakeExt = ["Mooncake"]
@@ -58,7 +56,6 @@ DifferentiationInterface = "0.6.39"
 Distributions = "0.25"
 DocStringExtensions = "0.9"
 EnzymeCore = "0.6 - 0.8"
-ForwardDiff = "0.10"
 JET = "0.9"
 KernelAbstractions = "0.9.33"
 LinearAlgebra = "1.6"

docs/src/api.md

+2 -1
@@ -54,10 +54,11 @@ logjoint
 
 ### LogDensityProblems.jl interface
 
-The [LogDensityProblems.jl](https://github.com/tpapp/LogDensityProblems.jl) interface is also supported by simply wrapping a [`Model`](@ref) in a `DynamicPPL.LogDensityFunction`:
+The [LogDensityProblems.jl](https://github.com/tpapp/LogDensityProblems.jl) interface is also supported by wrapping a [`Model`](@ref) in a `DynamicPPL.LogDensityFunction` or `DynamicPPL.LogDensityFunctionWithGrad`.
 
 ```@docs
 DynamicPPL.LogDensityFunction
+DynamicPPL.LogDensityFunctionWithGrad
 ```
 
 ## Condition and decondition
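(Not part of the diff, just for orientation on the interface documented above: the two wrappers advertise different orders via `LogDensityProblems.capabilities`, which is how downstream code discovers whether gradients are available. Assuming `model` is any DynamicPPL model and ForwardDiff is loaded:)

    ldf = DynamicPPL.LogDensityFunction(model)
    LogDensityProblems.capabilities(typeof(ldf))       # LogDensityOrder{0}(): density only
    ldf_grad = DynamicPPL.LogDensityFunctionWithGrad(ldf, AutoForwardDiff())
    LogDensityProblems.capabilities(typeof(ldf_grad))  # LogDensityOrder{1}(): density and gradient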

ext/DynamicPPLForwardDiffExt.jl

-27
This file was deleted.

src/logdensityfunction.jl

+55 -33
@@ -4,6 +4,10 @@ import DifferentiationInterface as DI
     LogDensityFunction
 
 A callable representing a log density function of a `model`.
+`DynamicPPL.LogDensityFunction` implements the LogDensityProblems.jl interface,
+but only to 0th-order, i.e. it is only possible to calculate the log density,
+and not its gradient. If you need to calculate the gradient as well, you have
+to construct a [`DynamicPPL.LogDensityFunctionWithGrad`](@ref) object.
 
 # Fields
 $(FIELDS)
@@ -55,16 +59,6 @@ struct LogDensityFunction{V,M,C}
     context::C
 end
 
-# TODO: Deprecate.
-function LogDensityFunction(
-    varinfo::AbstractVarInfo,
-    model::Model,
-    sampler::AbstractSampler,
-    context::AbstractContext,
-)
-    return LogDensityFunction(varinfo, model, SamplingContext(sampler, context))
-end
-
 function LogDensityFunction(
     model::Model,
     varinfo::AbstractVarInfo=VarInfo(model),
@@ -94,48 +88,76 @@ function setmodel(f::DynamicPPL.LogDensityFunction, model::DynamicPPL.Model)
     return Accessors.@set f.model = model
 end
 
-# HACK: heavy usage of `AbstractSampler` for, well, _everything_, is being phased out. In the mean time
-# we need to define these annoying methods to ensure that we stay compatible with everything.
-getsampler(f::LogDensityFunction) = getsampler(getcontext(f))
-hassampler(f::LogDensityFunction) = hassampler(getcontext(f))
-
 """
     getparams(f::LogDensityFunction)
 
 Return the parameters of the wrapped varinfo as a vector.
 """
 getparams(f::LogDensityFunction) = f.varinfo[:]
 
-# LogDensityProblems interface
-function LogDensityProblems.logdensity(f::LogDensityFunction, θ::AbstractVector)
+# LogDensityProblems interface: logp (0th order)
+function LogDensityProblems.logdensity(f::LogDensityFunction, x::AbstractVector)
     context = getcontext(f)
-    vi_new = unflatten(f.varinfo, θ)
+    vi_new = unflatten(f.varinfo, x)
     return getlogp(last(evaluate!!(f.model, vi_new, context)))
 end
+function _flipped_logdensity(x::AbstractVector, f::LogDensityFunction)
+    return LogDensityProblems.logdensity(f, x)
+end
 function LogDensityProblems.capabilities(::Type{<:LogDensityFunction})
     return LogDensityProblems.LogDensityOrder{0}()
 end
 # TODO: should we instead implement and call on `length(f.varinfo)` (at least in the cases where no sampler is involved)?
 LogDensityProblems.dimension(f::LogDensityFunction) = length(getparams(f))
 
-_flipped_logdensity(θ, f) = LogDensityProblems.logdensity(f, θ)
+# LogDensityProblems interface: gradient (1st order)
+"""
+    LogDensityFunctionWithGrad(ldf::DynamicPPL.LogDensityFunction, adtype::ADTypes.AbstractADType)
+
+A callable representing a log density function of a `model`.
+`DynamicPPL.LogDensityFunctionWithGrad` implements the LogDensityProblems.jl
+interface to 1st-order, meaning that you can both calculate the log density
+using
+
+    LogDensityProblems.logdensity(f, x)
+
+and its gradient using
+
+    LogDensityProblems.logdensity_and_gradient(f, x)
+
+where `f` is a `LogDensityFunctionWithGrad` object and `x` is a vector of parameters.
+
+# Fields
+$(FIELDS)
+"""
+struct LogDensityFunctionWithGrad{V,M,C,TAD<:ADTypes.AbstractADType}
+    ldf::LogDensityFunction{V,M,C}
+    adtype::TAD
+    prep::DI.GradientPrep
+
+    function LogDensityFunctionWithGrad(
+        ldf::LogDensityFunction{V,M,C}, adtype::TAD
+    ) where {V,M,C,TAD}
+        # Get a set of dummy params to use for prep and concretise type
+        x = map(identity, getparams(ldf))
+        prep = DI.prepare_gradient(_flipped_logdensity, adtype, x, DI.Constant(ldf))
+        # Store the prep with the struct
+        return new{V,M,C,TAD}(ldf, adtype, prep)
+    end
+end
+function LogDensityProblems.logdensity(f::LogDensityFunctionWithGrad, x::AbstractVector)
+    return LogDensityProblems.logdensity(f.ldf, x)
+end
+function LogDensityProblems.capabilities(::Type{<:LogDensityFunctionWithGrad})
+    return LogDensityProblems.LogDensityOrder{1}()
+end
 # By default, the AD backend to use is inferred from the context, which would
 # typically be a SamplingContext which contains a sampler.
 function LogDensityProblems.logdensity_and_gradient(
-    f::LogDensityFunction, θ::AbstractVector
-)
-    adtype = getadtype(getsampler(getcontext(f)))
-    return LogDensityProblems.logdensity_and_gradient(f, θ, adtype)
-end
-
-# Extra method allowing one to manually specify the AD backend to use, thus
-# overriding the default AD backend inferred from the sampler.
-function LogDensityProblems.logdensity_and_gradient(
-    f::LogDensityFunction, θ::AbstractVector, adtype::ADTypes.AbstractADType
+    f::LogDensityFunctionWithGrad, x::AbstractVector
 )
-    # Ensure we concretise the elements of the params.
-    θ = map(identity, θ) # TODO: Is this needed?
-    prep = DI.prepare_gradient(_flipped_logdensity, adtype, θ, DI.Constant(f))
-    return DI.value_and_gradient(_flipped_logdensity, prep, adtype, θ, DI.Constant(f))
+    x = map(identity, x) # Concretise type
+    return DI.value_and_gradient(
        _flipped_logdensity, f.prep, f.adtype, x, DI.Constant(f.ldf)
+    )
 end
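One design note on the diff above, with a sketch that is not part of the commit (the function `g`, backend, and values are made up for illustration): `DI.prepare_gradient` can be expensive (for example, building a ReverseDiff tape), so `LogDensityFunctionWithGrad` runs it once in its constructor and stores the result, rather than re-preparing inside every `logdensity_and_gradient` call as the removed code did. The underlying DifferentiationInterface pattern looks roughly like this:

    using DifferentiationInterface
    using ADTypes: AutoForwardDiff
    import ForwardDiff  # backend package must be loaded for AutoForwardDiff to work

    # The differentiated argument comes first; extra data is passed as a Constant context,
    # mirroring how `_flipped_logdensity(x, f)` takes the parameters first and the
    # LogDensityFunction as a constant.
    g(x, c) = c * sum(abs2, x)

    backend = AutoForwardDiff()
    x0 = [1.0, 2.0, 3.0]
    # Prepare once (potentially costly), then reuse the preparation for every call
    prep = prepare_gradient(g, backend, x0, Constant(2.0))
    val, grad = value_and_gradient(g, prep, backend, x0, Constant(2.0))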
src/sampler.jl

-3
@@ -54,9 +54,6 @@ Sampler(alg) = Sampler(alg, Selector())
 Sampler(alg, model::Model) = Sampler(alg, model, Selector())
 Sampler(alg, model::Model, s::Selector) = Sampler(alg, s)
 
-# Extract the AD type from the underlying algorithm
-getadtype(s::Sampler) = getadtype(s.alg)
-
 # AbstractMCMC interface for SampleFromUniform and SampleFromPrior
 function AbstractMCMC.step(
     rng::Random.AbstractRNG,

test/ad.jl

+40 -19
@@ -1,31 +1,52 @@
+using DynamicPPL: LogDensityFunction, LogDensityFunctionWithGrad
+
 @testset "AD: ForwardDiff, ReverseDiff, and Mooncake" begin
     @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS
        rand_param_values = DynamicPPL.TestUtils.rand_prior_true(m)
        vns = DynamicPPL.TestUtils.varnames(m)
        varinfos = DynamicPPL.TestUtils.setup_varinfos(m, rand_param_values, vns)
 
        @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos
-            f = DynamicPPL.LogDensityFunction(m, varinfo)
-            # convert to `Vector{Float64}` to avoid `ReverseDiff` initializing the gradients to Integer 0
-            # reference: https://github.com/TuringLang/DynamicPPL.jl/pull/571#issuecomment-1924304489
-            θ = convert(Vector{Float64}, varinfo[:])
+            f = LogDensityFunction(m, varinfo)
+            x = DynamicPPL.getparams(f)
             # Calculate reference logp + gradient of logp using ForwardDiff
             default_adtype = ADTypes.AutoForwardDiff()
+            ldf_with_grad = LogDensityFunctionWithGrad(f, default_adtype)
             ref_logp, ref_grad = LogDensityProblems.logdensity_and_gradient(
-                f, θ, default_adtype
+                ldf_with_grad, x
             )
 
             @testset "$adtype" for adtype in [
-                ADTypes.AutoReverseDiff(; compile=false),
-                ADTypes.AutoReverseDiff(; compile=true),
-                ADTypes.AutoMooncake(; config=nothing),
+                AutoReverseDiff(; compile=false),
+                AutoReverseDiff(; compile=true),
+                AutoMooncake(; config=nothing),
             ]
-                # Mooncake can't currently handle something that is going on in
-                # SimpleVarInfo{<:VarNamedVector}. Disable all SimpleVarInfo tests for now.
-                if adtype isa ADTypes.AutoMooncake && varinfo isa DynamicPPL.SimpleVarInfo
-                    @test_broken 1 == 0
+                @info "Testing AD on: $(m.f) - $(short_varinfo_name(varinfo)) - $adtype"
+
+                # Put predicates here to avoid long lines
+                is_mooncake = adtype isa AutoMooncake
+                is_1_10 = v"1.10" <= VERSION < v"1.11"
+                is_1_11 = v"1.11" <= VERSION < v"1.12"
+                is_svi_vnv = varinfo isa SimpleVarInfo{<:DynamicPPL.VarNamedVector}
+                is_svi_od = varinfo isa SimpleVarInfo{<:OrderedDict}
+
+                # Mooncake doesn't work with several combinations of SimpleVarInfo.
+                if is_mooncake && is_1_11 && is_svi_vnv
+                    # https://github.com/compintell/Mooncake.jl/issues/470
+                    @test_throws ArgumentError LogDensityFunctionWithGrad(f, adtype)
+                elseif is_mooncake && is_1_10 && is_svi_vnv
+                    # TODO: report upstream
+                    @test_throws UndefRefError LogDensityFunctionWithGrad(f, adtype)
+                elseif is_mooncake && is_1_10 && is_svi_od
+                    # TODO: report upstream
+                    @test_throws Mooncake.MooncakeRuleCompilationError LogDensityFunctionWithGrad(
+                        f, adtype
+                    )
                 else
-                    logp, grad = LogDensityProblems.logdensity_and_gradient(f, θ, adtype)
+                    ldf_with_grad = LogDensityFunctionWithGrad(f, adtype)
+                    logp, grad = LogDensityProblems.logdensity_and_gradient(
+                        ldf_with_grad, x
+                    )
                     @test grad ≈ ref_grad
                     @test logp ≈ ref_logp
                 end
@@ -62,15 +83,15 @@
        # of implementation
        struct MyEmptyAlg end
        DynamicPPL.getspace(::DynamicPPL.Sampler{MyEmptyAlg}) = ()
-        DynamicPPL.assume(rng, ::DynamicPPL.Sampler{MyEmptyAlg}, dist, vn, vi) =
-            DynamicPPL.assume(dist, vn, vi)
+        DynamicPPL.assume(
+            ::Random.AbstractRNG, ::DynamicPPL.Sampler{MyEmptyAlg}, dist, vn, vi
+        ) = DynamicPPL.assume(dist, vn, vi)
 
        # Compiling the ReverseDiff tape used to fail here
        spl = Sampler(MyEmptyAlg())
        vi = VarInfo(model)
-        ldf = DynamicPPL.LogDensityFunction(vi, model, SamplingContext(spl))
-        @test LogDensityProblems.logdensity_and_gradient(
-            ldf, vi[:], AutoReverseDiff(; compile=true)
-        ) isa Any
+        ldf = LogDensityFunction(vi, model, SamplingContext(spl))
+        ldf_grad = LogDensityFunctionWithGrad(ldf, AutoReverseDiff(; compile=true))
+        @test LogDensityProblems.logdensity_and_gradient(ldf_grad, vi[:]) isa Any
    end
 end

test/ext/DynamicPPLForwardDiffExt.jl

-14
This file was deleted.

test/runtests.jl

-1
@@ -75,7 +75,6 @@ include("test_util.jl")
        include("ext/DynamicPPLJETExt.jl")
    end
    @testset "ad" begin
-        include("ext/DynamicPPLForwardDiffExt.jl")
        include("ext/DynamicPPLMooncakeExt.jl")
        include("ad.jl")
    end

test/test_util.jl

+9
@@ -56,6 +56,15 @@ function short_varinfo_name(vi::TypedVarInfo)
 end
 short_varinfo_name(::UntypedVarInfo) = "UntypedVarInfo"
 short_varinfo_name(::DynamicPPL.VectorVarInfo) = "VectorVarInfo"
+function short_varinfo_name(::SimpleVarInfo{<:NamedTuple,<:Ref})
+    return "SimpleVarInfo{<:NamedTuple,<:Ref}"
+end
+function short_varinfo_name(::SimpleVarInfo{<:OrderedDict,<:Ref})
+    return "SimpleVarInfo{<:OrderedDict,<:Ref}"
+end
+function short_varinfo_name(::SimpleVarInfo{<:DynamicPPL.VarNamedVector,<:Ref})
+    return "SimpleVarInfo{<:VarNamedVector,<:Ref}"
+end
 short_varinfo_name(::SimpleVarInfo{<:NamedTuple}) = "SimpleVarInfo{<:NamedTuple}"
 short_varinfo_name(::SimpleVarInfo{<:OrderedDict}) = "SimpleVarInfo{<:OrderedDict}"
 function short_varinfo_name(::SimpleVarInfo{<:DynamicPPL.VarNamedVector})

0 commit comments
