Is reparameterization not available for Mixture MultivariateNormalDiag using MixtureSameFamily?

Open: moonmoondog opened this issue 3 years ago • 1 comment

I'm trying to sample from a mixture of multivariate normal distributions using the reparameterization trick, but an error is raised.

import tensorflow_probability as tfp
tfd = tfp.distributions

gm = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(
        probs=[0.3, 0.7]),
    components_distribution=tfd.MultivariateNormalDiag(
        loc=[[-1., 1],  # component 1
             [1, -1]],  # component 2
        scale_identity_multiplier=[.3, .6]),
    reparameterize=True)
sample = gm.sample()

Running the code above raises the following error:

InvalidArgumentError: `univariate_components` must have scalar event
Condition x == y did not hold.
First 1 elements of x:
[False]
First 1 elements of y:
[ True]

Is there any way to work around this, or is reparameterization simply not available for this kind of distribution? Thank you.

moonmoondog, Apr 11 '22 20:04

The underlying machinery requires components_distribution to be manifestly "factorized" (or at least "factorizable"), i.e. a product of independent univariate distributions over the event dimensions. MVNDiag should actually qualify, but it isn't implemented as such. A somewhat hacky workaround is to replace it with an Independent(Normal(...)), which the implementation does recognize as factorizable:

import tensorflow_probability as tfp
tfd = tfp.distributions

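gm = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(
        probs=[0.3, 0.7]),
    # Independent reinterprets the trailing batch dimension of Normal as an
    # event dimension, so each component is an explicit product of 1-D Normals.
    components_distribution=tfd.Independent(
        tfd.Normal(
            loc=[[-1., 1],  # component 1
                 [1, -1]],  # component 2
            # Shape [2, 1]: one scale per component, broadcast over both
            # dimensions (the same as scale_identity_multiplier=[.3, .6]).
            scale=[[.3], [.6]]),
        reinterpreted_batch_ndims=1),
    reparameterize=True)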
sample = gm.sample()

More on how it's implemented here: https://github.com/tensorflow/probability/blob/v0.16.0/tensorflow_probability/python/distributions/mixture_same_family.py#L446-L474
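The following plain-Python sketch illustrates why factorization matters. It assumes the conditional-CDF form of the distributional transform. Each coordinate is mapped through F(x_d | x_1, ..., x_{d-1}) = sum_k w_k F_{k,d}(x_d), where w_k is the posterior weight of component k given the earlier coordinates. That form needs a univariate CDF F_{k,d} for every component and every event dimension, which Independent(Normal) exposes directly. The function names are hypothetical, and the code shows the idea only. It is not TFP's implementation:

```python
import math


def normal_cdf(x, mu, sigma):
    """CDF of a univariate Normal(mu, sigma)."""
    return 0.5 * (1.0 + math.erf((x - mu) / (sigma * math.sqrt(2.0))))


def normal_logpdf(x, mu, sigma):
    """Log-density of a univariate Normal(mu, sigma)."""
    z = (x - mu) / sigma
    return -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2.0 * math.pi)


def distributional_transform(x, weights, locs, scales):
    """Map a point x of a mixture of diagonal Gaussians to
    u_d = F(x_d | x_1, ..., x_{d-1}) for each dimension d.

    weights: mixture probabilities, length K.
    locs, scales: K rows (one per component), each of length D.
    """
    num_components = len(weights)
    # Unnormalized log posterior weights, starting from the prior weights.
    log_w = [math.log(w) for w in weights]
    u = []
    for d, x_d in enumerate(x):
        # Normalize the current weights (shift by the max for stability).
        m = max(log_w)
        unnorm = [math.exp(lw - m) for lw in log_w]
        total = sum(unnorm)
        post = [p / total for p in unnorm]
        # Conditional CDF of x_d: posterior-weighted univariate CDFs.
        u.append(sum(post[k] * normal_cdf(x_d, locs[k][d], scales[k][d])
                     for k in range(num_components)))
        # Condition on x_d before moving on to the next dimension.
        log_w = [log_w[k] + normal_logpdf(x_d, locs[k][d], scales[k][d])
                 for k in range(num_components)]
    return u


u = distributional_transform(
    x=[0.0, 0.0],
    weights=[0.3, 0.7],
    locs=[[-1.0, 1.0], [1.0, -1.0]],
    scales=[[0.3, 0.3], [0.6, 0.6]],
)
print(u)
```

Implicit reparameterization then differentiates this map while holding u fixed. In one dimension, that gives dx/dtheta = -(dF/dtheta) / (dF/dx).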

This feature was kindly contributed by one of the authors of the implicit reparameterization work, and it looks like there's some nice documentation in the code. It might be possible to generalize it a bit so that easy cases like MVNDiag work, too.

csuter, Apr 12 '22 18:04