Commit 5b05ad3

Dynamically decide whether to use closure vs constant
1 parent b314e91 commit 5b05ad3

1 file changed: src/logdensityfunction.jl (+62 −12 lines)

@@ -101,16 +101,54 @@ function LogDensityProblems.logdensity(f::LogDensityFunction, x::AbstractVector)
     vi_new = unflatten(f.varinfo, x)
     return getlogp(last(evaluate!!(f.model, vi_new, context)))
 end
-function _flipped_logdensity(x::AbstractVector, f::LogDensityFunction)
-    return LogDensityProblems.logdensity(f, x)
-end
 function LogDensityProblems.capabilities(::Type{<:LogDensityFunction})
     return LogDensityProblems.LogDensityOrder{0}()
 end
 # TODO: should we instead implement and call on `length(f.varinfo)` (at least in the cases where no sampler is involved)?
 LogDensityProblems.dimension(f::LogDensityFunction) = length(getparams(f))
 
 # LogDensityProblems interface: gradient (1st order)
+"""
+    use_closure(adtype::ADTypes.AbstractADType)
+
+In LogDensityProblems, we want to calculate the derivative of `logdensity(f, x)`
+with respect to `x`, where `f` is the model (in our case `LogDensityFunction`)
+and is a constant. However, DifferentiationInterface generally expects a
+single-argument function `g(x)` to differentiate.
+
+There are two ways of dealing with this:
+
+1. Construct a closure over the model, i.e. let `g = Base.Fix1(logdensity, f)`.
+
+2. Use a constant context. This lets us pass a two-argument function to
+   DifferentiationInterface, as long as we also give it the 'inactive argument'
+   (i.e. the model) wrapped in `DI.Constant`.
+
+The relative performance of the two approaches, however, depends on the AD
+backend used. Some benchmarks are provided here:
+https://github.com/TuringLang/DynamicPPL.jl/pull/806#issuecomment-2658061480
+
+This function is used to determine whether a given AD backend should use a
+closure or a constant. If `use_closure(adtype)` returns `true`, then the
+closure approach will be used. By default, this function returns `false`, i.e.
+the constant approach will be used.
+"""
+use_closure(::ADTypes.AbstractADType) = false
+use_closure(::ADTypes.AutoForwardDiff) = false
+use_closure(::ADTypes.AutoMooncake) = false
+use_closure(::ADTypes.AutoReverseDiff) = true
+
+"""
+    _flipped_logdensity(x::AbstractVector, f::LogDensityFunction)
+
+This function is the same as `LogDensityProblems.logdensity(f, x)` but with the
+arguments flipped. It is used in the 'constant' approach to DifferentiationInterface
+(see `use_closure` for more information).
+"""
+function _flipped_logdensity(x::AbstractVector, f::LogDensityFunction)
+    return LogDensityProblems.logdensity(f, x)
+end
+
 """
     LogDensityFunctionWithGrad(ldf::DynamicPPL.LogDensityFunction, adtype::ADTypes.AbstractADType)
 
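
The two approaches described in the `use_closure` docstring can be sketched with a toy two-argument function standing in for `LogDensityProblems.logdensity`. The names `ldens`, `flipped`, `model`, and `x0` below are illustrative, not part of this commit:

import DifferentiationInterface as DI
using ADTypes: AutoForwardDiff
using ForwardDiff: ForwardDiff

# Stand-in for logdensity(f, x): the model comes first, then the parameters
# we differentiate with respect to.
ldens(f, x) = -sum(abs2, x .- f.mean)
model = (; mean=[1.0, 2.0])
x0 = [0.5, 0.5]
backend = AutoForwardDiff()

# Approach 1: closure. Base.Fix1 pins the model as the first argument, so DI
# sees a single-argument function of x.
g1 = DI.gradient(Base.Fix1(ldens, model), backend, x0)

# Approach 2: constant. DI's context API expects the active argument first,
# so the arguments are flipped (as _flipped_logdensity does in the diff above)
# and the model is marked inactive with DI.Constant.
flipped(x, f) = ldens(f, x)
g2 = DI.gradient(flipped, backend, x0, DI.Constant(model))

g1 == g2  # same gradient either way; which is faster depends on the backend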
@@ -134,15 +172,25 @@ struct LogDensityFunctionWithGrad{V,M,C,TAD<:ADTypes.AbstractADType}
     ldf::LogDensityFunction{V,M,C}
     adtype::TAD
     prep::DI.GradientPrep
+    with_closure::Bool
 
     function LogDensityFunctionWithGrad(
         ldf::LogDensityFunction{V,M,C}, adtype::TAD
     ) where {V,M,C,TAD}
-        # Get a set of dummy params to use for prep and concretise type
+        # Get a set of dummy params to use for prep
         x = map(identity, getparams(ldf))
-        prep = DI.prepare_gradient(_flipped_logdensity, adtype, x, DI.Constant(ldf))
-        # Store the prep with the struct
-        return new{V,M,C,TAD}(ldf, adtype, prep)
+        with_closure = use_closure(adtype)
+        if with_closure
+            prep = DI.prepare_gradient(
+                Base.Fix1(LogDensityProblems.logdensity, ldf), adtype, x
+            )
+        else
+            prep = DI.prepare_gradient(_flipped_logdensity, adtype, x, DI.Constant(ldf))
+        end
+        # Store the prep with the struct. We also store whether a closure was used because
+        # we need to know this when calling `DI.value_and_gradient`. In practice we could
+        # recalculate it, but this runs the risk of introducing inconsistencies.
+        return new{V,M,C,TAD}(ldf, adtype, prep, with_closure)
     end
 end
 function LogDensityProblems.logdensity(f::LogDensityFunctionWithGrad)
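
The constructor above prepares the gradient once and stores the prep object for reuse. A minimal sketch of that prepare-then-evaluate pattern, under the same toy setup as the previous example (names are illustrative):

import DifferentiationInterface as DI
using ADTypes: AutoForwardDiff
using ForwardDiff: ForwardDiff

ldens(f, x) = -sum(abs2, x .- f.mean)   # toy stand-in for logdensity
flipped(x, f) = ldens(f, x)
model = (; mean=[1.0, 2.0])
backend = AutoForwardDiff()

# Preparation runs once, on dummy parameters of the right shape and type...
x_dummy = [0.5, 0.5]
prep = DI.prepare_gradient(flipped, backend, x_dummy, DI.Constant(model))

# ...and is then reused across many gradient evaluations. The evaluation call
# must mirror the preparation call, which is why the struct records whether
# the closure or the constant variant was prepared.
for x in ([0.0, 0.0], [1.0, 1.0])
    val, grad = DI.value_and_gradient(flipped, prep, backend, x, DI.Constant(model))
end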
@@ -151,13 +199,15 @@ end
 function LogDensityProblems.capabilities(::Type{<:LogDensityFunctionWithGrad})
     return LogDensityProblems.LogDensityOrder{1}()
 end
-# By default, the AD backend to use is inferred from the context, which would
-# typically be a SamplingContext which contains a sampler.
 function LogDensityProblems.logdensity_and_gradient(
     f::LogDensityFunctionWithGrad, x::AbstractVector
 )
     x = map(identity, x) # Concretise type
-    return DI.value_and_gradient(
-        _flipped_logdensity, f.prep, f.adtype, x, DI.Constant(f.ldf)
-    )
+    return if f.with_closure
+        DI.value_and_gradient(
+            Base.Fix1(LogDensityProblems.logdensity, f.ldf), f.prep, f.adtype, x
+        )
+    else
+        DI.value_and_gradient(_flipped_logdensity, f.prep, f.adtype, x, DI.Constant(f.ldf))
+    end
 end
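
Downstream, callers see only the standard LogDensityProblems first-order interface, regardless of which branch is taken internally. A rough usage sketch, assuming a trivial `@model` (the model and conditioning values are made up for illustration and are not part of this commit):

using DynamicPPL, Distributions
using ADTypes: AutoReverseDiff
using ReverseDiff: ReverseDiff
import LogDensityProblems

@model function demo()
    m ~ Normal(0, 1)
    x ~ Normal(m, 1)
end

ldf = DynamicPPL.LogDensityFunction(demo() | (; x=1.0))
# AutoReverseDiff takes the closure path (use_closure returns true);
# ForwardDiff and Mooncake would take the constant path instead.
ldfg = DynamicPPL.LogDensityFunctionWithGrad(ldf, AutoReverseDiff())
LogDensityProblems.logdensity_and_gradient(ldfg, [0.0])  # value and gradient w.r.t. m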
