Lena Martens
Hm, I think this means we need a `switch_error_check` that will do this functionalizing/merging for us. Thanks for reporting! (And for bearing with us while we make checkify feature-complete!)
Aha, actually the issue is that the `lax.switch` doesn't have any inputs: if you add an operand, the check is successfully eliminated (so you don't get the first functionalization error)...
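A minimal sketch of the operand fix, assuming a recent JAX where `jax.experimental.checkify` supports checks inside `lax.switch` branches. The branch functions and the message are hypothetical; the point is that `x` reaches the branches as an operand of the switch instead of being closed over.

```python
import jax
from jax import lax
from jax.experimental import checkify


def checked_branch(x):
    # Effectful check: only valid once the whole function is checkified.
    checkify.check(x > 0, "x must be positive")
    return x + 1.0


def plain_branch(x):
    return x - 1.0


def f(index, x):
    # x is passed to lax.switch as an operand instead of being closed over.
    return lax.switch(index, [checked_branch, plain_branch], x)


# checkify functionalizes the check; jit can then stage the result.
checked_f = jax.jit(checkify.checkify(f))

err, out = checked_f(0, -1.0)        # branch 0 runs with x <= 0, so err holds a failure
err_ok, out_ok = checked_f(1, 3.0)   # branch 1 has no check, so err_ok is clean
print(err.get())
print(err_ok.get())
```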
I'm pretty sure the issue is that jit `unflatten`s the `Boxes` when passing the args into the checkified function, so the `check` in the `__init__` is called outside of the...
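One possible workaround, sketched under the assumption that the validation only needs to run when user code first builds a `Box`: have `tree_unflatten` rebuild the object without calling `__init__`, so jit's flatten/unflatten round trips never re-run the `check` outside of `checkify`. The `Box` below is a hypothetical stand-in for the class in the report, and this swaps in a general pytree pattern rather than anything confirmed in the thread.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify


@jax.tree_util.register_pytree_node_class
class Box:
    def __init__(self, value):
        # Validation in __init__ is effectful: it must run under checkify.
        checkify.check(jnp.all(value > 0), "Box value must be positive")
        self.value = value

    def tree_flatten(self):
        return (self.value,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Rebuild without calling __init__, so flattening/unflattening
        # (e.g. by jit) never re-runs the check outside checkify.
        obj = object.__new__(cls)
        obj.value = children[0]
        return obj


# The check runs once, inside checkify, when the Box is built.
make_box = jax.jit(checkify.checkify(lambda x: Box(x)))
err, box = make_box(jnp.array(3.0))
err.throw()  # no-op here, since 3.0 > 0

# Passing the Box into another jitted function unflattens it without re-checking.
double = jax.jit(checkify.checkify(lambda b: b.value * 2))
err2, out = double(box)
```

The trade-off is that a `Box` rebuilt by JAX is trusted as-is; the check only guards construction in user code.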
In this case you can `checkify` again! The issue here is that you can't throw an `error` without functionalizing that error effect, but you can do that "unboxing" (`checkify`) and...
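A sketch of the "checkify again" pattern, assuming the goal is to raise an `Error` produced by an inner `checkify` from inside a jitted function. `checkify.check_error` is the documented way to re-raise an `Error` value in a form that an outer `checkify` can functionalize; `inner` and `outer` are hypothetical names.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify


def inner(x):
    checkify.check(x > 0, "x must be positive")
    return jnp.log(x)


def outer(x):
    # The inner checkify "unboxes" the effect into an explicit Error value.
    err, y = checkify.checkify(inner)(x)
    # check_error turns that value back into an effect, which the
    # outer checkify below functionalizes again.
    checkify.check_error(err)
    return y * 2.0


checked_outer = jax.jit(checkify.checkify(outer))
err, out = checked_outer(-1.0)
err_ok, out_ok = checked_outer(1.0)
```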
As discussed offline: only raise in the case of effectful branches.
We plan on starting this migration once `PRNGKeyArray` is no longer a PyTree, i.e., when we can register it as a JAX type through the upcoming typeclass mechanism (like `vmappable`: https://github.com/google/jax/pull/8451).
Whoops, I seem to have missed this! Yes, you can use a partial or lambda to get around this, or in the case of methods, transform the _bound method_ in...
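The comment doesn't say which argument caused the problem, so this sketch assumes it is a non-array (plain Python) argument to a checkified, jitted function. It binds that argument with `functools.partial` first, and for a method it transforms the bound method so `self` is closed over. `scale`, `Model`, and `factor` are hypothetical.

```python
import functools

import jax
import jax.numpy as jnp
from jax.experimental import checkify


def scale(x, factor):
    checkify.check(x >= 0, "x must be non-negative")
    return x * factor


# Bind the non-array argument before transforming, so only x is traced.
checked_scale = jax.jit(checkify.checkify(functools.partial(scale, factor=3)))
err, out = checked_scale(jnp.array(2.0))


class Model:
    def __init__(self, factor):
        self.factor = factor

    def apply(self, x):
        checkify.check(x >= 0, "x must be non-negative")
        return x * self.factor


model = Model(2.0)
# Transform the bound method: self (and self.factor) is closed over.
checked_apply = jax.jit(checkify.checkify(model.apply))
err2, out2 = checked_apply(jnp.array(5.0))
```

A `lambda x: scale(x, 3)` would work the same way as the `partial`.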
Nice, +1 on the `danger of weak_type=True for triggering recompilation`; we've had people ask for more documentation on that. I might try to add that.
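A small illustration of that danger (not from the thread), assuming default 32-bit mode: a Python float passed to a jitted function is weakly typed, while `jnp.float32(...)` is strongly typed, so the two calls map to different cache entries. The counter is a Python side effect, so it only increments when JAX (re)traces `f`.

```python
import jax
import jax.numpy as jnp

trace_count = 0


@jax.jit
def f(x):
    global trace_count
    trace_count += 1  # runs at trace time only, not on cached calls
    return x * 2


f(1.0)               # Python float: weakly typed float32 -> trace #1
f(2.0)               # same weak type and shape -> cache hit
f(jnp.float32(1.0))  # strongly typed float32 -> trace #2
f(jnp.float32(3.0))  # cache hit
print(trace_count)
```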
Sorry, I took a shortcut in my previous answer for brevity, but in doing so I omitted some important details. This is partly why better documentation would be great! The fundamental issue here...
Is this for logging purposes, or would you use the value itself in a subsequent computation? Would a way to print these counts through, e.g., a JAX `print` be enough?...
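If printing is enough, something like the following sketch may do, assuming `jax.debug.print` as the "JAX print". The thread doesn't say what is being counted, so the number of non-finite entries here is a hypothetical stand-in; the count is printed from inside the jitted function without being returned.

```python
import jax
import jax.numpy as jnp


@jax.jit
def count_nonfinite(x):
    n_bad = jnp.sum(~jnp.isfinite(x))
    # Printed at runtime, inside jit; the value is not an output of the function.
    jax.debug.print("non-finite entries: {n}", n=n_bad)
    return jnp.where(jnp.isfinite(x), x, 0.0)


out = count_nonfinite(jnp.array([1.0, jnp.nan, jnp.inf, 2.0]))
```

If the count is needed in a later computation, returning it as an extra output is the more natural choice.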