Description
TL;DR: Callbacks passed to `sample` are called on every iteration, but they cannot always tell which chain they are running on. As a result, callbacks are unsuitable whenever the chain id is required.
I needed a timing callback that records the current time (via `time_ns()`) at each step of a `sample` call. To do this, I wrote a callback (similar to those in some of the inference tests):
```julia
# Define a custom callback
Base.@kwdef struct Timings{A}
    times::A = Vector{UInt64}()
end

function (callback::Timings)(
    rng,
    model,
    sampler,
    sample,
    state,
    iteration;
    kwargs...,
)
    t = time_ns() # Get the current time in nanoseconds
    if iteration == 1
        # Initialize the times vector on the first iteration
        empty!(callback.times)
    end
    # Store the time for this iteration
    push!(callback.times, t)
    return nothing
end

callback = Timings() # uses the default `Vector{UInt64}` storage
```
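To see why only the second chain's timings survive, the sampler's calls can be simulated by hand. This sketch assumes that, as with `MCMCSerial()`, the same callback object is invoked for chain 1 and then for chain 2, with `iteration` restarting at 1 for each chain. `nothing` stands in for the `rng`, `model`, `sampler`, `sample` and `state` arguments, which this callback ignores.

```julia
# Same callback as above, repeated so this block runs on its own.
Base.@kwdef struct Timings{A}
    times::A = Vector{UInt64}()
end

function (callback::Timings)(rng, model, sampler, sample, state, iteration; kwargs...)
    t = time_ns()
    iteration == 1 && empty!(callback.times)  # wipes data from any previous chain
    push!(callback.times, t)
    return nothing
end

callback = Timings()
n_iterations, n_chains = 5, 2

# Simulated serial sampling: chains run one after another with the same callback.
for chain in 1:n_chains, iteration in 1:n_iterations
    callback(nothing, nothing, nothing, nothing, nothing, iteration)
end

length(callback.times)  # 5, not 10: chain 1's timings were emptied when chain 2 started
```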
However, when calling `sample(..., ..., MCMCSerial(), ..., 2; callback = callback)`, only one vector of timings is saved, for the second chain. Both chains share the same callback object, so the `empty!` at the first iteration of chain 2 discards the timings recorded for chain 1.

Ideally, either multiple callback objects could be passed (one per chain, i.e. `n_chains` of them), or a chain id would be passed to the callback, which would allow a matrix to be used to store the results. (Of course, for the latter, multithreaded sampling would require pre-allocating the storage rather than relying on a copy of the callback object.) At the moment, neither option is possible.
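A minimal sketch of the second proposal, assuming a hypothetical `chain_id` keyword argument that the sampler would pass to the callback (AbstractMCMC does not pass this today). The `n_iterations × n_chains` matrix is pre-allocated up front, so under multithreaded sampling each chain would write only to its own column and no per-chain copies of the callback would be needed. The sampler's calls are again simulated by hand, with `Threads.@threads` standing in for parallel chains.

```julia
# Hypothetical: assumes the sampler passes `chain_id` as a keyword argument.
struct ChainTimings
    times::Matrix{UInt64}  # rows = iterations, columns = chains
end

# Pre-allocate storage for every (iteration, chain) pair.
ChainTimings(n_iterations::Integer, n_chains::Integer) =
    ChainTimings(zeros(UInt64, n_iterations, n_chains))

function (callback::ChainTimings)(rng, model, sampler, sample, state, iteration;
                                  chain_id::Integer, kwargs...)
    # Each chain writes to its own column, so nothing is emptied or overwritten.
    callback.times[iteration, chain_id] = time_ns()
    return nothing
end

n_iterations, n_chains = 5, 2
callback = ChainTimings(n_iterations, n_chains)

# Simulated parallel sampling: each task plays the role of one chain.
Threads.@threads for chain in 1:n_chains
    for iteration in 1:n_iterations
        callback(nothing, nothing, nothing, nothing, nothing, iteration; chain_id = chain)
    end
end

# Per-iteration durations for each chain, in nanoseconds.
durations = diff(callback.times; dims = 1)
```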