Charlie Brummitt

2 comments by Charlie Brummitt

Based on https://github.com/google/jax/issues/763, I see I can pass a tuple to `xs`, like so:

```python
from jax import lax

_, (states, observations) = lax.scan(
    f=_step,
    init=params.initial_state,
    xs=(rngs, u),
)
```

Then the `_step` function...
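For illustration, here is a minimal, self-contained sketch of what such a `_step` function might look like. It assumes a hypothetical linear state-space model: the matrices `A`, `B`, `C`, the noise scale `0.1`, and all shapes are my own assumptions, not taken from the issue. The point it shows is that `lax.scan` slices every element of the `xs` tuple along its leading axis, so `_step` receives a per-step tuple `(rng, u_t)` as its second argument. Returning a tuple as the per-step output makes `scan` stack each component separately.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical linear-Gaussian state-space model (illustration only).
A = jnp.array([[0.9, 0.1], [0.0, 0.8]])  # state transition
B = jnp.array([[0.0], [1.0]])            # control input matrix
C = jnp.array([[1.0, 0.0]])              # observation matrix


def _step(state, xs):
    # xs is this step's slice of the tuple passed to lax.scan: (rng_t, u_t).
    rng, u_t = xs
    key_process, key_obs = jax.random.split(rng)
    next_state = A @ state + B @ u_t + 0.1 * jax.random.normal(key_process, state.shape)
    observation = C @ next_state + 0.1 * jax.random.normal(key_obs, (C.shape[0],))
    # The carry is the new state; the per-step output (state, observation)
    # is stacked by lax.scan along a new leading time axis.
    return next_state, (next_state, observation)


T = 5
rngs = jax.random.split(jax.random.PRNGKey(0), T)  # one key per time step
u = jnp.ones((T, 1))                                # one scalar control per step
initial_state = jnp.zeros(2)

final_state, (states, observations) = lax.scan(_step, initial_state, (rngs, u))
print(states.shape, observations.shape)  # (5, 2) (5, 1)
```

Because every array in `xs` must share the same leading length `T`, the keys are pre-split into `T` pieces rather than being threaded through the carry. Threading a single key through the carry and splitting it inside `_step` would also work; this version just mirrors the tuple-`xs` pattern from the issue.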