Is reparameterization not available for a mixture of MultivariateNormalDiag distributions using MixtureSameFamily?
I'm trying to sample from a mixture of multivariate normal distributions using the reparameterization trick, but I get an error.
import tensorflow_probability as tfp
tfd = tfp.distributions

gm = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(
        probs=[0.3, 0.7]),
    components_distribution=tfd.MultivariateNormalDiag(
        loc=[[-1., 1.],   # component 1
             [1., -1.]],  # component 2
        # Same per-component scales as scale_identity_multiplier=[.3, .6],
        # which recent TFP releases no longer accept.
        scale_diag=[[.3, .3],
                    [.6, .6]]),
    reparameterize=True)
sample = gm.sample()
Running the code above raises this error:
InvalidArgumentError: `univariate_components` must have scalar event
Condition x == y did not hold.
First 1 elements of x:
[False]
First 1 elements of y:
[ True]
Is there any way to work around this? Or is reparameterization simply not available for such a distribution? Thank you.
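The underlying machinery requires components_distribution to be manifestly "factorized" (or at least "factorizable"): either a distribution with scalar events, or an Independent wrapping one. That is what the `univariate_components` must have scalar event assertion is checking. MVNDiag should qualify in principle, since its coordinates are independent, but it isn't implemented that way: its event shape is [2], so the check fails. A somewhat hacky workaround is to replace it with Independent(Normal(...)), which the implementation recognizes as factorizable: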
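import tensorflow_probability as tfp
tfd = tfp.distributions

gm = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(
        probs=[0.3, 0.7]),
    components_distribution=tfd.Independent(
        tfd.Normal(
            loc=[[-1., 1.],   # component 1
                 [1., -1.]],  # component 2
            # Shape [2, 1]: one scale per component, shared by both dims.
            # A flat [.3, .6] would broadcast along the last (event) axis instead.
            scale=[[.3], [.6]]),
        # Reinterpret the last batch dimension as the event dimension.
        reinterpreted_batch_ndims=1),
    reparameterize=True)
sample = gm.sample()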
More on how it's implemented here: https://github.com/tensorflow/probability/blob/v0.16.0/tensorflow_probability/python/distributions/mixture_same_family.py#L446-L474
This feature was kindly contributed by one of the authors of the implicit reparameterization work, and it looks like there's some nice documentation in the code. It might be possible to generalize it a bit so that easy cases like MVNDiag work, too.