Skip to content

Commit 7cb38f3

Browse files
committed
Remove DynamicPPLForwardDiffExt
1 parent 8de4742 commit 7cb38f3

File tree

5 files changed

+27
-61
lines changed

5 files changed

+27
-61
lines changed

Project.toml

-3
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
2929
[weakdeps]
3030
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3131
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
32-
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
3332
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
3433
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
3534
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
@@ -38,7 +37,6 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
3837
[extensions]
3938
DynamicPPLChainRulesCoreExt = ["ChainRulesCore"]
4039
DynamicPPLEnzymeCoreExt = ["EnzymeCore"]
41-
DynamicPPLForwardDiffExt = ["ForwardDiff"]
4240
DynamicPPLJETExt = ["JET"]
4341
DynamicPPLMCMCChainsExt = ["MCMCChains"]
4442
DynamicPPLMooncakeExt = ["Mooncake"]
@@ -58,7 +56,6 @@ DifferentiationInterface = "0.6.39"
5856
Distributions = "0.25"
5957
DocStringExtensions = "0.9"
6058
EnzymeCore = "0.6 - 0.8"
61-
ForwardDiff = "0.10"
6259
JET = "0.9"
6360
KernelAbstractions = "< 0.9.32"
6461
LinearAlgebra = "1.6"

ext/DynamicPPLForwardDiffExt.jl

-27
This file was deleted.

src/logdensityfunction.jl

+24-16
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@ import DifferentiationInterface as DI
44
LogDensityFunction
55
66
A callable representing a log density function of a `model`.
7+
`DynamicPPL.LogDensityFunction` implements the LogDensityProblems.jl interface,
8+
but only to 0th-order, i.e. it is only possible to calculate the log density,
9+
and not its gradient. If you need to calculate the gradient as well, you have
10+
to construct a [`DynamicPPL.LogDensityFunctionWithGrad`](@ref) object.
711
812
# Fields
913
$(FIELDS)
@@ -55,16 +59,6 @@ struct LogDensityFunction{V,M,C}
5559
context::C
5660
end
5761

58-
# TODO: Deprecate.
59-
function LogDensityFunction(
60-
varinfo::AbstractVarInfo,
61-
model::Model,
62-
sampler::AbstractSampler,
63-
context::AbstractContext,
64-
)
65-
return LogDensityFunction(varinfo, model, SamplingContext(sampler, context))
66-
end
67-
6862
function LogDensityFunction(
6963
model::Model,
7064
varinfo::AbstractVarInfo=VarInfo(model),
@@ -94,11 +88,6 @@ function setmodel(f::DynamicPPL.LogDensityFunction, model::DynamicPPL.Model)
9488
return Accessors.@set f.model = model
9589
end
9690

97-
# HACK: heavy usage of `AbstractSampler` for, well, _everything_, is being phased out. In the mean time
98-
# we need to define these annoying methods to ensure that we stay compatible with everything.
99-
getsampler(f::LogDensityFunction) = getsampler(getcontext(f))
100-
hassampler(f::LogDensityFunction) = hassampler(getcontext(f))
101-
10291
"""
10392
getparams(f::LogDensityFunction)
10493
@@ -122,7 +111,26 @@ end
122111
LogDensityProblems.dimension(f::LogDensityFunction) = length(getparams(f))
123112

124113
# LogDensityProblems interface: gradient (1st order)
125-
struct LogDensityFunctionWithGrad{V,M,C,TAD}
114+
"""
115+
LogDensityFunctionWithGrad(ldf::DynamicPPL.LogDensityFunction, adtype::ADTypes.AbstractADType)
116+
117+
A callable representing a log density function of a `model`.
118+
`DynamicPPL.LogDensityFunctionWithGrad` implements the LogDensityProblems.jl
119+
interface to 1st-order, meaning that you can both calculate the log density
120+
using
121+
122+
LogDensityProblems.logdensity(f, x)
123+
124+
and its gradient using
125+
126+
LogDensityProblems.logdensity_and_gradient(f, x)
127+
128+
where `f` is a `LogDensityFunctionWithGrad` object and `x` is a vector of parameters.
129+
130+
# Fields
131+
$(FIELDS)
132+
"""
133+
struct LogDensityFunctionWithGrad{V,M,C,TAD<:ADTypes.AbstractADType}
126134
ldf::LogDensityFunction{V,M,C}
127135
adtype::TAD
128136
prep::DI.GradientPrep

test/ad.jl

+3-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,9 @@
2727
@test_broken 1 == 0
2828
else
2929
ldf_with_grad = DynamicPPL.LogDensityFunctionWithGrad(f, adtype)
30-
logp, grad = LogDensityProblems.logdensity_and_gradient(f, x)
30+
logp, grad = LogDensityProblems.logdensity_and_gradient(
31+
ldf_with_grad, x
32+
)
3133
@test grad ≈ ref_grad
3234
@test logp ≈ ref_logp
3335
end

test/ext/DynamicPPLForwardDiffExt.jl

-14
This file was deleted.

0 commit comments

Comments
 (0)