JSL
Figure out why the tfp MVN works but the distrax MVN does not
In https://github.com/probml/JSL/blob/main/jsl/demos/hmm_lillypad.py we use:
hmm = HMM(trans_dist=distrax.Categorical(probs=A),
          init_dist=distrax.Categorical(probs=initial_probs),
          obs_dist=distrax.as_distribution(
              tfp.substrates.jax.distributions.MultivariateNormalFullCovariance(
                  loc=mu_collection,
                  covariance_matrix=cov_collection)))
But it fails when I switch to the native distrax distribution:
hmm = HMM(trans_dist=distrax.Categorical(probs=A),
          init_dist=distrax.Categorical(probs=initial_probs),
          obs_dist=distrax.MultivariateNormalFullCovariance(
              loc=mu_collection,
              covariance_matrix=cov_collection))
Why?