TPU operations don't appear if profile is too long
Description
Hi! I noticed unexpected behavior with the JAX profiler (on TPU v3). No TPU operations appear on the profile, and in some cases the overview shows "Host: CPU" even though operations are running on the TPU. I believe this happens when the traced window is too long (~30-40s). Here's what I observed with my train loop:
- (worked) Profiling 20 steps of my train loop with a dummy dataset: 3s total on the profile, showed TPU activity as expected.
- (failed) 20 steps of the same train loop, but with an expensive data pipeline that loads images from GCS instead of dummy data, plus a large batch size: 35s, only showed CPU activity.
- (worked) Expensive data pipeline + smaller batch size: 5s, showed TPU activity.
- (worked) Expensive data pipeline + profiling only 2 steps: 2s, showed TPU activity.
I managed to reproduce this with a minimal example, and uploaded profiles (linked below) of the same script showing TPU ops when profiling for 2 steps but not for 25.
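To check the "traced window too long" hypothesis, here is a minimal sketch that times the window between `start_trace` and `stop_trace` on a dummy jitted step; `step_fn` and the array sizes are stand-ins, not the real train step:

```python
# Sketch: time the traced window to compare against the ~30-40s threshold
# where TPU ops stopped appearing. step_fn is a stand-in train step.
import tempfile
import time

import jax
import jax.numpy as jnp


@jax.jit
def step_fn(x):
    # Dummy compute standing in for the real update function.
    return (x @ x).sum()


x = jnp.ones((64, 64))
log_dir = tempfile.mkdtemp()

t0 = time.perf_counter()
jax.profiler.start_trace(log_dir)
for i in range(5):
    with jax.profiler.StepTraceAnnotation("train", step_num=i):
        out = step_fn(x)
    jax.block_until_ready(out)
jax.profiler.stop_trace()
elapsed = time.perf_counter() - t0
print(f"traced window: {elapsed:.2f}s")  # the failing profiles were ~35s
```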
Minimal mocked example
def main():
    import jax
    import numpy as np
    import granular
    from big_vision.datasets.interleaved.interleaved import make_interleaved_mixture
    from big_vision import utils as u

    P = jax.sharding.PartitionSpec
    jax.distributed.initialize()
    mesh = jax.sharding.Mesh(jax.devices(), "d")
    d = jax.sharding.NamedSharding(mesh, P("d"))
    n = jax.sharding.NamedSharding(mesh, P())

    # Define dummy update function.
    def fn(x, y):
        y = y["image"][:, :, 0, 0, 0].repeat(256 // 16, 1)
        res = x @ y.repeat(len(x) // len(y), 0)
        loss = res.sum()
        return res, loss

    fn = jax.jit(fn, in_shardings=d, out_shardings=(d, n))
    x = jax.device_put(np.ones((256, 256)), d)

    dataset = ...  # Expensive dataset that loads images from GCS.
    loader = ...  # Multiprocessing dataloader with large batch size.
    it = iter(loader)

    start_step = 5
    end_step = 7  # Setting this to 25 results in no TPU ops on the profile.
    for i, y in zip(range(30), it):
        if i == start_step:
            jax.profiler.start_trace('profiles')
        if i > 0:
            prev_loss = loss
        with jax.profiler.StepTraceAnnotation('train', step_num=i):
            res, loss = fn(x, y)
        # Run ahead at most one batch.
        if i > 0:
            jax.block_until_ready(prev_loss)
        if i == end_step:
            jax.profiler.stop_trace()

if __name__ == '__main__':
    main()
Profiles here: https://github.com/jlin816/jax-profiling-bug
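In the meantime, a possible workaround consistent with the observations above is to capture several short traces instead of one long one, so each profile stays in the few-second range that rendered TPU ops correctly. This is only a sketch; `step_fn`, `chunk`, and the sizes are illustrative:

```python
# Sketch of a workaround: trace a few steps at a time into separate
# directories, keeping each individual profile short.
import os
import tempfile

import jax
import jax.numpy as jnp


@jax.jit
def step_fn(x):
    # Stand-in for the real train step.
    return (x @ x).sum()


x = jnp.ones((128, 128))
log_root = tempfile.mkdtemp()
num_steps = 25
chunk = 3  # trace at most this many consecutive steps per profile

for start in range(0, num_steps, 10):
    trace_dir = os.path.join(log_root, f"steps_{start}")
    with jax.profiler.trace(trace_dir):
        for i in range(start, min(start + chunk, num_steps)):
            with jax.profiler.StepTraceAnnotation("train", step_num=i):
                loss = step_fn(x)
            jax.block_until_ready(loss)
```

Each `trace_dir` can then be opened separately in TensorBoard/Perfetto.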
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.31
jaxlib: 0.4.31
numpy: 1.26.4
python: 3.10.12 (main, Jul 29 2024, 16:56:48) [GCC 11.4.0]
jax.devices (8 total, 8 local): [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0) TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1) ... TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0) TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
process_count: 1
platform: uname_result(system='Linux', node='t1v-n-61c88ef7-w-0', release='5.19.0-1022-gcp', version='#24~22.04.1-Ubuntu SMP Sun Apr 23 09:51:08 UTC 2023', machine='x86_64')
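For reference, the block above appears to be the output of `jax.print_environment_info()`, which is a convenient way to regenerate this information when re-testing on newer versions:

```python
# Collect the same environment summary programmatically.
import jax

info = jax.print_environment_info(return_string=True)
print(info)
```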