@@ -108,86 +108,14 @@ function DynamicPPL.generated_quantities(
108
108
varinfo = DynamicPPL. VarInfo (model)
109
109
iters = Iterators. product (1 : size (chain, 1 ), 1 : size (chain, 3 ))
110
110
return map (iters) do (sample_idx, chain_idx)
111
- if DynamicPPL. supports_varname_indexing (chain)
112
- varname_pairs = _varname_pairs_with_varname_indexing (
113
- chain, varinfo, sample_idx, chain_idx
114
- )
115
- else
116
- varname_pairs = _varname_pairs_without_varname_indexing (
117
- chain, varinfo, sample_idx, chain_idx
118
- )
119
- end
120
- fixed_model = DynamicPPL. fix (model, Dict (varname_pairs))
121
- return fixed_model ()
111
+ # TODO : Use `fix` once we've addressed https://github.com/TuringLang/DynamicPPL.jl/issues/702.
112
+ # Update the varinfo with the current sample and make variables not present in `chain`
113
+ # to be sampled.
114
+ DynamicPPL. setval_and_resample! (varinfo, chain, sample_idx, chain_idx)
115
+ # NOTE: Some of the varialbes can be a view into the `varinfo`, so we need to
116
+ # `deepcopy` the `varinfo` before passing it to the `model`.
117
+ model (deepcopy (varinfo))
122
118
end
123
119
end
124
120
125
- """
126
- _varname_pairs_with_varname_indexing(
127
- chain::MCMCChains.Chains, varinfo, sample_idx, chain_idx
128
- )
129
-
130
- Get pairs of `VarName => value` for all the variables in the `varinfo`, picking the values
131
- from the chain.
132
-
133
- This implementation assumes `chain` can be indexed using variable names, and is the
134
- preffered implementation.
135
- """
136
- function _varname_pairs_with_varname_indexing (
137
- chain:: MCMCChains.Chains , varinfo, sample_idx, chain_idx
138
- )
139
- vns = DynamicPPL. varnames (chain)
140
- vn_parents = Iterators. map (vns) do vn
141
- # The call nested_setindex_maybe! is used to handle cases where vn is not
142
- # the variable name used in the model, but rather subsumed by one. Except
143
- # for the subsumption part, this could be
144
- # vn => getindex_varname(chain, sample_idx, vn, chain_idx)
145
- # TODO (mhauru) This call to nested_setindex_maybe! is unintuitive.
146
- DynamicPPL. nested_setindex_maybe! (
147
- varinfo, DynamicPPL. getindex_varname (chain, sample_idx, vn, chain_idx), vn
148
- )
149
- end
150
- varname_pairs = Iterators. map (Iterators. filter (! isnothing, vn_parents)) do vn_parent
151
- vn_parent => varinfo[vn_parent]
152
- end
153
- return varname_pairs
154
- end
155
-
156
- """
157
- Check which keys in `key_strings` are subsumed by `vn_string` and return the their values.
158
-
159
- The subsumption check is done with `DynamicPPL.subsumes_string`, which is quite weak, and
160
- won't catch all cases. We should get rid of this if we can.
161
- """
162
- # TODO (mhauru) See docstring above.
163
- function _vcat_subsumed_values (vn_string, values, key_strings)
164
- indices = findall (Base. Fix1 (DynamicPPL. subsumes_string, vn_string), key_strings)
165
- return ! isempty (indices) ? reduce (vcat, values[indices]) : nothing
166
- end
167
-
168
- """
169
- _varname_pairs_without_varname_indexing(
170
- chain::MCMCChains.Chains, varinfo, sample_idx, chain_idx
171
- )
172
-
173
- Get pairs of `VarName => value` for all the variables in the `varinfo`, picking the values
174
- from the chain.
175
-
176
- This implementation does not assume that `chain` can be indexed using variable names. It is
177
- thus not guaranteed to work in cases where the variable names have complex subsumption
178
- patterns, such as if the model has a variable `x` but the chain stores `x.a[1]`.
179
- """
180
- function _varname_pairs_without_varname_indexing (
181
- chain:: MCMCChains.Chains , varinfo, sample_idx, chain_idx
182
- )
183
- values = chain. value[sample_idx, :, chain_idx]
184
- keys = Base. keys (chain)
185
- keys_strings = map (string, keys)
186
- varname_pairs = [
187
- vn => _vcat_subsumed_values (string (vn), values, keys_strings) for
188
- vn in Base. keys (varinfo)
189
- ]
190
- return varname_pairs
191
- end
192
-
193
121
end
0 commit comments