
Commit 8de4742

Implement LogDensityFunctionWithGrad
1 parent: 8e22c05

3 files changed: +53 -27 lines

HISTORY.md (+9)

```diff
@@ -49,6 +49,15 @@ This release removes the feature of `VarInfo` where it kept track of which varia
 
 This change also affects sampling in Turing.jl.
 
+**Other changes**
+
+LogDensityProblemsAD is now removed as a dependency.
+Instead of constructing a `LogDensityProblemsAD.ADgradient` object, we now directly use `DifferentiationInterface` to calculate the gradient of the log density with respect to model parameters.
+
+In practice, this means that if you want to calculate the gradient for a model, you can do:
+
+TODO(penelopeysm): Finish this
+
 ## 0.34.2
 
 - Fixed bugs in ValuesAsInModelContext as well as DebugContext where underlying PrefixContexts were not being applied.
```
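The changelog entry above ends in a TODO; judging from the interface added in src/logdensityfunction.jl below, the intended workflow is roughly the following sketch (the model `demo` and its contents are illustrative, not part of this commit):

```julia
using DynamicPPL, Distributions
using ADTypes: AutoForwardDiff
import ForwardDiff
using LogDensityProblems

# Hypothetical model, for illustration only.
@model function demo()
    x ~ Normal()
end

model = demo()
ldf = DynamicPPL.LogDensityFunction(model, VarInfo(model))

# Wrap the log density together with an AD backend; the gradient
# "prep" is computed once here, at construction time.
ldf_grad = DynamicPPL.LogDensityFunctionWithGrad(ldf, AutoForwardDiff())

# Evaluate logp and its gradient at the current parameter values.
params = DynamicPPL.getparams(ldf)
logp, grad = LogDensityProblems.logdensity_and_gradient(ldf_grad, params)
```

Because the `DI.GradientPrep` is built in the constructor, the wrapper is meant to be created once per model and AD backend, then reused across gradient evaluations.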

src/logdensityfunction.jl (+32 -19)
```diff
@@ -106,36 +106,49 @@ Return the parameters of the wrapped varinfo as a vector.
 """
 getparams(f::LogDensityFunction) = f.varinfo[:]
 
-# LogDensityProblems interface
-function LogDensityProblems.logdensity(f::LogDensityFunction, θ::AbstractVector)
+# LogDensityProblems interface: logp (0th order)
+function LogDensityProblems.logdensity(f::LogDensityFunction, x::AbstractVector)
     context = getcontext(f)
-    vi_new = unflatten(f.varinfo, θ)
+    vi_new = unflatten(f.varinfo, x)
     return getlogp(last(evaluate!!(f.model, vi_new, context)))
 end
+function _flipped_logdensity(x::AbstractVector, f::LogDensityFunction)
+    return LogDensityProblems.logdensity(f, x)
+end
 function LogDensityProblems.capabilities(::Type{<:LogDensityFunction})
     return LogDensityProblems.LogDensityOrder{0}()
 end
 # TODO: should we instead implement and call on `length(f.varinfo)` (at least in the cases where no sampler is involved)?
 LogDensityProblems.dimension(f::LogDensityFunction) = length(getparams(f))
 
-_flipped_logdensity(θ, f) = LogDensityProblems.logdensity(f, θ)
-
+# LogDensityProblems interface: gradient (1st order)
+struct LogDensityFunctionWithGrad{V,M,C,TAD}
+    ldf::LogDensityFunction{V,M,C}
+    adtype::TAD
+    prep::DI.GradientPrep
+
+    function LogDensityFunctionWithGrad(
+        ldf::LogDensityFunction{V,M,C}, adtype::TAD
+    ) where {V,M,C,TAD}
+        # Get a set of dummy params to use for prep
+        x = ldf.varinfo[:]
+        prep = DI.prepare_gradient(_flipped_logdensity, adtype, x, DI.Constant(ldf))
+        # Store the prep with the struct
+        return new{V,M,C,TAD}(ldf, adtype, prep)
+    end
+end
+function LogDensityProblems.logdensity(f::LogDensityFunctionWithGrad, x::AbstractVector)
+    return LogDensityProblems.logdensity(f.ldf, x)
+end
+function LogDensityProblems.capabilities(::Type{<:LogDensityFunctionWithGrad})
+    return LogDensityProblems.LogDensityOrder{1}()
+end
 # By default, the AD backend to use is inferred from the context, which would
 # typically be a SamplingContext which contains a sampler.
 function LogDensityProblems.logdensity_and_gradient(
-    f::LogDensityFunction, θ::AbstractVector
-)
-    adtype = getadtype(getsampler(getcontext(f)))
-    return LogDensityProblems.logdensity_and_gradient(f, θ, adtype)
-end
-
-# Extra method allowing one to manually specify the AD backend to use, thus
-# overriding the default AD backend inferred from the sampler.
-function LogDensityProblems.logdensity_and_gradient(
-    f::LogDensityFunction, θ::AbstractVector, adtype::ADTypes.AbstractADType
+    f::LogDensityFunctionWithGrad, x::AbstractVector
 )
-    # Ensure we concretise the elements of the params.
-    θ = map(identity, θ) # TODO: Is this needed?
-    prep = DI.prepare_gradient(_flipped_logdensity, adtype, θ, DI.Constant(f))
-    return DI.value_and_gradient(_flipped_logdensity, prep, adtype, θ, DI.Constant(f))
+    return DI.value_and_gradient(
+        _flipped_logdensity, f.prep, f.adtype, x, DI.Constant(f.ldf)
+    )
 end
```
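Two things stand out in the new interface above. First, `_flipped_logdensity` exists because DifferentiationInterface differentiates with respect to the first argument only, so the parameter vector must come first and the `LogDensityFunction` is passed along as a non-differentiated `DI.Constant` context. Second, `DI.prepare_gradient` now runs once, in the inner constructor, instead of on every `logdensity_and_gradient` call as in the removed method. A standalone sketch of that DifferentiationInterface pattern (the function `g` and the constant `2.0` are illustrative, not DynamicPPL code):

```julia
using DifferentiationInterface
using ADTypes: AutoForwardDiff
import ForwardDiff
const DI = DifferentiationInterface

# DI differentiates w.r.t. the first argument; the second is wrapped in
# DI.Constant so it is treated as inactive, like the LogDensityFunction above.
g(x, c) = c * sum(abs2, x)

backend = AutoForwardDiff()
x0 = randn(3)
prep = DI.prepare_gradient(g, backend, x0, DI.Constant(2.0))  # one-time setup
val, grad = DI.value_and_gradient(g, prep, backend, x0, DI.Constant(2.0))
```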

test/ad.jl (+12 -8)
```diff
@@ -8,11 +8,12 @@
     f = DynamicPPL.LogDensityFunction(m, varinfo)
     # convert to `Vector{Float64}` to avoid `ReverseDiff` initializing the gradients to Integer 0
     # reference: https://github.com/TuringLang/DynamicPPL.jl/pull/571#issuecomment-1924304489
-    θ = convert(Vector{Float64}, varinfo[:])
+    x = convert(Vector{Float64}, varinfo[:])
     # Calculate reference logp + gradient of logp using ForwardDiff
     default_adtype = ADTypes.AutoForwardDiff()
+    ldf_with_grad = DynamicPPL.LogDensityFunctionWithGrad(f, default_adtype)
     ref_logp, ref_grad = LogDensityProblems.logdensity_and_gradient(
-        f, θ, default_adtype
+        ldf_with_grad, x
     )
 
     @testset "$adtype" for adtype in [
@@ -25,7 +26,8 @@
         if adtype isa ADTypes.AutoMooncake && varinfo isa DynamicPPL.SimpleVarInfo
             @test_broken 1 == 0
         else
-            logp, grad = LogDensityProblems.logdensity_and_gradient(f, θ, adtype)
+            ldf_with_grad = DynamicPPL.LogDensityFunctionWithGrad(f, adtype)
+            logp, grad = LogDensityProblems.logdensity_and_gradient(ldf_with_grad, x)
             @test grad ≈ ref_grad
             @test logp ≈ ref_logp
         end
@@ -62,15 +64,17 @@
     # of implementation
     struct MyEmptyAlg end
     DynamicPPL.getspace(::DynamicPPL.Sampler{MyEmptyAlg}) = ()
-    DynamicPPL.assume(rng, ::DynamicPPL.Sampler{MyEmptyAlg}, dist, vn, vi) =
-        DynamicPPL.assume(dist, vn, vi)
+    DynamicPPL.assume(
+        ::Random.AbstractRNG, ::DynamicPPL.Sampler{MyEmptyAlg}, dist, vn, vi
+    ) = DynamicPPL.assume(dist, vn, vi)
 
     # Compiling the ReverseDiff tape used to fail here
     spl = Sampler(MyEmptyAlg())
     vi = VarInfo(model)
     ldf = DynamicPPL.LogDensityFunction(vi, model, SamplingContext(spl))
-    @test LogDensityProblems.logdensity_and_gradient(
-        ldf, vi[:], AutoReverseDiff(; compile=true)
-    ) isa Any
+    ldf_grad = DynamicPPL.LogDensityFunctionWithGrad(
+        ldf, AutoReverseDiff(; compile=true)
+    )
+    @test LogDensityProblems.logdensity_and_gradient(ldf_grad, vi[:]) isa Any
 end
 end
```
