Skip to content

Commit 50d493c

Browse files
New libtask interface
1 parent 8b4b558 commit 50d493c

File tree

2 files changed

+19
-25
lines changed

2 files changed

+19
-25
lines changed

ext/AdvancedPSLibtaskExt.jl

+17-23
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,7 @@ State wrapper to hold `Libtask.CTask` model initiated from `f`.
2424
function AdvancedPS.LibtaskModel(
2525
f::AdvancedPS.AbstractGenericModel, rng::Random.AbstractRNG, args...
2626
) # Changed the API, need to take care of the RNG properly
27-
return AdvancedPS.LibtaskModel(
28-
f,
29-
Libtask.TapedTask(
30-
f, rng, args...; deepcopy_types=Union{AdvancedPS.TracedRNG,typeof(f)}
31-
),
32-
)
27+
return AdvancedPS.LibtaskModel(f, Libtask.TapedTask(rng, f, args...))
3328
end
3429

3530
"""
@@ -51,12 +46,16 @@ end
5146

5247
# step to the next observe statement and
5348
# return the log probability of the transition (or nothing if done)
54-
function AdvancedPS.advance!(t::LibtaskTrace, isref::Bool=false)
55-
isref ? AdvancedPS.load_state!(t.rng) : AdvancedPS.save_state!(t.rng)
56-
AdvancedPS.inc_counter!(t.rng)
49+
function AdvancedPS.advance!(trace::LibtaskTrace, isref::Bool=false)
50+
# Where is the RNG ?
51+
# isref ? AdvancedPS.load_state!(trace.rng) : AdvancedPS.save_state!(trace.model.ctask.dynamic_scope) # Nasty
52+
isref ? AdvancedPS.load_state!(trace.rng) : AdvancedPS.save_state!(trace.rng)
53+
AdvancedPS.inc_counter!(trace.rng)
54+
55+
Libtask.set_dynamic_scope!(trace.model.ctask, trace.rng)
5756

5857
# Move to next step
59-
return Libtask.consume(t.model.ctask)
58+
return Libtask.consume(trace.model.ctask)
6059
end
6160

6261
# create a backward reference in task_local_storage
@@ -70,8 +69,9 @@ function AdvancedPS.addreference!(task::Task, trace::LibtaskTrace)
7069
end
7170

7271
function AdvancedPS.update_rng!(trace::LibtaskTrace)
73-
rng, = trace.model.ctask.args
74-
trace.rng = rng
72+
new_rng = deepcopy(trace.rng)
73+
trace.rng = new_rng
74+
Libtask.set_dynamic_scope!(trace.model.ctask, trace.rng)
7575
return trace
7676
end
7777

@@ -81,27 +81,23 @@ function AdvancedPS.fork(trace::LibtaskTrace, isref::Bool=false)
8181
AdvancedPS.update_rng!(newtrace)
8282
isref && AdvancedPS.delete_retained!(newtrace.model.f)
8383
isref && delete_seeds!(newtrace)
84-
85-
# add backward reference
86-
AdvancedPS.addreference!(newtrace.model.ctask.task, newtrace)
8784
return newtrace
8885
end
8986

9087
# PG requires keeping all randomness for the reference particle
9188
# Create new task and copy randomness
9289
function AdvancedPS.forkr(trace::LibtaskTrace)
93-
newf = AdvancedPS.reset_model(trace.model.f)
90+
newf = AdvancedPS.reset_model(trace.model.ctask.fargs[1])
9491
Random123.set_counter!(trace.rng, 1)
9592

96-
ctask = Libtask.TapedTask(
97-
newf, trace.rng; deepcopy_types=Union{AdvancedPS.TracedRNG,typeof(trace.model.f)}
98-
)
93+
ctask = Libtask.TapedTask(trace.rng, newf)
9994
new_tapedmodel = AdvancedPS.LibtaskModel(newf, ctask)
10095

10196
# add backward reference
10297
newtrace = AdvancedPS.Trace(new_tapedmodel, trace.rng)
103-
AdvancedPS.addreference!(ctask.task, newtrace)
10498
AdvancedPS.gen_refseed!(newtrace)
99+
100+
Libtask.set_dynamic_scope!(ctask, trace.rng) # Sync trace and rng
105101
return newtrace
106102
end
107103

@@ -117,7 +113,7 @@ function AdvancedPS.observe(dist::Distributions.Distribution, x)
117113
end
118114

119115
"""
120-
AbstractMCMC interface. We need libtask to sample from arbitrary callable AbstractModel
116+
AbstractMCMC interface. We need libtask to sample from arbitrary callable AbstractModelext
121117
"""
122118

123119
function AbstractMCMC.step(
@@ -138,7 +134,6 @@ function AbstractMCMC.step(
138134
else
139135
trng = AdvancedPS.TracedRNG()
140136
trace = AdvancedPS.Trace(deepcopy(model), trng)
141-
AdvancedPS.addreference!(trace.model.ctask.task, trace) # TODO: Do we need it here ?
142137
trace
143138
end
144139
end
@@ -176,7 +171,6 @@ function AbstractMCMC.sample(
176171
traces = map(1:(sampler.nparticles)) do i
177172
trng = AdvancedPS.TracedRNG()
178173
trace = AdvancedPS.Trace(deepcopy(model), trng)
179-
AdvancedPS.addreference!(trace.model.ctask.task, trace) # Do we need it here ?
180174
trace
181175
end
182176

src/container.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -206,12 +206,12 @@ function resample_propagate!(
206206

207207
Random.seed!(p.rng, seeds[1])
208208

209-
children[j += 1] = p
209+
children[j+=1] = p
210210
# fork additional children
211211
for k in 2:ni
212212
part = fork(p, isref)
213213
Random.seed!(part.rng, seeds[k])
214-
children[j += 1] = part
214+
children[j+=1] = part
215215
end
216216
end
217217
end

0 commit comments

Comments
 (0)