| 1 | +import DifferentiationInterface as DI
| 2 | +
1 | 3 | """
2 | 4 |     LogDensityFunction
3 | 5 |

81 | 83 |
82 | 84 | Return the `DynamicPPL.Model` wrapped in the given log-density function `f`.
83 | 85 | """
84 | | -getmodel(f::LogDensityProblemsAD.ADGradientWrapper) =
85 | | -    getmodel(LogDensityProblemsAD.parent(f))
86 | 86 | getmodel(f::DynamicPPL.LogDensityFunction) = f.model
87 | 87 |
88 | 88 | """
89 | 89 |     setmodel(f, model[, adtype])
90 | 90 |
91 | 91 | Set the `DynamicPPL.Model` in the given log-density function `f` to `model`.
92 | | -
93 | | -!!! warning
94 | | -    Note that if `f` is a `LogDensityProblemsAD.ADGradientWrapper` wrapping a
95 | | -    `DynamicPPL.LogDensityFunction`, performing an update of the `model` in `f`
96 | | -    might require recompilation of the gradient tape, depending on the AD backend.
97 | 92 | """
98 | | -function setmodel(
99 | | -    f::LogDensityProblemsAD.ADGradientWrapper,
100 | | -    model::DynamicPPL.Model,
101 | | -    adtype::ADTypes.AbstractADType,
102 | | -)
103 | | -    # TODO: Should we handle `SciMLBase.NoAD`?
104 | | -    # For an `ADGradientWrapper` we do the following:
105 | | -    # 1. Update the `Model` in the underlying `LogDensityFunction`.
106 | | -    # 2. Re-construct the `ADGradientWrapper` using `ADgradient` using the provided `adtype`
107 | | -    #    to ensure that the recompilation of gradient tapes, etc. also occur. For example,
108 | | -    #    ReverseDiff.jl in compiled mode will cache the compiled tape, which means that just
109 | | -    #    replacing the corresponding field with the new model won't be sufficient to obtain
110 | | -    #    the correct gradients.
111 | | -    return LogDensityProblemsAD.ADgradient(
112 | | -        adtype, setmodel(LogDensityProblemsAD.parent(f), model)
113 | | -    )
114 | | -end
115 | 93 | function setmodel(f::DynamicPPL.LogDensityFunction, model::DynamicPPL.Model)
116 | 94 |     return Accessors.@set f.model = model
117 | 95 | end
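
As an aside (not part of this commit): a minimal sketch of how `getmodel` and `setmodel` are typically used with a `DynamicPPL.LogDensityFunction`. The `@model` definition, its data, and the variable names below are illustrative assumptions rather than code from this repository.

    using DynamicPPL, Distributions

    # Hypothetical toy model, purely for illustration.
    @model function demo(x)
        μ ~ Normal(0, 1)
        x ~ Normal(μ, 1)
    end

    f = DynamicPPL.LogDensityFunction(demo(1.0))
    m = DynamicPPL.getmodel(f)               # retrieve the wrapped `Model`
    f2 = DynamicPPL.setmodel(f, demo(2.0))   # new `LogDensityFunction` with the model replaced

With the `ADGradientWrapper` methods deleted above, the remaining `setmodel` method simply replaces the `model` field via `Accessors.@set`.
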
@@ -140,18 +118,24 @@ end
140 | 118 | # TODO: should we instead implement and call on `length(f.varinfo)` (at least in the cases where no sampler is involved)?
141 | 119 | LogDensityProblems.dimension(f::LogDensityFunction) = length(getparams(f))
142 | 120 |
143 | | -# This is important for performance -- one needs to provide `ADGradient` with a vector of
144 | | -# parameters, or DifferentiationInterface will not have sufficient information to e.g.
145 | | -# compile a rule for Mooncake (because it won't know the type of the input), or pre-allocate
146 | | -# a tape when using ReverseDiff.jl.
147 | | -function _make_ad_gradient(ad::ADTypes.AbstractADType, ℓ::LogDensityFunction)
148 | | -    x = map(identity, getparams(ℓ)) # ensure we concretise the elements of the params
149 | | -    return LogDensityProblemsAD.ADgradient(ad, ℓ; x)
150 | | -end
| 121 | +_flipped_logdensity(θ, f) = LogDensityProblems.logdensity(f, θ)
151 | 122 |
152 | | -function LogDensityProblemsAD.ADgradient(ad::ADTypes.AutoMooncake, f::LogDensityFunction)
153 | | -    return _make_ad_gradient(ad, f)
| 123 | +# By default, the AD backend to use is inferred from the context, which would
| 124 | +# typically be a SamplingContext which contains a sampler.
| 125 | +function LogDensityProblems.logdensity_and_gradient(
| 126 | +    f::LogDensityFunction, θ::AbstractVector
| 127 | +)
| 128 | +    adtype = getadtype(getsampler(getcontext(f)))
| 129 | +    return LogDensityProblems.logdensity_and_gradient(f, θ, adtype)
154 | 130 | end
155 | | -function LogDensityProblemsAD.ADgradient(ad::ADTypes.AutoReverseDiff, f::LogDensityFunction)
156 | | -    return _make_ad_gradient(ad, f)
| 131 | +
| 132 | +# Extra method allowing one to manually specify the AD backend to use, thus
| 133 | +# overriding the default AD backend inferred from the sampler.
| 134 | +function LogDensityProblems.logdensity_and_gradient(
| 135 | +    f::LogDensityFunction, θ::AbstractVector, adtype::ADTypes.AbstractADType
| 136 | +)
| 137 | +    # Ensure we concretise the elements of the params.
| 138 | +    θ = map(identity, θ) # TODO: Is this needed?
| 139 | +    prep = DI.prepare_gradient(_flipped_logdensity, adtype, θ, DI.Constant(f))
| 140 | +    return DI.value_and_gradient(_flipped_logdensity, prep, adtype, θ, DI.Constant(f))
157 | 141 | end
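
As an aside (not part of this commit): a minimal, self-contained sketch of the DifferentiationInterface pattern the new `logdensity_and_gradient` methods rely on, namely preparing a gradient once with the log-density callable passed as a `DI.Constant` context and then evaluating value and gradient. The toy log density and the choice of `AutoForwardDiff` are assumptions for illustration only.

    import DifferentiationInterface as DI
    import ForwardDiff
    using ADTypes: AutoForwardDiff

    # Toy stand-in for `_flipped_logdensity(θ, f)`: a standard-normal log density
    # that ignores its second (constant) argument.
    toy_logdensity(θ, f) = -sum(abs2, θ) / 2

    θ = [0.5, -1.0]
    adtype = AutoForwardDiff()
    prep = DI.prepare_gradient(toy_logdensity, adtype, θ, DI.Constant(nothing))
    value, grad = DI.value_and_gradient(toy_logdensity, prep, adtype, θ, DI.Constant(nothing))

With an actual `LogDensityFunction` `f`, calling `LogDensityProblems.logdensity_and_gradient(f, θ, adtype)` performs the same preparation and evaluation internally, and omitting `adtype` falls back to the backend inferred from the sampler in `f`'s context.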