Skip to content

Commit a4220c5

Browse files
committed
setadtype --> LogDensityFunction
1 parent f76bb3d commit a4220c5

File tree

3 files changed

+17
-16
lines changed

3 files changed

+17
-16
lines changed

HISTORY.md

+3-2
Original file line numberDiff line numberDiff line change
@@ -155,13 +155,14 @@ LogDensityProblems.logdensity_and_gradient(ldf, params)
155155
156156
without having to construct a separate `ADgradient` object.
157157
158-
If you prefer, you can also use `setadtype` to tack on the AD type afterwards:
158+
If you prefer, you can also construct a new `LogDensityFunction` with a new AD type afterwards.
159+
The model, varinfo, and context will be taken from the original `LogDensityFunction`:
159160
160161
```julia
161162
@model f() = ...
162163
163164
ldf = LogDensityFunction(f()) # by default, no adtype set
164-
ldf_with_ad = setadtype(ldf, AutoForwardDiff())
165+
ldf_with_ad = LogDensityFunction(ldf, AutoForwardDiff())
165166
```
166167
167168
## 0.34.2

src/logdensityfunction.jl

+10-10
Original file line numberDiff line numberDiff line change
@@ -143,18 +143,18 @@ struct LogDensityFunction{
143143
end
144144

145145
"""
146-
setadtype(f::LogDensityFunction, adtype::Union{Nothing,ADTypes.AbstractADType})
147-
148-
Set the AD type used for evaluation of log density gradient in the given
149-
LogDensityFunction. This function also performs preparation of the gradient,
150-
and sets the `prep` field of the LogDensityFunction.
151-
152-
If `adtype` is `nothing`, the `prep` field will be set to `nothing` as well.
146+
LogDensityFunction(
147+
ldf::LogDensityFunction,
148+
adtype::Union{Nothing,ADTypes.AbstractADType}
149+
)
153150
154-
This function returns a new LogDensityFunction with the updated AD type, i.e. it does
155-
not mutate the input LogDensityFunction.
151+
Create a new LogDensityFunction using the model, varinfo, and context from the given
152+
`ldf` argument, but with the AD type set to `adtype`. To remove the AD type, pass
153+
`nothing` as the second argument.
156154
"""
157-
function setadtype(f::LogDensityFunction, adtype::Union{Nothing,ADTypes.AbstractADType})
155+
function LogDensityFunction(
156+
f::LogDensityFunction, adtype::Union{Nothing,ADTypes.AbstractADType}
157+
)
158158
return if adtype === f.adtype
159159
f # Avoid recomputing prep if not needed
160160
else

test/ad.jl

+4-4
Original file line numberDiff line numberDiff line change
@@ -39,17 +39,17 @@ using DynamicPPL: LogDensityFunction
3939
# Mooncake doesn't work with several combinations of SimpleVarInfo.
4040
if is_mooncake && is_1_11 && is_svi_vnv
4141
# https://github.com/compintell/Mooncake.jl/issues/470
42-
@test_throws ArgumentError DynamicPPL.setadtype(ref_ldf, adtype)
42+
@test_throws ArgumentError DynamicPPL.LogDensityFunction(ref_ldf, adtype)
4343
elseif is_mooncake && is_1_10 && is_svi_vnv
4444
# TODO: report upstream
45-
@test_throws UndefRefError DynamicPPL.setadtype(ref_ldf, adtype)
45+
@test_throws UndefRefError DynamicPPL.LogDensityFunction(ref_ldf, adtype)
4646
elseif is_mooncake && is_1_10 && is_svi_od
4747
# TODO: report upstream
48-
@test_throws Mooncake.MooncakeRuleCompilationError DynamicPPL.setadtype(
48+
@test_throws Mooncake.MooncakeRuleCompilationError DynamicPPL.LogDensityFunction(
4949
ref_ldf, adtype
5050
)
5151
else
52-
ldf = DynamicPPL.setadtype(ref_ldf, adtype)
52+
ldf = DynamicPPL.LogDensityFunction(ref_ldf, adtype)
5353
logp, grad = LogDensityProblems.logdensity_and_gradient(ldf, x)
5454
@test grad ref_grad
5555
@test logp ref_logp

0 commit comments

Comments
 (0)