module DynamicPPLForwardDiffExt

using DynamicPPL: ADTypes, DynamicPPL, LogDensityProblems
using ForwardDiff

# Decide whether the DynamicPPL tag should be used for an AD type.
# Dispatches on the second type parameter of `AutoForwardDiff`: `true` when it
# is `Nothing` (i.e. the AD type does not already have a tag), `false` otherwise.
use_dynamicppl_tag(::ADTypes.AutoForwardDiff{<:Any,Nothing}) = true
use_dynamicppl_tag(::ADTypes.AutoForwardDiff) = false

# Build a new `AutoForwardDiff` whose tag and chunk size are derived from `ad`
# and from the parameter vector stored in `vi`.
#
# Arguments:
# - `ad`: the `AutoForwardDiff` backend to adjust; its first type parameter is
#   captured as `chunk_size`.
# - the model and context arguments are accepted for dispatch only and are not
#   read in this method.
# - `vi`: the varinfo; `vi[:]` supplies the parameters whose element type and
#   length are used below.
#
# Returns an `ADTypes.AutoForwardDiff` constructed from the resolved chunk size
# and tag.
function DynamicPPL.tweak_adtype(
    ad::ADTypes.AutoForwardDiff{chunk_size},
    ::DynamicPPL.Model,
    vi::DynamicPPL.AbstractVarInfo,
    ::DynamicPPL.AbstractContext,
) where {chunk_size}
    params = vi[:]

    # Use DynamicPPL tag to improve stack traces
    # https://www.stochasticlifestyle.com/improved-forwarddiff-jl-stacktraces-with-package-tags/
    # NOTE: DifferentiationInterface disables tag checking if the
    # tag inside the AutoForwardDiff type is not nothing. See
    # https://github.com/JuliaDiff/DifferentiationInterface.jl/blob/1df562180bdcc3e91c885aa5f4162a0be2ced850/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl#L338-L350.
    # So we don't currently need to override ForwardDiff.checktag as well.
    tag = if use_dynamicppl_tag(ad)
        ForwardDiff.Tag(DynamicPPL.DynamicPPLTag(), eltype(params))
    else
        # A tag was already supplied on `ad`; keep it unchanged.
        ad.tag
    end

    # Optimise chunk size according to size of model.
    # A chunk size of `nothing` or `0` is treated as "not specified", in which
    # case the chunk is derived from the parameter vector itself.
    chunk = if chunk_size === nothing || chunk_size == 0
        ForwardDiff.Chunk(params)
    else
        ForwardDiff.Chunk(length(params), chunk_size)
    end

    return ADTypes.AutoForwardDiff(; chunksize=ForwardDiff.chunksize(chunk), tag=tag)
end

end # module