@@ -101,16 +101,54 @@ function LogDensityProblems.logdensity(f::LogDensityFunction, x::AbstractVector)
     vi_new = unflatten(f.varinfo, x)
     return getlogp(last(evaluate!!(f.model, vi_new, context)))
 end
-function _flipped_logdensity(x::AbstractVector, f::LogDensityFunction)
-    return LogDensityProblems.logdensity(f, x)
-end
 function LogDensityProblems.capabilities(::Type{<:LogDensityFunction})
     return LogDensityProblems.LogDensityOrder{0}()
 end
 # TODO: should we instead implement and call on `length(f.varinfo)` (at least in the cases where no sampler is involved)?
 LogDensityProblems.dimension(f::LogDensityFunction) = length(getparams(f))
 
 # LogDensityProblems interface: gradient (1st order)
+"""
+    use_closure(adtype::ADTypes.AbstractADType)
+
+In LogDensityProblems, we want to calculate the derivative of logdensity(f, x)
+with respect to x, where f is the model (in our case LogDensityFunction) and is
+a constant. However, DifferentiationInterface generally expects a
+single-argument function g(x) to differentiate.
+
+There are two ways of dealing with this:
+
+1. Construct a closure over the model, i.e. let g = Base.Fix1(logdensity, f)
+
+2. Use a constant context. This lets us pass a two-argument function to
+   DifferentiationInterface, as long as we also give it the 'inactive argument'
+   (i.e. the model) wrapped in `DI.Constant`.
+
+The relative performance of the two approaches, however, depends on the AD
+backend used. Some benchmarks are provided here:
+https://github.com/TuringLang/DynamicPPL.jl/pull/806#issuecomment-2658061480
+
+This function is used to determine whether a given AD backend should use a
+closure or a constant. If `use_closure(adtype)` returns `true`, then the
+closure approach will be used. By default, this function returns `false`, i.e.
+the constant approach will be used.
+"""
+use_closure(::ADTypes.AbstractADType) = false
+use_closure(::ADTypes.AutoForwardDiff) = false
+use_closure(::ADTypes.AutoMooncake) = false
+use_closure(::ADTypes.AutoReverseDiff) = true
+
+"""
+    _flipped_logdensity(x::AbstractVector, f::LogDensityFunction)
+
+This function is the same as `LogDensityProblems.logdensity(f, x)` but with the
+arguments flipped. It is used in the 'constant' approach to DifferentiationInterface
+(see `use_closure` for more information).
+"""
+function _flipped_logdensity(x::AbstractVector, f::LogDensityFunction)
+    return LogDensityProblems.logdensity(f, x)
+end
+
 """
     LogDensityFunctionWithGrad(ldf::DynamicPPL.LogDensityFunction, adtype::ADTypes.AbstractADType)
 
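
The docstring above describes the choice abstractly; the following standalone sketch shows both strategies side by side on a toy density. `ToyDensity`, `toy_logdensity`, and `flipped` are hypothetical stand-ins for `LogDensityFunction`, `LogDensityProblems.logdensity`, and `_flipped_logdensity`; only the DifferentiationInterface calls mirror what the PR does.

    # Illustrative sketch only; ToyDensity and toy_logdensity are hypothetical
    # stand-ins for LogDensityFunction and LogDensityProblems.logdensity.
    import DifferentiationInterface as DI
    import ADTypes
    import ForwardDiff  # backend implementation behind ADTypes.AutoForwardDiff()

    struct ToyDensity
        scale::Float64
    end
    toy_logdensity(f::ToyDensity, x::AbstractVector) = -f.scale * sum(abs2, x)

    f = ToyDensity(0.5)
    x = [1.0, 2.0]
    backend = ADTypes.AutoForwardDiff()

    # Approach 1: close over the 'model' so DI sees a one-argument function.
    grad1 = DI.gradient(Base.Fix1(toy_logdensity, f), backend, x)

    # Approach 2: keep two arguments and mark the 'model' as an inactive constant.
    flipped(x, f) = toy_logdensity(f, x)
    grad2 = DI.gradient(flipped, backend, x, DI.Constant(f))

    grad1 == grad2  # true; the two approaches differ only in performance
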
@@ -134,15 +172,25 @@ struct LogDensityFunctionWithGrad{V,M,C,TAD<:ADTypes.AbstractADType}
     ldf::LogDensityFunction{V,M,C}
     adtype::TAD
     prep::DI.GradientPrep
+    with_closure::Bool
 
     function LogDensityFunctionWithGrad(
         ldf::LogDensityFunction{V,M,C}, adtype::TAD
     ) where {V,M,C,TAD}
-        # Get a set of dummy params to use for prep and concretise type
+        # Get a set of dummy params to use for prep
         x = map(identity, getparams(ldf))
-        prep = DI.prepare_gradient(_flipped_logdensity, adtype, x, DI.Constant(ldf))
-        # Store the prep with the struct
-        return new{V,M,C,TAD}(ldf, adtype, prep)
+        with_closure = use_closure(adtype)
+        if with_closure
+            prep = DI.prepare_gradient(
+                Base.Fix1(LogDensityProblems.logdensity, ldf), adtype, x
+            )
+        else
+            prep = DI.prepare_gradient(_flipped_logdensity, adtype, x, DI.Constant(ldf))
+        end
+        # Store the prep with the struct. We also store whether a closure was used
+        # because we need to know this when calling `DI.value_and_gradient`. In
+        # practice we could recalculate it, but this runs the risk of introducing
+        # inconsistencies.
+        return new{V,M,C,TAD}(ldf, adtype, prep, with_closure)
     end
 end
 function LogDensityProblems.logdensity(f::LogDensityFunctionWithGrad)
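
The branching in the constructor follows DifferentiationInterface's preparation workflow: `prepare_gradient` pays a one-time setup cost, and later calls must pass the same function and argument layout alongside the resulting prep, which is why `with_closure` is stored rather than recomputed. A minimal sketch of that workflow, with a toy one-argument function as a hypothetical stand-in for `Base.Fix1(logdensity, ldf)`:

    import DifferentiationInterface as DI
    import ADTypes
    import ForwardDiff

    g(x) = sum(abs2, x)  # hypothetical stand-in for Base.Fix1(logdensity, ldf)
    backend = ADTypes.AutoForwardDiff()
    x = randn(3)

    # One-time setup; the same function/argument layout must be used later.
    prep = DI.prepare_gradient(g, backend, x)
    # Cheap repeated calls reusing the preparation.
    val, grad = DI.value_and_gradient(g, prep, backend, x)
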
@@ -151,13 +199,15 @@ end
 function LogDensityProblems.capabilities(::Type{<:LogDensityFunctionWithGrad})
     return LogDensityProblems.LogDensityOrder{1}()
 end
-# By default, the AD backend to use is inferred from the context, which would
-# typically be a SamplingContext which contains a sampler.
 function LogDensityProblems.logdensity_and_gradient(
     f::LogDensityFunctionWithGrad, x::AbstractVector
 )
     x = map(identity, x) # Concretise type
-    return DI.value_and_gradient(
-        _flipped_logdensity, f.prep, f.adtype, x, DI.Constant(f.ldf)
-    )
+    return if f.with_closure
+        DI.value_and_gradient(
+            Base.Fix1(LogDensityProblems.logdensity, f.ldf), f.prep, f.adtype, x
+        )
+    else
+        DI.value_and_gradient(_flipped_logdensity, f.prep, f.adtype, x, DI.Constant(f.ldf))
+    end
 end
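
For context, a hedged end-to-end sketch of how the new wrapper might be exercised, assuming DynamicPPL's `@model` macro and the constructors shown in this diff; it is not itself part of the PR.

    using DynamicPPL, Distributions
    import ADTypes, LogDensityProblems
    import ForwardDiff

    @model function demo()
        x ~ Normal()
    end

    # Assumes the one-argument LogDensityFunction(model) constructor with
    # default varinfo and context.
    ldf = DynamicPPL.LogDensityFunction(demo())
    ldg = DynamicPPL.LogDensityFunctionWithGrad(ldf, ADTypes.AutoForwardDiff())
    lp, grad = LogDensityProblems.logdensity_and_gradient(ldg, [0.5])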