jax
Adding `JAX_LOGGING_LEVEL` configuration option
For example, setting `JAX_LOGGING_LEVEL=DEBUG` and running `jax.jit(lambda x: x)(jnp.ones((10,)))` gives:

```
DEBUG:2024-09-16 10:11:49,929:jax._src.path:45: etils.epath was not found. Using pathlib for file I/O.
DEBUG:2024-09-16 10:11:50,045:jax._src.dispatch:178: Finished tracing + transforming convert_element_type for pjit in 0.000222683 sec
DEBUG:2024-09-16 10:11:50,048:jax._src.xla_bridge:579: Discovered path based JAX plugin: jax_plugins.cuda
DEBUG:2024-09-16 10:11:50,048:jax._src.xla_bridge:579: Discovered path based JAX plugin: jax_plugins.rocm
DEBUG:2024-09-16 10:11:50,054:jax._src.xla_bridge:594: Loading plugin module jax_plugins.rocm
DEBUG:2024-09-16 10:11:50,055:jax._src.xla_bridge:594: Loading plugin module jax_plugins.cuda
DEBUG:2024-09-16 10:11:50,055:jax._src.xla_bridge:970: Initializing backend 'cpu'
DEBUG:2024-09-16 10:11:50,057:jax._src.xla_bridge:982: Backend 'cpu' initialized
DEBUG:2024-09-16 10:11:50,057:jax._src.xla_bridge:970: Initializing backend 'rocm'
INFO:2024-09-16 10:11:50,057:jax._src.xla_bridge:895: Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
DEBUG:2024-09-16 10:11:50,057:jax._src.xla_bridge:970: Initializing backend 'tpu'
INFO:2024-09-16 10:11:50,058:jax._src.xla_bridge:895: Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
DEBUG:2024-09-16 10:11:50,059:jax._src.interpreters.pxla:1903: Compiling convert_element_type with global shapes and types [ShapedArray(float32[])]. Argument mapping: (UnspecifiedValue,).
DEBUG:2024-09-16 10:11:50,066:jax._src.dispatch:178: Finished jaxpr to MLIR module conversion jit(convert_element_type) in 0.007055998 sec
DEBUG:2024-09-16 10:11:50,066:jax._src.compiler:168: get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:2024-09-16 10:11:50,067:jax._src.compiler:227: get_compile_options XLA-AutoFDO profile: using XLA-AutoFDO profile version -1
DEBUG:2024-09-16 10:11:50,067:jax._src.cache_key:120: get_cache_key hash of serialized computation: 96f6c2c849611cb65ca14596725417384a55335a17aaf471c2f18031b69bb736
DEBUG:2024-09-16 10:11:50,067:jax._src.cache_key:126: get_cache_key hash after serializing computation: 96f6c2c849611cb65ca14596725417384a55335a17aaf471c2f18031b69bb736
DEBUG:2024-09-16 10:11:50,067:jax._src.cache_key:120: get_cache_key hash of serialized jax_lib version: 5f45bbcdc1f605c0aa3cce55583a208fba43af4faefecfdf157d982dac286c25
DEBUG:2024-09-16 10:11:50,067:jax._src.cache_key:126: get_cache_key hash after serializing jax_lib version: 321fe1b73386c60fa57bb696231a28dbda25a62a329efc2c26d72fee523ea5fd
...
```
The log lines use the same format as the existing per-module JAX debug logging enabled by the `jax_debug_log_modules` option.
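The log layout shown above (level, timestamp, logger name, line number, message) can be sketched with Python's standard `logging` module. This is a minimal illustration under stated assumptions: the format string is inferred from the sample lines rather than copied from JAX's source, and `make_logger` and the logger name `demo.xla_bridge` are hypothetical names used only for the demo.

```python
import io
import logging

# Assumed layout, inferred from the sample output (not taken from JAX's source):
#   LEVEL:asctime:logger-name:line-number: message
FORMAT = "{levelname}:{asctime}:{name}:{lineno}: {message}"


def make_logger(name: str, stream: io.StringIO, level: str = "DEBUG") -> logging.Logger:
    """Hypothetical helper: return a logger that writes to `stream` in FORMAT.

    `level` stands in for a JAX_LOGGING_LEVEL value such as "DEBUG" or "INFO".
    """
    handler = logging.StreamHandler(stream)
    handler.setFormatter(logging.Formatter(FORMAT, style="{"))
    logger = logging.getLogger(name)
    logger.handlers.clear()   # avoid duplicate handlers if called twice
    logger.addHandler(handler)
    logger.setLevel(level)    # setLevel accepts level names as strings
    logger.propagate = False  # keep records out of the root logger
    return logger


buf = io.StringIO()
log = make_logger("demo.xla_bridge", buf, level="INFO")
log.debug("Initializing backend 'cpu'")         # filtered out: DEBUG < INFO
log.info("Unable to initialize backend 'tpu'")  # emitted
print(buf.getvalue(), end="")
```

With the level set to `INFO`, the `DEBUG` record is dropped and only the `INFO` line is written, which presumably mirrors how a less verbose `JAX_LOGGING_LEVEL` value would hide the `Initializing backend` lines while keeping the `Unable to initialize backend` ones.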
Also, can you provide an example command and its output in the PR description?