
Commit 60ee68e

penelopeysm and sunxd3 authored
Implement AD testing and benchmarking (hand rolled) (#882)
* Implement AD testing and benchmarking (hand rolled)
* Also pass varinfo to LogDensityFunction
* Improve docstring

  Co-authored-by: Xianda Sun <[email protected]>

* Fix docstring again
* Fix out of sync docstring
* Bump version, add changelog entry

---------

Co-authored-by: Xianda Sun <[email protected]>
1 parent c7bdc3f commit 60ee68e

File tree

7 files changed: +237 -7 lines

Diff for: HISTORY.md (+5)

@@ -1,5 +1,10 @@
 # DynamicPPL Changelog
 
+## 0.35.8
+
+Added the `DynamicPPL.TestUtils.AD.run_ad` function to test the correctness and/or benchmark the performance of an automatic differentiation backend on DynamicPPL models.
+Please see [the docstring](https://turinglang.org/DynamicPPL.jl/api/#DynamicPPL.TestUtils.AD.run_ad) for more information.
+
 ## 0.35.7
 
 `check_model_and_trace` now errors if any NaN's are encountered when evaluating the model.
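
As a quick illustration of the entry point added in this changelog entry, here is a minimal sketch; the toy model is a hypothetical stand-in, and the AD backend being exercised must be installed and imported:

```julia
using DynamicPPL, Distributions
using ADTypes: AutoForwardDiff
import ForwardDiff  # run_ad needs the backend itself to be loaded

# Hypothetical toy model, purely for illustration.
@model function demo(x)
    μ ~ Normal()
    x ~ Normal(μ, 1)
end

# By default this tests correctness (test=true, benchmark=false).
result = DynamicPPL.TestUtils.AD.run_ad(demo(1.0), AutoForwardDiff())
```

In practice one would pass the backend under test, e.g. `AutoReverseDiff()` after `import ReverseDiff`, since ForwardDiff is also the default reference backend.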

Diff for: Project.toml (+5 -1)

@@ -1,6 +1,6 @@
 name = "DynamicPPL"
 uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
-version = "0.35.7"
+version = "0.35.8"
 
 [deps]
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -10,6 +10,7 @@ Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
 BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
 Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
+Chairmarks = "0ca39b1e-fe0b-4e98-acfc-b1656634c4de"
 Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
 ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
 DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
@@ -22,6 +23,7 @@ MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
 OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Requires = "ae029012-a4dd-5104-9daa-d747884805df"
+Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 
 [weakdeps]
@@ -49,6 +51,7 @@ Accessors = "0.1"
 BangBang = "0.4.1"
 Bijectors = "0.13.18, 0.14, 0.15"
 ChainRulesCore = "1"
+Chairmarks = "1.3.1"
 Compat = "4"
 ConstructionBase = "1.5.4"
 DifferentiationInterface = "0.6.41"
@@ -67,5 +70,6 @@ Mooncake = "0.4.95"
 OrderedCollections = "1"
 Random = "1.6"
 Requires = "1"
+Statistics = "1"
 Test = "1.6"
 julia = "1.10"
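
The new `Chairmarks` and `Statistics` dependencies support the benchmarking path in `src/test_utils/ad.jl` below (`@be` and `median`). As a rough sketch of the `@be setup expr` pattern used there, where `_` stands for the value produced by the setup expression:

```julia
using Chairmarks: @be
using Statistics: median

# Benchmark summing a freshly generated vector; `_` is the tuple
# produced by the setup expression `(rand(100),)`.
b = @be (rand(100),) sum(_[1])
median(b).time  # median runtime in seconds
```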

Diff for: docs/src/api.md (+10 -1)

@@ -205,7 +205,16 @@ values_as_in_model
 NamedDist
 ```
 
-## Testing Utilities
+## AD testing and benchmarking utilities
+
+To test and/or benchmark the performance of an AD backend on a model, DynamicPPL provides the following utilities:
+
+```@docs
+DynamicPPL.TestUtils.AD.run_ad
+DynamicPPL.TestUtils.AD.ADResult
+```
+
+## Demo models
 
 DynamicPPL provides several demo models and helpers for testing samplers in the `DynamicPPL.TestUtils` submodule.
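To make the newly documented knobs concrete: per the `run_ad` docstring (shown in full later in this diff), the prior draw stored in the VarInfo can be seeded by constructing it with an RNG. A sketch, with an illustrative model:

```julia
using DynamicPPL, Distributions
using ADTypes: AutoForwardDiff
using Random: Xoshiro
import ForwardDiff

@model function gauss(y)
    σ ~ Exponential()
    y ~ Normal(0, σ)
end

model = gauss(0.5)
vi = VarInfo(Xoshiro(468), model)  # seeded prior draw of the parameters
DynamicPPL.TestUtils.AD.run_ad(model, AutoForwardDiff(); varinfo=vi, params=vi[:])
```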

Diff for: src/DynamicPPL.jl (+1 -1)

@@ -175,7 +175,6 @@ include("context_implementations.jl")
 include("compiler.jl")
 include("pointwise_logdensities.jl")
 include("submodel_macro.jl")
-include("test_utils.jl")
 include("transforming.jl")
 include("logdensityfunction.jl")
 include("model_utils.jl")
@@ -184,6 +183,7 @@ include("values_as_in_model.jl")
 
 include("debug_utils.jl")
 using .DebugUtils
+include("test_utils.jl")
 
 include("experimental.jl")
 include("deprecated.jl")

Diff for: src/test_utils.jl (+1)

@@ -18,5 +18,6 @@ include("test_utils/models.jl")
 include("test_utils/contexts.jl")
 include("test_utils/varinfo.jl")
 include("test_utils/sampler.jl")
+include("test_utils/ad.jl")
 
 end

Diff for: src/test_utils/ad.jl (+209, new file)

@@ -0,0 +1,209 @@
+module AD
+
+using ADTypes: AbstractADType, AutoForwardDiff
+using Chairmarks: @be
+import DifferentiationInterface as DI
+using DocStringExtensions
+using DynamicPPL: Model, LogDensityFunction, VarInfo, AbstractVarInfo
+using LogDensityProblems: logdensity, logdensity_and_gradient
+using Random: Random, Xoshiro
+using Statistics: median
+using Test: @test
+
+export ADResult, run_ad
+
+# This function is needed to work around the fact that different backends can
+# return different AbstractArrays for the gradient. See
+# https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/754 for more
+# context.
+_to_vec_f64(x::AbstractArray) = x isa Vector{Float64} ? x : collect(Float64, x)
+
+"""
+    REFERENCE_ADTYPE
+
+Reference AD backend to use for comparison. In this case, ForwardDiff.jl, since
+it's the default AD backend used in Turing.jl.
+"""
+const REFERENCE_ADTYPE = AutoForwardDiff()
+
+"""
+    ADResult
+
+Data structure to store the results of the AD correctness test.
+"""
+struct ADResult
+    "The DynamicPPL model that was tested"
+    model::Model
+    "The VarInfo that was used"
+    varinfo::AbstractVarInfo
+    "The values at which the model was evaluated"
+    params::Vector{<:Real}
+    "The AD backend that was tested"
+    adtype::AbstractADType
+    "The absolute tolerance for the value of logp"
+    value_atol::Real
+    "The absolute tolerance for the gradient of logp"
+    grad_atol::Real
+    "The expected value of logp"
+    value_expected::Union{Nothing,Float64}
+    "The expected gradient of logp"
+    grad_expected::Union{Nothing,Vector{Float64}}
+    "The value of logp (calculated using `adtype`)"
+    value_actual::Union{Nothing,Real}
+    "The gradient of logp (calculated using `adtype`)"
+    grad_actual::Union{Nothing,Vector{Float64}}
+    "If benchmarking was requested, the time taken by the AD backend to calculate the gradient of logp, divided by the time taken to evaluate logp itself"
+    time_vs_primal::Union{Nothing,Float64}
+end
+
+"""
+    run_ad(
+        model::Model,
+        adtype::ADTypes.AbstractADType;
+        test=true,
+        benchmark=false,
+        value_atol=1e-6,
+        grad_atol=1e-6,
+        varinfo::AbstractVarInfo=VarInfo(model),
+        params::Vector{<:Real}=varinfo[:],
+        reference_adtype::ADTypes.AbstractADType=REFERENCE_ADTYPE,
+        expected_value_and_grad::Union{Nothing,Tuple{Real,Vector{<:Real}}}=nothing,
+        verbose=true,
+    )::ADResult
+
+Test the correctness and/or benchmark the AD backend `adtype` for the model
+`model`.
+
+Whether to test and benchmark is controlled by the `test` and `benchmark`
+keyword arguments. By default, `test` is `true` and `benchmark` is `false`.
+
+Returns an [`ADResult`](@ref) object, which contains the results of the
+test and/or benchmark.
+
+Note that to run AD successfully you will need to import the AD backend itself.
+For example, to test with `AutoReverseDiff()` you will need to run `import
+ReverseDiff`.
+
+There are two positional arguments, which absolutely must be provided:
+
+1. `model` - The model being tested.
+2. `adtype` - The AD backend being tested.
+
+Everything else is optional, and can be categorised into several groups:
+
+1. _How to specify the VarInfo._
+
+   DynamicPPL contains several different types of VarInfo objects which change
+   the way model evaluation occurs. If you want to use a specific type of
+   VarInfo, pass it as the `varinfo` argument. Otherwise, it will default to
+   using a `TypedVarInfo` generated from the model.
+
+2. _How to specify the parameters._
+
+   For maximum control over this, generate a vector of parameters yourself and
+   pass this as the `params` argument. If you don't specify this, it will be
+   taken from the contents of the VarInfo.
+
+   Note that if the VarInfo is not specified (and thus automatically generated)
+   the parameters in it will have been sampled from the prior of the model. If
+   you want to seed the parameter generation, the easiest way is to pass a
+   `rng` argument to the VarInfo constructor (i.e. do `VarInfo(rng, model)`).
+
+   Finally, note that these only reflect the parameters used for _evaluating_
+   the gradient. If you also want to control the parameters used for
+   _preparing_ the gradient, then you need to manually set these parameters in
+   the VarInfo object, for example using `vi = DynamicPPL.unflatten(vi,
+   prep_params)`. You could then evaluate the gradient at a different set of
+   parameters using the `params` keyword argument.
+
+3. _How to specify the results to compare against._ (Only if `test=true`.)
+
+   Once logp and its gradient have been calculated with the specified `adtype`,
+   they must be tested for correctness.
+
+   This can be done either by specifying `reference_adtype`, in which case logp
+   and its gradient will also be calculated with this reference in order to
+   obtain the ground truth; or by using `expected_value_and_grad`, which is a
+   tuple of `(logp, gradient)` that the calculated values must match. The
+   latter is useful if you are testing multiple AD backends and want to avoid
+   recalculating the ground truth multiple times.
+
+   The default reference backend is ForwardDiff. If none of these parameters
+   are specified, ForwardDiff will be used to calculate the ground truth.
+
+4. _How to specify the tolerances._ (Only if `test=true`.)
+
+   The tolerances for the value and gradient can be set using `value_atol` and
+   `grad_atol`. These default to 1e-6.
+
+5. _Whether to output extra logging information._
+
+   By default, this function prints messages when it runs. To silence it, set
+   `verbose=false`.
+"""
+function run_ad(
+    model::Model,
+    adtype::AbstractADType;
+    test=true,
+    benchmark=false,
+    value_atol=1e-6,
+    grad_atol=1e-6,
+    varinfo::AbstractVarInfo=VarInfo(model),
+    params::Vector{<:Real}=varinfo[:],
+    reference_adtype::AbstractADType=REFERENCE_ADTYPE,
+    expected_value_and_grad::Union{Nothing,Tuple{Real,Vector{<:Real}}}=nothing,
+    verbose=true,
+)::ADResult
+    verbose && @info "Running AD on $(model.f) with $(adtype)\n"
+    params = map(identity, params)
+    verbose && println("       params : $(params)")
+    ldf = LogDensityFunction(model, varinfo; adtype=adtype)
+
+    value, grad = logdensity_and_gradient(ldf, params)
+    grad = _to_vec_f64(grad)
+    verbose && println("       actual : $((value, grad))")
+
+    if test
+        # Calculate ground truth to compare against
+        value_true, grad_true = if expected_value_and_grad === nothing
+            ldf_reference = LogDensityFunction(model, varinfo; adtype=reference_adtype)
+            logdensity_and_gradient(ldf_reference, params)
+        else
+            expected_value_and_grad
+        end
+        verbose && println("     expected : $((value_true, grad_true))")
+        grad_true = _to_vec_f64(grad_true)
+        # Then compare
+        @test isapprox(value, value_true; atol=value_atol)
+        @test isapprox(grad, grad_true; atol=grad_atol)
+    else
+        value_true = nothing
+        grad_true = nothing
+    end
+
+    time_vs_primal = if benchmark
+        primal_benchmark = @be (ldf, params) logdensity(_[1], _[2])
+        grad_benchmark = @be (ldf, params) logdensity_and_gradient(_[1], _[2])
+        t = median(grad_benchmark).time / median(primal_benchmark).time
+        verbose && println("grad / primal : $(t)")
+        t
+    else
+        nothing
+    end
+
+    return ADResult(
+        model,
+        varinfo,
+        params,
+        adtype,
+        value_atol,
+        grad_atol,
+        value_true,
+        grad_true,
+        value,
+        grad,
+        time_vs_primal,
+    )
+end
+
+end # module DynamicPPL.TestUtils.AD
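
To illustrate the `expected_value_and_grad` path that the docstring above recommends for testing several backends without recomputing the ground truth, here is a sketch; the model and the choice of backends are illustrative assumptions:

```julia
using DynamicPPL, Distributions
using ADTypes: AutoForwardDiff, AutoReverseDiff
using LogDensityProblems: logdensity_and_gradient
import ForwardDiff, ReverseDiff

@model function demo()
    a ~ Normal()
    b ~ Exponential()
end

model = demo()
vi = VarInfo(model)
params = vi[:]

# Compute the ground truth once with the reference backend...
ref_ldf = DynamicPPL.LogDensityFunction(model, vi; adtype=AutoForwardDiff())
ref = logdensity_and_gradient(ref_ldf, params)

# ...then reuse it for each backend under test.
for adtype in (AutoReverseDiff(; compile=false), AutoReverseDiff(; compile=true))
    DynamicPPL.TestUtils.AD.run_ad(
        model, adtype; varinfo=vi, params=params, expected_value_and_grad=ref
    )
end
```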

Diff for: test/ad.jl (+6 -4)

@@ -56,10 +56,12 @@ using DynamicPPL: LogDensityFunction
                 ref_ldf, adtype
             )
         else
-            ldf = DynamicPPL.LogDensityFunction(ref_ldf, adtype)
-            logp, grad = LogDensityProblems.logdensity_and_gradient(ldf, x)
-            @test grad ≈ ref_grad
-            @test logp ≈ ref_logp
+            DynamicPPL.TestUtils.AD.run_ad(
+                m,
+                adtype;
+                varinfo=varinfo,
+                expected_value_and_grad=(ref_logp, ref_grad),
+            )
         end
     end
end
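
The benchmark path is not exercised in this test file, but a sketch of a benchmark-only call (toy model assumed) shows where the `time_vs_primal` field of `ADResult` comes from:

```julia
using DynamicPPL, Distributions
using ADTypes: AutoForwardDiff
import ForwardDiff

@model function one_normal()
    x ~ Normal()
end

# Skip the correctness test; measure gradient cost relative to a
# plain logp evaluation instead.
result = DynamicPPL.TestUtils.AD.run_ad(
    one_normal(), AutoForwardDiff(); test=false, benchmark=true
)
result.time_vs_primal  # median gradient time / median primal time
```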
