Skip to content

Commit 326d7ed

Browse files
authored
Replace PriorExtractorContext with PriorDistributionAccumulator (#907)
1 parent 299e17b commit 326d7ed

File tree

1 file changed

+35
-23
lines changed

1 file changed

+35
-23
lines changed

src/extract_priors.jl

+35-23
Original file line numberDiff line numberDiff line change
@@ -1,44 +1,47 @@
1-
struct PriorExtractorContext{D<:OrderedDict{VarName,Any},Ctx<:AbstractContext} <:
2-
AbstractContext
1+
struct PriorDistributionAccumulator{D<:OrderedDict{VarName,Any}} <: AbstractAccumulator
32
priors::D
4-
context::Ctx
53
end
64

7-
PriorExtractorContext(context) = PriorExtractorContext(OrderedDict{VarName,Any}(), context)
5+
PriorDistributionAccumulator() = PriorDistributionAccumulator(OrderedDict{VarName,Any}())
86

9-
NodeTrait(::PriorExtractorContext) = IsParent()
10-
childcontext(context::PriorExtractorContext) = context.context
11-
function setchildcontext(parent::PriorExtractorContext, child)
12-
return PriorExtractorContext(parent.priors, child)
7+
accumulator_name(::PriorDistributionAccumulator) = :PriorDistributionAccumulator
8+
9+
split(acc::PriorDistributionAccumulator) = PriorDistributionAccumulator(empty(acc.priors))
10+
function combine(acc1::PriorDistributionAccumulator, acc2::PriorDistributionAccumulator)
11+
return PriorDistributionAccumulator(merge(acc1.priors, acc2.priors))
1312
end
1413

15-
function setprior!(context::PriorExtractorContext, vn::VarName, dist::Distribution)
16-
return context.priors[vn] = dist
14+
function setprior!(acc::PriorDistributionAccumulator, vn::VarName, dist::Distribution)
15+
acc.priors[vn] = dist
16+
return acc
1717
end
1818

1919
function setprior!(
20-
context::PriorExtractorContext, vns::AbstractArray{<:VarName}, dist::Distribution
20+
acc::PriorDistributionAccumulator, vns::AbstractArray{<:VarName}, dist::Distribution
2121
)
2222
for vn in vns
23-
context.priors[vn] = dist
23+
acc.priors[vn] = dist
2424
end
25+
return acc
2526
end
2627

2728
function setprior!(
28-
context::PriorExtractorContext,
29+
acc::PriorDistributionAccumulator,
2930
vns::AbstractArray{<:VarName},
3031
dists::AbstractArray{<:Distribution},
3132
)
3233
for (vn, dist) in zip(vns, dists)
33-
context.priors[vn] = dist
34+
acc.priors[vn] = dist
3435
end
36+
return acc
3537
end
3638

37-
function DynamicPPL.tilde_assume(context::PriorExtractorContext, right, vn, vi)
38-
setprior!(context, vn, right)
39-
return DynamicPPL.tilde_assume(childcontext(context), right, vn, vi)
39+
function accumulate_assume!!(acc::PriorDistributionAccumulator, val, logjac, vn, right)
40+
return setprior!(acc, vn, right)
4041
end
4142

43+
accumulate_observe!!(acc::PriorDistributionAccumulator, right, left, vn) = acc
44+
4245
"""
4346
extract_priors([rng::Random.AbstractRNG, ]model::Model)
4447
@@ -108,9 +111,13 @@ julia> length(extract_priors(rng, model)[@varname(x)])
108111
extract_priors(args::Union{Model,AbstractVarInfo}...) =
109112
extract_priors(Random.default_rng(), args...)
110113
function extract_priors(rng::Random.AbstractRNG, model::Model)
111-
context = PriorExtractorContext(SamplingContext(rng))
112-
evaluate!!(model, VarInfo(), context)
113-
return context.priors
114+
varinfo = VarInfo()
115+
# TODO(mhauru) This doesn't actually need the NumProduceAccumulator, it's only a
116+
# workaround for the fact that `order` is still hardcoded in VarInfo, and hence you
117+
# can't push new variables without knowing the num_produce. Remove this when possible.
118+
varinfo = setaccs!!(varinfo, (PriorDistributionAccumulator(), NumProduceAccumulator()))
119+
varinfo = last(evaluate!!(model, varinfo, SamplingContext(rng)))
120+
return getacc(varinfo, Val(:PriorDistributionAccumulator)).priors
114121
end
115122

116123
"""
@@ -122,7 +129,12 @@ This is done by evaluating the model at the values present in `varinfo`
122129
and recording the distributions that are present at each tilde statement.
123130
"""
124131
function extract_priors(model::Model, varinfo::AbstractVarInfo)
125-
context = PriorExtractorContext(DefaultContext())
126-
evaluate!!(model, deepcopy(varinfo), context)
127-
return context.priors
132+
# TODO(mhauru) This doesn't actually need the NumProduceAccumulator, it's only a
133+
# workaround for the fact that `order` is still hardcoded in VarInfo, and hence you
134+
# can't push new variables without knowing the num_produce. Remove this when possible.
135+
varinfo = setaccs!!(
136+
deepcopy(varinfo), (PriorDistributionAccumulator(), NumProduceAccumulator())
137+
)
138+
varinfo = last(evaluate!!(model, varinfo, DefaultContext()))
139+
return getacc(varinfo, Val(:PriorDistributionAccumulator)).priors
128140
end

0 commit comments

Comments
 (0)