Skip to content

Commit 3d18cfc

Browse files
penelopeysmsunxd3
andauthored
Don't include lhs of := in results of predict() (#766)
* Don't include lhs of := in results of predict() * Bump minor version * Remove unused constructor * Add a test for `values_as_in_model(rng, model, ...)` --------- Co-authored-by: Xianda Sun <[email protected]>
1 parent b7fd9ea commit 3d18cfc

File tree

5 files changed

+102
-58
lines changed

5 files changed

+102
-58
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.32.2"
3+
version = "0.33.0"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

ext/DynamicPPLMCMCChainsExt.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ function DynamicPPL.predict(
117117
DynamicPPL.setval_and_resample!(varinfo, parameter_only_chain, sample_idx, chain_idx)
118118
model(rng, varinfo, DynamicPPL.SampleFromPrior())
119119

120-
vals = DynamicPPL.values_as_in_model(model, varinfo)
120+
vals = DynamicPPL.values_as_in_model(model, false, varinfo)
121121
varname_vals = mapreduce(
122122
collect,
123123
vcat,

src/values_as_in_model.jl

+14-11
Original file line numberDiff line numberDiff line change
@@ -22,22 +22,22 @@ $(TYPEDFIELDS)
2222
struct ValuesAsInModelContext{T,C<:AbstractContext} <: AbstractContext
2323
"values that are extracted from the model"
2424
values::T
25+
"whether to extract variables on the LHS of :="
26+
include_colon_eq::Bool
2527
"child context"
2628
context::C
2729
end
28-
29-
ValuesAsInModelContext(values) = ValuesAsInModelContext(values, DefaultContext())
30-
function ValuesAsInModelContext(context::AbstractContext)
31-
return ValuesAsInModelContext(OrderedDict(), context)
30+
function ValuesAsInModelContext(include_colon_eq, context::AbstractContext)
31+
return ValuesAsInModelContext(OrderedDict(), include_colon_eq, context)
3232
end
3333

3434
NodeTrait(::ValuesAsInModelContext) = IsParent()
3535
childcontext(context::ValuesAsInModelContext) = context.context
3636
function setchildcontext(context::ValuesAsInModelContext, child)
37-
return ValuesAsInModelContext(context.values, child)
37+
return ValuesAsInModelContext(context.values, context.include_colon_eq, child)
3838
end
3939

40-
is_extracting_values(context::ValuesAsInModelContext) = true
40+
is_extracting_values(context::ValuesAsInModelContext) = context.include_colon_eq
4141
function is_extracting_values(context::AbstractContext)
4242
return is_extracting_values(NodeTrait(context), context)
4343
end
@@ -114,8 +114,8 @@ function dot_tilde_assume(
114114
end
115115

116116
"""
117-
values_as_in_model(model::Model[, varinfo::AbstractVarInfo, context::AbstractContext])
118-
values_as_in_model(rng::Random.AbstractRNG, model::Model[, varinfo::AbstractVarInfo, context::AbstractContext])
117+
values_as_in_model(model::Model, include_colon_eq::Bool[, varinfo::AbstractVarInfo, context::AbstractContext])
118+
values_as_in_model(rng::Random.AbstractRNG, model::Model, include_colon_eq::Bool[, varinfo::AbstractVarInfo, context::AbstractContext])
119119
120120
Get the values of `varinfo` as they would be seen in the model.
121121
@@ -132,6 +132,7 @@ of additional model evaluations.
132132
133133
# Arguments
134134
- `model::Model`: model to extract realizations from.
135+
- `include_colon_eq::Bool`: whether to also include variables on the LHS of `:=`.
135136
- `varinfo::AbstractVarInfo`: variable information to use for the extraction.
136137
- `context::AbstractContext`: context to use for the extraction. If `rng` is specified, then `context`
137138
will be wrapped in a [`SamplingContext`](@ref) with the provided `rng`.
@@ -183,24 +184,26 @@ false
183184
julia> # Approach 2: Extract realizations using `values_as_in_model`.
184185
# (✓) `values_as_in_model` will re-run the model and extract
185186
# the correct realization of `y` given the new values of `x`.
186-
lb ≤ values_as_in_model(model, varinfo_linked)[@varname(y)] ≤ ub
187+
lb ≤ values_as_in_model(model, true, varinfo_linked)[@varname(y)] ≤ ub
187188
true
188189
```
189190
"""
190191
function values_as_in_model(
191192
model::Model,
193+
include_colon_eq::Bool,
192194
varinfo::AbstractVarInfo=VarInfo(),
193195
context::AbstractContext=DefaultContext(),
194196
)
195-
context = ValuesAsInModelContext(context)
197+
context = ValuesAsInModelContext(include_colon_eq, context)
196198
evaluate!!(model, varinfo, context)
197199
return context.values
198200
end
199201
function values_as_in_model(
200202
rng::Random.AbstractRNG,
201203
model::Model,
204+
include_colon_eq::Bool,
202205
varinfo::AbstractVarInfo=VarInfo(),
203206
context::AbstractContext=DefaultContext(),
204207
)
205-
return values_as_in_model(model, varinfo, SamplingContext(rng, context))
208+
return values_as_in_model(model, true, varinfo, SamplingContext(rng, context))
206209
end

test/compiler.jl

+9-2
Original file line numberDiff line numberDiff line change
@@ -702,10 +702,17 @@ module Issue537 end
702702
@test haskey(varinfo, @varname(x))
703703
@test !haskey(varinfo, @varname(y))
704704

705-
# While `values_as_in_model` should contain both `x` and `y`.
706-
values = values_as_in_model(model, deepcopy(varinfo))
705+
# While `values_as_in_model` should contain both `x` and `y`, if
706+
# include_colon_eq is set to `true`.
707+
values = values_as_in_model(model, true, deepcopy(varinfo))
707708
@test haskey(values, @varname(x))
708709
@test haskey(values, @varname(y))
710+
711+
# And if include_colon_eq is set to `false`, then `values` should
712+
# only contain `x`.
713+
values = values_as_in_model(model, false, deepcopy(varinfo))
714+
@test haskey(values, @varname(x))
715+
@test !haskey(values, @varname(y))
709716
end
710717
end
711718

test/model.jl

+77-43
Original file line numberDiff line numberDiff line change
@@ -383,7 +383,10 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
383383
example_values = DynamicPPL.TestUtils.rand_prior_true(model)
384384
varinfos = DynamicPPL.TestUtils.setup_varinfos(model, example_values, vns)
385385
@testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos
386-
realizations = values_as_in_model(model, varinfo)
386+
# We can set the include_colon_eq arg to false because none of
387+
# the demo models contain :=. The behaviour when
388+
# include_colon_eq is true is tested in test/compiler.jl
389+
realizations = values_as_in_model(model, false, varinfo)
387390
# Ensure that all variables are found.
388391
vns_found = collect(keys(realizations))
389392
@test vns vns_found == vns vns_found
@@ -393,6 +396,22 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
393396
end
394397
end
395398
end
399+
400+
@testset "check that sampling obeys rng if passed" begin
401+
@model function f()
402+
x ~ Normal(0)
403+
return y ~ Normal(x)
404+
end
405+
model = f()
406+
# Call values_as_in_model with the rng
407+
values = values_as_in_model(Random.Xoshiro(43), model, false)
408+
# Check that they match the values that would be used if vi was seeded
409+
# with that seed instead
410+
expected_vi = VarInfo(Random.Xoshiro(43), model)
411+
for vn in keys(values)
412+
@test values[vn] == expected_vi[vn]
413+
end
414+
end
396415
end
397416

398417
@testset "Erroneous model call" begin
@@ -432,72 +451,87 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
432451

433452
@testset "predict" begin
434453
@testset "with MCMCChains.Chains" begin
435-
DynamicPPL.Random.seed!(100)
436-
437454
@model function linear_reg(x, y, σ=0.1)
438455
β ~ Normal(0, 1)
439456
for i in eachindex(y)
440457
y[i] ~ Normal* x[i], σ)
441458
end
459+
# Insert a := block to test that it is not included in predictions
460+
return σ2 := σ^2
442461
end
443462

444-
@model function linear_reg_vec(x, y, σ=0.1)
445-
β ~ Normal(0, 1)
446-
return y ~ MvNormal.* x, σ^2 * I)
447-
end
448-
463+
# Construct a chain with 'sampled values' of β
449464
ground_truth_β = 2
450465
β_chain = MCMCChains.Chains(rand(Normal(ground_truth_β, 0.002), 1000), [])
451466

467+
# Generate predictions from that chain
452468
xs_test = [10 + 0.1, 10 + 2 * 0.1]
453469
m_lin_reg_test = linear_reg(xs_test, fill(missing, length(xs_test)))
454470
predictions = DynamicPPL.predict(m_lin_reg_test, β_chain)
455471

456-
ys_pred = vec(mean(Array(group(predictions, :y)); dims=1))
457-
@test ys_pred[1] ground_truth_β * xs_test[1] atol = 0.01
458-
@test ys_pred[2] ground_truth_β * xs_test[2] atol = 0.01
459-
460-
# Ensure that `rng` is respected
461-
rng = MersenneTwister(42)
462-
predictions1 = DynamicPPL.predict(rng, m_lin_reg_test, β_chain[1:2])
463-
predictions2 = DynamicPPL.predict(
464-
MersenneTwister(42), m_lin_reg_test, β_chain[1:2]
465-
)
466-
@test all(Array(predictions1) .== Array(predictions2))
467-
468-
# Predict on two last indices for vectorized
469-
m_lin_reg_test = linear_reg_vec(xs_test, missing)
470-
predictions_vec = DynamicPPL.predict(m_lin_reg_test, β_chain)
471-
ys_pred_vec = vec(mean(Array(group(predictions_vec, :y)); dims=1))
472-
473-
@test ys_pred_vec[1] ground_truth_β * xs_test[1] atol = 0.01
474-
@test ys_pred_vec[2] ground_truth_β * xs_test[2] atol = 0.01
472+
# Also test a vectorized model
473+
@model function linear_reg_vec(x, y, σ=0.1)
474+
β ~ Normal(0, 1)
475+
return y ~ MvNormal.* x, σ^2 * I)
476+
end
477+
m_lin_reg_test_vec = linear_reg_vec(xs_test, missing)
475478

