
TPU operations don't appear if profile is too long

jlin816 opened this issue on Sep 5, 2024

Description

Hi! I noticed unexpected behavior with the JAX profiler (on TPU v3). No TPU operations appear in the profile, and in some cases the overview shows "Host: CPU" even when operations are running on the TPU. I believe this happens when the trace is too long (~30-40s). Here's what I observed with my train loop (each run used the same start/stop trace pattern, sketched after the list):

  • (worked) Profiling 20 steps of my train loop with a dummy dataset: 3 seconds total on the profile, TPU activity shown as expected.
  • (failed) 20 steps of the same train loop with an expensive data pipeline that loaded images from GCS instead of dummy data, plus a large batch size: 35s, only CPU activity shown.
  • (worked) Expensive data pipeline with a smaller batch size: 5s, TPU activity shown.
  • (worked) Expensive data pipeline, profiling only 2 steps: 2s, TPU activity shown.
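
For context, every run above captured the trace programmatically with the same pattern. A minimal sketch of that pattern (train_step, batch, and num_steps are placeholders, not the real code):

import jax

jax.profiler.start_trace('profiles')  # write the trace to ./profiles
for step in range(num_steps):
    with jax.profiler.StepTraceAnnotation('train', step_num=step):
        loss = train_step(batch)  # placeholder for the real jitted step
jax.block_until_ready(loss)  # finish async dispatch inside the trace
jax.profiler.stop_trace()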

I managed to reproduce this with a minimal example; the profiles linked below come from the same script, which shows TPU ops when profiling 2 steps but not 25.

Minimal mocked example
def main():
    import jax
    import numpy as np
    import granular
    from big_vision.datasets.interleaved.interleaved import make_interleaved_mixture
    from big_vision import utils as u
    P = jax.sharding.PartitionSpec

    jax.distributed.initialize()

    mesh = jax.sharding.Mesh(jax.devices(), "d")
    d = jax.sharding.NamedSharding(mesh, P("d"))
    n = jax.sharding.NamedSharding(mesh, P())

    # Define dummy update function.
    def fn(x, y):
        y = y["image"][:, :, 0, 0, 0].repeat(256 // 16, 1)
        res = x @ y.repeat(len(x) // len(y), 0)
        loss = res.sum()
        return res, loss

    fn = jax.jit(fn, in_shardings=d, out_shardings=(d, n))
    x = jax.device_put(np.ones((256, 256)), d)

    dataset = ... # Expensive dataset that loads images from GCS
    loader = ... # Multiprocessing dataloader with large batch size
    it = iter(loader)

    start_step = 5
    end_step = 7  # Setting this to 25 results in no TPU ops on profile

    for i, y in zip(range(30), it):
        if i == start_step:
            jax.profiler.start_trace('profiles')

        if i > 0:
            prev_loss = loss
        with jax.profiler.StepTraceAnnotation('train', step_num=i):
            res, loss = fn(x, y)
        # Run ahead by at most one batch.
        if i > 0:
            jax.block_until_ready(prev_loss)

        if i == end_step:
            jax.profiler.stop_trace()


if __name__ == '__main__':
    main()

Profiles here: https://github.com/jlin816/jax-profiling-bug
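
Given the observations above (short captures work, long ones fail), one possible workaround is to split a long profile into several short traces. This is only a sketch, not something I've verified beyond the runs above; profile_in_chunks is a hypothetical helper, and jax.profiler.trace is the context-manager form of start_trace/stop_trace:

import jax

def profile_in_chunks(step_fn, num_steps, chunk_size=2, log_dir='profiles'):
    # Capture each chunk of steps as its own short trace, since captures
    # of a few seconds reliably showed TPU ops above.
    step = 0
    while step < num_steps:
        with jax.profiler.trace(f'{log_dir}/chunk_{step}'):
            for _ in range(min(chunk_size, num_steps - step)):
                with jax.profiler.StepTraceAnnotation('train', step_num=step):
                    out = step_fn(step)  # step_fn is a placeholder
                step += 1
            jax.block_until_ready(out)  # keep async work inside the trace

Each chunk then lands in its own log directory and can be opened separately in TensorBoard.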

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.31
jaxlib: 0.4.31
numpy:  1.26.4
python: 3.10.12 (main, Jul 29 2024, 16:56:48) [GCC 11.4.0]
jax.devices (8 total, 8 local): [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0) TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1) ... TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0) TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
process_count: 1
platform: uname_result(system='Linux', node='t1v-n-61c88ef7-w-0', release='5.19.0-1022-gcp', version='#24~22.04.1-Ubuntu SMP Sun Apr 23 09:51:08 UTC 2023', machine='x86_64')
