WIP: GumbelSoftmax / RelaxedOneHotCategoricalStraightThrough
Hi all, since it's been a while I thought I should maybe give a sign of life and continue from here. This tries to implement https://github.com/pyro-ppl/numpyro/issues/559.
There are still some things I haven't figured out myself yet, so I was planning to only request the review when I'm more ready, but of course feel free to take a look if you want already.
- [ ] write tests
- [x] sampling GumbelSoftmax with low temperatures
- [x] sampling GumbelSoftmax with high temperatures
- [ ] more tests
- [ ] make working GumbelSoftmaxProbs
- [ ] pass test_log_prob_gradient
- [x] log_prob in general
- [ ] log_prob with prepended shapes (i.e. the failing test_distribution_constraints test)
- [x] sampling in general
- [ ] discretize at the forward pass, not the backward pass.
- [ ] mean, variance?
- [ ] documentation
- [ ] distributions.rst
- [x] every test
- [ ] every docstring
- [ ] figure out consistent and proper interface
- [ ] big cleanup before review
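For anyone following along, the two sampling items above can be illustrated with a minimal sketch in plain JAX. This is only the textbook Gumbel-softmax recipe, not the code in this PR; `gumbel_softmax_sample` is a hypothetical helper name, and the temperatures are arbitrary.

```python
import jax
import jax.numpy as jnp


def gumbel_softmax_sample(key, logits, temperature):
    # Hypothetical helper: perturb the logits with Gumbel(0, 1) noise,
    # then apply a temperature-scaled softmax.
    gumbel = jax.random.gumbel(key, logits.shape)
    return jax.nn.softmax((logits + gumbel) / temperature, axis=-1)


key = jax.random.PRNGKey(0)
logits = jnp.log(jnp.array([0.2, 0.5, 0.3]))

# Same key, so both samples share the same noise and differ only in temperature.
cold = gumbel_softmax_sample(key, logits, 0.05)  # low temperature: close to one-hot
hot = gumbel_softmax_sample(key, logits, 50.0)   # high temperature: close to uniform
```

Both samples lie on the simplex; lowering the temperature only sharpens the sample towards its largest entry.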
Hi @daydreamt, thanks for the PR! I think the main blocker for your work is defining custom derivative rules for some of your operators. I'll update the repo to the latest JAX version today to unblock your work.
@daydreamt FYI, I think @tbsexton only needs RelaxedOneHotCategorical (or GumbelSoftmax) in his feature request because he wanted to use MCMC (instead of SVI) to draw samples from the relaxed distribution. @tbsexton could you confirm that StraightThrough is not required?
@fehiepsi I was originally only using HMC, though planning to test out SVI as well. As long as not having access to a backward pass doesn't preclude using NUTS for inference of latent variables, should work!
This is in practice a work-around for not having discrete latent variables; see my original example problem here.
Thanks, @tbsexton! In your model, you want to infer each ϕ for each cascade, so I guess you can replace
ϕ = ny.sample("ϕ", dist.Dirichlet(np.ones(n_nodes)))
x0 = ny.sample("x0", dist.Categorical(ϕ))
infectious, hist = spread_jax(s_ij, x0, 5)
by
ϕ = ny.sample("ϕ", dist.Dirichlet(np.ones(n_nodes)))
infectious, hist = spread_jax(s_ij, ϕ, 5)
Or, if you want the prior for ϕ to be closer to discrete, you can choose a suitable temperature (or put a prior on it) and use RelaxedOneHotCategorical
ϕ = ny.sample("ϕ", dist.RelaxedOneHotCategorical(temperature, logits=np.ones(n_nodes)))
infectious, hist = spread_jax(s_ij, ϕ, 5)
The reason is that the support of RelaxedOneHotCategorical is the "simplex", and there is a transform that maps a simplex to an "unconstrained" value, which HMC/NUTS requires. The support of RelaxedOneHotCategoricalStraightThrough is discrete, hence there is no such transform.
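To make the transform point concrete, here is a minimal sketch of one such map (stick-breaking), written in plain JAX. It is an illustration of the general idea under my own assumptions, not necessarily what NumPyro does internally, and `unconstrained_to_simplex` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def unconstrained_to_simplex(x):
    # Hypothetical helper: maps n-1 unconstrained reals to a point on the
    # n-dimensional simplex via stick-breaking.
    n_minus_1 = x.shape[-1]
    # The offset makes x = 0 land on the uniform point of the simplex.
    offset = jnp.log(n_minus_1 - jnp.arange(n_minus_1))
    z = jax.nn.sigmoid(x - offset)             # fraction of the remaining stick to break off
    remaining = jnp.cumprod(1.0 - z, axis=-1)  # stick length left after each break
    ones = jnp.ones(x.shape[:-1] + (1,))
    z_padded = jnp.concatenate([z, ones], axis=-1)
    remaining_padded = jnp.concatenate([ones, remaining], axis=-1)
    return z_padded * remaining_padded


y = unconstrained_to_simplex(jnp.array([0.3, -1.2, 2.0]))
uniform = unconstrained_to_simplex(jnp.zeros(2))
```

Any real-valued input gives a valid simplex point, which is what lets HMC/NUTS work in the unconstrained space.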
If you want something like straight through, you can simply use
ϕ = ny.sample("ϕ", dist.RelaxedOneHotCategorical(temperature, logits=np.ones(n_nodes)))
ϕ_quantize = quantize(ϕ)
by defining a "straight-through" quantize operator, as in Pyro:
def quantize(x):
    return x + jax.lax.stop_gradient((x == np.max(x, -1, keepdims=True)) - x)
You can use numpyro.deterministic(...) to record those quantized values. I am happy to add new helpers to NumPyro for your convenience when you start using SVI.
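A small self-contained check of what the straight-through trick does, in plain JAX: the forward value is one-hot, while the gradient flows as if the operator were the identity. The `probs` and `weights` values are made up for illustration, and the explicit `astype` cast is my addition.

```python
import jax
import jax.numpy as jnp


def quantize(x):
    # Forward pass: one-hot at the argmax. Backward pass: identity, because
    # the correction term is wrapped in stop_gradient.
    hard = (x == jnp.max(x, axis=-1, keepdims=True)).astype(x.dtype)
    return x + jax.lax.stop_gradient(hard - x)


probs = jnp.array([0.1, 0.7, 0.2])    # illustrative relaxed sample
weights = jnp.array([1.0, 2.0, 3.0])  # illustrative downstream weights

forward = quantize(probs)
# Gradient of a weighted sum of the quantized values with respect to probs.
grad = jax.grad(lambda p: jnp.sum(weights * quantize(p)))(probs)
```

Here `forward` is (numerically) the one-hot vector for index 1, and `grad` equals `weights`, which is the gradient the unquantized `probs` would have received.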
@fehiepsi much appreciated! I think I should update the model there to reflect some local changes, but primarily I think it makes more sense to pull the Dirichlet out of the plates:
def diff_kg(infections):
    n_cascades, n_nodes = infections.shape
    n_edges = n_nodes*(n_nodes-1)//2  # complete graph
    # node initial infection, relative probability
    ϕ = ny.sample("ϕ", dist.Dirichlet(np.ones(n_nodes)))
    # beta hyperpriors
    u = ny.sample("u", dist.Uniform(np.zeros(n_edges),
                                    np.ones(n_edges)))
    v = ny.sample("v", dist.Gamma(np.ones(n_edges),
                                  20*np.ones(n_edges)))
    Λ = ny.sample("Λ", dist.Beta(u*v, (1-u)*v))
    s_ij = jax_squareform(Λ)  # adjacency matrix to recover via inference
    with ny.plate("n_cascades", n_cascades):
        # infer infection source node
        x0 = ny.sample("x0", dist.Categorical(ϕ))
        # simulate ode and realize
        infectious, hist = spread_jax(s_ij, x0, 5)
        numpyro.sample("obs", dist.Bernoulli(probs=infectious),
                       obs=infections)
The main idea is that certain nodes have a general tendency to be "sources", represented by the Dirichlet prior, and that tendency manifests as the conditional probability that each node was the source (given any individual observed infection cascade). That should be realized as one node for the spread_jax sim, or at least very close to one node (hence the [relaxed] categorical).
Maybe that Dirichlet prior is unnecessary partial pooling? I will definitely give the new relaxed categorical a try. @daydreamt would it be helpful if I tested things out before the PR gets merged?
> pull the dirichlet out of the plates
Agreed, this makes more sense. With this model, you can define RelaxedOneHotCategorical for x0. (FYI, in PyTorch, Categorical samples are indices: 0, 1, 2, 3. If you want the one-hot version, you can use RelaxedOneHotCat... or OneHotCat...)
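To illustrate the index versus one-hot distinction, here is a sketch in plain JAX rather than with the distribution classes mentioned above; the probabilities are made up.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
logits = jnp.log(jnp.array([0.1, 0.6, 0.3]))  # illustrative class probabilities

# A categorical draw is an integer index into the classes ...
index = jax.random.categorical(key, logits)
# ... while the one-hot version is a vector with a single 1 at that index.
one_hot = jax.nn.one_hot(index, logits.shape[-1])
```

A relaxed one-hot sample has the same shape as `one_hot` but takes values in the interior of the simplex.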
Hey @daydreamt, any progress on this?
Hi @dirmeier, not really, please feel free to take over or supersede with another MR.