Expose the acceptance ratio in Metropolis-Hastings
When debugging MH inference (e.g. "why is my proposal never accepted?"), it would be useful to have a variant of mh that exposes the acceptance probability. We could add something like mh_with_diagnostics, which contains the main implementation, and make mh a thin wrapper around it that (for backwards compatibility) does not return the acceptance probability.
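A minimal sketch of that wrapper pattern, assuming a toy one-dimensional target and a symmetric Gaussian random-walk proposal in place of a Gen trace. The names `toy_mh_with_diagnostics`, `toy_mh`, and `ToyMHDiagnostics` are hypothetical and are not part of Gen:

```julia
using Random

# Hypothetical stand-in for the proposed diagnostics (not Gen's API).
struct ToyMHDiagnostics
    acceptance_logprob::Float64
    accepted::Bool
end

# Toy target: standard normal log-density, up to an additive constant.
toy_logdensity(x::Float64) = -x^2 / 2

# Main implementation: one random-walk MH step that also reports diagnostics.
function toy_mh_with_diagnostics(x::Float64; step::Float64 = 1.0,
                                 rng::AbstractRNG = Random.default_rng())
    x_new = x + step * randn(rng)   # symmetric Gaussian proposal
    # With a symmetric proposal the forward and backward proposal densities cancel.
    log_alpha = min(0.0, toy_logdensity(x_new) - toy_logdensity(x))
    accepted = log(rand(rng)) < log_alpha
    return (accepted ? x_new : x), accepted, ToyMHDiagnostics(log_alpha, accepted)
end

# Backwards-compatible wrapper: same (state, accepted) shape, diagnostics dropped.
function toy_mh(x::Float64; kwargs...)
    new_x, accepted, _ = toy_mh_with_diagnostics(x; kwargs...)
    return new_x, accepted
end
```

The wrapper keeps the `(new_state, accepted)` return shape that `Gen.mh` uses today, so existing callers would be unaffected, while debugging code calls the `_with_diagnostics` variant directly.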
Yes, this would be great. I also think breaking the ratio down into its four components (the model scores of the new and old traces, and the forward and backward proposal scores) would be useful in a diagnostic MH function.
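For reference, the four components combine into the MH log acceptance probability as log α = min(0, [log p(new) - log p(old)] + [log q(old | new) - log q(new | old)]). The helper below is a hypothetical sketch (not part of Gen) that returns each component alongside the result, so a rejected proposal can be traced to the term responsible:

```julia
# Hypothetical helper (not Gen's API): report the four scores and the
# resulting MH log acceptance probability together.
function mh_acceptance_breakdown(new_model_score::Real, old_model_score::Real,
                                 forward_proposal_score::Real,
                                 backward_proposal_score::Real)
    # log alpha = [log p(new) - log p(old)] + [log q(old | new) - log q(new | old)]
    log_ratio = (new_model_score - old_model_score) +
                (backward_proposal_score - forward_proposal_score)
    return (new_model_score = new_model_score,
            old_model_score = old_model_score,
            forward_proposal_score = forward_proposal_score,
            backward_proposal_score = backward_proposal_score,
            acceptance_logprob = min(zero(log_ratio), log_ratio))
end
```

A proposal that is almost never accepted would typically show up here either as a large negative model-score difference or as a forward proposal score that is much larger than the backward one.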
Instrumenting a local build of Gen with @debug statements printing the relevant quantities has been very useful. In theory we could add @debug statements to master too, but I think this feature is best placed in the non-debug code path as another return value of mh that can be ignored if desired.
Straw-man proposal:
(new_trace::Trace, accepted::Bool, diagnostics) = Gen.mh(::Trace, ::Selection)
(new_trace::Trace, accepted::Bool, diagnostics) = Gen.mh(::Trace, ::GenerativeFunction, proposal_args::Tuple)
(new_trace::Trace, accepted::Bool, diagnostics) = Gen.mh(::Trace, ::GenerativeFunction, proposal_args::Tuple, involution)
The type of the diagnostics can depend on which MH variant is used. For example, the diagnostics for involution MH can include the log-det-Jacobian of the diffeomorphism, and the non-involution variants could have something like
struct MHDiagnostics
    """log(model density of new trace) - log(model density of old trace)"""
    log_ratio_prior_prob::Float64
    """
    log(proposal density of old trace -> new trace) -
    log(proposal density of new trace -> old trace)
    """
    log_ratio_proposal_prob::Float64
    """min(0, log_ratio_prior_prob - log_ratio_proposal_prob)"""
    acceptance_logprob::Float64
    accepted::Bool
end
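For the involution variant, a separate diagnostics type could carry the Jacobian term as well. The sketch below is hypothetical (neither `InvolutionMHDiagnostics` nor `involution_diagnostics` exists in Gen). It follows the sign convention of the docstrings above, where the proposal ratio is forward minus backward and is therefore subtracted, and it takes the uniform draw `u` as an argument so the result is deterministic:

```julia
# Hypothetical diagnostics for involution MH (not Gen's API).
struct InvolutionMHDiagnostics
    log_ratio_prior_prob::Float64      # log p(new) - log p(old)
    log_ratio_proposal_prob::Float64   # log q(forward) - log q(backward)
    log_abs_det_jacobian::Float64      # log |det J| of the involution
    acceptance_logprob::Float64
    accepted::Bool
end

# Builds the diagnostics from the three log terms and a Uniform(0, 1) draw `u`.
function involution_diagnostics(log_ratio_prior_prob::Float64,
                                log_ratio_proposal_prob::Float64,
                                log_abs_det_jacobian::Float64, u::Float64)
    # Proposal ratio is forward minus backward, so it enters with a minus sign.
    log_ratio = log_ratio_prior_prob - log_ratio_proposal_prob + log_abs_det_jacobian
    acceptance_logprob = min(0.0, log_ratio)
    accepted = log(u) < acceptance_logprob
    return InvolutionMHDiagnostics(log_ratio_prior_prob, log_ratio_proposal_prob,
                                   log_abs_det_jacobian, acceptance_logprob, accepted)
end
```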
@marcoct @alex-lew @ztangent wdyt?
I think this is a good idea. It makes sense to design inference library functions so that debug and dynamic checks can be enabled and disabled as needed; Gen already handles dynamic checks this way in many cases. Even if the overhead is low here, we will want to support diagnostic information uniformly across the inference library, and that information could become complicated, so an affordance for choosing whether diagnostics are computed makes sense.
What about adding an optional flag to inference library functions that changes the type signature? Something like:
const DIAGNOSTICS_ON = Val{true}()
const DIAGNOSTICS_OFF = Val{false}()
...
function mh(..., ::Val{true})
    ...
    return (new_trace, diagnostics::MHDiagnostics)
end
function mh(..., ::Val{false})
    ...
    return (new_trace, nothing)
end
mh(...) = mh(..., DIAGNOSTICS_ON)
Also, I think the acceptance Boolean is probably not necessary separately from this. That is, mh would just return a tuple of two elements, the trace and the debug info.
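One reason to dispatch on `Val` rather than branch on a runtime `Bool` is type stability: each method has a single concrete return type, which `Test.@inferred` can check. A minimal sketch, assuming a toy scalar state and a symmetric proposal; `toy_target` and `toy_mh_step` are hypothetical names, not Gen functions:

```julia
using Test: @inferred

const DIAGNOSTICS_ON = Val{true}()
const DIAGNOSTICS_OFF = Val{false}()

# Hypothetical toy target (not Gen): standard normal log-density up to a constant.
toy_target(x::Float64) = -x^2 / 2

# Diagnostics on: returns (new_state, NamedTuple). `u` is a Uniform(0, 1) draw
# passed in so the step is deterministic; the proposal is assumed symmetric,
# so only the target ratio enters the acceptance probability.
function toy_mh_step(x::Float64, x_proposed::Float64, u::Float64, ::Val{true})
    acceptance_logprob = min(0.0, toy_target(x_proposed) - toy_target(x))
    accepted = log(u) < acceptance_logprob
    new_x = accepted ? x_proposed : x
    return (new_x, (acceptance_logprob = acceptance_logprob, accepted = accepted))
end

# Diagnostics off: same 2-tuple shape, with `nothing` in the second slot.
function toy_mh_step(x::Float64, x_proposed::Float64, u::Float64, ::Val{false})
    acceptance_logprob = min(0.0, toy_target(x_proposed) - toy_target(x))
    new_x = log(u) < acceptance_logprob ? x_proposed : x
    return (new_x, nothing)
end

# Default flag, mirroring `mh(...) = mh(..., DIAGNOSTICS_ON)` above.
toy_mh_step(x::Float64, x_proposed::Float64, u::Float64) =
    toy_mh_step(x, x_proposed, u, DIAGNOSTICS_ON)
```

With a plain `Bool` flag and an `if` inside a single method, the inferred return type would in general be a `Union` of the two tuple types; whether that matters depends on how hot the calling loop is.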
Do you prefer having separate methods for the different values of the flag, as opposed to having if statements inside the body? If so, is it because you'd hope for more optimized compiled code in the version without diagnostics? My sense is that we don't currently have cases where computing the diagnostics incurs significant additional computational cost, and if we did, we would probably want to distinguish between easy-to-compute diagnostics and hard-to-compute diagnostics in the API anyway.
I like the idea of having the function always return the same number of values, so that the user can more easily flip the diagnostics flag in their code. In fact, we could have something like
@enum DiagnosticsOption begin
    DIAGNOSTICS_OFF
    DIAGNOSTICS_LITE
    DIAGNOSTICS_FULL
end
Each option could correspond to a different diagnostics type, with Nothing being the diagnostics type for DIAGNOSTICS_OFF.
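A sketch of how the three options could map to different diagnostics payloads while the step always returns a 2-tuple, assuming a toy scalar state and scores supplied by the caller. `mh_step_with_option` and the payload field names are hypothetical, not part of Gen:

```julia
@enum DiagnosticsOption begin
    DIAGNOSTICS_OFF
    DIAGNOSTICS_LITE
    DIAGNOSTICS_FULL
end

# Hypothetical step (not Gen's API): always returns (new_state, diagnostics);
# only the diagnostics payload depends on `opt`. `u` is a Uniform(0, 1) draw.
function mh_step_with_option(opt::DiagnosticsOption, x, x_proposed;
                             new_model_score::Float64, old_model_score::Float64,
                             forward_proposal_score::Float64,
                             backward_proposal_score::Float64, u::Float64)
    log_ratio_model = new_model_score - old_model_score
    log_ratio_proposal = forward_proposal_score - backward_proposal_score
    acceptance_logprob = min(0.0, log_ratio_model - log_ratio_proposal)
    accepted = log(u) < acceptance_logprob
    diagnostics = if opt == DIAGNOSTICS_OFF
        nothing                                    # Nothing for DIAGNOSTICS_OFF
    elseif opt == DIAGNOSTICS_LITE
        (acceptance_logprob = acceptance_logprob, accepted = accepted)
    else                                           # DIAGNOSTICS_FULL
        (acceptance_logprob = acceptance_logprob, accepted = accepted,
         log_ratio_model = log_ratio_model, log_ratio_proposal = log_ratio_proposal)
    end
    return (accepted ? x_proposed : x), diagnostics
end
```

Because the tuple shape never changes, code written as `x, diag = mh_step_with_option(opt, ...)` keeps working when `opt` is flipped; only code that reads fields of `diag` needs to know which option is active.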
> Also, I think the acceptance Boolean is probably not necessary separately from this. That is, mh would just return a tuple of two elements, the trace and the debug info.
Would have said the same but wanted to allow it to be a non-breaking(ish) change.
We also need to return the proposed trace and the Jacobian correction, when there is one.
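One way to carry both extra values, sketched under the assumption that the trace type is a type parameter and that variants without a Jacobian correction store `nothing`. `FullMHDiagnostics` and `full_mh_diagnostics` are hypothetical, not Gen types:

```julia
# Hypothetical diagnostics type (not Gen's API). `T` is the trace type.
struct FullMHDiagnostics{T}
    proposed_trace::T                              # kept even if the move is rejected
    log_ratio_model::Float64                       # log p(new) - log p(old)
    log_ratio_proposal::Float64                    # log q(forward) - log q(backward)
    log_abs_det_jacobian::Union{Nothing,Float64}   # `nothing` when there is no correction
    acceptance_logprob::Float64
    accepted::Bool
end

# Builds the diagnostics from the ratios and a Uniform(0, 1) draw `u`.
function full_mh_diagnostics(proposed_trace, log_ratio_model::Float64,
                             log_ratio_proposal::Float64,
                             log_abs_det_jacobian::Union{Nothing,Float64}, u::Float64)
    jacobian_term = something(log_abs_det_jacobian, 0.0)   # "no correction" counts as 0
    acceptance_logprob = min(0.0, log_ratio_model - log_ratio_proposal + jacobian_term)
    return FullMHDiagnostics{typeof(proposed_trace)}(
        proposed_trace, log_ratio_model, log_ratio_proposal,
        log_abs_det_jacobian, acceptance_logprob, log(u) < acceptance_logprob)
end
```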