
Off-by-one error in non-stationary HMM transitions

Open umeshksingla opened this issue 1 month ago • 2 comments

A recent PR, #436 by @colecitrenbaum, fixed an off-by-one error in the non-stationary HMM transitions code. However, the code seems to have another bug. If anyone has encountered it before, please let me know.

In abstractions.py, at L341:

lp = jnp.sum(expected_transitions * log_trans_matrix)

It gives an error:

     lp = jnp.sum(expected_transitions * log_trans_matrix)
                 ~~~~~~~~~~~~~~~~~~~~~^~~~~~~~~~~~~~~~~~
    return lax.mul(x, y) if x.dtype != bool else lax.bitwise_and(x, y)
           ~~~~~~~^^^^^^
TypeError: mul got incompatible shapes for broadcasting: (9, 3, 3), (8, 3, 3).

This occurs on a setup with T=10 timesteps.

I have attached a screenshot below showing where I think the bug is (highlighted in pink, L276 and L298); I can also provide a minimal example that reproduces the error if needed. It seems to me that the pytree_slice at L298 is unnecessary, since the slicing is already handled at L276 when the transition matrices are computed.

Image
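
To make the shape arithmetic concrete: with T=10 there are T-1=9 transitions, so expected_transitions has shape (9, 3, 3), and trimming the inputs a second time leaves only 8 transition matrices. The snippet below is a toy sketch of that effect, not the dynamax code. The helper names (drop_first, toy_transition_matrices), the input dimension D, and the choice of which end of the inputs gets trimmed are all assumptions made for illustration.

```python
import jax.numpy as jnp

T, K, D = 10, 3, 2  # timesteps, states, input dim (D is a made-up value)
inputs = jnp.zeros((T, D))

def drop_first(x):
    # Hypothetical stand-in for the slicing step: trims one timestep.
    # Which end is trimmed is an assumption; only the count matters here.
    return x[1:]

def toy_transition_matrices(inputs):
    # Hypothetical stand-in: one uniform K x K matrix per input row.
    n = inputs.shape[0]
    return jnp.full((n, K, K), 1.0 / K)

# One expected-transition count matrix per transition t -> t+1.
expected_transitions = jnp.ones((T - 1, K, K))

sliced_once = toy_transition_matrices(drop_first(inputs))
sliced_twice = toy_transition_matrices(drop_first(drop_first(inputs)))
print(sliced_once.shape, sliced_twice.shape)

# Slicing once lines up with expected_transitions.
lp = jnp.sum(expected_transitions * jnp.log(sliced_once))

# Slicing twice reproduces the broadcasting failure from the traceback.
error_raised = False
try:
    jnp.sum(expected_transitions * jnp.log(sliced_twice))
except (TypeError, ValueError) as err:
    error_raised = True
    print(type(err).__name__, err)
```

If the real code follows the same pattern, removing one of the two slicing steps should restore the (9, 3, 3) shape on both sides of the product.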

umeshksingla avatar Dec 10 '25 01:12 umeshksingla

The code to reproduce the above error is here. It is a minimal HMM implementation with input-driven state transitions and Gaussian emissions.

umeshksingla avatar Dec 10 '25 01:12 umeshksingla

Hi, thanks for your patience on this. I agree with your assessment. I just sent a PR that should fix this, and your code runs with it. Please let me know if there are any further issues once it's merged!

colecitrenbaum avatar Jan 05 '26 23:01 colecitrenbaum

Hi @colecitrenbaum , thanks so much for taking this on!

I just noticed a minor inconsistency with HMMInitialState as well. In abstractions.py, at L144, the code currently returns pytree_slice(inputs, 0).

However, at L124, _compute_initial_probs (and, by extension, distribution) expects to receive the full inputs array. So either the pytree_slice could be removed, which would make it consistent with the HMMTransitions change above, or the definitions of _compute_initial_probs and distribution could be updated to reflect that they take only the single t=0 input.
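
The two options can be sketched side by side. This is only an illustration of the calling conventions, not the dynamax implementation: slice_tree is a hypothetical stand-in for pytree_slice, and the softmax-of-linear-map initial distribution, the weights, and both function names are assumptions.

```python
import jax
import jax.numpy as jnp

def slice_tree(tree, idx):
    # Hypothetical stand-in for pytree_slice: index every leaf at idx.
    return jax.tree_util.tree_map(lambda x: x[idx], tree)

K, T, D = 3, 10, 2  # states, timesteps, input dim (made-up values)
weights = jnp.zeros((K, D))  # made-up parameters
inputs = jnp.arange(T * D, dtype=jnp.float32).reshape(T, D)

def initial_probs_from_first_input(u0):
    # Convention B: the function takes only the t=0 input, shape (D,).
    return jax.nn.softmax(weights @ u0)

def initial_probs_from_full_inputs(inputs):
    # Convention A: the function takes all inputs, shape (T, D),
    # and picks out t=0 itself.
    return initial_probs_from_first_input(slice_tree(inputs, 0))

p_a = initial_probs_from_full_inputs(inputs)                # caller passes everything
p_b = initial_probs_from_first_input(slice_tree(inputs, 0)) # caller slices first

# Mixing the conventions slices twice and leaves a scalar, not a (D,) vector.
double_sliced = slice_tree(slice_tree(inputs, 0), 0)
print(p_a.shape, p_b.shape, double_sliced.shape)
```

Either convention gives the same probabilities as long as the caller and the callee agree on who does the slicing.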

umeshksingla avatar Jan 16 '26 07:01 umeshksingla