@@ -383,7 +383,10 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
383
383
example_values = DynamicPPL. TestUtils. rand_prior_true (model)
384
384
varinfos = DynamicPPL. TestUtils. setup_varinfos (model, example_values, vns)
385
385
@testset " $(short_varinfo_name (varinfo)) " for varinfo in varinfos
386
- realizations = values_as_in_model (model, varinfo)
386
+ # We can set the include_colon_eq arg to false because none of
387
+ # the demo models contain :=. The behaviour when
388
+ # include_colon_eq is true is tested in test/compiler.jl
389
+ realizations = values_as_in_model (model, false , varinfo)
387
390
# Ensure that all variables are found.
388
391
vns_found = collect (keys (realizations))
389
392
@test vns ∩ vns_found == vns ∪ vns_found
@@ -393,6 +396,22 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
393
396
end
394
397
end
395
398
end
399
+
400
+ @testset " check that sampling obeys rng if passed" begin
401
+ @model function f ()
402
+ x ~ Normal (0 )
403
+ return y ~ Normal (x)
404
+ end
405
+ model = f ()
406
+ # Call values_as_in_model with the rng
407
+ values = values_as_in_model (Random. Xoshiro (43 ), model, false )
408
+ # Check that they match the values that would be used if vi was seeded
409
+ # with that seed instead
410
+ expected_vi = VarInfo (Random. Xoshiro (43 ), model)
411
+ for vn in keys (values)
412
+ @test values[vn] == expected_vi[vn]
413
+ end
414
+ end
396
415
end
397
416
398
417
@testset " Erroneous model call" begin
@@ -432,72 +451,87 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
432
451
433
452
@testset " predict" begin
434
453
@testset " with MCMCChains.Chains" begin
435
- DynamicPPL. Random. seed! (100 )
436
-
437
454
@model function linear_reg (x, y, σ= 0.1 )
438
455
β ~ Normal (0 , 1 )
439
456
for i in eachindex (y)
440
457
y[i] ~ Normal (β * x[i], σ)
441
458
end
459
+ # Insert a := block to test that it is not included in predictions
460
+ return σ2 := σ^ 2
442
461
end
443
462
444
- @model function linear_reg_vec (x, y, σ= 0.1 )
445
- β ~ Normal (0 , 1 )
446
- return y ~ MvNormal (β .* x, σ^ 2 * I)
447
- end
448
-
463
+ # Construct a chain with 'sampled values' of β
449
464
ground_truth_β = 2
450
465
β_chain = MCMCChains. Chains (rand (Normal (ground_truth_β, 0.002 ), 1000 ), [:β ])
451
466
467
+ # Generate predictions from that chain
452
468
xs_test = [10 + 0.1 , 10 + 2 * 0.1 ]
453
469
m_lin_reg_test = linear_reg (xs_test, fill (missing , length (xs_test)))
454
470
predictions = DynamicPPL. predict (m_lin_reg_test, β_chain)
455
471
456
- ys_pred = vec (mean (Array (group (predictions, :y )); dims= 1 ))
457
- @test ys_pred[1 ] ≈ ground_truth_β * xs_test[1 ] atol = 0.01
458
- @test ys_pred[2 ] ≈ ground_truth_β * xs_test[2 ] atol = 0.01
459
-
460
- # Ensure that `rng` is respected
461
- rng = MersenneTwister (42 )
462
- predictions1 = DynamicPPL. predict (rng, m_lin_reg_test, β_chain[1 : 2 ])
463
- predictions2 = DynamicPPL. predict (
464
- MersenneTwister (42 ), m_lin_reg_test, β_chain[1 : 2 ]
465
- )
466
- @test all (Array (predictions1) .== Array (predictions2))
467
-
468
- # Predict on two last indices for vectorized
469
- m_lin_reg_test = linear_reg_vec (xs_test, missing )
470
- predictions_vec = DynamicPPL. predict (m_lin_reg_test, β_chain)
471
- ys_pred_vec = vec (mean (Array (group (predictions_vec, :y )); dims= 1 ))
472
-
473
- @test ys_pred_vec[1 ] ≈ ground_truth_β * xs_test[1 ] atol = 0.01
474
- @test ys_pred_vec[2 ] ≈ ground_truth_β * xs_test[2 ] atol = 0.01
472
+ # Also test a vectorized model
473
+ @model function linear_reg_vec (x, y, σ= 0.1 )
474
+ β ~ Normal (0 , 1 )
475
+ return y ~ MvNormal (β .* x, σ^ 2 * I)
476
+ end
477
+ m_lin_reg_test_vec = linear_reg_vec (xs_test, missing )
475
478
476
- # Multiple chains
477
- multiple_β_chain = MCMCChains. Chains (
478
- reshape (rand (Normal (ground_truth_β, 0.002 ), 1000 , 2 ), 1000 , 1 , 2 ), [:β ]
479
- )
480
- m_lin_reg_test = linear_reg (xs_test, fill (missing , length (xs_test)))
481
- predictions = DynamicPPL. predict (m_lin_reg_test, multiple_β_chain)
482
- @test size (multiple_β_chain, 3 ) == size (predictions, 3 )
479
+ @testset " variables in chain" begin
480
+ # Note that this also checks that variables on the lhs of :=,
481
+ # such as σ2, are not included in the resulting chain
482
+ @test Set (keys (predictions)) == Set ([Symbol (" y[1]" ), Symbol (" y[2]" )])
483
+ end
483
484
484
- for chain_idx in MCMCChains . chains (multiple_β_chain)
485
- ys_pred = vec (mean (Array (group (predictions[:, :, chain_idx] , :y )); dims= 1 ))
485
+ @testset " accuracy " begin
486
+ ys_pred = vec (mean (Array (group (predictions, :y )); dims= 1 ))
486
487
@test ys_pred[1 ] ≈ ground_truth_β * xs_test[1 ] atol = 0.01
487
488
@test ys_pred[2 ] ≈ ground_truth_β * xs_test[2 ] atol = 0.01
488
489
end
489
490
490
- # Predict on two last indices for vectorized
491
- m_lin_reg_test = linear_reg_vec (xs_test, missing )
492
- predictions_vec = DynamicPPL. predict (m_lin_reg_test, multiple_β_chain)
493
-
494
- for chain_idx in MCMCChains. chains (multiple_β_chain)
495
- ys_pred_vec = vec (
496
- mean (Array (group (predictions_vec[:, :, chain_idx], :y )); dims= 1 )
491
+ @testset " ensure that rng is respected" begin
492
+ rng = MersenneTwister (42 )
493
+ predictions1 = DynamicPPL. predict (rng, m_lin_reg_test, β_chain[1 : 2 ])
494
+ predictions2 = DynamicPPL. predict (
495
+ MersenneTwister (42 ), m_lin_reg_test, β_chain[1 : 2 ]
497
496
)
497
+ @test all (Array (predictions1) .== Array (predictions2))
498
+ end
499
+
500
+ @testset " accuracy on vectorized model" begin
501
+ predictions_vec = DynamicPPL. predict (m_lin_reg_test_vec, β_chain)
502
+ ys_pred_vec = vec (mean (Array (group (predictions_vec, :y )); dims= 1 ))
503
+
498
504
@test ys_pred_vec[1 ] ≈ ground_truth_β * xs_test[1 ] atol = 0.01
499
505
@test ys_pred_vec[2 ] ≈ ground_truth_β * xs_test[2 ] atol = 0.01
500
506
end
507
+
508
+ @testset " prediction from multiple chains" begin
509
+ # Normal linreg model
510
+ multiple_β_chain = MCMCChains. Chains (
511
+ reshape (rand (Normal (ground_truth_β, 0.002 ), 1000 , 2 ), 1000 , 1 , 2 ), [:β ]
512
+ )
513
+ predictions = DynamicPPL. predict (m_lin_reg_test, multiple_β_chain)
514
+ @test size (multiple_β_chain, 3 ) == size (predictions, 3 )
515
+
516
+ for chain_idx in MCMCChains. chains (multiple_β_chain)
517
+ ys_pred = vec (
518
+ mean (Array (group (predictions[:, :, chain_idx], :y )); dims= 1 )
519
+ )
520
+ @test ys_pred[1 ] ≈ ground_truth_β * xs_test[1 ] atol = 0.01
521
+ @test ys_pred[2 ] ≈ ground_truth_β * xs_test[2 ] atol = 0.01
522
+ end
523
+
524
+ # Vectorized linreg model
525
+ predictions_vec = DynamicPPL. predict (m_lin_reg_test_vec, multiple_β_chain)
526
+
527
+ for chain_idx in MCMCChains. chains (multiple_β_chain)
528
+ ys_pred_vec = vec (
529
+ mean (Array (group (predictions_vec[:, :, chain_idx], :y )); dims= 1 )
530
+ )
531
+ @test ys_pred_vec[1 ] ≈ ground_truth_β * xs_test[1 ] atol = 0.01
532
+ @test ys_pred_vec[2 ] ≈ ground_truth_β * xs_test[2 ] atol = 0.01
533
+ end
534
+ end
501
535
end
502
536
503
537
@testset " with AbstractVector{<:AbstractVarInfo}" begin
0 commit comments