
`JointDistributionCoroutineAutoBatched.sample_distributions` errors using jax substrate

Status: Open. jeffpollock9 opened this issue 1 year ago (0 comments).

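Hi, I found that `JointDistributionCoroutineAutoBatched.sample_distributions` can raise an error when using the JAX substrate. Other setups seem fine: `sample` works, the TensorFlow substrate appears to work, and the non-auto-batched joint distribution (`JointDistributionCoroutine`) also works.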

Here is a small example:

```python
from functools import partial

import jax
from tensorflow_probability.substrates import jax as tfp

tfd = tfp.distributions

tfp.__version__
# 0.25.0

# Non-auto-batched JointDistributionCoroutine with batch_ndims=0
@partial(tfd.JointDistributionCoroutine, batch_ndims=0)
def joint_dist():
    x = yield tfd.Gamma(2.0, 10.0, name="x")
    y = yield tfd.Gamma(x, 10.0, name="y")


seed = jax.random.key(123)

# ok: sample_distributions works for the non-auto-batched model
dists, samples = joint_dist.sample_distributions(x=[1.0, 2.0], seed=seed)

# samples:
# StructTuple(
#   x=Array([1., 2.], dtype=float32),
#   y=Array([0.10665689, 0.21802416], dtype=float32)
# )


# Same model, now auto-batched
@tfd.JointDistributionCoroutineAutoBatched
def joint_dist():
    x = yield tfd.Gamma(2.0, 10.0, name="x")
    y = yield tfd.Gamma(x, 10.0, name="y")


# ok: plain sample works for the auto-batched model
samples = joint_dist.sample(x=[1.0, 2.0], seed=seed)

# samples:
# StructTuple(
#   x=Array([1., 2.], dtype=float32),
#   y=Array([0.05508393, 0.14792603], dtype=float32)
# )

# fails: sample_distributions on the auto-batched model raises
# ValueError: Attempt to convert a value (<object object at 0x717aa57398a0>) with an unsupported type (<class 'object'>) to a Tensor.
dists, samples = joint_dist.sample_distributions(x=[1.0, 2.0], seed=seed)
```

Thanks!

jeffpollock9, Dec 05 '24 16:12