From 2176a0768b0a5045d7ba91bff8dfc3e91e06c31a Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 2 May 2025 14:15:52 +0100 Subject: [PATCH] Replace PriorExtractorContext with PriorDistributionAccumulator --- src/extract_priors.jl | 58 ++++++++++++++++++++++++++----------------- 1 file changed, 35 insertions(+), 23 deletions(-) diff --git a/src/extract_priors.jl b/src/extract_priors.jl index 0f312fa2c..9047c9f0a 100644 --- a/src/extract_priors.jl +++ b/src/extract_priors.jl @@ -1,44 +1,47 @@ -struct PriorExtractorContext{D<:OrderedDict{VarName,Any},Ctx<:AbstractContext} <: - AbstractContext +struct PriorDistributionAccumulator{D<:OrderedDict{VarName,Any}} <: AbstractAccumulator priors::D - context::Ctx end -PriorExtractorContext(context) = PriorExtractorContext(OrderedDict{VarName,Any}(), context) +PriorDistributionAccumulator() = PriorDistributionAccumulator(OrderedDict{VarName,Any}()) -NodeTrait(::PriorExtractorContext) = IsParent() -childcontext(context::PriorExtractorContext) = context.context -function setchildcontext(parent::PriorExtractorContext, child) - return PriorExtractorContext(parent.priors, child) +accumulator_name(::PriorDistributionAccumulator) = :PriorDistributionAccumulator + +split(acc::PriorDistributionAccumulator) = PriorDistributionAccumulator(empty(acc.priors)) +function combine(acc1::PriorDistributionAccumulator, acc2::PriorDistributionAccumulator) + return PriorDistributionAccumulator(merge(acc1.priors, acc2.priors)) end -function setprior!(context::PriorExtractorContext, vn::VarName, dist::Distribution) - return context.priors[vn] = dist +function setprior!(acc::PriorDistributionAccumulator, vn::VarName, dist::Distribution) + acc.priors[vn] = dist + return acc end function setprior!( - context::PriorExtractorContext, vns::AbstractArray{<:VarName}, dist::Distribution + acc::PriorDistributionAccumulator, vns::AbstractArray{<:VarName}, dist::Distribution ) for vn in vns - context.priors[vn] = dist + acc.priors[vn] = dist end + return acc end function setprior!( - context::PriorExtractorContext, + acc::PriorDistributionAccumulator, vns::AbstractArray{<:VarName}, dists::AbstractArray{<:Distribution}, ) for (vn, dist) in zip(vns, dists) - context.priors[vn] = dist + acc.priors[vn] = dist end + return acc end -function DynamicPPL.tilde_assume(context::PriorExtractorContext, right, vn, vi) - setprior!(context, vn, right) - return DynamicPPL.tilde_assume(childcontext(context), right, vn, vi) +function accumulate_assume!!(acc::PriorDistributionAccumulator, val, logjac, vn, right) + return setprior!(acc, vn, right) end +accumulate_observe!!(acc::PriorDistributionAccumulator, right, left, vn) = acc + """ extract_priors([rng::Random.AbstractRNG, ]model::Model) @@ -108,9 +111,13 @@ julia> length(extract_priors(rng, model)[@varname(x)]) extract_priors(args::Union{Model,AbstractVarInfo}...) = extract_priors(Random.default_rng(), args...) function extract_priors(rng::Random.AbstractRNG, model::Model) - context = PriorExtractorContext(SamplingContext(rng)) - evaluate!!(model, VarInfo(), context) - return context.priors + varinfo = VarInfo() + # TODO(mhauru) This doesn't actually need the NumProduceAccumulator, it's only a + # workaround for the fact that `order` is still hardcoded in VarInfo, and hence you + # can't push new variables without knowing the num_produce. Remove this when possible. + varinfo = setaccs!!(varinfo, (PriorDistributionAccumulator(), NumProduceAccumulator())) + varinfo = last(evaluate!!(model, varinfo, SamplingContext(rng))) + return getacc(varinfo, Val(:PriorDistributionAccumulator)).priors end """ @@ -122,7 +129,12 @@ This is done by evaluating the model at the values present in `varinfo` and recording the distributions that are present at each tilde statement. """ function extract_priors(model::Model, varinfo::AbstractVarInfo) - context = PriorExtractorContext(DefaultContext()) - evaluate!!(model, deepcopy(varinfo), context) - return context.priors + # TODO(mhauru) This doesn't actually need the NumProduceAccumulator, it's only a + # workaround for the fact that `order` is still hardcoded in VarInfo, and hence you + # can't push new variables without knowing the num_produce. Remove this when possible. + varinfo = setaccs!!( + deepcopy(varinfo), (PriorDistributionAccumulator(), NumProduceAccumulator()) + ) + varinfo = last(evaluate!!(model, varinfo, DefaultContext())) + return getacc(varinfo, Val(:PriorDistributionAccumulator)).priors end