Add some interface functions to support the new Gibbs sampler in Turing
The recently added Gibbs sampler provides a way forward for the Turing inference stack.
A near-to-medium-term goal has been to further reduce the glue code between Turing and the inference packages (ref https://github.com/TuringLang/Turing.jl/issues/2281). The new Gibbs implementation lays out a good plan for achieving this goal.
This PR is modeled on the interface of @torfjelde's recent PR, and in some respects it is a rehash of https://github.com/TuringLang/AbstractMCMC.jl/pull/86.
(the explanation here is outdated, please refer to https://github.com/TuringLang/AbstractMCMC.jl/pull/144#issuecomment-2337681868)
~~The goal of this PR is to determine and implement some necessary interface improvements so that, when we bring the inference packages up to the interface, they will more or less "just work" with the new Gibbs implementation.~~

~~As a first step, we test-flight two new functions, `recompute_logprob!!(rng, model, sampler, state)` and `getparams(state)`:~~

- ~~`recompute_logprob!!(rng, model, sampler, state)` recomputes the logprob given the `state`~~
- ~~`getparams(state)` extracts the parameter values~~

~~Some considerations:~~

- ~~This assumes a `state` is implemented by `AbstractMCMC`-compatible inference packages, and that a `state` at least stores the values of the parameters from the current iteration (traditionally, in the form of a `Transition`) and the logprob.~~
- ~~`recompute_logprob!!(rng, model, sampler, state)`~~
  - ~~do we need `rng`?~~
  - ~~should we make `model` an `AbstractMCMC.LogDensityModel` or just a `LogDensityProblem` (and make inference packages depend on `LogDensityProblems` in the latter case)? This should allow inference packages to be independent of DynamicPPL; we can use `getparams` to construct a `varinfo` in Turing~~
- ~~`getparams(state)`~~
  - ~~What does this function return? A vector, a `transition`?~~
  - ~~Do we need `setparams`?~~
- ~~Do we also need some interface functions for `state`, like `getstats`?~~
~~Tor also says (in a Slack conversation) that a `condition(model, params)` function is needed, but that it is better implemented by the packages that define the model, which I agree with.~~
@yebai @devmotion @cpfiffer
How is #86 related to this PR?
Hmm, it's unclear to me whether it's worth adding these methods when they have "no use" unless some notion of conditioning is also added
> How is https://github.com/TuringLang/AbstractMCMC.jl/pull/86 related to this PR?

`getparams` probably overlaps between the two PRs, but the `recompute_logprob!!` method does not.
I am in favor of adding a `condition` interface; should we upstream this from AbstractPPL?
I think AbstractPPL imports AbstractMCMC, so it would also be a good idea to define `condition` here and then re-export it from AbstractPPL.
Okay, `condition` and `decondition` are now moved to AbstractMCMC from AbstractPPL.
Do we want `fix` here as well?
@devmotion @yebai @torfjelde @mhauru a penny for your thoughts?
I'm still a bit uncertain about all of this tbh. I feel like right now we're just shoving `condition` and `decondition` (which I don't think we need for Gibbs?) into AbstractMCMC.jl to motivate the inclusion of `recompute_logprob!!`, without much thought about whether it makes sense or not.
I think if this is the case, then I'd rather ignore my original comment about "needing `condition` to motivate `recompute_logprob!!`", i.e. just leave it as you did originally (without `condition` and `decondition`).
I removed `condition` (and `decondition`) and used the `public` keyword for the new interface functions. The latter technically changes the interface, so I bumped the minor version.
I also think we should add something like `AbstractState` to normalize the design of `state`. This would introduce types for `state` everywhere, and I am unsure of the impact. What are your thoughts on this?
> I also think we should add something like `AbstractState` to normalize the design of `state`. This would introduce types for `state` everywhere, and I am unsure of the impact. What are your thoughts on this?
Not for this PR at least :) If we want to discuss this, we should open an issue and move the discussion there.
> The latter will technically change the interface, so I bumped the minor version.
It seems you've bumped the major version, not the minor version?
Also, if we're making this part of the interface, we should probably document this?
Oops, you're right.
> we should probably document this?
By using the `public` keyword, maybe we can say "this is not official yet"? ~~I am a little hesitant to add official documentation right now, because we don't yet have a crystal-clear idea of how the interface should behave.~~
Will add docs.
Some high-level comments:
- Let's introduce a `setparams` function to complement the `getparams` function.
- Let's introduce some tests to exercise the interface and get a more grounded view of the design.
- Think of an alternative name for `recompute_logprob!!`, which is a bit unintuitive in terms of what it means.
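To make the proposed get/set pairing concrete, here is a minimal sketch of what a sampler package might implement; the `MyState` type, its fields, and the exact function names are hypothetical, not part of any existing package:

```julia
# Hypothetical sampler state: flattened parameter values plus a cached logprob.
struct MyState{T<:Real}
    params::Vector{T}
    logp::T
end

# Extract the parameter values from the state.
getparams(state::MyState) = state.params

# Return a state carrying the new parameters; the `!!` convention signals that
# the state may be replaced rather than mutated in place. The cached logprob is
# invalidated here (set to NaN) because it no longer matches the parameters.
setparams!!(state::MyState, params) = MyState(collect(params), oftype(state.logp, NaN))
```

With both halves in place, a meta-sampler such as Gibbs could move parameter values between sub-sampler states without knowing their internals.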
@sunxd3 please also take a careful look at
- https://github.com/TuringLang/AbstractMCMC.jl/pull/117
- https://github.com/TuringLang/AbstractMCMC.jl/pull/86
- the `DynamicHMC` sampling interface design, specifically the `warmup_stages` and `reporter` arguments;

we want to push for merging these PRs and incorporate some of their nice ideas elsewhere in the ecosystem.
Current concerns:

- We need `recompute_logprob` for Gibbs sampling. The reason: the values of the parameters on which a parameter block is conditioned will, in the general case, change between two Gibbs steps, so recomputing the logp is generally needed. But storing the recomputed logprob in a state is not necessary, thus `set_logprob!!` is not necessary for Gibbs, although it would aid other "meta samplers".
- Adopting `recompute_logprob!!` as an `AbstractMCMC` interface is an issue for me, for two non-fundamental reasons:
  - the function is for Gibbs only, and the "re" part is overly specific in my eyes; this is very philosophical and a matter of personal taste
  - `recompute_logprob!!` needs to take a `model` argument; do we need to make clear what this `model` is?
    - options are `AbstractMCMC.LogDensityModel` or `AbstractPPL.AbstractModel`, but neither is perfect
What I am proposing in the current state of this PR is:

- `AbstractMCMC` defines the interface of `state`
- `Gibbs` defines an interface `recompute_logprob!!`, which returns a `state`
- model-definition packages (e.g. `DynamicPPL`) implement `recompute_logprob!!` using the interface functions of `state` defined in `AbstractMCMC`
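A rough sketch of that layering, with all names hypothetical (`DemoState`, `getparams`, and the toy log density stand in for whatever the final interface uses):

```julia
# Layer 1 (AbstractMCMC): a generic accessor on a sampler's state.
struct DemoState
    params::Vector{Float64}
    logp::Float64
end
getparams(state::DemoState) = state.params

# Layer 2 (Gibbs): only declares the function it needs, with no method body.
function recompute_logprob!! end

# Layer 3 (a model-definition package, e.g. DynamicPPL): implements it via the
# state interface, so the sampler package never depends on the model package.
function recompute_logprob!!(logdensity_f, state::DemoState)
    lp = logdensity_f(getparams(state))
    return DemoState(getparams(state), lp)
end

# Toy target for the sketch: standard normal log density up to a constant.
logdensity_f(x) = -0.5 * sum(abs2, x)
```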
Some new updates, based on an offline discussion with @yebai (edited Sep 17th):

- `gibbs.jl` is moved into `/src`
  - it no longer uses `OrderedDict`; this is to keep the dependencies light (`OrderedDict` requires `OrderedCollections`)
  - the `sampler_map` argument now only accepts a `NamedTuple`; this is less flexible than Turing's Gibbs sampler, but should be fine for a first step
  - the `vi` field is renamed to `trace`; it only holds the values, not the logp
- the four proposed interface functions are all removed, and each is replaced:
  - `get_logp` and `set_logp!!` are combined into a single function called `logdensity_and_state(logdensity_f, state; recompute_logp)`; if `recompute_logp` is true, the log density should be re-evaluated, otherwise the log density stored in the state is returned
  - `get_params` --> `Base.vec` (https://github.com/TuringLang/AbstractMCMC.jl/blob/1382054b3be11cf28f231b636ff1c34be187d881/src/gibbs.jl#L106-L111)
  - `set_params!!` is removed
- because `gibbs.jl` is in `/src`, a `condition` function is added to `AbstractMCMC`; we should find a way to combine this with `AbstractPPL`. Also, do we want to define `condition(logdensityproblem, ..)`?
- a sampler package now needs to implement
  - the `logdensity_and_state(logdensityfunction, state)` interface
  - `Base.vec`, which returns a flattened representation of the parameters
    - another option is to use the `Iterator` interface
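As an illustration of the `logdensity_and_state` contract described above (the `DemoState` type is made up for the sketch; only the keyword behaviour mirrors the proposal):

```julia
# Hypothetical sampler state: parameter values plus a cached log density.
struct DemoState
    params::Vector{Float64}
    logp::Float64
end

# Flattened representation of the parameters, as proposed above.
Base.vec(state::DemoState) = state.params

# Returns (logp, state). When `recompute_logp=true`, the log density is
# re-evaluated; this is needed after other Gibbs blocks have changed the values
# the current block is conditioned on. Otherwise the cached value is returned.
function logdensity_and_state(logdensity_f, state::DemoState; recompute_logp=true)
    recompute_logp || return state.logp, state
    lp = logdensity_f(vec(state))
    return lp, DemoState(vec(state), lp)
end
```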
and then we can add explicit impls of higher-order samplers later once we've tested out the interface + arrived at really good implementations of these, e.g. Gibbs.
I'd encourage @torfjelde and @sunxd3 to have more discussions to converge on a sensible Gibbs design; we will never get a good implementation if we don't start working on it.
> I'd encourage @torfjelde and @sunxd3 to have more discussions to converge on a sensible Gibbs design; we will never get a good implementation if we don't start working on it.
Definitely! But I think it's also partially a thing that we need to try out a bit, and we're not going to get it right on the first try, hence my reluctance to put an impl in AbstractMCMC.jl as soon as we enable the functionality.
@torfjelde thanks for the very nice review. The code and the markdown doc are out of sync; a major difference is that the one in the markdown doc uses `OrderedDict`, while the one in the current `src/gibbs.jl` only uses `NamedTuple` (this is to avoid introducing a dependency on `OrderedCollections`). The latter restricts each Gibbs block to a single variable name. Sorry for the confusion.
But if we don't end up introducing `gibbs.jl` into AbstractMCMC, then we should probably support `OrderedDict` as well as `NamedTuple`.
I'll update the code and documentation once we decide where we want to keep the Gibbs code.
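For illustration, the `NamedTuple` form of `sampler_map` ties exactly one variable name to each sub-sampler; the sampler types below are placeholders, not real AbstractMCMC samplers:

```julia
struct MHSampler end    # placeholder sub-sampler types, for illustration only
struct HMCSampler end

# NamedTuple form: each key is a single variable name mapped to its sampler.
sampler_map = (x = MHSampler(), y = HMCSampler())

# An OrderedDict-based form could instead group several variable names per
# block, at the cost of a dependency on OrderedCollections, e.g.
#   OrderedDict((:x, :y) => MHSampler(), (:z,) => HMCSampler())
```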
Codecov Report
All modified and coverable lines are covered by tests :white_check_mark:
Project coverage is 97.54%. Comparing base (`2a77f53`) to head (`3ed5cb3`). Report is 5 commits behind head on master.
Additional details and impacted files
```
@@            Coverage Diff             @@
##           master     #144      +/-   ##
==========================================
+ Coverage   97.19%   97.54%   +0.34%
==========================================
  Files           8        8
  Lines         321      326       +5
==========================================
+ Hits          312      318       +6
+ Misses          9        8       -1
```
@yebai @torfjelde this has been updated; many issues in the previous version have been corrected (thanks for the discussion and code review). I also added more notes in the `design_notes` folder; comments are welcome. Can you give it another read?
In its current form, no interface changes are made to AbstractMCMC; all the interface functions come from other packages.
The test error seems to be specific to Julia 1.6.
related reply from @devmotion https://github.com/TuringLang/Turing.jl/pull/2304#issuecomment-2291097709
I do think the entire process of this would be quite a bit less painful if we did the following (I believe I've mentioned this before; if not, I apologize):
- Improve #86 to a finalized form. This is useful not just for Gibbs sampling.
- Make a separate package, e.g. `AbstractMCMCGibbs.jl`, which implements the Gibbs-only stuff, e.g. `recompute_logprob!!` and the sampler-mapping stuff.
This is how we're doing it with MCMCTempering.jl, i.e. keep it as a separate package and slowly move pieces to AbstractMCMC.jl if it seems suitable. My problem, as stated before, is that the current Gibbs impls we're working with are really not good enough, as I think is evident from a) the issues we've encountered with my Turing.jl impl in https://github.com/TuringLang/Turing.jl/pull/2328#issuecomment-2378971734, and b) the amount of iterating you've done in this PR. This shit is complicated :grimacing: And I imagine it's really annoying iterating on this back and forth without actually getting stuff merged..
So, I think a separate package would just make this entire process much easier @sunxd3 ; then we can iterate much faster on ideas (just make breaking releases), and then we can just upstream changes as we finalize things there + we can even inform about this in the official AbstractMCMC.jl docs and then people can easily support this via extensions.
https://github.com/TuringLang/AbstractMCMC.jl/issues/85#issuecomment-2061300622