@@ -24,12 +24,7 @@ State wrapper to hold `Libtask.CTask` model initiated from `f`.
24
24
function AdvancedPS. LibtaskModel (
25
25
f:: AdvancedPS.AbstractGenericModel , rng:: Random.AbstractRNG , args...
26
26
) # Changed the API, need to take care of the RNG properly
27
- return AdvancedPS. LibtaskModel (
28
- f,
29
- Libtask. TapedTask (
30
- f, rng, args... ; deepcopy_types= Union{AdvancedPS. TracedRNG,typeof (f)}
31
- ),
32
- )
27
+ return AdvancedPS. LibtaskModel (f, Libtask. TapedTask (rng, f, args... ))
33
28
end
34
29
35
30
"""
51
46
52
47
# step to the next observe statement and
53
48
# return the log probability of the transition (or nothing if done)
54
- function AdvancedPS. advance! (t:: LibtaskTrace , isref:: Bool = false )
55
- isref ? AdvancedPS. load_state! (t. rng) : AdvancedPS. save_state! (t. rng)
56
- AdvancedPS. inc_counter! (t. rng)
49
+ function AdvancedPS. advance! (trace:: LibtaskTrace , isref:: Bool = false )
50
+ # Where is the RNG ?
51
+ # isref ? AdvancedPS.load_state!(trace.rng) : AdvancedPS.save_state!(trace.model.ctask.dynamic_scope) # Nasty
52
+ isref ? AdvancedPS. load_state! (trace. rng) : AdvancedPS. save_state! (trace. rng)
53
+ AdvancedPS. inc_counter! (trace. rng)
54
+
55
+ Libtask. set_dynamic_scope! (trace. model. ctask, trace. rng)
57
56
58
57
# Move to next step
59
- return Libtask. consume (t . model. ctask)
58
+ return Libtask. consume (trace . model. ctask)
60
59
end
61
60
62
61
# create a backward reference in task_local_storage
@@ -70,8 +69,9 @@ function AdvancedPS.addreference!(task::Task, trace::LibtaskTrace)
70
69
end
71
70
72
71
function AdvancedPS. update_rng! (trace:: LibtaskTrace )
73
- rng, = trace. model. ctask. args
74
- trace. rng = rng
72
+ new_rng = deepcopy (trace. rng)
73
+ trace. rng = new_rng
74
+ Libtask. set_dynamic_scope! (trace. model. ctask, trace. rng)
75
75
return trace
76
76
end
77
77
@@ -81,27 +81,23 @@ function AdvancedPS.fork(trace::LibtaskTrace, isref::Bool=false)
81
81
AdvancedPS. update_rng! (newtrace)
82
82
isref && AdvancedPS. delete_retained! (newtrace. model. f)
83
83
isref && delete_seeds! (newtrace)
84
-
85
- # add backward reference
86
- AdvancedPS. addreference! (newtrace. model. ctask. task, newtrace)
87
84
return newtrace
88
85
end
89
86
90
87
# PG requires keeping all randomness for the reference particle
91
88
# Create new task and copy randomness
92
89
function AdvancedPS. forkr (trace:: LibtaskTrace )
93
- newf = AdvancedPS. reset_model (trace. model. f )
90
+ newf = AdvancedPS. reset_model (trace. model. ctask . fargs[ 1 ] )
94
91
Random123. set_counter! (trace. rng, 1 )
95
92
96
- ctask = Libtask. TapedTask (
97
- newf, trace. rng; deepcopy_types= Union{AdvancedPS. TracedRNG,typeof (trace. model. f)}
98
- )
93
+ ctask = Libtask. TapedTask (trace. rng, newf)
99
94
new_tapedmodel = AdvancedPS. LibtaskModel (newf, ctask)
100
95
101
96
# add backward reference
102
97
newtrace = AdvancedPS. Trace (new_tapedmodel, trace. rng)
103
- AdvancedPS. addreference! (ctask. task, newtrace)
104
98
AdvancedPS. gen_refseed! (newtrace)
99
+
100
+ Libtask. set_dynamic_scope! (ctask, trace. rng) # Sync trace and rng
105
101
return newtrace
106
102
end
107
103
@@ -117,7 +113,7 @@ function AdvancedPS.observe(dist::Distributions.Distribution, x)
117
113
end
118
114
119
115
"""
120
- AbstractMCMC interface. We need libtask to sample from arbitrary callable AbstractModel
116
+ AbstractMCMC interface. We need libtask to sample from arbitrary callable AbstractModelext
121
117
"""
122
118
123
119
function AbstractMCMC. step (
@@ -138,7 +134,6 @@ function AbstractMCMC.step(
138
134
else
139
135
trng = AdvancedPS. TracedRNG ()
140
136
trace = AdvancedPS. Trace (deepcopy (model), trng)
141
- AdvancedPS. addreference! (trace. model. ctask. task, trace) # TODO : Do we need it here ?
142
137
trace
143
138
end
144
139
end
@@ -176,7 +171,6 @@ function AbstractMCMC.sample(
176
171
traces = map (1 : (sampler. nparticles)) do i
177
172
trng = AdvancedPS. TracedRNG ()
178
173
trace = AdvancedPS. Trace (deepcopy (model), trng)
179
- AdvancedPS. addreference! (trace. model. ctask. task, trace) # Do we need it here ?
180
174
trace
181
175
end
182
176
0 commit comments