Turing.jl

Progress reporting in parallel sampling

Open SamuelBrand1 opened this issue 1 year ago • 4 comments

Hi everyone,

One thing I've noticed is that progress reporting when sampling chains in parallel (for example using MCMCThreads()) is not informative: the progress meter only updates when a chain finishes, rather than reporting within-chain progress as serial sampling does.

Is there any movement towards chain-by-chain progress reporting, as in Stan?

SamuelBrand1 avatar Jun 12 '24 09:06 SamuelBrand1

This is an issue that has come up fairly often but AFAIK no perfect solution exists.

Ref: https://github.com/TuringLang/AbstractMCMC.jl/issues/82 https://github.com/TuringLang/AbstractMCMC.jl/issues/105

There's a Discourse thread where someone seems to have come up with a "solution" (https://discourse.julialang.org/t/displaying-parallel-progress-bars/4148/8), but that was ages ago and I'm not sure whether it still works.

Note that you can provide an arbitrary callback to sample, which is executed after every step and where you could do custom progress-keeping, but atm there's no good built-in solution unfortunately :confused:
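
To illustrate the callback route, here is a minimal sketch, not tested against any particular Turing.jl version. CountingCallback is a hypothetical helper, not something Turing.jl or AbstractMCMC.jl ships. The call signature is the one AbstractMCMC.jl uses for callbacks. The sketch assumes that a single callback object is shared by all chains under MCMCThreads() (hence the atomic counter) and that the callback fires once per retained draw.

```julia
# Hypothetical helper (not part of Turing.jl): counts draws across all chains.
struct CountingCallback
    counter::Threads.Atomic{Int}  # shared by every chain, so it must be atomic
    total::Int                    # expected number of callback calls overall
    every::Int                    # log once per `every` draws
end

CountingCallback(total::Int; every::Int=1_000) =
    CountingCallback(Threads.Atomic{Int}(0), total, every)

function (cb::CountingCallback)(rng, model, sampler, sample, state, iteration; kwargs...)
    # atomic_add! returns the previous value, so add 1 to get the new count.
    done = Threads.atomic_add!(cb.counter, 1) + 1
    if done % cb.every == 0 || done == cb.total
        @info "Sampled $done of $(cb.total) draws across all chains"
    end
    return nothing
end

# Usage (assumes Turing is loaded and `model` is defined):
# cb = CountingCallback(num_samples * num_chains)
# sample(model, HMC(0.1, 32), MCMCThreads(), num_samples, num_chains;
#        callback=cb, progress=false)
```

This only gives an aggregate count rather than per-chain bars, since the callback is not told which chain it is running in.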

torfjelde avatar Jun 12 '24 11:06 torfjelde

Thanks for flagging this up @torfjelde ! I guess this will keep circling around :-(.

SamuelBrand1 avatar Jun 12 '24 11:06 SamuelBrand1

Might be possible to do something with this: https://github.com/timholy/ProgressMeter.jl/pull/157

In fact, if I use that branch + some minor changes to AbstractMCMC.jl, the following

using ProgressMeter
using Turing

struct ProgressCallback{P}
    p::P
    index::Int
end

function (callback::ProgressCallback)(rng, model, sampler, sample, state, iteration; kwargs...)
    # Can do more stuff here if you want.
    next!(callback.p[callback.index])
end

@model demo() = x ~ Normal()
model = demo()

num_samples = 100_000
num_chains = 10
p = MultipleProgress(
    [Progress(num_samples; desc="Chain $i ") for i in 1:num_chains],
    Progress(num_samples * num_chains; desc="Total ")
)
callbacks = map(1:num_chains) do i
    ProgressCallback(p, i)
end
chain = sample(
    model,
    HMC(0.1, 32),
    MCMCThreads(),
    num_samples,
    num_chains,
    callback=callbacks,
    progress=false,
    thinning=10
)

results in

[image: screenshot of the resulting progress bars]

It's small, so I'm not sure you can see it, but it creates one bar for each chain plus a global progress bar.

(Note that this relies on minor changes to AbstractMCMC.jl plus that experimental branch of ProgressMeter.jl, which only supports the REPL and not, say, IJulia.)

Miiight be worth adopting this in TuringCallbacks.jl as a bridge until there's a good solution.

torfjelde avatar Jun 12 '24 11:06 torfjelde

Maybe something like this would be nice to support?

nsiccha avatar May 08 '25 06:05 nsiccha