
While fine-tuning Swin UNETR, training loss is not decreasing and training crashes after 10 epochs

Mgithus opened this issue 2 years ago · 10 comments

Describe the bug
I am trying to reproduce Swin UNETR. I am fine-tuning it using the code and model.pt file given at: https://github.com/Project-MONAI/research-contributions/tree/main/SwinUNETR/BRATS21

That model was trained on BraTS 2021 data. I am using BraTS 2023 data, provided on request from Synapse: https://www.synapse.org/#!Synapse:syn27046444/wiki/616992

I am using an A100 GPU provided by Colab Pro and running the following command for fine-tuning on a single GPU:

```
!python '/content/drive/MyDrive/Mgithus/SWIN/SwinUNETR/BRATS21/main.py' \
  --json_list='/content/drive/MyDrive/data/whole_2023/ASNR-MICCAI-BraTS2023-GLI-Challenge-TrainingData.json' \
  --data_dir='/content/drive/MyDrive/data/whole_2023/train_ds/ASNR-MICCAI-BraTS2023-GLI-Challenge-TrainingData' \
  --val_every=10 --noamp --pretrained_model_name='Swin UNETR' \
  --pretrained_dir='/content/drive/MyDrive/fold1_f48_ep300_4gpu_dice0_9059/fold1_f48_ep300_4gpu_dice0_9059/model.pt' \
  --fold=1 --roi_x=128 --roi_y=128 --roi_z=128 --in_channels=4 --spatial_dims=3 \
  --use_checkpoint --feature_size=48 --max_epochs=80 --batch_size=3 --workers=12
```
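For reference, a quick way to sanity-check that the checkpoint file deserializes at all before launching a long run; just a sketch, using the path from the command above:

```python
# Sanity-check sketch: confirm the pretrained checkpoint loads and inspect its
# top-level keys before starting fine-tuning. The path is the one used above.
import torch

ckpt = torch.load(
    "/content/drive/MyDrive/fold1_f48_ep300_4gpu_dice0_9059/fold1_f48_ep300_4gpu_dice0_9059/model.pt",
    map_location="cpu",
)
print(type(ckpt))
if isinstance(ckpt, dict):
    print(list(ckpt.keys())[:10])  # e.g. a "state_dict" entry, epoch counters, etc.
```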

In the paper, Swin UNETR: Swin Transformers for Semantic Segmentation of Brain Tumors in MRI Images (https://arxiv.org/abs/2201.01266), they achieved an average Dice score of 0.913. But when I train it on BraTS 2023, it does not show any meaningful reduction in loss, and it also gives an error after the first validation, as follows:

[Screenshots of training logs: epoch 0, epoch 5, epoch 6 (after the 5th epoch the loss increased from 0.9465 to 0.97 instead of decreasing), the first validation, and epoch 10.]

Am I not using the pretrained model in the correct way? Which hyperparameter values could help increase the Dice score, given that the ET Dice is 0 up to the end of the third epoch? Also, although the input data has 1251 sample folders, even with 4-fold cross-validation the model runs 936 iterations instead of 939 at batch size 1 on a T4 GPU. Is this related to the runtime type, i.e. the CUDA out-of-memory problem?


Mgithus · Sep 10 '23

same issue

FengheTan9 · Sep 21 '23

Hi, have you figured it out?

Luffy03 · Oct 10 '23

CUDA out of memory... The training-crash problem was solved by reducing the number of workers and the batch size to 1... but the problem of the loss increasing after the 5th or 6th epoch is still there...
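To illustrate, the change amounts to lowering the batch size and worker count where the training DataLoader is built (in this recipe they come from the --batch_size and --workers flags). A self-contained sketch, with a dummy tensor dataset standing in for the real BraTS volumes:

```python
# Sketch of the settings that stopped the CUDA OOM crashes: batch_size=1 and
# fewer DataLoader workers. A dummy dataset stands in for the BraTS volumes.
import torch
from torch.utils.data import DataLoader, TensorDataset

train_ds = TensorDataset(torch.zeros(2, 4, 128, 128, 128))  # 2 fake 4-channel volumes
train_loader = DataLoader(train_ds, batch_size=1, num_workers=2, shuffle=True, pin_memory=True)

for (batch,) in train_loader:
    print(batch.shape)  # torch.Size([1, 4, 128, 128, 128])
    break
```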

Mgithus · Oct 11 '23

> CUDA out of memory... The training-crash problem was solved by reducing the number of workers and the batch size to 1... but the problem of the loss increasing after the 5th or 6th epoch is still there...

Thanks for sharing! Would you please share your MONAI and PyTorch versions? I met the same problem and found a possible solution here (https://github.com/Project-MONAI/model-zoo/issues/180), but it does not work for me...

Luffy03 · Oct 11 '23

My pleasure... I will try it... Thanks!

Mgithus · Oct 11 '23

I used Google Colab Pro+ and it automatically installed the latest versions of MONAI and PyTorch without my specifying particular versions. However, when I tried to run this model in a virtual environment in VS Code using the latest versions (at the time, MONAI 1.2 and PyTorch 2.0.1), it did not work.

Mgithus · Oct 11 '23

> I used Google Colab Pro+ and it automatically installed the latest versions of MONAI and PyTorch without my specifying particular versions. However, when I tried to run this model in a virtual environment in VS Code using the latest versions (at the time, MONAI 1.2 and PyTorch 2.0.1), it did not work.

I am still struggling to get it to work...

Luffy03 · Oct 17 '23

@Mgithus it shows ET as 0 because the segmentation labels are different. The BraTS "ConvertMultiChannel..." transform expects labels 0, 1, 2, 4, following the segmentation labelling of earlier datasets (up to 2021). The labelling changed to 0, 1, 2, 3 in 2023. You have to fix that either by modifying the labels in the "ConvertMultiChannel..." transform with custom code to match the 2021 segmentation labels and then retraining the model, or by using nibabel to convert label 3 to 4 in the 2023 and 2024 data. It's a label-mismatch error.
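A minimal sketch of the nibabel route, assuming the 2023 segmentation files follow the usual `*-seg.nii.gz` naming and that overwriting them in place is acceptable (back them up first); the root path and glob pattern are placeholders:

```python
# Sketch: remap BraTS 2023 label 3 (enhancing tumor) to the 2021-style label 4
# so a pipeline trained on 2021-style labels sees the classes it expects.
# The root path and "*-seg.nii.gz" pattern are assumptions; adjust to your layout.
import glob

import nibabel as nib
import numpy as np

for seg_path in glob.glob(
    "ASNR-MICCAI-BraTS2023-GLI-Challenge-TrainingData/**/*-seg.nii.gz", recursive=True
):
    img = nib.load(seg_path)
    data = np.asanyarray(img.dataobj)
    if (data == 3).any():
        data[data == 3] = 4  # ET: 3 (2023 convention) -> 4 (2021 convention)
        nib.save(nib.Nifti1Image(data.astype(np.uint8), img.affine, img.header), seg_path)
```

Alternatively, the remapping can be done on the fly in the transform chain; MONAI ships MapLabelValued for exactly this kind of relabeling (the "label" key is assumed to match the recipe's dictionary transforms):

```python
from monai.transforms import MapLabelValued

# Insert before the multi-channel conversion so downstream transforms
# see 2021-style labels (0, 1, 2, 4).
remap = MapLabelValued(keys=["label"], orig_labels=[3], target_labels=[4])
```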

NkwamPhilip · Aug 07 '24

Hi, we reproduced the results at https://github.com/Luffy03/Large-Scale-Medical. You can find our implementation at https://github.com/Luffy03/Large-Scale-Medical/tree/main/Downstream/monai/BRATS21.

Luffy03 · Oct 14 '24

> @Mgithus it shows ET as 0 because the segmentation labels are different. The BraTS "ConvertMultiChannel..." transform expects labels 0, 1, 2, 4, following the segmentation labelling of earlier datasets (up to 2021). The labelling changed to 0, 1, 2, 3 in 2023. You have to fix that either by modifying the labels in the "ConvertMultiChannel..." transform with custom code to match the 2021 segmentation labels and then retraining the model, or by using nibabel to convert label 3 to 4 in the 2023 and 2024 data. It's a label-mismatch error.

Hi, do you know BraTS 2018's labels? Please tell me, I'd really appreciate it!

sekioo · Apr 03 '25