torchtitan

Process got stuck when trying to optimize different groups of parameters using different types of data

Open Yangyi-Chen opened this issue 1 year ago • 3 comments

Hi,

I'm adding a new linear projection layer (nn.Linear) to the original Llama 3 architecture to process a new type of data. During training, I use two types of data: language-only and multimodal. For language-only data, all Llama 3 parameters are fine-tuned. For multimodal data, all Llama 3 parameters plus the parameters of the added linear layer are fine-tuned. Each setup works fine on its own.

However, when I combine the two types of data for multi-task learning, the process just hangs without printing anything further. Does torchtitan currently not support this kind of setup? Thanks.


Yangyi-Chen, Sep 18 '24

For more context, I'm running single-node, multi-GPU distributed training. After waiting a long time, I received the following messages:

```
[rank0]: return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
[rank0]:[rank0]:[E918 17:48:37.892017038 ProcessGroupNCCL.cpp:1423] [PG ID 0 PG GUID 0(default_pg) Rank 0] Observed flight recorder dump signal from another rank via TCPStore.
[rank0]:[rank0]:[E918 17:48:37.892143284 ProcessGroupNCCL.cpp:1484] [PG ID 0 PG GUID 0(default_pg) Rank 0] Received a dump signal due to a collective timeout from rank 3 and we will try our best to dump the debug info. Last enqueued NCCL work: 108, last completed NCCL work: 107.This is most likely caused by incorrect usages of collectives, e.g., wrong sizes used across ranks, the order of collectives is not same for all ranks or the scheduled collective, for some reason, didn't run. Additionally, this can be caused by GIL deadlock or other reasons such as network errors or bugs in the communications library (e.g. NCCL), etc.
[rank0]:[rank0]:[E918 17:48:37.892317119 ProcessGroupNCCL.cpp:1288] [PG ID 0 PG GUID 0(default_pg) Rank 0] ProcessGroupNCCL preparing to dump debug info.
[rank0]:[rank0]:[E918 17:48:37.935023931 ProcessGroupNCCL.cpp:616] [Rank 0] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=108, OpType=ALLREDUCE, NumelIn=1, NumelOut=1, Timeout(ms)=300000) ran for 300032 milliseconds before timing out.
[rank0]:[rank0]:[E918 17:48:37.938135753 ProcessGroupNCCL.cpp:1785] [PG ID 0 PG GUID 0(default_pg) Rank 0] Exception (either an error or timeout) detected by watchdog at work: 108, last enqueued NCCL work: 108, last completed NCCL work: 107.
```

Yangyi-Chen, Sep 18 '24
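To see why one rank ends up stuck on an ALLREDUCE that never completes, here is a toy, single-process sketch (not torchtitan code) of the collectives each data-parallel rank issues in one step. Collectives on a process group are matched by the order in which each rank issues them, so if the per-rank sequences diverge, some rank blocks until the NCCL watchdog timeout seen in the log above. It assumes, hypothetically, that `fully_shard` was applied directly to the new projection, so that module gets its own all-gather and reduce-scatter; `collectives_for_step`, `first_mismatch`, and the op names are illustrative only.

```python
from typing import List, Optional


def collectives_for_step(batch_type: str, feed_dummy: bool = False) -> List[str]:
    """Toy model of the collectives one data-parallel rank issues in one step.

    Hypothetical: assumes the new projection is its own FSDP unit, so running
    it costs an extra all-gather (forward) and reduce-scatter (backward).
    """
    uses_proj = batch_type == "multimodal" or feed_dummy
    ops = []
    if uses_proj:
        ops.append("all_gather(proj)")
    ops.append("all_gather(llama)")
    # Backward visits modules in reverse order, reducing gradients as it goes.
    ops.append("reduce_scatter(llama)")
    if uses_proj:
        ops.append("reduce_scatter(proj)")
    # A scalar all-reduce, similar in size to the NumelIn=1 ALLREDUCE in the log.
    ops.append("all_reduce(scalar)")
    return ops


def first_mismatch(per_rank_ops: List[List[str]]) -> Optional[int]:
    """Index of the first collective on which the ranks disagree, or None."""
    longest = max(len(ops) for ops in per_rank_ops)
    for i in range(longest):
        # A rank that has already finished contributes None: its peers would block.
        issued = {ops[i] if i < len(ops) else None for ops in per_rank_ops}
        if len(issued) > 1:
            return i
    return None


ranks = ["text", "multimodal", "text", "text"]
print(first_mismatch([collectives_for_step(b) for b in ranks]))  # 0
print(first_mismatch([collectives_for_step(b, feed_dummy=True) for b in ranks]))  # None
```

With `feed_dummy=True`, every rank issues the same sequence, which is the idea behind feeding dummy data through the unused layer.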

It may help if you can provide a repro of some kind and/or give some more information about what parallelism you are using.

awgu, Sep 19 '24

Hi, thanks for the follow-up question. I'm basically using the default settings from the ./train_configs/llama3_8b.toml file:

```toml
[training]
batch_size = 1
seq_len = 8192 # 8192 # 16384
warmup_steps = 200 # lr scheduler warm up
max_norm = 1.0 # grad norm clipping
steps = 3000
data_parallel_degree = -1
tensor_parallel_degree = 1
enable_fp8_linear = false
compile = true
dataset = "imagenet+dclm"

[experimental]
pipeline_parallel_degree = 1
```

Yangyi-Chen, Sep 19 '24
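As we understand torchtitan's config, `data_parallel_degree = -1` means "use every GPU not already taken by tensor and pipeline parallelism", so with `tensor_parallel_degree = 1` and `pipeline_parallel_degree = 1` each GPU is its own data-parallel rank and draws its own batch. The sketch below uses hypothetical helpers (not torchtitan's code) and assumes an 8-GPU node and independent per-rank sampling, where each rank draws a multimodal batch with probability `p`. It shows how rarely all ranks take the same code path in a given step.

```python
def resolve_dp_degree(world_size: int, dp: int, tp: int = 1, pp: int = 1) -> int:
    """Resolve data_parallel_degree = -1 (assumed semantics, see lead-in)."""
    if dp == -1:
        if world_size % (tp * pp) != 0:
            raise ValueError("world_size must be divisible by tp * pp")
        dp = world_size // (tp * pp)
    if dp * tp * pp != world_size:
        raise ValueError(f"dp * tp * pp = {dp * tp * pp}, expected {world_size}")
    return dp


def prob_all_ranks_agree(p_multimodal: float, dp: int) -> float:
    """Probability that all dp ranks draw the same batch type in one step,
    assuming each rank independently draws a multimodal batch with prob p."""
    return p_multimodal ** dp + (1.0 - p_multimodal) ** dp


dp = resolve_dp_degree(world_size=8, dp=-1)  # hypothetical single 8-GPU node
print(dp)  # 8
print(prob_all_ranks_agree(0.5, dp))  # 0.0078125
```

Under these assumptions, a 50/50 mix leaves all ranks on the same path in under 1% of steps, so almost every step risks mismatched collectives.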

Is it possible that you have any kind of conditional computation? For example, could one data-parallel rank have received no multimodal data, so that the linear layer was not used? It also depends a bit on how you applied FSDP to the modules.

Is it difficult to provide a way to repro the issue so that we can help debug? (I understand it might be very hard but just wanted to ask.)

awgu, Sep 20 '24

Yes, that can happen (one data-parallel rank uses the linear layer and the others do not). So it seems the current implementation doesn't support this kind of setup, right?

Yes 😂 it is still an ongoing project, so we haven't open-sourced the code yet.

Yangyi-Chen, Sep 24 '24

> SO it seems like the current implementation doesn't support such function, right?

Yeah... you might need to feed some dummy data through, since this breaks SPMD semantics: in our case, there is no way for the rank that does not use the linear layer to know that it should still participate in the collectives.

awgu, Sep 24 '24
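One common way to feed dummy data through (our assumption, not code from torchtitan or from this thread) is to run a zero-filled dummy input through the new projection on language-only batches and add its output multiplied by zero. Every rank then runs the same modules, while the language-only result is numerically unchanged. `ToyMultimodalModel` and its sizes are hypothetical stand-ins for Llama 3 plus the added `nn.Linear`. On a single CPU process we can only check the autograd side: the projection stays in the graph, so its gradients come out as zeros rather than `None`.

```python
import torch
import torch.nn as nn


class ToyMultimodalModel(nn.Module):
    """Hypothetical stand-in for Llama 3 plus the added nn.Linear projection."""

    def __init__(self, feat_dim: int = 16, dim: int = 32, vocab: int = 100):
        super().__init__()
        self.feat_dim = feat_dim
        self.proj = nn.Linear(feat_dim, dim)  # the newly added projection
        self.tok_embeddings = nn.Embedding(vocab, dim)
        self.backbone = nn.Linear(dim, dim)  # stands in for the Transformer
        self.output = nn.Linear(dim, vocab)

    def forward(self, tokens: torch.Tensor, features: torch.Tensor = None) -> torch.Tensor:
        h = self.tok_embeddings(tokens)  # (batch, seq, dim)
        if features is not None:
            # Multimodal batch: prepend projected features to the token embeddings.
            h = torch.cat([self.proj(features), h], dim=1)
        else:
            # Language-only batch: still run proj on a zero-filled dummy input so
            # every rank executes the same modules. Scaling by 0.0 leaves h unchanged
            # but keeps proj's parameters in the autograd graph (zero gradients).
            dummy = torch.zeros(tokens.size(0), 1, self.feat_dim,
                                dtype=self.proj.weight.dtype, device=tokens.device)
            h = h + 0.0 * self.proj(dummy).sum()
        return self.output(self.backbone(h))


model = ToyMultimodalModel()
tokens = torch.randint(0, 100, (2, 5))
model(tokens).sum().backward()  # language-only batch, dummy path
print(model.proj.weight.grad.abs().sum().item())  # 0.0
```

Computing this zero gradient is the conservative choice; whether it is strictly needed depends on how FSDP is applied to the model.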

I see. Thanks for your help!

Yangyi-Chen, Sep 24 '24

One quick question: when we run the dummy input through the added linear layer, do we need to compute the gradient of the linear layer with respect to this dummy part? Or is running the dummy input through the entire model (the added linear layer and the whole Transformer) enough?

Yangyi-Chen, Oct 01 '24

I think there is some nuance depending on how you apply FSDP to the model. If you are not calling `fully_shard` directly on that linear but rather on some parent module, then I think it should be OK not to compute the gradient w.r.t. that linear for the dummy input 🤔

awgu, Oct 01 '24

Thanks for the clarification!

Yangyi-Chen, Oct 02 '24

Hi all, I came across this issue by accident and wanted to leave a comment, even though it may not be helpful :) Maybe this DeepSpeed issue is related?

TL;DR: if process 1 consumes text-only minibatches and process 2 takes vision-text minibatches, their forward graphs differ, so during backpropagation process 2 requests an all-reduce for the vision module but process 1 does not, and process 2 waits forever. So the LLaVA project creates dummy images and forwards them for text-only batches, then makes them empty right before the LLM. This trick works for ZeRO-3 (Microsoft DeepSpeed's counterpart to FSDP).

SeunghyunSEO, Nov 26 '24
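A minimal sketch of the empty-slice trick described above, assuming a toy `vision_tower` and `projector` (hypothetical stand-ins, not LLaVA's actual modules or sizes). For a text-only sample, a dummy image is still run through the vision path, and an empty slice `image_feats[0:0]` of the result is concatenated onto the text embeddings. The LLM input is unchanged, but the vision modules stay in the autograd graph, so they receive zero gradients rather than `None`.

```python
import torch
import torch.nn as nn

vision_tower = nn.Linear(12, 8)  # hypothetical stand-in for the ViT
projector = nn.Linear(8, 8)      # hypothetical stand-in for the multimodal projector


def build_inputs_embeds(text_embeds: torch.Tensor, images: torch.Tensor = None) -> torch.Tensor:
    """text_embeds: (seq, 8); images: (num_patches, 12) or None for text-only."""
    if images is None:
        # Text-only sample: still run a dummy image through the vision path...
        dummy = torch.zeros(1, 12)
        image_feats = projector(vision_tower(dummy))
        # ...then keep only an empty (0, 8) slice, so nothing reaches the LLM.
        return torch.cat([text_embeds, image_feats[0:0]], dim=0)
    image_feats = projector(vision_tower(images))
    return torch.cat([image_feats, text_embeds], dim=0)


text = torch.randn(5, 8)
out = build_inputs_embeds(text)
out.sum().backward()
print(out.shape)  # torch.Size([5, 8])
print(torch.count_nonzero(vision_tower.weight.grad).item())  # 0
```

Because the dummy path contributes exactly zero gradient, plain SGD would leave the vision weights unchanged on text-only steps, although an optimizer with weight decay or momentum state (such as AdamW) can still move them.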

@SeunghyunSEO this is useful. One question though: LLaVA feeds empty images for all text-only batches; will this cause the ViT to be updated by the dummy data and text input?

lucasjinreal, May 17 '25