
Easy repetition of distribution parameters

Open bryorsnef opened this issue 4 years ago • 2 comments

Hi all,

I'm trying to figure out whether there is an existing, general way to take a tfp.distributions distribution, repeat certain batch members, and build a new distribution to sample from. The code below works, but it quickly gets cumbersome when I need to do the repetition on many different types of distributions.

import tensorflow as tf
import tensorflow_probability as tfp

norm = tfp.distributions.Normal([[0,1],[2,3]], [[1,2],[3,4]])
### What if you want to repeat the first batch row multiple times? Here: rows [0, 0, 0, 1].
norm_repeat = tfp.distributions.Normal(tf.gather(norm.loc, [0,0,0,1]), tf.gather(norm.scale, [0,0,0,1]))
norm_repeat.sample()

bryorsnef, Dec 09 '21 00:12

Distribution batch slicing (e.g., norm[:1, :]) is close to what you want, but doesn't quite work because, AFAIK, there's no way to specify repeated indices. But I think you could use the same underlying machinery to call tf.gather:

import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_probability.python.internal import batch_shape_lib

norm = tfp.distributions.Normal([[0,1],[2,3]], [[1,2],[3,4]])
norm_repeat = type(norm)( 
  **batch_shape_lib.map_fn_over_parameters_with_event_ndims(
      norm,
      lambda p, event_ndims: tf.gather(p, [0, 0, 0, 1])))

This appears to work for me (with the usual caveat that internal APIs may change without notice).

davmre, Dec 09 '21 00:12

I would also like this functionality, because I want to be able to efficiently sample from the marginal of an independent distribution, but it seems the only way to do this is to write an ad-hoc implementation for every distribution using tf.gather on the parameters. Doing it in a generalised way looks like a nightmare, since you have to worry about batch_shape vs. event_shape and about whether the distribution is a tfp.distributions.Independent, a tfp.distributions.BatchBroadcast, etc.

meowcakes, Dec 16 '21 11:12