Boris Yangel

Results: 12 issues by Boris Yangel

### Description

I'm trying to do a switch over functions that have a checkify check inside. It goes somewhat like this:

```python
def make_branch(i):
    def branch():
        result = jnp.full((1,), i)...
```

bug

### Description

The most obvious way to run into this would be to apply `checkify` to a class method of a class that is not a valid JAX type.

```python
...
```

bug

Hello! I've noticed that the gradient is being scaled by 0.5 after every model step: https://github.com/Hwhitetooth/jax_muzero/blob/b8ab36251d22e0246d61514841ba17a22f4b2a36/algorithms/agents.py#L116-L120 Can you clarify the motivation for that? Are you aware of any experimental results...

### Description

I have a transformer model where each transformer block is rematerialized. The model is distributed over multiple devices using jit. Each transformer block has dropout enabled. To prevent...

bug

### Description

Consider the following code snippet:

```python
import jax
import flax.linen as nn
from jax.sharding import Mesh
import functools


class Model(nn.Module):
    output_dim = 32768 * 8

    @nn.compact
    def __call__(self,...
```

bug

### Description

This is a duplicate of [this issue](https://github.com/google/jax/issues/19893), but I think it might be flax-related, so I'm also posting it here. Consider the following code snippet:

```python
import jax
import flax.linen...
```

Which can be problematic, for instance, when it's used in conjunction with pjit.

### Description

I've encountered an interesting situation that I've described in more detail here: https://github.com/google/jax/discussions/20284#discussioncomment-8815174

Basically, the problem is as follows:

* I've written some code for running inference on...

bug
NVIDIA GPU

JAX recently added [a transformation](https://jax.readthedocs.io/en/latest/debugging/checkify_guide.html) that validates user-defined assertions and checks for some standard issues, such as division by zero or NaNs arising in computations. Unfortunately,...
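For readers unfamiliar with the transformation, here is a minimal sketch of how `checkify` is typically used. The function names `safe_log` and `plain_log` and the input values are made up for illustration and do not come from the issue.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify


def safe_log(x):
    # User-defined assertion: the predicate is a traced boolean scalar.
    checkify.check(jnp.all(x > 0), "safe_log expects strictly positive inputs")
    return jnp.log(x)


def plain_log(x):
    # No explicit check here; NaNs are caught by the automatic float checks below.
    return jnp.log(x)


# checkify makes the function return (error, output) instead of raising,
# which lets the checked function be wrapped in jit.
checked_safe_log = jax.jit(checkify.checkify(safe_log, errors=checkify.user_checks))
checked_plain_log = checkify.checkify(plain_log, errors=checkify.float_checks)

err_ok, _ = checked_safe_log(jnp.array([1.0, 2.0]))
err_bad, _ = checked_safe_log(jnp.array([1.0, -2.0]))
err_nan, _ = checked_plain_log(jnp.array([-1.0]))

ok_message = err_ok.get()    # None when no check failed
bad_message = err_bad.get()  # string describing the failed user check
nan_message = err_nan.get()  # string describing the NaN-producing primitive
```

Calling `err.throw()` on a returned error raises it as a Python exception outside the transformed function.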

`orbax/checkpoint/pytree_checkpoint_handler.py:661` has the following check: `if not item` It most likely should be `if item is None`, as otherwise this check will raise an error when item is an array...
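The difference between the two checks can be sketched without Orbax. `ArrayLike` below is a hypothetical stand-in that copies only one behavior of NumPy and JAX arrays: refusing truth-value conversion when they hold more than one element. The helper names are also illustrative and are not Orbax code.

```python
class ArrayLike:
    """Hypothetical stand-in for a multi-element array (not an Orbax or NumPy type)."""

    def __init__(self, values):
        self.values = list(values)

    def __bool__(self):
        # Arrays with more than one element have no unambiguous truth value.
        if len(self.values) > 1:
            raise ValueError("truth value of a multi-element array is ambiguous")
        return bool(self.values[0]) if self.values else False


def is_missing_truthy(item):
    # Pattern quoted in the issue: this evaluates bool(item).
    return not item


def is_missing_identity(item):
    # Suggested replacement: only None counts as missing.
    return item is None


def truthy_check_raises(item):
    # Reports whether the truthiness-based check fails with ValueError.
    try:
        is_missing_truthy(item)
    except ValueError:
        return True
    return False


array_item = ArrayLike([1.0, 2.0])
```

With these definitions, `is_missing_truthy(array_item)` raises `ValueError`, while `is_missing_identity(array_item)` simply returns `False`. The truthiness check also treats an empty container such as `{}` as missing, which may or may not be intended.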