Easy repetition of distribution parameters
Hi all,
I'm trying to figure out whether there is a built-in, general way to take a TFP distribution, repeat certain batch members, and build a new distribution to sample from. The code below works, but it quickly gets cumbersome when I need to do the repetition on many different types of distributions.
```python
import tensorflow as tf
import tensorflow_probability as tfp

norm = tfp.distributions.Normal([[0, 1], [2, 3]], [[1, 2], [3, 4]])  # batch_shape [2, 2]

### What if you want to repeat the first batch member (row 0) multiple times?
# Gather along the first batch axis by hand, once per parameter.
norm_repeat = tfp.distributions.Normal(
    tf.gather(norm.loc, [0, 0, 0, 1]),
    tf.gather(norm.scale, [0, 0, 0, 1]))  # batch_shape [4, 2]
norm_repeat.sample()
```
Distribution batch slicing (e.g., norm[:1, :]) is close to what you want, but it doesn't quite work because, AFAIK, there's no way to specify repeated indices. I think you could use the same underlying machinery to call tf.gather, though:
```python
import tensorflow as tf
import tensorflow_probability as tfp
# Internal module: not part of the public API.
from tensorflow_probability.python.internal import batch_shape_lib

norm = tfp.distributions.Normal([[0, 1], [2, 3]], [[1, 2], [3, 4]])
# Apply tf.gather to every parameter, then rebuild a distribution of the same type.
norm_repeat = type(norm)(
    **batch_shape_lib.map_fn_over_parameters_with_event_ndims(
        norm,
        lambda p, event_ndims: tf.gather(p, [0, 0, 0, 1])))
```
This appears to work for me (with the usual caveat that internal APIs may change without notice).
I would also like this functionality: I want to sample efficiently from the marginal of an independent distribution, but the only way I've found is to write an ad-hoc implementation for each distribution using tf.gather on its parameters. Doing it in a generalised way seems like a nightmare, since you have to worry about batch_shape vs. event_shape and whether the distribution is a tfp.distributions.Independent, a tfp.distributions.BatchBroadcast, etc.