
OnExceptionCheckpoint callback suppresses exceptions and results in NCCL timeout

jackdent opened this issue 1 year ago • 0 comments

Bug description

When running Lightning with a multi-device training strategy (e.g. DDP), the OnExceptionCheckpoint callback:

  • silently swallows exceptions, which makes it hard to identify the root cause of a failure
  • results in an NCCL timeout

This is due to the following:

  • When an exception is raised, it is caught by _call_and_handle_interrupt, which calls into _interrupt: https://github.com/Lightning-AI/pytorch-lightning/blob/1551a16b94f5234a4a78801098f64d0732ef5cb5/src/lightning/pytorch/trainer/call.py#L67
  • The original exception is supposed to be re-raised at the end of that function, but we never get there, because...
  • _interrupt calls _call_callback_hooks, which runs the on_exception callbacks: https://github.com/Lightning-AI/pytorch-lightning/blob/1551a16b94f5234a4a78801098f64d0732ef5cb5/src/lightning/pytorch/trainer/call.py#L76
  • If OnExceptionCheckpoint is enabled, its on_exception hook runs next. That hook never finishes, because it calls trainer.save_checkpoint: https://github.com/Lightning-AI/pytorch-lightning/blob/master/src/lightning/pytorch/callbacks/on_exception_checkpoint.py#L67
  • trainer.save_checkpoint saves the checkpoint and then calls self.strategy.barrier("Trainer.save_checkpoint"), which waits for the other processes to reach the same barrier. The other processes never raised an exception, so they never enter this code path, and the failing process blocks at the barrier until it times out (see the sketch after this list).
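
The same deadlock pattern can be reproduced outside Lightning. The following is a minimal, self-contained sketch (hypothetical code, not taken from the Lightning source) in which only rank 0 raises, and its exception handler enters a barrier that the healthy rank never reaches; the short gloo timeout here plays the role the NCCL watchdog plays on GPU:

```python
import os
import time
from datetime import timedelta

import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank: int, world_size: int) -> None:
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29501"
    # Short timeout so the demo fails fast; on GPU, the NCCL watchdog
    # is what eventually kills the stuck barrier.
    dist.init_process_group(
        "gloo", rank=rank, world_size=world_size, timeout=timedelta(seconds=10)
    )
    try:
        if rank == 0:
            raise RuntimeError("boom")  # only rank 0 hits the exception
        # Healthy ranks saw no exception, so they never run the handler
        # below; in real training they would be busy in other collectives.
        time.sleep(60)
    except RuntimeError:
        # Mirrors on_exception -> Trainer.save_checkpoint -> strategy.barrier():
        # rank 0 now waits for peers that will never arrive.
        dist.barrier()  # blocks here until the gloo timeout raises


if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)
```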

As described in the docstring for Trainer.save_checkpoint:

This method needs to be called on all processes in case the selected strategy is handling distributed checkpointing.

In practice, this means that our jobs eventually time out with an NCCL error and never print the original exception.
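
As a stopgap for the first symptom, the exception can be printed before the hook reaches the barrier. The subclass below is an untested sketch of that idea (the class name is mine); it makes the original traceback visible but does not fix the timeout itself:

```python
import traceback

from lightning.pytorch.callbacks import OnExceptionCheckpoint


class VerboseOnExceptionCheckpoint(OnExceptionCheckpoint):
    """Hypothetical workaround: print the original traceback before
    trainer.save_checkpoint can block on the distributed barrier."""

    def on_exception(self, trainer, pl_module, exception) -> None:
        # Surface the real cause first, while this rank can still make progress.
        traceback.print_exception(type(exception), exception, exception.__traceback__)
        super().on_exception(trainer, pl_module, exception)
```

It would be registered like the original callback, e.g. `Trainer(callbacks=[VerboseOnExceptionCheckpoint(dirpath=".")])`.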

What version are you seeing the problem on?

v2.4

How to reproduce the bug

No response

Error messages and logs

No response

Environment

Current environment
#- PyTorch Lightning Version (e.g., 2.4.0):
#- PyTorch Version (e.g., 2.4):
#- Python version (e.g., 3.12):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning (`conda`, `pip`, source):

More info

No response

jackdent · Aug 10 '24 05:08