restore test_apply_paddings_check runtime_checks test
The main idea is that we need to call jax.effects_barrier(), because the error may be raised in an XLA computation that runs asynchronously with respect to the main Python thread, so we need to block until it finishes. (There may have been a recent change in behavior where JAX runs more computations asynchronously on the CPU backend.) We could put the call to jax.effects_barrier() in the test code (and in the corresponding user code), or we could build it into the runtime_checks context manager. This commit currently does the latter.
I also tweaked the runtime_checks logic to use a try/finally pattern that restores the state when the context is exited, even when it's exited via an exception. We may want to do the same for other context managers like numeric_checks.
While the test now passes, there is a gross warning printed about "Exception ignored in atexit callback". That may be a JAX internal bug, or it may be some quirk of CPython 3.10; I haven't investigated further. Let me know if that seems like a problem.
What do you think?
My intrepid teammates @yashk2810 and @hawkinsp noticed that in the most recent release of JAX we no longer raise jaxlib.xla_extension.XlaRuntimeError but rather jax.errors.JaxRuntimeError (EDIT: or maybe that's just a public-facing alias for the same object...). See https://github.com/jax-ml/jax/pull/23943. I'll try to update that in this PR (under a version switch), or send a follow-up PR if this PR gets merged before I make the fix.
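One possible shape for that switch uses feature detection instead of comparing version strings. This is a sketch that assumes what the comment above describes: newer JAX exposes `jax.errors.JaxRuntimeError`, while older releases only provide `jaxlib.xla_extension.XlaRuntimeError`. The helper name `runtime_error_type` is hypothetical.

```python
import jax


def runtime_error_type() -> type:
    """Hypothetical helper: return the exception class JAX raises for XLA runtime errors."""
    # Assumption: newer JAX releases expose jax.errors.JaxRuntimeError.
    error_type = getattr(jax.errors, "JaxRuntimeError", None)
    if error_type is None:
        # Assumption: older releases only provide the jaxlib class.
        from jaxlib.xla_extension import XlaRuntimeError as error_type
    return error_type
```

A test could then write `with self.assertRaises(runtime_error_type()):` and work on either side of the change.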
Thank you!
Hi @mattjj @matthew-e-hopkins, is this PR still relevant? Thanks
Well, the test is still present and commented out at HEAD: https://github.com/apple/axlearn/blob/df7ed095656599675e75df2451c505e107c988f3/axlearn/common/transducer_test.py#L302-L313
It's @matthew-e-hopkins's TODO, so how about we let Matt decide!
(I personally am fine with closing it.)
I'll just close it for now but you guys / @matthew-e-hopkins should re-open and merge if you care!