|
module AD

using ADTypes: AbstractADType, AutoForwardDiff
using Chairmarks: @be
import DifferentiationInterface as DI
using DocStringExtensions
using DynamicPPL: Model, LogDensityFunction, VarInfo, AbstractVarInfo
using LogDensityProblems: logdensity, logdensity_and_gradient
using Random: Random, Xoshiro
using Statistics: median
using Test: @test

export ADResult, run_ad

# This function is needed to work around the fact that different backends can
# return different AbstractArrays for the gradient. See
# https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/754 for more
# context.
_to_vec_f64(x::AbstractArray) = x isa Vector{Float64} ? x : collect(Float64, x)
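
# For example, a view into a `Vector{Float64}` is a `SubArray`, not a
# `Vector{Float64}`, so it gets collected into a fresh vector:
#
#     _to_vec_f64(view([1.0, 2.0, 3.0], 1:2)) == [1.0, 2.0]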
| 20 | + |
"""
    REFERENCE_ADTYPE

Reference AD backend to use for comparison. In this case, ForwardDiff.jl, since
it's the default AD backend used in Turing.jl.
"""
const REFERENCE_ADTYPE = AutoForwardDiff()
| 28 | + |
"""
    ADResult

Data structure to store the results of the AD correctness test and/or
benchmark.
"""
struct ADResult
    "The DynamicPPL model that was tested"
    model::Model
    "The VarInfo that was used"
    varinfo::AbstractVarInfo
    "The values at which the model was evaluated"
    params::Vector{<:Real}
    "The AD backend that was tested"
    adtype::AbstractADType
    "The absolute tolerance for the value of logp"
    value_atol::Real
    "The absolute tolerance for the gradient of logp"
    grad_atol::Real
    "The expected value of logp"
    value_expected::Union{Nothing,Float64}
    "The expected gradient of logp"
    grad_expected::Union{Nothing,Vector{Float64}}
    "The value of logp (calculated using `adtype`)"
    value_actual::Union{Nothing,Real}
    "The gradient of logp (calculated using `adtype`)"
    grad_actual::Union{Nothing,Vector{Float64}}
    "If benchmarking was requested, the time taken by the AD backend to calculate the gradient of logp, divided by the time taken to evaluate logp itself"
    time_vs_primal::Union{Nothing,Float64}
end
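
# Note on interpreting `time_vs_primal` (hypothetical numbers): a value of 5.0
# would mean that computing the gradient took five times as long as a single
# evaluation of logp itself.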
| 58 | + |
"""
    run_ad(
        model::Model,
        adtype::ADTypes.AbstractADType;
        test=true,
        benchmark=false,
        value_atol=1e-6,
        grad_atol=1e-6,
        varinfo::AbstractVarInfo=VarInfo(model),
        params::Vector{<:Real}=varinfo[:],
        reference_adtype::ADTypes.AbstractADType=REFERENCE_ADTYPE,
        expected_value_and_grad::Union{Nothing,Tuple{Real,Vector{<:Real}}}=nothing,
        verbose=true,
    )::ADResult

Test the correctness and/or benchmark the AD backend `adtype` for the model
`model`.

Whether to test and benchmark is controlled by the `test` and `benchmark`
keyword arguments. By default, `test` is `true` and `benchmark` is `false`.

Returns an [`ADResult`](@ref) object, which contains the results of the
test and/or benchmark.

Note that to run AD successfully you will need to import the AD backend itself.
For example, to test with `AutoReverseDiff()` you will need to run `import
ReverseDiff`.

There are two positional arguments, which absolutely must be provided:

1. `model` - The model being tested.
2. `adtype` - The AD backend being tested.

Everything else is optional, and can be categorised into several groups:

1. _How to specify the VarInfo._

   DynamicPPL contains several different types of VarInfo objects which change
   the way model evaluation occurs. If you want to use a specific type of
   VarInfo, pass it as the `varinfo` argument. Otherwise, it will default to
   using a `TypedVarInfo` generated from the model.

2. _How to specify the parameters._

   For maximum control over this, generate a vector of parameters yourself and
   pass this as the `params` argument. If you don't specify this, it will be
   taken from the contents of the VarInfo.

   Note that if the VarInfo is not specified (and thus automatically generated)
   the parameters in it will have been sampled from the prior of the model. If
   you want to seed the parameter generation, the easiest way is to pass a
   `rng` argument to the VarInfo constructor (i.e. do `VarInfo(rng, model)`),
   as in the sketch below.

   Finally, note that these only reflect the parameters used for _evaluating_
   the gradient. If you also want to control the parameters used for
   _preparing_ the gradient, then you need to manually set these parameters in
   the VarInfo object, for example using `vi = DynamicPPL.unflatten(vi,
   prep_params)`. You could then evaluate the gradient at a different set of
   parameters using the `params` keyword argument.
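
   A hedged sketch of seeded parameter generation (`my_model` stands in for
   any DynamicPPL model, the seed is arbitrary, and this module is assumed to
   be accessible as `DynamicPPL.TestUtils.AD`):

   ```julia
   using DynamicPPL
   using DynamicPPL.TestUtils.AD: run_ad
   using ADTypes: AutoReverseDiff
   using Random: Xoshiro
   import ReverseDiff

   vi = VarInfo(Xoshiro(468), my_model)  # seeded parameter generation
   run_ad(my_model, AutoReverseDiff(); varinfo=vi, params=vi[:])
   ```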
|
3. _How to specify the results to compare against._ (Only if `test=true`.)

   Once logp and its gradient have been calculated with the specified `adtype`,
   they must be tested for correctness.

   This can be done either by specifying `reference_adtype`, in which case logp
   and its gradient will also be calculated with this reference in order to
   obtain the ground truth; or by using `expected_value_and_grad`, which is a
   tuple of `(logp, gradient)` that the calculated values must match. The
   latter is useful if you are testing multiple AD backends and want to avoid
   recalculating the ground truth multiple times, as in the sketch below.

   If neither of these is specified, the reference backend defaults to
   ForwardDiff (i.e. [`REFERENCE_ADTYPE`](@ref)), which is then used to
   calculate the ground truth.
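
   For example, a sketch of testing several backends against a single
   precomputed ground truth (`my_model` and the choice of backends are
   assumptions for illustration):

   ```julia
   using DynamicPPL
   using DynamicPPL.TestUtils.AD: run_ad
   using ADTypes: AutoForwardDiff, AutoReverseDiff
   import ForwardDiff, ReverseDiff
   import LogDensityProblems

   vi = VarInfo(my_model)
   params = vi[:]
   ldf = LogDensityFunction(my_model, vi; adtype=AutoForwardDiff())
   truth = LogDensityProblems.logdensity_and_gradient(ldf, params)

   for adtype in [AutoReverseDiff(), AutoReverseDiff(; compile=true)]
       run_ad(my_model, adtype; varinfo=vi, params=params,
              expected_value_and_grad=truth)
   end
   ```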
|
4. _How to specify the tolerances._ (Only if `test=true`.)

   The tolerances for the value and gradient can be set using `value_atol` and
   `grad_atol`. These default to 1e-6.

5. _Whether to output extra logging information._

   By default, this function prints messages when it runs. To silence it, set
   `verbose=false`.
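
A minimal end-to-end call might look like this (`my_model` again being a
hypothetical model, and assuming this module is accessible as
`DynamicPPL.TestUtils.AD`):

```julia
using DynamicPPL.TestUtils.AD: run_ad
using ADTypes: AutoReverseDiff
import ReverseDiff

result = run_ad(my_model, AutoReverseDiff(); benchmark=true)
result.grad_actual     # gradient computed with ReverseDiff
result.time_vs_primal  # gradient time / primal evaluation time
```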
"""
function run_ad(
    model::Model,
    adtype::AbstractADType;
    test=true,
    benchmark=false,
    value_atol=1e-6,
    grad_atol=1e-6,
    varinfo::AbstractVarInfo=VarInfo(model),
    params::Vector{<:Real}=varinfo[:],
    reference_adtype::AbstractADType=REFERENCE_ADTYPE,
    expected_value_and_grad::Union{Nothing,Tuple{Real,Vector{<:Real}}}=nothing,
    verbose=true,
)::ADResult
    verbose && @info "Running AD on $(model.f) with $(adtype)\n"
    # Narrow the element type of `params` (e.g. Vector{Real} -> Vector{Float64})
    params = map(identity, params)
    verbose && println("       params : $(params)")
    ldf = LogDensityFunction(model, varinfo; adtype=adtype)

    value, grad = logdensity_and_gradient(ldf, params)
    grad = _to_vec_f64(grad)
    verbose && println("       actual : $((value, grad))")

    if test
        # Calculate ground truth to compare against
        value_true, grad_true = if expected_value_and_grad === nothing
            ldf_reference = LogDensityFunction(model, varinfo; adtype=reference_adtype)
            logdensity_and_gradient(ldf_reference, params)
        else
            expected_value_and_grad
        end
        verbose && println("     expected : $((value_true, grad_true))")
        grad_true = _to_vec_f64(grad_true)
        # Then compare
        @test isapprox(value, value_true; atol=value_atol)
        @test isapprox(grad, grad_true; atol=grad_atol)
    else
        value_true = nothing
        grad_true = nothing
    end

    time_vs_primal = if benchmark
        # Chairmarks passes the setup tuple `(ldf, params)` to the benchmarked
        # expression as `_`, so each sample runs `logdensity(ldf, params)` (and
        # likewise for the gradient).
        primal_benchmark = @be (ldf, params) logdensity(_[1], _[2])
        grad_benchmark = @be (ldf, params) logdensity_and_gradient(_[1], _[2])
        t = median(grad_benchmark).time / median(primal_benchmark).time
        verbose && println("grad / primal : $(t)")
        t
    else
        nothing
    end

    return ADResult(
        model,
        varinfo,
        params,
        adtype,
        value_atol,
        grad_atol,
        value_true,
        grad_true,
        value,
        grad,
        time_vs_primal,
    )
end

end # module DynamicPPL.TestUtils.AD