AbstractMCMC.jl

Add some interface functions to support the new Gibbs sampler in Turing

Open sunxd3 opened this issue 1 year ago • 20 comments

The recent new Gibbs sampler provides a way forward for the Turing inference stack.

A near-to-medium-range goal has been to further reduce the glue code between Turing and inference packages (ref https://github.com/TuringLang/Turing.jl/issues/2281). The new Gibbs implementation laid out a great plan to achieve this goal.

This PR is modeled after the interface of @torfjelde's recent PR. And in some aspects, it is a rehash of https://github.com/TuringLang/AbstractMCMC.jl/pull/86.

(the explanation here is outdated, please refer to https://github.com/TuringLang/AbstractMCMC.jl/pull/144#issuecomment-2337681868)

~~The goal of this PR is to determine and implement some necessary interface improvements, so that, when we update the inference packages up to the interface, they will more or less "just work" with the new Gibbs implementation.~~

~~As a first step, we test-flight two new functions recompute_logprob!!(rng, model, sampler, state) and getparams(state):~~

  • ~~recompute_logprob!!(rng, model, sampler, state) recomputes the logprob given the state~~
  • ~~getparams(state) extracts the parameter values~~

~~Some considerations:~~

  • ~~This assumes a state is implemented with AbstractMCMC compatible inference packages. And a state at least stores values of parameters from the current iteration (traditionally, this is in the form of a Transition) and logprob.~~
  • ~~recompute_logprob!!(rng, model, sampler, state)~~
    • ~~do we need rng?~~
    • ~~should we make model into AbstractMCMC.LogDensityModel or just LogDensityProblem (and make inference packages depend on LogDensityProblems in the latter case)? This should allow inference packages to be independent from DynamicPPL, we can use getparams to construct a varinfo in Turing~~
  • ~~getparams(state)~~
    • ~~What does this function return? A vector, a transition?~~
    • ~~Do we need setparams?~~
  • ~~Do we also need some interface functions for state like getstats?~~

~~Tor also says (in a Slack conversation) that a condition(model, params) is needed, but that it is better implemented by the packages that define the model, which I agree with.~~

sunxd3 avatar Jul 12 '24 15:07 sunxd3

@yebai @devmotion @cpfiffer

sunxd3 avatar Jul 12 '24 15:07 sunxd3

How is #86 related to this PR?

devmotion avatar Jul 14 '24 21:07 devmotion

Hmm, it's unclear to me whether it's worth adding these methods when they have "no use" unless some notion of conditioning is also added 😕

> How is https://github.com/TuringLang/AbstractMCMC.jl/pull/86 related to this PR?

getparams is probably overlapping between the two PRs, but the recompute_logprob!! method is not

torfjelde avatar Jul 15 '24 07:07 torfjelde

I am for adding a condition interface, should we upstream this from AbstractPPL?

sunxd3 avatar Jul 16 '24 07:07 sunxd3

I think AbstractPPL imports AbstractMCMC, so it is also a good idea to define condition here and then reexport from AbstractPPL.

yebai avatar Jul 16 '24 14:07 yebai

Okay, now condition and decondition are moved to AbstractMCMC from AbstractPPL.

Do we want fix here?

sunxd3 avatar Jul 18 '24 10:07 sunxd3

@devmotion @yebai @torfjelde @mhauru a penny for your thoughts?

sunxd3 avatar Jul 19 '24 10:07 sunxd3

> Do we want fix here?

I'd keep it in DynamicPPL / AbstractPPL unless there is a reason to move here.

yebai avatar Jul 19 '24 11:07 yebai

I'm still a bit uncertain about all of this tbh. I feel like right now we're just shoving condition and decondition (which I don't think we need for Gibbs?) into AbstractMCMC.jl to motivate the inclusion of recompute_logprob!! without much thought about whether it makes sense or not 😅

I think if this is the case, then I'd prefer to ignore my original comment of "needing condition to motivate recompute_logprob!!", i.e. just leave it as you did originally (without condition and decondition).

torfjelde avatar Jul 19 '24 19:07 torfjelde

I removed condition (and decondition) and used the public keyword for the new interface functions. The latter will technically change the interface, so I bumped the minor version.

I also think we should add something like AbstractState to normalize the design of state. This will introduce types for state everywhere, and I am unsure of the impact. What are your thoughts on this?

sunxd3 avatar Jul 22 '24 07:07 sunxd3

> I also think we should add something like AbstractState to normalize the design of state. This will introduce types for state everywhere, and I am unsure of the impact. What are your thoughts on this?

Not for this PR at least:) If we want to discuss this, then we should open an issue and move discussion there.

torfjelde avatar Jul 22 '24 08:07 torfjelde

> The latter will technically change the interface, so I bumped the minor version.

It seems you've bumped the major version, not the minor version?

Also, if we're making this part of the interface, we should probably document this?

torfjelde avatar Jul 22 '24 08:07 torfjelde

Oops, you're right.

> we should probably document this?

By using the public keyword, maybe we can say "this is not official yet"? ~~I am a little hesitant to add official documentation right now, because we don't yet have a crystal clear idea of how the interface should behave.~~ Will add docs.

sunxd3 avatar Jul 22 '24 09:07 sunxd3

Some high-level comments:

  • Let's introduce a setparams function to complement the getparams function.
  • Let's introduce some tests to test the interface and get a more grounded view of the design.
  • Think of an alternative name to recompute_logprob!!, which is a bit unintuitive in terms of what it means.

@sunxd3 please also take a careful look at

  • https://github.com/TuringLang/AbstractMCMC.jl/pull/117
  • https://github.com/TuringLang/AbstractMCMC.jl/pull/86
  • and the DynamicHMC sampling interface design, specifically the warmup_stages and reporter arguments,

We want to push for merging these PRs and incorporate some nice ideas from elsewhere in the ecosystem.

yebai avatar Jul 23 '24 19:07 yebai

current concerns:

  • We need recompute_logprob for Gibbs sampling. The reason: in the general case, the values of the parameters on which a parameter block is conditioned change between two Gibbs steps, so the logp generally needs to be recomputed. But storing the recomputed logprob in the state is not necessary, so set_logprob!! is not needed for Gibbs, although it would help other "meta samplers".
  • Adopting recompute_logprob!! as an AbstractMCMC interface function is an issue for me, for two non-fundamental reasons:
    • the function is for Gibbs only and the re part is overly specific in my eyes; this is very philosophical and a matter of personal taste
    • recompute_logprob!! needs to take a model argument; do we need to make clear what this model is?
      • options are AbstractMCMC.LogDensityModel or AbstractPPL.AbstractModel, but neither is perfect

What I am proposing in the current state of this PR is:

  • AbstractMCMC defines interfaces of state
  • Gibbs defines an interface recompute_logprob!!, which returns a state
  • model definition packages (e.g. DynamicPPL) implement recompute_logprob!! using the interface functions of state, defined in AbstractMCMC
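
To illustrate why the cached log-density goes stale between Gibbs steps, here is a minimal sketch. It is not the PR's implementation: `BlockState`, `joint_logp`, `condition_on_b` and `recompute_logprob` are hypothetical names chosen for this example, and the model is a toy two-variable density.

```julia
# Hypothetical names throughout; none of this is AbstractMCMC API.

# Toy joint log-density over two scalar blocks `a` and `b` (unnormalized).
joint_logp(a, b) = -0.5 * (a - b)^2 - 0.5 * b^2

# A sub-sampler state that caches the parameter value and its log-density.
struct BlockState
    params::Float64
    logp::Float64
end

# Condition the joint on `b`, giving a log-density in `a` alone.
condition_on_b(b) = a -> joint_logp(a, b)

# Re-evaluate the cached log-density under a (possibly new) conditional
# and return a fresh state; the parameter value itself is unchanged.
recompute_logprob(cond_logp, state::BlockState) =
    BlockState(state.params, cond_logp(state.params))

# State for block `a`, built while b == 0.0.
state_a = BlockState(1.0, condition_on_b(0.0)(1.0))

# Another Gibbs step moves `b`, so the cached value no longer matches
# the conditional that block `a` is now sampled from.
new_b = 2.0
stale = state_a.logp
fresh = recompute_logprob(condition_on_b(new_b), state_a)
```

Here `stale` was computed with `b == 0.0`, while `fresh.logp` is evaluated with `b == 2.0` at the same parameter value, which is the situation the first bullet above describes.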

sunxd3 avatar Aug 23 '24 14:08 sunxd3

Some new updates, based on an offline discussion with @yebai

(edited Sep 17th)

  • gibbs.jl is moved into /src
    • it no longer uses OrderedDict, this is to keep the dependency light (OrderedDict requires OrderedCollections).
    • now the sampler_map argument only accepts a NamedTuple; this is less flexible than Turing's Gibbs sampler, but should be fine for a first step
    • the vi field is renamed to trace, it only holds the values, not logp
  • the four proposed interface functions are all removed or replaced:
    • get_logp and set_logp!! are combined into a single function called logdensity_and_state(logdensity_f, state; recompute_logp); if recompute_logp is true, the logdensity should be reevaluated, otherwise return the logdensity in state
    • get_params --> Base.vec https://github.com/TuringLang/AbstractMCMC.jl/blob/1382054b3be11cf28f231b636ff1c34be187d881/src/gibbs.jl#L106-L111
    • set_params!! is removed
  • because gibbs.jl is in /src, a condition function is added to AbstractMCMC; we should find a way to combine this with the one in AbstractPPL. Also, do we want to define condition(logdensityproblem, ..)?
  • a sampler package now needs to implement
    • the logdensity_and_state(logdensityfunction, state) interface
    • Base.vec, which returns a flattened representation of the parameters
      • Another option is to use the iterator interface.
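
As a rough sketch of what a sampler package might provide under this proposal: the function name and the `recompute_logp` keyword follow the description above, but `MyState`, its fields, and the `std_normal` target are assumptions made for illustration, not code from the PR.

```julia
# Hypothetical sampler state; the type and field names are assumptions.
struct MyState
    params::Vector{Float64}
    logp::Float64
end

# Flattened representation of the parameters, as proposed via `Base.vec`.
Base.vec(state::MyState) = state.params

# Return `(logdensity, state)`. When `recompute_logp` is true the
# log-density is re-evaluated and stored in a new state; otherwise the
# value cached in `state` is returned unchanged.
function logdensity_and_state(logdensity_f, state::MyState; recompute_logp::Bool)
    if recompute_logp
        logp = logdensity_f(vec(state))
        return logp, MyState(state.params, logp)
    else
        return state.logp, state
    end
end

# Toy target: unnormalized standard normal log-density.
std_normal(x) = -0.5 * sum(abs2, x)

state = MyState([1.0, 2.0], 0.0)  # cached logp is deliberately stale
cached, _ = logdensity_and_state(std_normal, state; recompute_logp=false)
fresh, new_state = logdensity_and_state(std_normal, state; recompute_logp=true)
```

With `recompute_logp=false` the call just reads the cached value, so a Gibbs implementation would pass `true` whenever the conditioned-on values may have changed.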

sunxd3 avatar Sep 09 '24 10:09 sunxd3

> and then we can add explicit impls of higher-order samplers later once we've tested out the interface + arrived at really good implementations of these, e.g. Gibbs.

I'd encourage @torfjelde and @sunxd3 to have more discussions to converge on a sensible Gibbs design; we will never get a good implementation if we don't start working on it.

yebai avatar Sep 18 '24 12:09 yebai

> I'd encourage @torfjelde and @sunxd3 to have more discussions to converge on a sensible Gibbs design; we will never get a good implementation if we don't start working on it.

Definitely! But I think it's also partially a thing that we need to try out a bit, and we're not going to get it right on the first try, hence my reluctance to put an impl in AbstractMCMC.jl as soon as we enable the functionality.

torfjelde avatar Sep 18 '24 12:09 torfjelde

@torfjelde thanks for the very nice review. The code and the markdown doc are out of sync. A major difference is that the version in the markdown doc uses OrderedDict, while the one in the current src/gibbs.jl only uses NamedTuple (this is to avoid introducing a dependency on OrderedCollections). The latter, in turn, means that one Gibbs block can only have one variable name. Sorry for the confusion.

But if we don't end up introducing gibbs.jl into AbstractMCMC, then we should probably enable using OrderedDict as well as NamedTuple.

I'll update the code and documentation once we decide where we want to keep the gibbs code.
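
The restriction can be seen from the container shapes alone. In the sketch below, `DummySampler` and both maps are hypothetical stand-ins rather than the actual `sampler_map` from either version, and a plain `Dict` stands in for `OrderedDict` to avoid the extra dependency.

```julia
# Hypothetical sampler placeholder; only the container shapes matter here.
struct DummySampler
    name::String
end

# NamedTuple map: each key is a single Symbol, so one variable per block.
nt_map = (mu = DummySampler("mh"), sigma = DummySampler("hmc"))

# A dictionary keyed by tuples of Symbols can group several variables
# into one block. (`OrderedDict` from OrderedCollections would also fix
# the iteration order; a plain `Dict` does not.)
dict_map = Dict(
    (:mu, :sigma) => DummySampler("hmc"),
    (:z,) => DummySampler("mh"),
)

# Variables handled by each block under the two representations.
block_vars_nt = [(k,) for k in keys(nt_map)]
block_vars_dict = collect(keys(dict_map))
```

Every entry of `block_vars_nt` has length one, while `block_vars_dict` contains the two-variable block `(:mu, :sigma)`.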

sunxd3 avatar Sep 18 '24 14:09 sunxd3

Codecov Report

All modified and coverable lines are covered by tests :white_check_mark:

Project coverage is 97.54%. Comparing base (2a77f53) to head (3ed5cb3). Report is 5 commits behind head on master.

Additional details and impacted files
@@            Coverage Diff             @@
##           master     #144      +/-   ##
==========================================
+ Coverage   97.19%   97.54%   +0.34%     
==========================================
  Files           8        8              
  Lines         321      326       +5     
==========================================
+ Hits          312      318       +6     
+ Misses          9        8       -1     
Flag Coverage Δ
97.54% <ø> (?)


codecov[bot] avatar Sep 22 '24 20:09 codecov[bot]

@yebai @torfjelde this has been updated; many issues in the previous version have been corrected (thanks for the discussion and code review). I also added another note in the design_notes folder; comments are welcome. Can you give it another read?

sunxd3 avatar Oct 01 '24 20:10 sunxd3

In its current form, no interface change is made to AbstractMCMC, all the interface functions are from other packages.

sunxd3 avatar Oct 01 '24 20:10 sunxd3

the test error seems to be specific to Julia 1.6

sunxd3 avatar Oct 03 '24 12:10 sunxd3

related reply from @devmotion https://github.com/TuringLang/Turing.jl/pull/2304#issuecomment-2291097709

sunxd3 avatar Oct 03 '24 15:10 sunxd3

I do think the entire process of this would be quite a bit less painful if we did the following (I believe I've mentioned this before; if not, I apologize):

  1. Improve #86 to a finalized form. This is useful, not just for Gibbs sampling.
  2. Make a separate package, e.g. AbstractMCMCGibbs.jl, which implements the Gibbs-only stuff, e.g. recompute_logprob!! and the sampler mapping stuff.

This is how we're doing it with MCMCTempering.jl, i.e. keep it as a separate package and slowly move pieces to AbstractMCMC.jl if it seems suitable. My problem, as stated before, is that the current Gibbs impls we're working with are really not good enough as I think is evident by a) issues that we've encountered with my Turing.jl-impl in https://github.com/TuringLang/Turing.jl/pull/2328#issuecomment-2378971734, and b) the amount of iterating you've done in this PR. This shit is complicated :grimacing: And I imagine it's really annoying iterating on this back and forth but without actually getting stuff merged..

So, I think a separate package would just make this entire process much easier @sunxd3; then we can iterate much faster on ideas (just make breaking releases), and then we can just upstream changes as we finalize things there + we can even inform about this in the official AbstractMCMC.jl docs and then people can easily support this via extensions.

torfjelde avatar Oct 04 '24 10:10 torfjelde

https://github.com/TuringLang/AbstractMCMC.jl/issues/85#issuecomment-2061300622

torfjelde avatar Oct 04 '24 15:10 torfjelde