476-
# Multiple chains
477-
multiple_β_chain = MCMCChains.Chains(
478-
reshape(rand(Normal(ground_truth_β, 0.002), 1000, 2), 1000, 1, 2), []
479-
)
480-
m_lin_reg_test = linear_reg(xs_test, fill(missing, length(xs_test)))
481-
predictions = DynamicPPL.predict(m_lin_reg_test, multiple_β_chain)
482-
@test size(multiple_β_chain, 3) == size(predictions, 3)
479+
@testset "variables in chain" begin
480+
# Note that this also checks that variables on the lhs of :=,
481+
# such as σ2, are not included in the resulting chain
482+
@test Set(keys(predictions)) == Set([Symbol("y[1]"), Symbol("y[2]")])
483+
end
483484

484-
for chain_idx in MCMCChains.chains(multiple_β_chain)
485-
ys_pred = vec(mean(Array(group(predictions[:, :, chain_idx], :y)); dims=1))
485+
@testset "accuracy" begin
486+
ys_pred = vec(mean(Array(group(predictions, :y)); dims=1))
486487
@test ys_pred[1] ground_truth_β * xs_test[1] atol = 0.01
487488
@test ys_pred[2] ground_truth_β * xs_test[2] atol = 0.01
488489
end
489490

490-
# Predict on two last indices for vectorized
491-
m_lin_reg_test = linear_reg_vec(xs_test, missing)
492-
predictions_vec = DynamicPPL.predict(m_lin_reg_test, multiple_β_chain)
493-
494-
for chain_idx in MCMCChains.chains(multiple_β_chain)
495-
ys_pred_vec = vec(
496-
mean(Array(group(predictions_vec[:, :, chain_idx], :y)); dims=1)
491+
@testset "ensure that rng is respected" begin
492+
rng = MersenneTwister(42)
493+
predictions1 = DynamicPPL.predict(rng, m_lin_reg_test, β_chain[1:2])
494+
predictions2 = DynamicPPL.predict(
495+
MersenneTwister(42), m_lin_reg_test, β_chain[1:2]
497496
)
497+
@test all(Array(predictions1) .== Array(predictions2))
498+
end
499+
500+
@testset "accuracy on vectorized model" begin
501+
predictions_vec = DynamicPPL.predict(m_lin_reg_test_vec, β_chain)
502+
ys_pred_vec = vec(mean(Array(group(predictions_vec, :y)); dims=1))
503+
498504
@test ys_pred_vec[1] ground_truth_β * xs_test[1] atol = 0.01
499505
@test ys_pred_vec[2] ground_truth_β * xs_test[2] atol = 0.01
500506
end
507+
508+
@testset "prediction from multiple chains" begin
509+
# Normal linreg model
510+
multiple_β_chain = MCMCChains.Chains(
511+
reshape(rand(Normal(ground_truth_β, 0.002), 1000, 2), 1000, 1, 2), []
512+
)
513+
predictions = DynamicPPL.predict(m_lin_reg_test, multiple_β_chain)
514+
@test size(multiple_β_chain, 3) == size(predictions, 3)
515+
516+
for chain_idx in MCMCChains.chains(multiple_β_chain)
517+
ys_pred = vec(
518+
mean(Array(group(predictions[:, :, chain_idx], :y)); dims=1)
519+
)
520+
@test ys_pred[1] ground_truth_β * xs_test[1] atol = 0.01
521+
@test ys_pred[2] ground_truth_β * xs_test[2] atol = 0.01
522+
end
523+
524+
# Vectorized linreg model
525+
predictions_vec = DynamicPPL.predict(m_lin_reg_test_vec, multiple_β_chain)
526+
527+
for chain_idx in MCMCChains.chains(multiple_β_chain)
528+
ys_pred_vec = vec(
529+
mean(Array(group(predictions_vec[:, :, chain_idx], :y)); dims=1)
530+
)
531+
@test ys_pred_vec[1] ground_truth_β * xs_test[1] atol = 0.01
532+
@test ys_pred_vec[2] ground_truth_β * xs_test[2] atol = 0.01
533+
end
534+
end
501535
end
502536

503537
@testset "with AbstractVector{<:AbstractVarInfo}" begin

0 commit comments

Comments
 (0)