
WIP: GumbelSoftmax / RelaxedOneHotCategoricalStraightThrough

Open • daydreamt opened this issue 5 years ago • 8 comments

Hi all, since it's been a while, I thought I should give a sign of life and continue from here. This PR tries to implement https://github.com/pyro-ppl/numpyro/issues/559.

There are still some things I haven't figured out yet, so I was planning to request a review only once it's further along, but of course feel free to take a look already.

  • [ ] write tests
    • [x] sampling GumbelSoftmax with low temperatures
    • [x] sampling GumbelSoftmax with high temperatures
    • [ ] more tests
  • [ ] make GumbelSoftmaxProbs work
    • [ ] pass test_log_prob_gradient
    • [x] log_prob in general
    • [ ] log_prob with prepended shapes (i.e. the failing test_distribution_constraints test)
    • [x] sampling in general
    • [ ] discretize at the forward pass, not the backward pass.
  • [ ] mean, variance?
  • [ ] documentation
    • [ ] distributions.rst
    • [x] every test
    • [ ] every docstring
  • [ ] figure out consistent and proper interface
  • [ ] big cleanup before review
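For reference on the log_prob items above, here is a minimal sketch of the relaxed one-hot categorical (Concrete / Gumbel-Softmax) log density in plain JAX, taken with respect to Lebesgue measure on the first K-1 coordinates of the simplex: log p(x) = log (K-1)! + (K-1) log λ + Σ_k [ℓ_k - (λ+1) log x_k] - K log Σ_i exp(ℓ_i - λ log x_i), with logits ℓ and temperature λ. The function name `relaxed_one_hot_log_prob` and the logits parameterization are assumptions for illustration, not the interface this PR defines.

```python
import jax.numpy as jnp
from jax.scipy.special import gammaln, logsumexp


def relaxed_one_hot_log_prob(x, logits, temperature):
    """Log density of a relaxed one-hot categorical (Concrete) sample.

    x: point(s) strictly inside the simplex, shape (..., K)
    logits: unnormalized log class weights, broadcastable to x
    temperature: positive scalar
    """
    K = x.shape[-1]
    log_x = jnp.log(x)
    # log((K-1)!) + (K-1) * log(temperature)
    log_norm = gammaln(float(K)) + (K - 1) * jnp.log(temperature)
    # sum_k [logits_k - (temperature + 1) * log x_k]
    numer = jnp.sum(logits - (temperature + 1.0) * log_x, axis=-1)
    # K * log sum_i exp(logits_i - temperature * log x_i)
    denom = K * logsumexp(logits - temperature * log_x, axis=-1)
    return log_norm + numer - denom


# Sanity check for K = 2: the density over x_1 in (0, 1) should integrate to 1.
n = 20000
x1 = (jnp.arange(n) + 0.5) / n  # midpoint grid on (0, 1)
x = jnp.stack([x1, 1.0 - x1], axis=-1)
density = jnp.exp(relaxed_one_hot_log_prob(x, jnp.zeros(2), 2.0))
print(jnp.sum(density) / n)  # close to 1.0
```

In float32, log(x) is -inf at exact zeros, so samples produced at very low temperatures may need to be clipped away from the simplex boundary before scoring.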

daydreamt, Apr 06 '20 20:04

Hi @daydreamt, thanks for the PR! I think the main blocker for your work is defining custom derivative rules for some of your operators. I'll update the repo to the latest JAX version today to unblock you.
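As one illustration of a custom derivative rule, a straight-through one-hot operator can be written with `jax.custom_jvp`: the forward pass discretizes to a one-hot vector at the argmax, while the JVP treats the operation as the identity. This is a sketch only; the name `hard_one_hot` is hypothetical and not part of NumPyro.

```python
import jax
import jax.numpy as jnp


@jax.custom_jvp
def hard_one_hot(x):
    # Forward pass: one-hot vector at the argmax along the last axis.
    return jax.nn.one_hot(jnp.argmax(x, axis=-1), x.shape[-1], dtype=x.dtype)


@hard_one_hot.defjvp
def hard_one_hot_jvp(primals, tangents):
    (x,) = primals
    (x_dot,) = tangents
    # Straight-through rule: pass the incoming tangent through unchanged,
    # as if the discretization were the identity map.
    return hard_one_hot(x), x_dot


x = jnp.array([0.1, 0.7, 0.2])
w = jnp.array([1.0, 2.0, 3.0])
print(hard_one_hot(x))  # [0. 1. 0.]
print(jax.grad(lambda v: jnp.sum(hard_one_hot(v) * w))(x))  # [1. 2. 3.]
```

Compared with the `x + stop_gradient(hard - x)` trick, a custom JVP keeps the forward value exactly one-hot (the other form is one-hot only up to floating-point rounding), at the cost of a little more code.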

fehiepsi, Apr 07 '20 17:04

@daydreamt FYI, I think @tbsexton only needs RelaxedOneHotCategorical (or GumbelSoftmax) for his feature request, because he wants to use MCMC (instead of SVI) to draw samples from the relaxed distribution. @tbsexton, could you confirm that StraightThrough is not required?

fehiepsi, Apr 08 '20 02:04

@fehiepsi I was originally only using HMC, though I'm planning to test out SVI as well. As long as not having access to a backward pass doesn't preclude using NUTS to infer the latent variables, it should work!

This is in practice a work-around for not having discrete latent variables; see my original example problem here.

rtbs-dev, Apr 08 '20 13:04

Thanks, @tbsexton! In your model, you want to infer each ϕ for each cascade, so I guess you can replace

ϕ = ny.sample("ϕ", dist.Dirichlet(np.ones(n_nodes)))  
x0 = ny.sample("x0", dist.Categorical(ϕ))
infectious, hist = spread_jax(s_ij, x0, 5)

by

ϕ = ny.sample("ϕ", dist.Dirichlet(np.ones(n_nodes)))  
infectious, hist = spread_jax(s_ij, ϕ, 5)

Or, if you want the prior for ϕ to be closer to discrete, you can choose a suitable temperature (or define a prior for it) and use RelaxedOneHotCategorical:

ϕ = ny.sample("ϕ", dist.RelaxedOneHotCategorical(temperature, logits=np.ones(n_nodes)))
infectious, hist = spread_jax(s_ij, ϕ, 5)
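For intuition about what a RelaxedOneHotCategorical sample looks like, the Gumbel-softmax reparameterization can be sketched in plain JAX: add Gumbel noise to the logits and apply a softmax at the given temperature. The function name `gumbel_softmax_sample` is hypothetical, and this is not NumPyro's implementation.

```python
import jax
import jax.numpy as jnp


def gumbel_softmax_sample(key, logits, temperature):
    # Perturb the logits with i.i.d. standard Gumbel noise.
    g = jax.random.gumbel(key, logits.shape, dtype=logits.dtype)
    # Temperature-scaled softmax: a point on the simplex that approaches
    # a one-hot vector as the temperature goes to 0.
    return jax.nn.softmax((logits + g) / temperature, axis=-1)


key = jax.random.PRNGKey(0)
logits = jnp.ones(5)  # e.g. n_nodes = 5, uniform logits as in the snippet above
for temperature in (10.0, 1.0, 0.1):
    y = gumbel_softmax_sample(key, logits, temperature)
    print(temperature, y, y.sum())
```

With the same key, lowering the temperature pushes the sample toward the one-hot vector at the argmax of the perturbed logits, which is the sense in which the prior becomes "more discrete".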

The reason is that RelaxedOneHotCategorical has support on the simplex, and there is a transform from the simplex to an unconstrained space, which HMC/NUTS requires. The support of RelaxedOneHotCategoricalStraightThrough is discrete, so no such transform exists.
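To make the support argument concrete, here is one bijection between R^(K-1) and the interior of the simplex, sketched in plain JAX. It uses an additive log-ratio map (softmax with a fixed zero logit), a swapped-in choice for illustration; NumPyro's own transform for the simplex constraint is a stick-breaking transform. A full HMC transform would also need its log-abs-det-Jacobian, which is omitted here. The inverse is finite everywhere inside the simplex but not at a one-hot vertex.

```python
import jax
import jax.numpy as jnp


def to_simplex(z):
    # R^(K-1) -> interior of the simplex in R^K:
    # append a fixed zero logit and apply softmax.
    zeros = jnp.zeros(z.shape[:-1] + (1,), dtype=z.dtype)
    return jax.nn.softmax(jnp.concatenate([z, zeros], axis=-1), axis=-1)


def from_simplex(x):
    # Inverse map: log-ratios of the first K-1 components against the last one.
    return jnp.log(x[..., :-1]) - jnp.log(x[..., -1:])


z = jnp.array([0.3, -1.2])
x = to_simplex(z)
print(x, x.sum())  # a point strictly inside the simplex
print(from_simplex(x))  # recovers z up to rounding
print(from_simplex(jnp.array([0.0, 1.0, 0.0])))  # not finite at a one-hot vertex
```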

If you want something like straight-through, you can simply use

ϕ = ny.sample("ϕ", dist.RelaxedOneHotCategorical(temperature, logits=np.ones(n_nodes)))
ϕ_quantize = quantize(ϕ)

after defining a "straight-through" quantize operator, as in Pyro:

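def quantize(x):
    # straight-through: forward pass returns the one-hot at the max, backward pass is the identity
    hard = (x == np.max(x, -1, keepdims=True)).astype(x.dtype)
    return x + jax.lax.stop_gradient(hard - x)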

You can use numpyro.deterministic(...) to record those quantized values. I'm happy to add new helpers to NumPyro for your convenience when you start using SVI.
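A quick check of what the straight-through quantize operator does, assuming `np` is `jax.numpy` as in the snippets above: the forward value is the one-hot vector at the maximum, while the gradient is that of the identity. The input `phi` and weights `w` are arbitrary test values.

```python
import jax
import jax.numpy as jnp


def quantize(x):
    # Forward: one-hot at the max (up to float rounding); backward: identity.
    hard = (x == jnp.max(x, -1, keepdims=True)).astype(x.dtype)
    return x + jax.lax.stop_gradient(hard - x)


phi = jnp.array([0.2, 0.5, 0.3])
w = jnp.array([1.0, 2.0, 3.0])
print(quantize(phi))  # [0. 1. 0.]
print(jax.grad(lambda p: jnp.sum(quantize(p) * w))(phi))  # [1. 2. 3.]
```

Inside a model, the quantized value could then be recorded with `numpyro.deterministic("ϕ_quantize", quantize(ϕ))`; the site name here is illustrative.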

fehiepsi, Apr 12 '20 17:04

@fehiepsi much appreciated! I should update the model there to reflect some local changes, but mainly I think it makes more sense to pull the Dirichlet out of the plates:

def diff_kg(infections):
    n_cascades, n_nodes = infections.shape
    n_edges = n_nodes * (n_nodes - 1) // 2  # complete graph

    # node initial infection, relative probability
    ϕ = ny.sample("ϕ", dist.Dirichlet(np.ones(n_nodes)))

    # beta hyperpriors
    u = ny.sample("u", dist.Uniform(np.zeros(n_edges), np.ones(n_edges)))
    v = ny.sample("v", dist.Gamma(np.ones(n_edges), 20 * np.ones(n_edges)))
    Λ = ny.sample("Λ", dist.Beta(u * v, (1 - u) * v))
    s_ij = jax_squareform(Λ)  # adjacency matrix to recover via inference

    with ny.plate("n_cascades", n_cascades):
        # infer infection source node
        x0 = ny.sample("x0", dist.Categorical(ϕ))
        # simulate ODE and realize
        infectious, hist = spread_jax(s_ij, x0, 5)
        ny.sample("obs", dist.Bernoulli(probs=infectious), obs=infections)

The main idea is that certain nodes generally tend to be "sources", represented by the Dirichlet prior, and this tendency manifests as the conditional probability that each node was the source of any individual observed infection cascade. In the spread_jax simulation, that source should be realized as one node, or at least very close to one node (hence the [relaxed] categorical).

Maybe that Dirichlet prior is unnecessary partial pooling? I will definitely give the new relaxed categorical a try. @daydreamt, would it be helpful if I tested things out before the PR gets merged?

rtbs-dev, Apr 13 '20 14:04

> pull the dirichlet out of the plates

Agreed that this makes more sense. With this model, you can use RelaxedOneHotCategorical for x0. (FYI, in PyTorch, Categorical samples are integer indices 0, 1, 2, 3, and so on. If you want the one-hot version, you can use RelaxedOneHotCat... or OneHotCat...)
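To illustrate the index versus one-hot distinction, a Categorical draw gives one integer index per cascade, and its one-hot counterpart can be obtained with `jax.nn.one_hot`. This sketch uses `jax.random.categorical` on plain JAX arrays rather than a NumPyro model; the sizes and the uniform ϕ are made up for illustration.

```python
import jax
import jax.numpy as jnp

n_cascades, n_nodes = 4, 5
phi = jnp.full(n_nodes, 1.0 / n_nodes)  # uniform source probabilities (illustrative)

key = jax.random.PRNGKey(0)
# Integer-valued draws, one source index per cascade: shape (n_cascades,)
x0 = jax.random.categorical(key, jnp.log(phi), shape=(n_cascades,))
# One-hot version of the same draws: shape (n_cascades, n_nodes)
x0_one_hot = jax.nn.one_hot(x0, n_nodes)
print(x0)
print(x0_one_hot)
```

A relaxed one-hot x0 would instead be an (n_cascades, n_nodes) array of points on the simplex, so spread_jax would presumably need to accept a soft source vector rather than an index.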

fehiepsi, Apr 13 '20 15:04

Hey @daydreamt, any progress on this?

dirmeier, Jun 15 '22 09:06

Hi @dirmeier, not really, please feel free to take over or supersede with another MR.

daydreamt, Jun 15 '22 11:06