Figure out why the TFP MVN works but the distrax MVN does not

Open • murphyk opened this issue • 0 comments

In https://github.com/probml/JSL/blob/main/jsl/demos/hmm_lillypad.py we build the HMM with a TFP observation distribution wrapped by `distrax.as_distribution`, which works:

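```python
import distrax
import tensorflow_probability as tfp

# HMM, A, initial_probs, mu_collection and cov_collection come from hmm_lillypad.py.
hmm = HMM(trans_dist=distrax.Categorical(probs=A),
          init_dist=distrax.Categorical(probs=initial_probs),
          obs_dist=distrax.as_distribution(
              tfp.substrates.jax.distributions.MultivariateNormalFullCovariance(
                  loc=mu_collection, covariance_matrix=cov_collection)))
```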

But it fails when I switch to the native distrax distribution:


    hmm = HMM(trans_dist=distrax.Categorical(probs=A),
            init_dist=distrax.Categorical(probs=initial_probs),
            obs_dist=distrax.MultivariateNormalFullCovariance(
                loc=mu_collection, covariance_matrix=cov_collection))

Why?

murphyk, May 07 '22 23:05