Commit 8e22c05

Remove LogDensityProblemsAD

1 parent 7613dbb · commit 8e22c05

10 files changed: +41 -98 lines

Project.toml
+3 -7

@@ -12,15 +12,13 @@ Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
 ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
+DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
 InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
-# TODO(penelopeysm,mhauru) KernelAbstractions is only a dependency so that we can pin its version, see
-# https://github.com/TuringLang/DynamicPPL.jl/pull/781#event-16017866767
 KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
-LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
 MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
 OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -56,17 +54,15 @@ Bijectors = "0.13.18, 0.14, 0.15"
 ChainRulesCore = "1"
 Compat = "4"
 ConstructionBase = "1.5.4"
+DifferentiationInterface = "0.6.39"
 Distributions = "0.25"
 DocStringExtensions = "0.9"
-# TODO(penelopeysm,mhauru) See https://github.com/TuringLang/DynamicPPL.jl/pull/781#event-16017866767
-# for why KernelAbstractions is pinned like this.
-KernelAbstractions = "< 0.9.32"
 EnzymeCore = "0.6 - 0.8"
 ForwardDiff = "0.10"
 JET = "0.9"
+KernelAbstractions = "< 0.9.32"
 LinearAlgebra = "1.6"
 LogDensityProblems = "2"
-LogDensityProblemsAD = "1.7.0"
 MCMCChains = "6"
 MacroTools = "0.5.6"
 Mooncake = "0.4.59"

ext/DynamicPPLForwardDiffExt.jl
+2 -29

@@ -1,40 +1,13 @@
 module DynamicPPLForwardDiffExt
 
-if isdefined(Base, :get_extension)
-    using DynamicPPL: ADTypes, DynamicPPL, LogDensityProblems, LogDensityProblemsAD
-    using ForwardDiff
-else
-    using ..DynamicPPL: ADTypes, DynamicPPL, LogDensityProblems, LogDensityProblemsAD
-    using ..ForwardDiff
-end
+using DynamicPPL: ADTypes, DynamicPPL, LogDensityProblems
+using ForwardDiff
 
 getchunksize(::ADTypes.AutoForwardDiff{chunk}) where {chunk} = chunk
 
 standardtag(::ADTypes.AutoForwardDiff{<:Any,Nothing}) = true
 standardtag(::ADTypes.AutoForwardDiff) = false
 
-function LogDensityProblemsAD.ADgradient(
-    ad::ADTypes.AutoForwardDiff, ℓ::DynamicPPL.LogDensityFunction
-)
-    θ = DynamicPPL.getparams(ℓ)
-    f = Base.Fix1(LogDensityProblems.logdensity, ℓ)
-
-    # Define configuration for ForwardDiff.
-    tag = if standardtag(ad)
-        ForwardDiff.Tag(DynamicPPL.DynamicPPLTag(), eltype(θ))
-    else
-        ForwardDiff.Tag(f, eltype(θ))
-    end
-    chunk_size = getchunksize(ad)
-    chunk = if chunk_size == 0 || chunk_size === nothing
-        ForwardDiff.Chunk(θ)
-    else
-        ForwardDiff.Chunk(length(θ), chunk_size)
-    end
-
-    return LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), ℓ; chunk, tag, x=θ)
-end
-
 # Allow Turing tag in gradient etc. calls of the log density function
 function ForwardDiff.checktag(
     ::Type{ForwardDiff.Tag{DynamicPPL.DynamicPPLTag,V}},

src/DynamicPPL.jl
-1

@@ -14,7 +14,6 @@ using MacroTools: MacroTools
 using ConstructionBase: ConstructionBase
 using Accessors: Accessors
 using LogDensityProblems: LogDensityProblems
-using LogDensityProblemsAD: LogDensityProblemsAD
 
 using LinearAlgebra: LinearAlgebra, Cholesky

src/contexts.jl
+1

@@ -184,6 +184,7 @@ at which point it will return the sampler of that context.
 getsampler(context::SamplingContext) = context.sampler
 getsampler(context::AbstractContext) = getsampler(NodeTrait(context), context)
 getsampler(::IsParent, context::AbstractContext) = getsampler(childcontext(context))
+getsampler(::IsLeaf, ::AbstractContext) = error("No sampler found in context")
 
 """
 struct DefaultContext <: AbstractContext end
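
Note: with the new `IsLeaf` method, asking a context stack that contains no SamplingContext for its sampler now fails with an explicit error message. A minimal sketch of that behaviour (not part of the diff; it assumes only that `DefaultContext` is a leaf context, as in DynamicPPL):

using DynamicPPL

# DefaultContext is a leaf context, so the traversal bottoms out in the new
# method and throws instead of recursing further.
DynamicPPL.getsampler(DynamicPPL.DefaultContext())
# ERROR: No sampler found in context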

src/logdensityfunction.jl
+20 -36

@@ -1,3 +1,5 @@
+import DifferentiationInterface as DI
+
 """
     LogDensityFunction
 
@@ -81,37 +83,13 @@ end
 
 Return the `DynamicPPL.Model` wrapped in the given log-density function `f`.
 """
-getmodel(f::LogDensityProblemsAD.ADGradientWrapper) =
-    getmodel(LogDensityProblemsAD.parent(f))
 getmodel(f::DynamicPPL.LogDensityFunction) = f.model
 
 """
     setmodel(f, model[, adtype])
 
 Set the `DynamicPPL.Model` in the given log-density function `f` to `model`.
-
-!!! warning
-    Note that if `f` is a `LogDensityProblemsAD.ADGradientWrapper` wrapping a
-    `DynamicPPL.LogDensityFunction`, performing an update of the `model` in `f`
-    might require recompilation of the gradient tape, depending on the AD backend.
 """
-function setmodel(
-    f::LogDensityProblemsAD.ADGradientWrapper,
-    model::DynamicPPL.Model,
-    adtype::ADTypes.AbstractADType,
-)
-    # TODO: Should we handle `SciMLBase.NoAD`?
-    # For an `ADGradientWrapper` we do the following:
-    # 1. Update the `Model` in the underlying `LogDensityFunction`.
-    # 2. Re-construct the `ADGradientWrapper` using `ADgradient` using the provided `adtype`
-    #    to ensure that the recompilation of gradient tapes, etc. also occur. For example,
-    #    ReverseDiff.jl in compiled mode will cache the compiled tape, which means that just
-    #    replacing the corresponding field with the new model won't be sufficient to obtain
-    #    the correct gradients.
-    return LogDensityProblemsAD.ADgradient(
-        adtype, setmodel(LogDensityProblemsAD.parent(f), model)
-    )
-end
 function setmodel(f::DynamicPPL.LogDensityFunction, model::DynamicPPL.Model)
     return Accessors.@set f.model = model
 end
@@ -140,18 +118,24 @@ end
 # TODO: should we instead implement and call on `length(f.varinfo)` (at least in the cases where no sampler is involved)?
 LogDensityProblems.dimension(f::LogDensityFunction) = length(getparams(f))
 
-# This is important for performance -- one needs to provide `ADGradient` with a vector of
-# parameters, or DifferentiationInterface will not have sufficient information to e.g.
-# compile a rule for Mooncake (because it won't know the type of the input), or pre-allocate
-# a tape when using ReverseDiff.jl.
-function _make_ad_gradient(ad::ADTypes.AbstractADType, ℓ::LogDensityFunction)
-    x = map(identity, getparams(ℓ)) # ensure we concretise the elements of the params
-    return LogDensityProblemsAD.ADgradient(ad, ℓ; x)
-end
+_flipped_logdensity(θ, f) = LogDensityProblems.logdensity(f, θ)
 
-function LogDensityProblemsAD.ADgradient(ad::ADTypes.AutoMooncake, f::LogDensityFunction)
-    return _make_ad_gradient(ad, f)
+# By default, the AD backend to use is inferred from the context, which would
+# typically be a SamplingContext which contains a sampler.
+function LogDensityProblems.logdensity_and_gradient(
+    f::LogDensityFunction, θ::AbstractVector
+)
+    adtype = getadtype(getsampler(getcontext(f)))
+    return LogDensityProblems.logdensity_and_gradient(f, θ, adtype)
 end
-function LogDensityProblemsAD.ADgradient(ad::ADTypes.AutoReverseDiff, f::LogDensityFunction)
-    return _make_ad_gradient(ad, f)
+
+# Extra method allowing one to manually specify the AD backend to use, thus
+# overriding the default AD backend inferred from the sampler.
+function LogDensityProblems.logdensity_and_gradient(
+    f::LogDensityFunction, θ::AbstractVector, adtype::ADTypes.AbstractADType
+)
+    # Ensure we concretise the elements of the params.
+    θ = map(identity, θ) # TODO: Is this needed?
+    prep = DI.prepare_gradient(_flipped_logdensity, adtype, θ, DI.Constant(f))
+    return DI.value_and_gradient(_flipped_logdensity, prep, adtype, θ, DI.Constant(f))
 end
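
As an illustration of the new DifferentiationInterface-based path, the AD backend can now be passed straight to `logdensity_and_gradient`. A rough sketch (not part of the commit; the `demo` model, the parameter vector, and the choice of ForwardDiff are illustrative):

using DynamicPPL, Distributions, ADTypes, LogDensityProblems, ForwardDiff

@model function demo()
    x ~ Normal()
end

f = DynamicPPL.LogDensityFunction(demo())
θ = [0.5]

# Passing the backend explicitly overrides whatever the context's sampler would imply.
logp, grad = LogDensityProblems.logdensity_and_gradient(f, θ, ADTypes.AutoForwardDiff())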

src/sampler.jl
+3

@@ -54,6 +54,9 @@ Sampler(alg) = Sampler(alg, Selector())
 Sampler(alg, model::Model) = Sampler(alg, model, Selector())
 Sampler(alg, model::Model, s::Selector) = Sampler(alg, s)
 
+# Extract the AD type from the underlying algorithm
+getadtype(s::Sampler) = getadtype(s.alg)
+
 # AbstractMCMC interface for SampleFromUniform and SampleFromPrior
 function AbstractMCMC.step(
     rng::Random.AbstractRNG,
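
The `getadtype(s.alg)` forwarding above assumes the wrapped algorithm knows its own AD backend. A hypothetical sketch of that contract from the algorithm side (`MyAlg` and its `adtype` field are made up for illustration; only the `Sampler` forwarding comes from this commit):

using DynamicPPL, ADTypes

struct MyAlg
    adtype::ADTypes.AbstractADType
end

# An algorithm opts in by defining getadtype for itself;
# Sampler(alg) then exposes it via getadtype(s.alg).
DynamicPPL.getadtype(alg::MyAlg) = alg.adtype

spl = DynamicPPL.Sampler(MyAlg(ADTypes.AutoForwardDiff()))
DynamicPPL.getadtype(spl)  # AutoForwardDiff()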

test/Project.toml
-2

@@ -16,7 +16,6 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
-LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
 MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
 MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
 Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
@@ -46,7 +45,6 @@ EnzymeCore = "0.6 - 0.8"
 ForwardDiff = "0.10.12"
 JET = "0.9"
 LogDensityProblems = "2"
-LogDensityProblemsAD = "1.7.0"
 MCMCChains = "6.0.4"
 MacroTools = "0.5.6"
 Mooncake = "0.4.59"

test/ad.jl
+10 -10

@@ -1,21 +1,19 @@
 @testset "AD: ForwardDiff, ReverseDiff, and Mooncake" begin
     @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS
-        f = DynamicPPL.LogDensityFunction(m)
         rand_param_values = DynamicPPL.TestUtils.rand_prior_true(m)
         vns = DynamicPPL.TestUtils.varnames(m)
         varinfos = DynamicPPL.TestUtils.setup_varinfos(m, rand_param_values, vns)
 
         @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos
             f = DynamicPPL.LogDensityFunction(m, varinfo)
-
-            # use ForwardDiff result as reference
-            ad_forwarddiff_f = LogDensityProblemsAD.ADgradient(
-                ADTypes.AutoForwardDiff(; chunksize=0), f
-            )
             # convert to `Vector{Float64}` to avoid `ReverseDiff` initializing the gradients to Integer 0
             # reference: https://github.com/TuringLang/DynamicPPL.jl/pull/571#issuecomment-1924304489
             θ = convert(Vector{Float64}, varinfo[:])
-            logp, ref_grad = LogDensityProblems.logdensity_and_gradient(ad_forwarddiff_f, θ)
+            # Calculate reference logp + gradient of logp using ForwardDiff
+            default_adtype = ADTypes.AutoForwardDiff()
+            ref_logp, ref_grad = LogDensityProblems.logdensity_and_gradient(
+                f, θ, default_adtype
+            )
 
             @testset "$adtype" for adtype in [
                 ADTypes.AutoReverseDiff(; compile=false),
@@ -27,9 +25,9 @@
                 if adtype isa ADTypes.AutoMooncake && varinfo isa DynamicPPL.SimpleVarInfo
                     @test_broken 1 == 0
                 else
-                    ad_f = LogDensityProblemsAD.ADgradient(adtype, f)
-                    _, grad = LogDensityProblems.logdensity_and_gradient(ad_f, θ)
+                    logp, grad = LogDensityProblems.logdensity_and_gradient(f, θ, adtype)
                     @test grad ≈ ref_grad
+                    @test logp ≈ ref_logp
                 end
             end
         end
@@ -71,6 +69,8 @@
     spl = Sampler(MyEmptyAlg())
     vi = VarInfo(model)
     ldf = DynamicPPL.LogDensityFunction(vi, model, SamplingContext(spl))
-    @test LogDensityProblemsAD.ADgradient(AutoReverseDiff(; compile=true), ldf) isa Any
+    @test LogDensityProblems.logdensity_and_gradient(
+        ldf, vi[:], AutoReverseDiff(; compile=true)
+    ) isa Any
 end
 end

test/logdensityfunction.jl
+1 -12

@@ -1,22 +1,11 @@
-using Test, DynamicPPL, ADTypes, LogDensityProblems, LogDensityProblemsAD, ReverseDiff
+using Test, DynamicPPL, ADTypes, LogDensityProblems, ReverseDiff
 
 @testset "`getmodel` and `setmodel`" begin
     @testset "$(nameof(model))" for model in DynamicPPL.TestUtils.DEMO_MODELS
         model = DynamicPPL.TestUtils.DEMO_MODELS[1]
         ℓ = DynamicPPL.LogDensityFunction(model)
         @test DynamicPPL.getmodel(ℓ) == model
         @test DynamicPPL.setmodel(ℓ, model).model == model
-
-        # ReverseDiff related
-        ∇ℓ = LogDensityProblemsAD.ADgradient(:ReverseDiff, ℓ; compile=Val(false))
-        @test DynamicPPL.getmodel(∇ℓ) == model
-        @test DynamicPPL.getmodel(DynamicPPL.setmodel(∇ℓ, model, AutoReverseDiff())) ==
-            model
-        ∇ℓ = LogDensityProblemsAD.ADgradient(:ReverseDiff, ℓ; compile=Val(true))
-        new_∇ℓ = DynamicPPL.setmodel(∇ℓ, model, AutoReverseDiff())
-        @test DynamicPPL.getmodel(new_∇ℓ) == model
-        # HACK(sunxd): rely on internal implementation detail, i.e., naming of `compiledtape`
-        @test new_∇ℓ.compiledtape != ∇ℓ.compiledtape
     end
 end

test/runtests.jl
+1 -1

@@ -9,7 +9,7 @@ using Distributions
 using DistributionsAD
 using Documenter
 using ForwardDiff
-using LogDensityProblems, LogDensityProblemsAD
+using LogDensityProblems
 using MacroTools
 using MCMCChains
 using Mooncake: Mooncake

0 commit comments