
Shampoo conformer workload hangs

Open priyakasimbeg opened this issue 2 years ago • 4 comments

The conformer workload hangs when run with the Shampoo training algorithm.

Description

Traceback

I0505 23:26:00.158526 139795269302080 submission_runner.py:319] Starting training loop.
I0505 23:26:00.373614 139795269302080 input_pipeline.py:20] Loading split = train-clean-100
I0505 23:26:00.410641 139795269302080 input_pipeline.py:20] Loading split = train-clean-360
I0505 23:26:00.817146 139795269302080 input_pipeline.py:20] Loading split = train-other-500
2023-05-05 23:32:21.267134: E external/org_tensorflow/tensorflow/compiler/xla/service/rendezvous.cc:31] This thread has been waiting for 10 seconds and may be stuck:

Steps to Reproduce

Pull the docker image:

$ docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/base_image:timing

Run the container; its entrypoint script will launch the submission runner:

$ docker run -t -d -v /home/kasimbeg/data/:/data/ -v /home/kasimbeg/experiment_runs/:/experiment_runs -v /home/kasimbeg/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/base_image:timing -d librispeech -f jax -s baselines/shampoo/jax/submission.py -w librispeech_conformer -t baselines/shampoo/tuning_search_space_conformer.json -e timing_fancy_2_redo/timing_shampoo -m 20000 -c False -o True -r False 

To see the output of submission_runner.py, monitor the logs of the container:

$ docker logs -f <container_id printed by previous command>

Source or Possible Fix

I think this may be an XLA memory issue. On a different VM the runs got a little further along and errored out with what looks like a memory-related failure. I restarted all the VMs and now they don't get any further than the message above. I may have changed some environment flags on the VM that got further along. I tried setting XLA_PYTHON_CLIENT_PREALLOCATE=false, which had no effect, and XLA_PYTHON_CLIENT_MEM_FRACTION=.80, which made the run error out sooner.
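For anyone reproducing this, a minimal sketch of one way to set these flags (they are read when JAX initializes, so they have to be in the environment before submission_runner.py starts, e.g. exported in the container shell or passed with docker run -e):

$ export XLA_PYTHON_CLIENT_PREALLOCATE=false   # disable JAX's up-front preallocation of most GPU memory
$ export XLA_PYTHON_CLIENT_MEM_FRACTION=.80    # cap the fraction of GPU memory the XLA client may claim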

For reference, here is the output of the run that got further:

I0504 04:13:55.604297 139669110556480 submission_runner.py:415] Time since start: 5507.98s, 	Step: 3192, 	{'train/ctc_loss': DeviceArray(1.8280892, dtype=float32), 'train/wer': 0.44060410729030713, 'validation/ctc_loss': DeviceArray(2.347627, dtype=float32), 'validation/wer': 0.4757016469044564, 'validation/num_examples': 5348, 'test/ctc_loss': DeviceArray(1.9786975, dtype=float32), 'test/wer': 0.4258119553957711, 'test/num_examples': 2472, 'score': 5199.0200300216675, 'total_duration': 5507.9753386974335, 'accumulated_submission_time': 5199.0200300216675, 'accumulated_eval_time': 308.8330419063568, 'accumulated_logging_time': 0.07393193244934082}
I0504 04:13:55.626670 139496276358912 logging_writer.py:48] [3192] accumulated_eval_time=308.833042, accumulated_logging_time=0.073932, accumulated_submission_time=5199.020030, global_step=3192, preemption_count=0, score=5199.020030, test/ctc_loss=1.9786975383758545, test/num_examples=2472, test/wer=0.425812, total_duration=5507.975339, train/ctc_loss=1.8280892372131348, train/wer=0.440604, validation/ctc_loss=2.3476269245147705, validation/num_examples=5348, validation/wer=0.475702
I0504 04:14:11.395573 139496267966208 logging_writer.py:48] [3200] global_step=3200, grad_norm=0.9491458535194397, loss=1.8900938034057617
I0504 04:16:31.686281 139496276358912 logging_writer.py:48] [3300] global_step=3300, grad_norm=0.8079001307487488, loss=1.9073154926300049
I0504 04:18:51.079895 139496267966208 logging_writer.py:48] [3400] global_step=3400, grad_norm=0.7481346726417542, loss=1.8942415714263916
I0504 04:21:10.230679 139496276358912 logging_writer.py:48] [3500] global_step=3500, grad_norm=0.8360145092010498, loss=1.8996608257293701
I0504 04:23:28.858766 139496267966208 logging_writer.py:48] [3600] global_step=3600, grad_norm=1.0013821125030518, loss=1.9014889001846313
I0504 04:25:48.481894 139496276358912 logging_writer.py:48] [3700] global_step=3700, grad_norm=0.9406089186668396, loss=1.8606724739074707
I0504 04:28:07.921166 139496267966208 logging_writer.py:48] [3800] global_step=3800, grad_norm=0.8744378089904785, loss=1.8769983053207397
2023-05-04 04:30:00.136705: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2130] Execution of replica 5 failed: INTERNAL: Failed to launch CUDA kernel: fusion_59 with block dimensions: 32x1x1 and grid dimensions: 1x1x1: CUDA_ERROR_LAUNCH_FAILED: unspecified launch failure
2023-05-04 04:30:09.848574: E external/org_tensorflow/tensorflow/compiler/xla/service/rendezvous.cc:31] This thread has been waiting for 10 seconds and may be stuck:
2023-05-04 04:30:09.848806: E external/org_tensorflow/tensorflow/compiler/xla/service/rendezvous.cc:31] This thread has been waiting for 10 seconds and may be stuck:
2023-05-04 04:30:09.850079: E external/org_tensorflow/tensorflow/compiler/xla/service/rendezvous.cc:31] This thread has been waiting for 10 seconds and may be stuck:
2023-05-04 04:30:09.852182: E external/org_tensorflow/tensorflow/compiler/xla/service/rendezvous.cc:31] This thread has been waiting for 10 seconds and may be stuck:
2023-05-04 04:30:09.855277: E external/org_tensorflow/tensorflow/compiler/xla/service/rendezvous.cc:31] This thread has been waiting for 10 seconds and may be stuck:
2023-05-04 04:30:09.855440: E external/org_tensorflow/tensorflow/compiler/xla/service/rendezvous.cc:31] This thread has been waiting for 10 seconds and may be stuck:
2023-05-04 04:30:09.862346: E external/org_tensorflow/tensorflow/compiler/xla/service/rendezvous.cc:31] This thread has been waiting for 10 seconds and may be stuck:
2023-05-04 04:30:10.145779: F external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2275] Replicated computation launch failed, but not all replicas terminated. Aborting process to work around deadlock. Failure message (there may have been multiple failures, see the error log for all failures): 

Failed to launch CUDA kernel: fusion_59 with block dimensions: 32x1x1 and grid dimensions: 1x1x1: CUDA_ERROR_LAUNCH_FAILED: unspecified launch failure
Fatal Python error: Aborted

To debug in the container:

Run the container without starting the submission runner (not passing in a value for the -s flag):

$ docker run -t -d -v /home/kasimbeg/data/:/data/ -v /home/kasimbeg/experiment_runs/:/experiment_runs -v /home/kasimbeg/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/base_image:timing  -r false -b true

Start an interactive bash session in the running container:

$ docker exec -it <container_id> /bin/bash

Run submission_runner.py in the container:

$ python3 submission_runner.py --framework=jax --workload=librispeech_conformer --submission_path=baselines/shampoo/jax/submission.py --tuning_search_space=baselines/shampoo/tuning_search_space.json --data_dir=/data/librispeech --num_tuning_trials=1 --experiment_dir=/experiment_runs --experiment_name=timing_fancy_2_redo/timing_shampoo --overwrite=True --save_checkpoints=False --max_global_steps=20000 --librispeech_tokenizer_vocab_path=/data/librispeech/spm_model.vocab 2>&1 | tee -a /logs/librispeech_conformer_jax
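
Since the failure mode is a hang rather than a crash, it can also help to dump the Python stacks of the stuck runner. This is not part of the repro above and assumes py-spy is not already in the image (attaching may also require starting the container with --cap-add SYS_PTRACE, depending on the Docker setup):

$ pip install py-spy                     # sampling profiler / stack dumper for Python
$ ps aux | grep submission_runner        # find the PID of the stuck runner
$ py-spy dump --pid <pid>                # print the current stack of every Python thread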

You can also pull the code to the host VM and mount the local repo so that you can make changes to the code without losing them.

Pull the repo and check out the debugging branch:

$ cd $HOME
$ git clone https://github.com/priyakasimbeg/algorithmic-efficiency.git
$ cd algorithmic-efficiency
$ git fetch origin && git pull && git checkout shampoo_debugging

Run the container with the mounted directory:

$ docker run -t -d -v /home/kasimbeg/data/:/data/ -v /home/kasimbeg/experiment_runs/:/experiment_runs -v /home/kasimbeg/experiment_runs/logs:/logs -v $HOME/algorithmic-efficiency:/algorithmic-efficiency --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/base_image:timing -r False -b

priyakasimbeg avatar May 06 '23 01:05 priyakasimbeg

For completeness, the conformer run that did not get stuck printed these warnings:

I0504 02:42:07.626892 139669110556480 submission_runner.py:318] Starting training loop.
I0504 02:42:07.827529 139669110556480 input_pipeline.py:20] Loading split = train-clean-100
I0504 02:42:07.855990 139669110556480 input_pipeline.py:20] Loading split = train-clean-360
I0504 02:42:08.184822 139669110556480 input_pipeline.py:20] Loading split = train-other-500
/algorithmic-efficiency/baselines/shampoo/jax/distributed_shampoo.py:592: UserWarning: Explicitly requested dtype <class 'jax.numpy.float64'> requested in astype is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  matrix = matrix.astype(_MAT_INV_PTH_ROOT_DTYPE)
/algorithmic-efficiency/baselines/shampoo/jax/distributed_shampoo.py:593: UserWarning: Explicitly requested dtype <class 'jax.numpy.float64'> requested in asarray is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  alpha = jnp.asarray(-1.0 / p, _MAT_INV_PTH_ROOT_DTYPE)
/algorithmic-efficiency/baselines/shampoo/jax/distributed_shampoo.py:594: UserWarning: Explicitly requested dtype <class 'jax.numpy.float64'> requested in eye is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  identity = jnp.eye(matrix_size, dtype=_MAT_INV_PTH_ROOT_DTYPE)
/algorithmic-efficiency/baselines/shampoo/jax/distributed_shampoo.py:477: UserWarning: Explicitly requested dtype <class 'jax.numpy.float64'> requested in eye is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  power = jnp.eye(mat_m.shape[0], dtype=_MAT_INV_PTH_ROOT_DTYPE)
/usr/local/lib/python3.8/dist-packages/jax/interpreters/mlir.py:592: UserWarning: Some donated buffers were not usable: ShapedArray(float32[512]), ShapedArray(float32[512]), ShapedArray(float32[512]), ShapedArray(float32[512]), ShapedArray(float32[512]), ShapedArray(float32[512]), ShapedArray(float32[512]), ShapedArray(float32[512]).
See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer-donation.
  warnings.warn(f"Some donated buffers were not usable: {', '.join(unused_donations)}.\n{msg}")
I0504 02:48:44.909264 139494363756288 logging_writer.py:48] [0] global_step=0, grad_norm=46.149715423583984, loss=31.95041275024414
I0504 02:48:44.985919 139669110556480 spec.py:298] Evaluating on the training split.
I0504 02:48:45.101763 139669110556480 input_pipeline.py:20] Loading split = train-clean-100
I0504 02:48:45.130307 139669110556480 input_pipeline.py:20] Loading split = train-clean-360
I0504 02:48:45.215455 139669110556480 input_pipeline.py:20] Loading split = train-other-500
/usr/local/lib/python3.8/dist-packages/jax/_src/ops/scatter.py:87: FutureWarning: scatter inputs have incompatible types: cannot safely cast value from dtype=float32 to dtype=int32. In future JAX releases this will result in an error.
  warnings.warn("scatter inputs have incompatible types: cannot safely cast "
I0504 02:49:43.929699 139669110556480 spec.py:310] Evaluating on the validation split.
I0504 02:49:44.000639 139669110556480 input_pipeline.py:20] Loading split = dev-clean
I0504 02:49:44.005420 139669110556480 input_pipeline.py:20] Loading split = dev-other
I0504 02:50:26.653083 139669110556480 spec.py:326] Evaluating on the test split.
I0504 02:50:26.725322 139669110556480 input_pipeline.py:20] Loading split = test-clean
I0504 02:50:57.282124 139669110556480 submission_runner.py:415] Time since start: 529.65s, 	Step: 1, 	{'train/ctc_loss': DeviceArray(31.262894, dtype=float32), 'train/wer': 1.5752139599775108, 'validation/ctc_loss': DeviceArray(30.076275, dtype=float32), 'validation/wer': 1.145539271965962, 'validation/num_examples': 5348, 'test/ctc_loss': DeviceArray(30.172346, dtype=float32), 'test/wer': 1.2194259947596124, 'test/num_examples': 2472, 'score': 397.35881662368774, 'total_duration': 529.6535410881042, 'accumulated_submission_time': 397.35881662368774, 'accumulated_eval_time': 132.29455065727234, 'accumulated_logging_time': 0}
I0504 02:50:57.306291 139491301910272 logging_writer.py:48] [1] accumulated_eval_time=132.294551, accumulated_logging_time=0, accumulated_submission_time=397.358817, global_step=1, preemption_count=0, score=397.358817, test/ctc_loss=30.172346115112305, test/num_examples=2472, test/wer=1.219426, total_duration=529.653541, train/ctc_loss=31.262893676757812, train/wer=1.575214, validation/ctc_loss=30.076274871826172, validation/num_examples=5348, validation/wer=1.145539
I0504 02:59:56.601535 139496018323200 logging_writer.py:48] [100] global_step=100, grad_norm=3.0943682193756104, loss=5.815123081207275
I0504 03:02:11.086658 139496026715904 logging_writer.py:48] [200] global_step=200, grad_norm=1.56673264503479, loss=5.7747578620910645

priyakasimbeg avatar May 06 '23 02:05 priyakasimbeg

I strongly believe this is a memory issue. The workload runs fine with a smaller model or a smaller batch size.
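
One generic way to double-check this (not specific to this repo) is to watch device memory from the host while the workload runs:

$ nvidia-smi --query-gpu=index,memory.used,memory.total --format=csv -l 5

If every GPU sits close to memory.total just before the rendezvous warnings appear, that would be consistent with the CUDA_ERROR_LAUNCH_FAILED failure seen in the run above.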

priyakasimbeg avatar May 09 '23 17:05 priyakasimbeg

Marking as obsolete

priyakasimbeg avatar Dec 01 '23 02:12 priyakasimbeg

Reopening to explore feasibility of submission.

priyakasimbeg avatar Feb 06 '24 21:02 priyakasimbeg