1
- struct PriorExtractorContext{D<: OrderedDict{VarName,Any} ,Ctx<: AbstractContext } < :
2
- AbstractContext
1
+ struct PriorDistributionAccumulator{D<: OrderedDict{VarName,Any} } <: AbstractAccumulator
3
2
priors:: D
4
- context:: Ctx
5
3
end
6
4
7
- PriorExtractorContext (context ) = PriorExtractorContext (OrderedDict {VarName,Any} (), context )
5
+ PriorDistributionAccumulator ( ) = PriorDistributionAccumulator (OrderedDict {VarName,Any} ())
8
6
9
- NodeTrait (:: PriorExtractorContext ) = IsParent ()
10
- childcontext (context:: PriorExtractorContext ) = context. context
11
- function setchildcontext (parent:: PriorExtractorContext , child)
12
- return PriorExtractorContext (parent. priors, child)
7
+ accumulator_name (:: PriorDistributionAccumulator ) = :PriorDistributionAccumulator
8
+
9
+ split (acc:: PriorDistributionAccumulator ) = PriorDistributionAccumulator (empty (acc. priors))
10
+ function combine (acc1:: PriorDistributionAccumulator , acc2:: PriorDistributionAccumulator )
11
+ return PriorDistributionAccumulator (merge (acc1. priors, acc2. priors))
13
12
end
14
13
15
- function setprior! (context:: PriorExtractorContext , vn:: VarName , dist:: Distribution )
16
- return context. priors[vn] = dist
14
+ function setprior! (acc:: PriorDistributionAccumulator , vn:: VarName , dist:: Distribution )
15
+ acc. priors[vn] = dist
16
+ return acc
17
17
end
18
18
19
19
function setprior! (
20
- context :: PriorExtractorContext , vns:: AbstractArray{<:VarName} , dist:: Distribution
20
+ acc :: PriorDistributionAccumulator , vns:: AbstractArray{<:VarName} , dist:: Distribution
21
21
)
22
22
for vn in vns
23
- context . priors[vn] = dist
23
+ acc . priors[vn] = dist
24
24
end
25
+ return acc
25
26
end
26
27
27
28
function setprior! (
28
- context :: PriorExtractorContext ,
29
+ acc :: PriorDistributionAccumulator ,
29
30
vns:: AbstractArray{<:VarName} ,
30
31
dists:: AbstractArray{<:Distribution} ,
31
32
)
32
33
for (vn, dist) in zip (vns, dists)
33
- context . priors[vn] = dist
34
+ acc . priors[vn] = dist
34
35
end
36
+ return acc
35
37
end
36
38
37
- function DynamicPPL. tilde_assume (context:: PriorExtractorContext , right, vn, vi)
38
- setprior! (context, vn, right)
39
- return DynamicPPL. tilde_assume (childcontext (context), right, vn, vi)
39
+ function accumulate_assume!! (acc:: PriorDistributionAccumulator , val, logjac, vn, right)
40
+ return setprior! (acc, vn, right)
40
41
end
41
42
43
+ accumulate_observe!! (acc:: PriorDistributionAccumulator , right, left, vn) = acc
44
+
42
45
"""
43
46
extract_priors([rng::Random.AbstractRNG, ]model::Model)
44
47
@@ -108,9 +111,13 @@ julia> length(extract_priors(rng, model)[@varname(x)])
108
111
extract_priors (args:: Union{Model,AbstractVarInfo} ...) =
109
112
extract_priors (Random. default_rng (), args... )
110
113
function extract_priors (rng:: Random.AbstractRNG , model:: Model )
111
- context = PriorExtractorContext (SamplingContext (rng))
112
- evaluate!! (model, VarInfo (), context)
113
- return context. priors
114
+ varinfo = VarInfo ()
115
+ # TODO (mhauru) This doesn't actually need the NumProduceAccumulator, it's only a
116
+ # workaround for the fact that `order` is still hardcoded in VarInfo, and hence you
117
+ # can't push new variables without knowing the num_produce. Remove this when possible.
118
+ varinfo = setaccs!! (varinfo, (PriorDistributionAccumulator (), NumProduceAccumulator ()))
119
+ varinfo = last (evaluate!! (model, varinfo, SamplingContext (rng)))
120
+ return getacc (varinfo, Val (:PriorDistributionAccumulator )). priors
114
121
end
115
122
116
123
"""
@@ -122,7 +129,12 @@ This is done by evaluating the model at the values present in `varinfo`
122
129
and recording the distributions that are present at each tilde statement.
123
130
"""
124
131
function extract_priors (model:: Model , varinfo:: AbstractVarInfo )
125
- context = PriorExtractorContext (DefaultContext ())
126
- evaluate!! (model, deepcopy (varinfo), context)
127
- return context. priors
132
+ # TODO (mhauru) This doesn't actually need the NumProduceAccumulator, it's only a
133
+ # workaround for the fact that `order` is still hardcoded in VarInfo, and hence you
134
+ # can't push new variables without knowing the num_produce. Remove this when possible.
135
+ varinfo = setaccs!! (
136
+ deepcopy (varinfo), (PriorDistributionAccumulator (), NumProduceAccumulator ())
137
+ )
138
+ varinfo = last (evaluate!! (model, varinfo, DefaultContext ()))
139
+ return getacc (varinfo, Val (:PriorDistributionAccumulator )). priors
128
140
end
0 commit comments