Rethinking Threaded/Multichain Callbacks #2568

@joelkandiah

TL;DR: A callback passed to `sample` runs on every iteration, but it cannot always tell which chain it belongs to. Callbacks are therefore unsuitable whenever the chain ID is needed.

I needed a timing callback that records a timestamp (from `time_ns()`) at each step of a `sample` call. To do this, I wrote a callback similar to those in some of the inference tests:

```julia
# Define a custom callback that stores one timestamp per iteration
Base.@kwdef struct Timings{A}
    times::A = Vector{UInt64}()
end

function (callback::Timings)(
    rng,
    model,
    sampler,
    sample,
    state,
    iteration;
    kwargs...,
)
    t = time_ns()  # Current time in nanoseconds
    if iteration == 1
        # Reset the stored timestamps on the first iteration
        empty!(callback.times)
    end

    # Store the timestamp for this iteration
    push!(callback.times, t)
    return nothing
end

# Use the default `Vector{UInt64}` storage
callback = Timings()
```
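Because the callback is an ordinary callable struct, its behaviour with several chains can be checked without running a sampler. The sketch below invokes it the way a serial two-chain run would (each chain in turn, with `iteration` restarting at 1). It passes `nothing` for `rng`, `model`, `sampler`, `sample` and `state`, an assumption that only works because this callback never reads them. The struct is repeated so the snippet runs on its own.

```julia
# Same timing callback as above, repeated so this snippet is self-contained.
Base.@kwdef struct Timings{A}
    times::A = Vector{UInt64}()
end

function (callback::Timings)(rng, model, sampler, sample, state, iteration; kwargs...)
    t = time_ns()
    iteration == 1 && empty!(callback.times)  # reset on each chain's first step
    push!(callback.times, t)
    return nothing
end

# Mimic a serial run: chain 1 for 5 iterations, then chain 2 for 3 iterations.
# `nothing` stands in for the arguments the callback ignores (assumption).
callback = Timings()
for n_iter in (5, 3)  # one entry per chain
    for i in 1:n_iter
        callback(nothing, nothing, nothing, nothing, nothing, i)
    end
end

println(length(callback.times))  # 3: only chain 2's timestamps survive
```

The shared object is reset when the second chain starts, so the first chain's five timestamps are lost.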

However, when calling `sample(..., ..., MCMCSerial(), ..., 2; callback = callback)`, only a single vector of timings is kept, and it holds the timings of the second chain only. Both chains share the same callback object, so the `empty!` on the second chain's first iteration discards the first chain's timings.

Ideally, either several callback objects could be passed (one per chain, i.e. `n_chains` of them), or the chain ID would be passed to the callback so that the results could be stored in a matrix. (For the latter, multithreaded sampling would require preallocating the storage rather than copying the callback object.) At the moment, neither is possible.
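A minimal sketch of the second option, assuming a hypothetical `chain_id` keyword that the sampler would pass to the callback (AbstractMCMC does not do this today, which is the point of this issue). Storage is a preallocated matrix with one column per chain, so concurrent chains never write to the same element. The sampler is simulated with a `Threads.@threads` loop; `ChainTimings` and `chain_id` are illustrative names, not existing API.

```julia
# Hypothetical: assumes the sampler passes `chain_id` as a keyword argument.
struct ChainTimings
    times::Matrix{UInt64}  # rows = iterations, columns = chains
end

# Preallocate storage for `n_samples` iterations of `n_chains` chains.
ChainTimings(n_samples::Integer, n_chains::Integer) =
    ChainTimings(zeros(UInt64, n_samples, n_chains))

function (callback::ChainTimings)(
    rng, model, sampler, sample, state, iteration;
    chain_id::Integer,  # hypothetical keyword, not provided by AbstractMCMC
    kwargs...,
)
    # Each chain writes only to its own column, so no lock is needed
    # even when chains run on different threads.
    callback.times[iteration, chain_id] = time_ns()
    return nothing
end

# Simulate a multithreaded sampler that passes `chain_id` (assumption).
n_samples, n_chains = 100, 4
callback = ChainTimings(n_samples, n_chains)
Threads.@threads for chain in 1:n_chains
    for i in 1:n_samples
        callback(nothing, nothing, nothing, nothing, nothing, i; chain_id = chain)
    end
end
```

Per-step durations for chain `c` could then be obtained with `diff(callback.times[:, c])`.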
