Skip to content

Commit 566257e

Browse files
committed
Re-add ForwardDiffExt (including tests)
1 parent 05f1bce commit 566257e

File tree

5 files changed

+87
-44
lines changed

5 files changed

+87
-44
lines changed

Project.toml

+2
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,15 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2828
[weakdeps]
2929
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3030
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
31+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
3132
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
3233
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
3334
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
3435

3536
[extensions]
3637
DynamicPPLChainRulesCoreExt = ["ChainRulesCore"]
3738
DynamicPPLEnzymeCoreExt = ["EnzymeCore"]
39+
DynamicPPLForwardDiffExt = ["ForwardDiff"]
3840
DynamicPPLJETExt = ["JET"]
3941
DynamicPPLMCMCChainsExt = ["MCMCChains"]
4042
DynamicPPLMooncakeExt = ["Mooncake"]

ext/DynamicPPLForwardDiffExt.jl

+29-43
Original file line numberDiff line numberDiff line change
@@ -1,54 +1,40 @@
11
module DynamicPPLForwardDiffExt
22

3-
if isdefined(Base, :get_extension)
4-
using DynamicPPL: ADTypes, DynamicPPL, LogDensityProblems, LogDensityProblemsAD
5-
using ForwardDiff
6-
else
7-
using ..DynamicPPL: ADTypes, DynamicPPL, LogDensityProblems, LogDensityProblemsAD
8-
using ..ForwardDiff
9-
end
10-
11-
getchunksize(::ADTypes.AutoForwardDiff{chunk}) where {chunk} = chunk
12-
13-
standardtag(::ADTypes.AutoForwardDiff{<:Any,Nothing}) = true
14-
standardtag(::ADTypes.AutoForwardDiff) = false
15-
16-
function LogDensityProblemsAD.ADgradient(
17-
ad::ADTypes.AutoForwardDiff, ℓ::DynamicPPL.LogDensityFunction
18-
)
19-
θ = DynamicPPL.getparams(ℓ)
20-
f = Base.Fix1(LogDensityProblems.logdensity, ℓ)
21-
22-
# Define configuration for ForwardDiff.
23-
tag = if standardtag(ad)
24-
ForwardDiff.Tag(DynamicPPL.DynamicPPLTag(), eltype(θ))
3+
using DynamicPPL: ADTypes, DynamicPPL, LogDensityProblems
4+
using ForwardDiff
5+
6+
# check if the AD type already has a tag
7+
use_dynamicppl_tag(::ADTypes.AutoForwardDiff{<:Any,Nothing}) = true
8+
use_dynamicppl_tag(::ADTypes.AutoForwardDiff) = false
9+
10+
function DynamicPPL.tweak_adtype(
11+
ad::ADTypes.AutoForwardDiff{chunk_size},
12+
::DynamicPPL.Model,
13+
vi::DynamicPPL.AbstractVarInfo,
14+
::DynamicPPL.AbstractContext,
15+
) where {chunk_size}
16+
params = vi[:]
17+
18+
# Use DynamicPPL tag to improve stack traces
19+
# https://www.stochasticlifestyle.com/improved-forwarddiff-jl-stacktraces-with-package-tags/
20+
# NOTE: DifferentiationInterface disables tag checking if the
21+
# tag inside the AutoForwardDiff type is not nothing. See
22+
# https://github.com/JuliaDiff/DifferentiationInterface.jl/blob/1df562180bdcc3e91c885aa5f4162a0be2ced850/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl#L338-L350.
23+
# So we don't currently need to override ForwardDiff.checktag as well.
24+
tag = if use_dynamicppl_tag(ad)
25+
ForwardDiff.Tag(DynamicPPL.DynamicPPLTag(), eltype(params))
2526
else
26-
ForwardDiff.Tag(f, eltype(θ))
27+
ad.tag
2728
end
28-
chunk_size = getchunksize(ad)
29+
30+
# Optimise chunk size according to size of model
2931
chunk = if chunk_size == 0 || chunk_size === nothing
30-
ForwardDiff.Chunk(θ)
32+
ForwardDiff.Chunk(params)
3133
else
32-
ForwardDiff.Chunk(length(θ), chunk_size)
34+
ForwardDiff.Chunk(length(params), chunk_size)
3335
end
3436

