
`check_numerics` doesn't work inside `repeat.py`

Open ds-hwang opened this issue 1 year ago • 0 comments

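The assertion in `check_numerics`, `assert bool(jnp.isfinite(x).all()), f"Check numerics {msg_fmt.format(**msg_kwargs)}: {x}"`, doesn't work when `x` is a traced array: calling `bool()` on a tracer raises `TracerBoolConversionError`.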

assert bool(jnp.isfinite(x).all()), f"Check numerics {msg_fmt.format(**msg_kwargs)}: {x}"
Traceback (most recent call last):
  File "/Users/dongseong/miniforge3/envs/ajax/lib/python3.10/site-packages/jax/_src/core.py", line 782, in __bool__
    return self.aval._bool(self)
  File "/Users/dongseong/miniforge3/envs/ajax/lib/python3.10/site-packages/jax/_src/core.py", line 1538, in error
    raise TracerBoolConversionError(arg)
jax.errors.TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function fn at /Users/dongseong/Workspaces/axlearn/axlearn/common/base_layer.py:329 for checkpoint. This concrete value was not available in Python because it depends on the value of the argument kwargs['inputs'].
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError

So `check_numerics` doesn't work in any traced context, such as inside `jit`, `pmap`, or `scan` (the trace above comes from a `checkpoint`-wrapped function).

def check_numerics(x: Tensor, msg_fmt: str = "", **msg_kwargs):
    """Checks that all elements in `x` are finite."""
    global _enable_numeric_checks  # pylint: disable=global-statement,global-variable-not-assigned
    if _enable_numeric_checks:
        assert bool(jnp.isfinite(x).all()), f"Check numerics {msg_fmt.format(**msg_kwargs)}: {x}"
    return x
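
A minimal reproduction outside axlearn, as a sketch: `layer` and `body` are hypothetical stand-ins for a layer's forward pass and a scan body, and the global enable flag is dropped so the assertion always runs.

```python
import jax
import jax.numpy as jnp


def check_numerics(x, msg_fmt="", **msg_kwargs):
    """Same assertion as above, minus the global enable flag."""
    assert bool(jnp.isfinite(x).all()), f"Check numerics {msg_fmt.format(**msg_kwargs)}: {x}"
    return x


def layer(x):
    # Hypothetical stand-in for a layer's forward pass.
    return check_numerics(x * 2.0, "layer output")


x = jnp.ones((4,))

# Eager call: x is a concrete array, so bool() succeeds.
eager_out = layer(x)

# Under jit, x is a tracer and bool() raises while tracing.
try:
    jax.jit(layer)(x)
    jit_error = None
except jax.errors.TracerBoolConversionError as e:
    jit_error = e


# scan also traces its body, which is the situation inside repeat.py.
def body(carry, _):
    return layer(carry), None


try:
    jax.lax.scan(body, x, None, length=3)
    scan_error = None
except jax.errors.TracerBoolConversionError as e:
    scan_error = e
```

The eager call returns normally, while both the `jit` and the `scan` calls fail during tracing, before any numeric value exists to check.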

JAX does provide `checkify`, but it requires the top-level function to be wrapped, as in `checkify.checkify(main)`, and threading that through axlearn is not trivial. https://jax.readthedocs.io/en/latest/debugging/checkify_guide.html
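
For reference, a rough sketch of what the checkify route looks like in plain JAX, to show where the wrapping requirement comes from. This is not a proposed axlearn patch: `main` and `body` are hypothetical, and the `check_numerics` here is a simplified variant that uses `checkify.check` instead of `assert`.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify


def check_numerics(x, msg="non-finite values found"):
    # checkify.check records the failure functionally instead of calling bool() on a tracer.
    checkify.check(jnp.isfinite(x).all(), msg)
    return x


def body(carry, _):
    # Hypothetical scan body, standing in for a repeated layer.
    return check_numerics(carry * 2.0), None


def main(x):
    out, _ = jax.lax.scan(body, x, None, length=3)
    return out


# The whole entry point has to be wrapped; the wrapped function returns (error, output).
checked_main = jax.jit(checkify.checkify(main))

err, out = checked_main(jnp.ones((4,)))
err.throw()  # No-op here, since every value is finite.

err_bad, _ = checked_main(jnp.array([1.0, jnp.inf, 0.0, 0.0]))
bad_message = err_bad.get()  # Error message string, or None if no check failed.
```

The catch is the changed return signature: every caller of the wrapped function has to handle the extra `err` value, which is the part that does not drop into an existing training loop for free.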

ds-hwang, Nov 05 '24 23:11