Implement AD testing and benchmarking (hand rolled) #882

Merged · 6 commits · Apr 14, 2025

5 changes: 5 additions & 0 deletions HISTORY.md
@@ -1,5 +1,10 @@
# DynamicPPL Changelog

## 0.35.8

Added the `DynamicPPL.TestUtils.AD.run_ad` function to test the correctness and/or benchmark the performance of an automatic differentiation backend on DynamicPPL models.
Please see [the docstring](https://turinglang.org/DynamicPPL.jl/api/#DynamicPPL.TestUtils.AD.run_ad) for more information.

## 0.35.7

`check_model_and_trace` now errors if any NaNs are encountered when evaluating the model.
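
As a rough illustration of the new `run_ad` entry point described in the 0.35.8 entry above (this sketch is not part of the diff; the toy model and the choice of ReverseDiff as the backend under test are assumptions), a minimal correctness check might look like:

```julia
# Minimal sketch: check that ReverseDiff's gradient of logp matches the
# ForwardDiff reference on a toy model.
using DynamicPPL, Distributions
using ADTypes: AutoReverseDiff
import ForwardDiff, ReverseDiff  # the backend under test must be imported

@model function gdemo(x)
    s ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s))
    x ~ Normal(m, sqrt(s))
end

# `test=true` is the default: the value and gradient are compared against the
# ForwardDiff reference with an absolute tolerance of 1e-6.
result = DynamicPPL.TestUtils.AD.run_ad(gdemo(1.5), AutoReverseDiff())
```
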
6 changes: 5 additions & 1 deletion Project.toml
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.35.7"
version = "0.35.8"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -10,6 +10,7 @@ Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Chairmarks = "0ca39b1e-fe0b-4e98-acfc-b1656634c4de"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
@@ -22,6 +23,7 @@ MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[weakdeps]
@@ -49,6 +51,7 @@ Accessors = "0.1"
BangBang = "0.4.1"
Bijectors = "0.13.18, 0.14, 0.15"
ChainRulesCore = "1"
Chairmarks = "1.3.1"
Compat = "4"
ConstructionBase = "1.5.4"
DifferentiationInterface = "0.6.41"
@@ -67,5 +70,6 @@ Mooncake = "0.4.95"
OrderedCollections = "1"
Random = "1.6"
Requires = "1"
Statistics = "1"
Test = "1.6"
julia = "1.10"
11 changes: 10 additions & 1 deletion docs/src/api.md
@@ -205,7 +205,16 @@ values_as_in_model
NamedDist
```

## Testing Utilities
## AD testing and benchmarking utilities

To test and/or benchmark the performance of an AD backend on a model, DynamicPPL provides the following utilities:

```@docs
DynamicPPL.TestUtils.AD.run_ad
DynamicPPL.TestUtils.AD.ADResult
```

## Demo models

DynamicPPL provides several demo models and helpers for testing samplers in the `DynamicPPL.TestUtils` submodule.
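
To complement the API docs above, here is a hypothetical sketch (not part of the diff; the model is an assumption) of the benchmarking side of `run_ad`, which skips the correctness test and reports the cost of the gradient relative to a plain log-density evaluation:

```julia
# Sketch: benchmark ForwardDiff on a toy model and inspect the ratio of
# gradient time to primal (log-density) evaluation time.
using DynamicPPL, Distributions
using ADTypes: AutoForwardDiff
import ForwardDiff

@model function coinflip(y)
    p ~ Beta(1, 1)
    for i in eachindex(y)
        y[i] ~ Bernoulli(p)
    end
end

result = DynamicPPL.TestUtils.AD.run_ad(
    coinflip([1, 0, 1, 1]), AutoForwardDiff(); test=false, benchmark=true
)
result.time_vs_primal  # ratio > 1 means the gradient costs more than logp itself
```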

2 changes: 1 addition & 1 deletion src/DynamicPPL.jl
@@ -175,7 +175,6 @@ include("context_implementations.jl")
include("compiler.jl")
include("pointwise_logdensities.jl")
include("submodel_macro.jl")
include("test_utils.jl")
include("transforming.jl")
include("logdensityfunction.jl")
include("model_utils.jl")
@@ -184,6 +183,7 @@ include("values_as_in_model.jl")

include("debug_utils.jl")
using .DebugUtils
include("test_utils.jl")

include("experimental.jl")
include("deprecated.jl")
1 change: 1 addition & 0 deletions src/test_utils.jl
@@ -18,5 +18,6 @@ include("test_utils/models.jl")
include("test_utils/contexts.jl")
include("test_utils/varinfo.jl")
include("test_utils/sampler.jl")
include("test_utils/ad.jl")

end
209 changes: 209 additions & 0 deletions src/test_utils/ad.jl
@@ -0,0 +1,209 @@
module AD

using ADTypes: AbstractADType, AutoForwardDiff
using Chairmarks: @be
import DifferentiationInterface as DI
using DocStringExtensions
using DynamicPPL: Model, LogDensityFunction, VarInfo, AbstractVarInfo
using LogDensityProblems: logdensity, logdensity_and_gradient
using Random: Random, Xoshiro
using Statistics: median
using Test: @test

export ADResult, run_ad

# This function is needed to work around the fact that different backends can
# return different AbstractArrays for the gradient. See
# https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/754 for more
# context.
_to_vec_f64(x::AbstractArray) = x isa Vector{Float64} ? x : collect(Float64, x)

"""
    REFERENCE_ADTYPE

Reference AD backend to use for comparison. In this case, ForwardDiff.jl, since
it's the default AD backend used in Turing.jl.
"""
const REFERENCE_ADTYPE = AutoForwardDiff()

"""
    ADResult

Data structure to store the results of the AD correctness test.
"""
struct ADResult
    "The DynamicPPL model that was tested"
    model::Model
    "The VarInfo that was used"
    varinfo::AbstractVarInfo
    "The values at which the model was evaluated"
    params::Vector{<:Real}
    "The AD backend that was tested"
    adtype::AbstractADType
    "The absolute tolerance for the value of logp"
    value_atol::Real
    "The absolute tolerance for the gradient of logp"
    grad_atol::Real
    "The expected value of logp"
    value_expected::Union{Nothing,Float64}
    "The expected gradient of logp"
    grad_expected::Union{Nothing,Vector{Float64}}
    "The value of logp (calculated using `adtype`)"
    value_actual::Union{Nothing,Real}
    "The gradient of logp (calculated using `adtype`)"
    grad_actual::Union{Nothing,Vector{Float64}}
    "If benchmarking was requested, the time taken by the AD backend to calculate the gradient of logp, divided by the time taken to evaluate logp itself"
    time_vs_primal::Union{Nothing,Float64}
end

"""
    run_ad(
        model::Model,
        adtype::ADTypes.AbstractADType;
        test=true,
        benchmark=false,
        value_atol=1e-6,
        grad_atol=1e-6,
        varinfo::AbstractVarInfo=VarInfo(model),
        params::Vector{<:Real}=varinfo[:],
        reference_adtype::ADTypes.AbstractADType=REFERENCE_ADTYPE,
        expected_value_and_grad::Union{Nothing,Tuple{Real,Vector{<:Real}}}=nothing,
        verbose=true,
    )::ADResult

Test the correctness and/or benchmark the AD backend `adtype` for the model
`model`.

Whether to test and benchmark is controlled by the `test` and `benchmark`
keyword arguments. By default, `test` is `true` and `benchmark` is `false`.

Returns an [`ADResult`](@ref) object, which contains the results of the
test and/or benchmark.

Note that to run AD successfully you will need to import the AD backend itself.
For example, to test with `AutoReverseDiff()` you will need to run `import
ReverseDiff`.

There are two positional arguments, which absolutely must be provided:

1. `model` - The model being tested.
2. `adtype` - The AD backend being tested.

Everything else is optional, and can be categorised into several groups:

1. _How to specify the VarInfo._

DynamicPPL contains several different types of VarInfo objects which change
the way model evaluation occurs. If you want to use a specific type of
VarInfo, pass it as the `varinfo` argument. Otherwise, it will default to
using a `TypedVarInfo` generated from the model.

2. _How to specify the parameters._

For maximum control over this, generate a vector of parameters yourself and
pass this as the `params` argument. If you don't specify this, it will be
taken from the contents of the VarInfo.

Note that if the VarInfo is not specified (and thus automatically generated),
the parameters in it will have been sampled from the prior of the model. If
you want to seed the parameter generation, the easiest way is to pass an
`rng` argument to the VarInfo constructor (i.e. do `VarInfo(rng, model)`).

Finally, note that these only reflect the parameters used for _evaluating_
the gradient. If you also want to control the parameters used for
_preparing_ the gradient, then you need to manually set these parameters in
the VarInfo object, for example using `vi = DynamicPPL.unflatten(vi,
prep_params)`. You could then evaluate the gradient at a different set of
parameters using the `params` keyword argument.

3. _How to specify the results to compare against._ (Only if `test=true`.)

Once logp and its gradient have been calculated with the specified `adtype`,
they must be tested for correctness.

This can be done either by specifying `reference_adtype`, in which case logp
and its gradient will also be calculated with this reference in order to
obtain the ground truth; or by using `expected_value_and_grad`, which is a
tuple of `(logp, gradient)` that the calculated values must match. The
latter is useful if you are testing multiple AD backends and want to avoid
recalculating the ground truth multiple times.

If neither of these parameters is specified, the reference backend defaults to
ForwardDiff, which is then used to calculate the ground truth.

4. _How to specify the tolerances._ (Only if `test=true`.)

The tolerances for the value and gradient can be set using `value_atol` and
`grad_atol`. These default to 1e-6.

5. _Whether to output extra logging information._

By default, this function prints messages when it runs. To silence it, set
`verbose=false`.
"""
function run_ad(
    model::Model,
    adtype::AbstractADType;
    test=true,
    benchmark=false,
    value_atol=1e-6,
    grad_atol=1e-6,
    varinfo::AbstractVarInfo=VarInfo(model),
    params::Vector{<:Real}=varinfo[:],
    reference_adtype::AbstractADType=REFERENCE_ADTYPE,
    expected_value_and_grad::Union{Nothing,Tuple{Real,Vector{<:Real}}}=nothing,
    verbose=true,
)::ADResult
    verbose && @info "Running AD on $(model.f) with $(adtype)\n"
    params = map(identity, params)
    verbose && println(" params : $(params)")
    ldf = LogDensityFunction(model, varinfo; adtype=adtype)

    value, grad = logdensity_and_gradient(ldf, params)
    grad = _to_vec_f64(grad)
    verbose && println(" actual : $((value, grad))")

    if test
        # Calculate ground truth to compare against
        value_true, grad_true = if expected_value_and_grad === nothing
            ldf_reference = LogDensityFunction(model, varinfo; adtype=reference_adtype)
            logdensity_and_gradient(ldf_reference, params)
        else
            expected_value_and_grad
        end
        verbose && println(" expected : $((value_true, grad_true))")
        grad_true = _to_vec_f64(grad_true)
        # Then compare
        @test isapprox(value, value_true; atol=value_atol)
        @test isapprox(grad, grad_true; atol=grad_atol)
    else
        value_true = nothing
        grad_true = nothing
    end

    time_vs_primal = if benchmark
        primal_benchmark = @be (ldf, params) logdensity(_[1], _[2])
        grad_benchmark = @be (ldf, params) logdensity_and_gradient(_[1], _[2])
        t = median(grad_benchmark).time / median(primal_benchmark).time
        verbose && println("grad / primal : $(t)")
        t
    else
        nothing
    end

    return ADResult(
        model,
        varinfo,
        params,
        adtype,
        value_atol,
        grad_atol,
        value_true,
        grad_true,
        value,
        grad,
        time_vs_primal,
    )
end

end # module DynamicPPL.TestUtils.AD
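
As the `expected_value_and_grad` discussion in the docstring above suggests, the ground truth can be computed once and reused across several backends. The following sketch (illustrative only; the model, the backend list, and the seed are assumptions, not part of the diff) also seeds the VarInfo so the evaluation point is reproducible:

```julia
# Sketch: compute the reference value/gradient once with ForwardDiff, then
# test other backends against it without recomputing the ground truth.
using DynamicPPL, Distributions
using DynamicPPL.TestUtils.AD: run_ad
using ADTypes: AutoForwardDiff, AutoReverseDiff
using Random: Xoshiro
import ForwardDiff, ReverseDiff

@model function gauss(y)
    σ ~ truncated(Normal(0, 1); lower=0)
    μ ~ Normal(0, 10)
    for i in eachindex(y)
        y[i] ~ Normal(μ, σ)
    end
end

model = gauss(randn(Xoshiro(468), 10))
# Seeding the VarInfo fixes the parameters at which the gradient is evaluated.
vi = VarInfo(Xoshiro(468), model)

ref = run_ad(model, AutoForwardDiff(); varinfo=vi, test=false)
for adtype in (AutoReverseDiff(), AutoReverseDiff(; compile=true))
    run_ad(
        model,
        adtype;
        varinfo=vi,
        expected_value_and_grad=(ref.value_actual, ref.grad_actual),
    )
end
```
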
10 changes: 6 additions & 4 deletions test/ad.jl
@@ -56,10 +56,12 @@ using DynamicPPL: LogDensityFunction
ref_ldf, adtype
)
else
ldf = DynamicPPL.LogDensityFunction(ref_ldf, adtype)
logp, grad = LogDensityProblems.logdensity_and_gradient(ldf, x)
@test grad ≈ ref_grad
@test logp ≈ ref_logp
DynamicPPL.TestUtils.AD.run_ad(
m,
adtype;
varinfo=varinfo,
expected_value_and_grad=(ref_logp, ref_grad),
)
end
end
end