35-
return LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), ℓ; chunk, tag, x=θ)
36-
end
37-
38-
# Allow Turing tag in gradient etc. calls of the log density function
39-
function ForwardDiff.checktag(
40-
::Type{ForwardDiff.Tag{DynamicPPL.DynamicPPLTag,V}},
41-
::DynamicPPL.LogDensityFunction,
42-
::AbstractArray{W},
43-
) where {V,W}
44-
return true
45-
end
46-
function ForwardDiff.checktag(
47-
::Type{ForwardDiff.Tag{DynamicPPL.DynamicPPLTag,V}},
48-
::Base.Fix1{typeof(LogDensityProblems.logdensity),<:DynamicPPL.LogDensityFunction},
49-
::AbstractArray{W},
50-
) where {V,W}
51-
return true
37+
return ADTypes.AutoForwardDiff(; chunksize=ForwardDiff.chunksize(chunk), tag=tag)
5238
end
5339

5440
end # module

src/logdensityfunction.jl

+23-1
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,9 @@ struct LogDensityFunction{
116116
if adtype === nothing
117117
prep = nothing
118118
else
119-
# Check support
119+
# Make backend-specific tweaks to the adtype
120+
adtype = tweak_adtype(adtype, model, varinfo, context)
121+
# Check whether it is supported
120122
is_supported(adtype) ||
121123
@warn "The AD backend $adtype is not officially supported by DynamicPPL. Gradient calculations may still work, but compatibility is not guaranteed."
122124
# Get a set of dummy params to use for prep
@@ -227,6 +229,26 @@ LogDensityProblems.dimension(f::LogDensityFunction) = length(getparams(f))
227229

228230
### Utils
229231

232+
"""
233+
tweak_adtype(
234+
adtype::ADTypes.AbstractADType,
235+
model::Model,
236+
varinfo::AbstractVarInfo,
237+
context::AbstractContext
238+
)
239+
240+
Return an 'optimised' form of the adtype. This is useful for doing
241+
backend-specific optimisation of the adtype (e.g., for ForwardDiff, calculating
242+
the chunk size: see the method override in `ext/DynamicPPLForwardDiffExt.jl`).
243+
The model is passed as a parameter in case the optimisation depends on the
244+
model.
245+
246+
By default, this just returns the input unchanged.
247+
"""
248+
tweak_adtype(
249+
adtype::ADTypes.AbstractADType, ::Model, ::AbstractVarInfo, ::AbstractContext
250+
) = adtype
251+
230252
"""
231253
use_closure(adtype::ADTypes.AbstractADType)
232254

test/ext/DynamicPPLForwardDiffExt.jl

+32
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
module DynamicPPLForwardDiffExtTests
2+
3+
using DynamicPPL
4+
using ADTypes: AutoForwardDiff
5+
using ForwardDiff: ForwardDiff
6+
using Distributions: MvNormal
7+
using LinearAlgebra: I
8+
using Test: @test, @testset
9+
10+
# get_chunksize(ad::AutoForwardDiff{chunk}) where {chunk} = chunk
11+
12+
@testset "ForwardDiff tweak_adtype" begin
13+
MODEL_SIZE = 10
14+
@model f() = x ~ MvNormal(zeros(MODEL_SIZE), I)
15+
model = f()
16+
varinfo = VarInfo(model)
17+
context = DefaultContext()
18+
19+
@testset "Chunk size setting" for chunksize in (nothing, 0)
20+
base_adtype = AutoForwardDiff(; chunksize=chunksize)
21+
new_adtype = DynamicPPL.tweak_adtype(base_adtype, model, varinfo, context)
22+
@test new_adtype isa AutoForwardDiff{MODEL_SIZE}
23+
end
24+
25+
@testset "Tag setting" begin
26+
base_adtype = AutoForwardDiff()
27+
new_adtype = DynamicPPL.tweak_adtype(base_adtype, model, varinfo, context)
28+
@test new_adtype.tag isa ForwardDiff.Tag{DynamicPPL.DynamicPPLTag}
29+
end
30+
end
31+
32+
end

test/runtests.jl

+1
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ include("test_util.jl")
7575
include("ext/DynamicPPLJETExt.jl")
7676
end
7777
@testset "ad" begin
78+
include("ext/DynamicPPLForwardDiffExt.jl")
7879
include("ext/DynamicPPLMooncakeExt.jl")
7980
include("ad.jl")
8081
end

0 commit comments

Comments
 (0)