STS components with batching on the initial states break

Status: Open. jeffpollock9 opened this issue 1 year ago (0 comments).

I am trying to run many time series models with different initial states via batching. This seems to be supported in some parts of the code, although the documentation does not say so explicitly, so I am wondering whether this is a bug or simply unsupported. Either way, I think it would be useful to have it working.

The following example shows some odd results you can currently get. The joint_distribution method seems to give the wrong answer when the initial state prior is batched (it works as expected for batch_shape=[]): it appears to sum 3*3 log likelihoods instead of 3.

import tensorflow as tf
import tensorflow_probability as tfp

print(tf.__version__)
# 2.18.0

print(tfp.__version__)
# 0.25.0

sts = tfp.sts
tfd = tfp.distributions

batch_shape = [3]
num_timesteps = 10
param_vals = [1.0, 2.0]
observations = tf.ones(batch_shape + [num_timesteps, 1])

local_level = sts.LocalLevel(
    initial_level_prior=tfd.Normal(loc=tf.zeros(batch_shape), scale=1.0)
)

print(local_level.batch_shape)
# ()
print(local_level.initial_state_prior.batch_shape)
# (3,)

model = sts.Sum(components=[local_level])

# joint dist
joint_dist = model.joint_distribution(observed_time_series=observations)
joint_dist_log_prob = joint_dist.log_prob(param_vals)

print(joint_dist_log_prob)
# tf.Tensor(-167.13658, shape=(), dtype=float32)

# ssm
ssm = model.make_state_space_model(num_timesteps=num_timesteps, param_vals=param_vals)
log_likelihood = sum(
    ssm.forward_filter(observations, final_step_only=True).log_likelihoods
)
prior = sum(p.prior.log_prob(x) for p, x in zip(model.parameters, param_vals))

print(log_likelihood + prior)
# tf.Tensor(-60.865345, shape=(), dtype=float32)

print(log_likelihood * 3 + prior)
# tf.Tensor(-167.13658, shape=(), dtype=float32)
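
The two printed totals are enough to separate the likelihood and prior terms. Below is a minimal plain-Python sketch, assuming the joint value is exactly 3 * log_likelihood + prior, as the last print suggests. The variable names are illustrative, and the inputs are the float32 values copied from the output above.

```python
# Values printed in the example above (float32, copied as shown).
correct_total = -60.865345   # log_likelihood + prior
joint_reported = -167.13658  # joint_dist.log_prob(param_vals)
batch_size = 3               # batch_shape = [3]

# Assumed relationship, based on the last print:
#   joint_reported = batch_size * ll + prior
#   correct_total  = ll + prior
# Subtracting gives (batch_size - 1) * ll = joint_reported - correct_total.
ll = (joint_reported - correct_total) / (batch_size - 1)
prior = correct_total - ll

print(f"log likelihood (sum over the 3 series): {ll:.5f}")
print(f"log prior: {prior:.5f}")
```

Under that assumption, the summed likelihood of the 3 series (about -53.14) enters the joint log-prob 3 times. That means 3 * 3 = 9 per-series terms are counted instead of 3, while the prior term (about -7.73) is counted once.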

My guess is that this happens because the components consider only their parameters, not their initial state prior, when determining their batch shape. I am happy to work on a fix or feature addition if that would help. Thanks!
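
To make the guessed mechanism concrete, here is a small NumPy sketch of how a batch-shape mismatch can produce this kind of double count. It is hypothetical. The per-series values are made up, and the extra (3, 1) axis only stands in for wherever the batched initial state might get broadcast against a component that reports batch shape (). It is not TFP's actual code path.

```python
import numpy as np

# Hypothetical per-series log likelihoods for a batch of 3 models
# (illustrative values, not taken from the issue).
per_series_ll = np.array([-17.0, -18.0, -18.5])

# Correct reduction: one term per batch member, 3 terms in total.
correct_total = per_series_ll.sum()

# Hypothetical mismatch: the component reports batch shape () while its
# initial state prior has batch shape (3,). If the shape-(3,) result is
# broadcast against an extra length-3 axis before the final reduction,
# every series is counted once per row: 3 * 3 = 9 terms.
broadcast = per_series_ll[np.newaxis, :] + np.zeros((3, 1))
buggy_total = broadcast.sum()

print(correct_total)  # -53.5
print(buggy_total)    # -160.5
```

Under this hypothesis, a component whose batch_shape also reflected its initial_state_prior (here (3,)) would not introduce the extra axis, and the reduction would see only the 3 intended terms.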

jeffpollock9, Nov 28 '24 13:11