Commit 5be363f: Remove LogDensityProblemsAD
1 parent: 7613dbb

10 files changed, +82 -142 lines

Project.toml (+3 -7)

@@ -12,15 +12,13 @@ Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
 ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
+DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
 InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
-# TODO(penelopeysm,mhauru) KernelAbstractions is only a dependency so that we can pin its version, see
-# https://github.com/TuringLang/DynamicPPL.jl/pull/781#event-16017866767
 KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
-LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
 MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
 OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

@@ -56,17 +54,15 @@ Bijectors = "0.13.18, 0.14, 0.15"
 ChainRulesCore = "1"
 Compat = "4"
 ConstructionBase = "1.5.4"
+DifferentiationInterface = "0.6.39"
 Distributions = "0.25"
 DocStringExtensions = "0.9"
-# TODO(penelopeysm,mhauru) See https://github.com/TuringLang/DynamicPPL.jl/pull/781#event-16017866767
-# for why KernelAbstractions is pinned like this.
-KernelAbstractions = "< 0.9.32"
 EnzymeCore = "0.6 - 0.8"
 ForwardDiff = "0.10"
 JET = "0.9"
+KernelAbstractions = "< 0.9.32"
 LinearAlgebra = "1.6"
 LogDensityProblems = "2"
-LogDensityProblemsAD = "1.7.0"
 MCMCChains = "6"
 MacroTools = "0.5.6"
 Mooncake = "0.4.59"

ext/DynamicPPLForwardDiffExt.jl (+2 -29)

@@ -1,40 +1,13 @@
 module DynamicPPLForwardDiffExt

-if isdefined(Base, :get_extension)
-    using DynamicPPL: ADTypes, DynamicPPL, LogDensityProblems, LogDensityProblemsAD
-    using ForwardDiff
-else
-    using ..DynamicPPL: ADTypes, DynamicPPL, LogDensityProblems, LogDensityProblemsAD
-    using ..ForwardDiff
-end
+using DynamicPPL: ADTypes, DynamicPPL, LogDensityProblems
+using ForwardDiff

 getchunksize(::ADTypes.AutoForwardDiff{chunk}) where {chunk} = chunk

 standardtag(::ADTypes.AutoForwardDiff{<:Any,Nothing}) = true
 standardtag(::ADTypes.AutoForwardDiff) = false

-function LogDensityProblemsAD.ADgradient(
-    ad::ADTypes.AutoForwardDiff, ℓ::DynamicPPL.LogDensityFunction
-)
-    θ = DynamicPPL.getparams(ℓ)
-    f = Base.Fix1(LogDensityProblems.logdensity, ℓ)
-
-    # Define configuration for ForwardDiff.
-    tag = if standardtag(ad)
-        ForwardDiff.Tag(DynamicPPL.DynamicPPLTag(), eltype(θ))
-    else
-        ForwardDiff.Tag(f, eltype(θ))
-    end
-    chunk_size = getchunksize(ad)
-    chunk = if chunk_size == 0 || chunk_size === nothing
-        ForwardDiff.Chunk(θ)
-    else
-        ForwardDiff.Chunk(length(θ), chunk_size)
-    end
-
-    return LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), ℓ; chunk, tag, x=θ)
-end
-
 # Allow Turing tag in gradient etc. calls of the log density function
 function ForwardDiff.checktag(
     ::Type{ForwardDiff.Tag{DynamicPPL.DynamicPPLTag,V}},
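For context: the removed `ADgradient` method above was where the chunk size and tag were assembled by hand; the surviving helpers simply dispatch on the type parameters of `ADTypes.AutoForwardDiff`. A minimal REPL sketch (not part of this commit, assuming only that ADTypes is loaded) of the parameters they inspect:

    using ADTypes

    ad_default = ADTypes.AutoForwardDiff()               # typeof: AutoForwardDiff{nothing, Nothing}
    ad_chunked = ADTypes.AutoForwardDiff(; chunksize=4)  # typeof: AutoForwardDiff{4, Nothing}

    # `getchunksize` extracts the first type parameter; `standardtag` is true
    # only when the tag parameter is Nothing, i.e. no custom tag was supplied.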

src/DynamicPPL.jl (-1)

@@ -14,7 +14,6 @@ using MacroTools: MacroTools
 using ConstructionBase: ConstructionBase
 using Accessors: Accessors
 using LogDensityProblems: LogDensityProblems
-using LogDensityProblemsAD: LogDensityProblemsAD

 using LinearAlgebra: LinearAlgebra, Cholesky

src/contexts.jl (+1)

@@ -184,6 +184,7 @@ at which point it will return the sampler of that context.
 getsampler(context::SamplingContext) = context.sampler
 getsampler(context::AbstractContext) = getsampler(NodeTrait(context), context)
 getsampler(::IsParent, context::AbstractContext) = getsampler(childcontext(context))
+getsampler(::IsLeaf, ::AbstractContext) = error("No sampler found in context")

 """
 struct DefaultContext <: AbstractContext end
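The new `IsLeaf` method replaces an uninformative `MethodError` with an explicit error when the context stack contains no `SamplingContext`. A behaviour sketch (not part of the commit; it calls these internal functions directly):

    using DynamicPPL

    DynamicPPL.getsampler(SamplingContext(SampleFromPrior()))  # returns SampleFromPrior()
    DynamicPPL.getsampler(DefaultContext())  # errors: "No sampler found in context"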

src/logdensityfunction.jl (+20 -36)

@@ -1,3 +1,5 @@
+import DifferentiationInterface as DI
+
 """
     LogDensityFunction

@@ -81,37 +83,13 @@ end

 Return the `DynamicPPL.Model` wrapped in the given log-density function `f`.
 """
-getmodel(f::LogDensityProblemsAD.ADGradientWrapper) =
-    getmodel(LogDensityProblemsAD.parent(f))
 getmodel(f::DynamicPPL.LogDensityFunction) = f.model

 """
     setmodel(f, model[, adtype])

 Set the `DynamicPPL.Model` in the given log-density function `f` to `model`.
-
-!!! warning
-    Note that if `f` is a `LogDensityProblemsAD.ADGradientWrapper` wrapping a
-    `DynamicPPL.LogDensityFunction`, performing an update of the `model` in `f`
-    might require recompilation of the gradient tape, depending on the AD backend.
 """
-function setmodel(
-    f::LogDensityProblemsAD.ADGradientWrapper,
-    model::DynamicPPL.Model,
-    adtype::ADTypes.AbstractADType,
-)
-    # TODO: Should we handle `SciMLBase.NoAD`?
-    # For an `ADGradientWrapper` we do the following:
-    # 1. Update the `Model` in the underlying `LogDensityFunction`.
-    # 2. Re-construct the `ADGradientWrapper` using `ADgradient` using the provided `adtype`
-    #    to ensure that the recompilation of gradient tapes, etc. also occur. For example,
-    #    ReverseDiff.jl in compiled mode will cache the compiled tape, which means that just
-    #    replacing the corresponding field with the new model won't be sufficient to obtain
-    #    the correct gradients.
-    return LogDensityProblemsAD.ADgradient(
-        adtype, setmodel(LogDensityProblemsAD.parent(f), model)
-    )
-end
 function setmodel(f::DynamicPPL.LogDensityFunction, model::DynamicPPL.Model)
     return Accessors.@set f.model = model
 end

@@ -140,18 +118,24 @@ end
 # TODO: should we instead implement and call on `length(f.varinfo)` (at least in the cases where no sampler is involved)?
 LogDensityProblems.dimension(f::LogDensityFunction) = length(getparams(f))

-# This is important for performance -- one needs to provide `ADGradient` with a vector of
-# parameters, or DifferentiationInterface will not have sufficient information to e.g.
-# compile a rule for Mooncake (because it won't know the type of the input), or pre-allocate
-# a tape when using ReverseDiff.jl.
-function _make_ad_gradient(ad::ADTypes.AbstractADType, ℓ::LogDensityFunction)
-    x = map(identity, getparams(ℓ)) # ensure we concretise the elements of the params
-    return LogDensityProblemsAD.ADgradient(ad, ℓ; x)
-end
+_flipped_logdensity(θ, f) = LogDensityProblems.logdensity(f, θ)

-function LogDensityProblemsAD.ADgradient(ad::ADTypes.AutoMooncake, f::LogDensityFunction)
-    return _make_ad_gradient(ad, f)
+# By default, the AD backend to use is inferred from the context, which would
+# typically be a SamplingContext which contains a sampler.
+function LogDensityProblems.logdensity_and_gradient(
+    f::LogDensityFunction, θ::AbstractVector
+)
+    adtype = getadtype(getsampler(getcontext(f)))
+    return LogDensityProblems.logdensity_and_gradient(f, θ, adtype)
 end
-function LogDensityProblemsAD.ADgradient(ad::ADTypes.AutoReverseDiff, f::LogDensityFunction)
-    return _make_ad_gradient(ad, f)
+
+# Extra method allowing one to manually specify the AD backend to use, thus
+# overriding the default AD backend inferred from the sampler.
+function LogDensityProblems.logdensity_and_gradient(
+    f::LogDensityFunction, θ::AbstractVector, adtype::ADTypes.AbstractADType
+)
+    # Ensure we concretise the elements of the params.
+    θ = map(identity, θ)
+    prep = DI.prepare_gradient(_flipped_logdensity, adtype, θ, DI.Constant(f))
+    return DI.value_and_gradient(_flipped_logdensity, prep, adtype, θ, DI.Constant(f))
 end
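With the `ADgradient` wrapper gone, gradients are requested directly from the `LogDensityFunction`, either via the explicit-backend method added here or via the sampler-derived default. A usage sketch (the demo model and parameter values are illustrative, not from this commit):

    using DynamicPPL, Distributions, ADTypes, LogDensityProblems
    using ForwardDiff  # the chosen backend must be loaded for DifferentiationInterface to use it

    @model function demo()
        x ~ Normal()
        y ~ Normal(x, 1)
    end

    f = DynamicPPL.LogDensityFunction(demo())
    θ = [0.5, -0.3]  # one entry per parameter (x, y)

    # Explicit-backend method introduced in this commit:
    logp, grad = LogDensityProblems.logdensity_and_gradient(f, θ, ADTypes.AutoForwardDiff())

The two-argument method instead looks the backend up via `getadtype(getsampler(getcontext(f)))`, so it only applies when `f` was built with a `SamplingContext` whose sampler defines `getadtype`.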

src/sampler.jl (+3)

@@ -54,6 +54,9 @@ Sampler(alg) = Sampler(alg, Selector())
 Sampler(alg, model::Model) = Sampler(alg, model, Selector())
 Sampler(alg, model::Model, s::Selector) = Sampler(alg, s)

+# Extract the AD type from the underlying algorithm
+getadtype(s::Sampler) = getadtype(s.alg)
+
 # AbstractMCMC interface for SampleFromUniform and SampleFromPrior
 function AbstractMCMC.step(
     rng::Random.AbstractRNG,
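Since `getadtype(s::Sampler)` just delegates to the wrapped algorithm, an algorithm opts into the default-backend lookup by defining `getadtype` on itself. A hypothetical sketch (`MyAlg` and its `adtype` field are illustrative, not from the repo):

    using DynamicPPL, ADTypes

    struct MyAlg{T<:ADTypes.AbstractADType}
        adtype::T
    end
    # Forward the algorithm's stored backend, mirroring getadtype(s::Sampler) above
    DynamicPPL.getadtype(alg::MyAlg) = alg.adtype

    spl = DynamicPPL.Sampler(MyAlg(ADTypes.AutoForwardDiff()))
    DynamicPPL.getadtype(spl)  # returns AutoForwardDiff(), via s.alg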

test/Project.toml (-2)

@@ -16,7 +16,6 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
-LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
 MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
 MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
 Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"

@@ -46,7 +45,6 @@ EnzymeCore = "0.6 - 0.8"
 ForwardDiff = "0.10.12"
 JET = "0.9"
 LogDensityProblems = "2"
-LogDensityProblemsAD = "1.7.0"
 MCMCChains = "6.0.4"
 MacroTools = "0.5.6"
 Mooncake = "0.4.59"

test/ad.jl (+7 -10)

@@ -7,15 +7,12 @@

 @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos
     f = DynamicPPL.LogDensityFunction(m, varinfo)
-
-    # use ForwardDiff result as reference
-    ad_forwarddiff_f = LogDensityProblemsAD.ADgradient(
-        ADTypes.AutoForwardDiff(; chunksize=0), f
-    )
     # convert to `Vector{Float64}` to avoid `ReverseDiff` initializing the gradients to Integer 0
     # reference: https://github.com/TuringLang/DynamicPPL.jl/pull/571#issuecomment-1924304489
     θ = convert(Vector{Float64}, varinfo[:])
-    logp, ref_grad = LogDensityProblems.logdensity_and_gradient(ad_forwarddiff_f, θ)
+    # Calculate reference logp + gradient of logp using ForwardDiff
+    default_adtype = ADTypes.AutoForwardDiff(; chunksize=0)
+    ref_logp, ref_grad = LogDensityProblems.logdensity_and_gradient(f, θ, default_adtype)

     @testset "$adtype" for adtype in [
         ADTypes.AutoReverseDiff(; compile=false),

@@ -27,9 +24,9 @@
         if adtype isa ADTypes.AutoMooncake && varinfo isa DynamicPPL.SimpleVarInfo
             @test_broken 1 == 0
         else
-            ad_f = LogDensityProblemsAD.ADgradient(adtype, f)
-            _, grad = LogDensityProblems.logdensity_and_gradient(ad_f, θ)
+            logp, grad = LogDensityProblems.logdensity_and_gradient(f, θ, adtype)
             @test grad ≈ ref_grad
+            @test logp ≈ ref_logp
         end
     end
 end

@@ -50,7 +47,7 @@
     x = Vector{T}(undef, TT)
     x[1] = α
     for t in 2:TT
-        x[t] = x[t - 1] + η[t - 1] * τ
+        x[t] = x[t-1] + η[t-1] * τ
     end
     # measurement model
     y ~ MvNormal(x, σ^2 * I)

@@ -71,6 +68,6 @@
     spl = Sampler(MyEmptyAlg())
     vi = VarInfo(model)
     ldf = DynamicPPL.LogDensityFunction(vi, model, SamplingContext(spl))
-    @test LogDensityProblemsAD.ADgradient(AutoReverseDiff(; compile=true), ldf) isa Any
+    @test LogDensityProblems.logdensity_and_gradient(ldf, vi[:], AutoReverseDiff(; compile=true)) isa Any
 end

test/logdensityfunction.jl (+1 -12)

@@ -1,22 +1,11 @@
-using Test, DynamicPPL, ADTypes, LogDensityProblems, LogDensityProblemsAD, ReverseDiff
+using Test, DynamicPPL, ADTypes, LogDensityProblems, ReverseDiff

 @testset "`getmodel` and `setmodel`" begin
     @testset "$(nameof(model))" for model in DynamicPPL.TestUtils.DEMO_MODELS
         model = DynamicPPL.TestUtils.DEMO_MODELS[1]
         ℓ = DynamicPPL.LogDensityFunction(model)
         @test DynamicPPL.getmodel(ℓ) == model
         @test DynamicPPL.setmodel(ℓ, model).model == model
-
-        # ReverseDiff related
-        ∇ℓ = LogDensityProblemsAD.ADgradient(:ReverseDiff, ℓ; compile=Val(false))
-        @test DynamicPPL.getmodel(∇ℓ) == model
-        @test DynamicPPL.getmodel(DynamicPPL.setmodel(∇ℓ, model, AutoReverseDiff())) ==
-            model
-        ∇ℓ = LogDensityProblemsAD.ADgradient(:ReverseDiff, ℓ; compile=Val(true))
-        new_∇ℓ = DynamicPPL.setmodel(∇ℓ, model, AutoReverseDiff())
-        @test DynamicPPL.getmodel(new_∇ℓ) == model
-        # HACK(sunxd): rely on internal implementation detail, i.e., naming of `compiledtape`
-        @test new_∇ℓ.compiledtape != ∇ℓ.compiledtape
     end
 end

test/runtests.jl (+45 -45)

@@ -9,7 +9,7 @@ using Distributions
 using DistributionsAD
 using Documenter
 using ForwardDiff
-using LogDensityProblems, LogDensityProblemsAD
+using LogDensityProblems
 using MacroTools
 using MCMCChains
 using Mooncake: Mooncake

@@ -45,57 +45,57 @@ include("test_util.jl")
 # groups are chosen to make both groups take roughly the same amount of
 # time, but beyond that there is no particular reason for the split.
 if GROUP == "All" || GROUP == "Group1"
-    include("utils.jl")
-    include("compiler.jl")
-    include("varnamedvector.jl")
-    include("varinfo.jl")
-    include("simple_varinfo.jl")
-    include("model.jl")
-    include("sampler.jl")
-    include("independence.jl")
-    include("distribution_wrappers.jl")
-    include("logdensityfunction.jl")
-    include("linking.jl")
-    include("serialization.jl")
-    include("pointwise_logdensities.jl")
-    include("lkj.jl")
-    include("deprecated.jl")
+    # include("utils.jl")
+    # include("compiler.jl")
+    # include("varnamedvector.jl")
+    # include("varinfo.jl")
+    # include("simple_varinfo.jl")
+    # include("model.jl")
+    # include("sampler.jl")
+    # include("independence.jl")
+    # include("distribution_wrappers.jl")
+    # include("logdensityfunction.jl")
+    # include("linking.jl")
+    # include("serialization.jl")
+    # include("pointwise_logdensities.jl")
+    # include("lkj.jl")
+    # include("deprecated.jl")
 end

 if GROUP == "All" || GROUP == "Group2"
-    include("contexts.jl")
-    include("context_implementations.jl")
-    include("threadsafe.jl")
-    include("debug_utils.jl")
-    @testset "compat" begin
-        include(joinpath("compat", "ad.jl"))
-    end
-    @testset "extensions" begin
-        include("ext/DynamicPPLMCMCChainsExt.jl")
-        include("ext/DynamicPPLJETExt.jl")
-    end
+    # include("contexts.jl")
+    # include("context_implementations.jl")
+    # include("threadsafe.jl")
+    # include("debug_utils.jl")
+    # @testset "compat" begin
+    #     include(joinpath("compat", "ad.jl"))
+    # end
+    # @testset "extensions" begin
+    #     include("ext/DynamicPPLMCMCChainsExt.jl")
+    #     include("ext/DynamicPPLJETExt.jl")
+    # end
     @testset "ad" begin
         include("ext/DynamicPPLForwardDiffExt.jl")
         include("ext/DynamicPPLMooncakeExt.jl")
         include("ad.jl")
     end
-    @testset "prob and logprob macro" begin
-        @test_throws ErrorException prob"..."
-        @test_throws ErrorException logprob"..."
-    end
-    @testset "doctests" begin
-        DocMeta.setdocmeta!(
-            DynamicPPL,
-            :DocTestSetup,
-            :(using DynamicPPL, Distributions);
-            recursive=true,
-        )
-        doctestfilters = [
-            # Ignore the source of a warning in the doctest output, since this is dependent on host.
-            # This is a line that starts with "└ @ " and ends with the line number.
-            r"└ @ .+:[0-9]+",
-        ]
-        doctest(DynamicPPL; manual=false, doctestfilters=doctestfilters)
-    end
+    # @testset "prob and logprob macro" begin
+    #     @test_throws ErrorException prob"..."
+    #     @test_throws ErrorException logprob"..."
+    # end
+    # @testset "doctests" begin
+    #     DocMeta.setdocmeta!(
+    #         DynamicPPL,
+    #         :DocTestSetup,
+    #         :(using DynamicPPL, Distributions);
+    #         recursive=true,
+    #     )
+    #     doctestfilters = [
+    #         # Ignore the source of a warning in the doctest output, since this is dependent on host.
+    #         # This is a line that starts with "└ @ " and ends with the line number.
+    #         r"└ @ .+:[0-9]+",
+    #     ]
+    #     doctest(DynamicPPL; manual=false, doctestfilters=doctestfilters)
+    # end
 end
 end
