RiNALMo

Are the downloadable finetuned weights for secondary structure prediction intra- or inter-family trained?

Open · shanry opened this issue 11 months ago · 9 comments

I noticed that the split in the ArchiveII dataset (fam-fold) is intra-family based. However, the paper claims that the fine-tuning on ArchiveII used an inter-family split.

Could the author clarify a bit about how the downloadable weights are fine-tuned?

shanry avatar Feb 24 '25 01:02 shanry

Hello 😄, As stated in the paper: "The dataset of 3865 RNAs from nine families was split nine times, and in each split, a different family was held out for evaluation while the other eight families were used for training and validation."
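For concreteness, here is a minimal sketch of that leave-one-family-out protocol, assuming a simple list of `(family, id, sequence, structure)` records; the family names and the record layout are illustrative assumptions and may not match the repository's actual ArchiveII file organisation.

```python
from collections import defaultdict

# Illustrative family names; the exact labels used in the dataset may differ.
FAMILIES = ["5s", "srp", "tRNA", "tmRNA", "RNaseP", "grp1", "16s", "23s", "telomerase"]

def leave_one_family_out(records):
    """records: iterable of (family, rna_id, sequence, structure) tuples."""
    by_family = defaultdict(list)
    for rec in records:
        by_family[rec[0]].append(rec)

    for held_out in FAMILIES:
        test = by_family[held_out]
        # The other eight families are pooled and later split into train/validation.
        train_val = [r for fam in FAMILIES if fam != held_out for r in by_family[fam]]
        yield held_out, train_val, test
```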

I just tried out the weights that were used for the SRP family evaluation and the results are as expected. Are you sure you are using the right weights? The file rinalmo_giga_ss_archiveII-srp_ft.pt contains the weights of the model trained (and validated) on all families except SRP.
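If you want to double-check which checkpoint you loaded, a quick sanity check like the sketch below can help; it assumes the file is a standard PyTorch checkpoint, and the key layout mentioned in the comments is only a guess.

```python
import torch

# Load the downloaded fine-tuned checkpoint on CPU and inspect its contents
# before wiring it into the model code from the repository.
ckpt = torch.load("rinalmo_giga_ss_archiveII-srp_ft.pt", map_location="cpu")

if isinstance(ckpt, dict):
    # Print a few keys to see whether it is a plain state_dict or a wrapper
    # (e.g. a Lightning-style checkpoint with a nested "state_dict" entry).
    print(list(ckpt.keys())[:10])
```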

RJPenic avatar Feb 24 '25 10:02 RJPenic

Thank you for your clarification! I’m now able to obtain the expected results for SRP using rinalmo_giga_ss_archiveII-srp_ft.pt.

If I want to replicate the finetuning process as well, do I need to reorganize the folder hierarchy so that it follows a "one family vs. all other families" structure? Could you confirm if this is the correct approach?

shanry avatar Feb 26 '25 07:02 shanry

I did use the default data splits to fine-tune the pre-trained model (it seems the files are already organized in an inter-family format). However, the test results are far from the expected values. Could you please check whether the hyperparameters are properly set?

I used the following command: /nfs/hpc/share/zhoutian/repos/RiNALMo/train_sec_struct_prediction.py ./ss_data2 --pretrained_rinalmo_weights ./weights/rinalmo_giga_pretrained.pt --output_dir myft_5s --dataset archiveII_5s --accelerator gpu --max_epochs 15 --wandb --ft_schedule ft_schedules/giga_sec_struct_ft.yaml

The summary metrics are:

{
  "_runtime": 12287.808371543884,
  "_step": 6562,
  "_timestamp": 1740607260.1589224,
  "_wandb.runtime": 12289,
  "epoch": 15,
  "lr-Adam": 0.00008306000000000176,
  "lr-Adam/pg1": 0.000050000000000001425,
  "lr-Adam/pg2": 0.000050000000000001425,
  "lr-Adam/pg3": 0.000050000000000001425,
  "lr-Adam/pg4": 0.000050000000000001425,
  "lr-Adam/pg5": 0.000050000000000001425,
  "test/f1": 0.35909613966941833,
  "test/loss": 0.029532011243427175,
  "test/precision": 0.563310444355011,
  "test/recall": 0.2690052092075348,
  "train/loss": 0.0020109469678422976,
  "trainer/global_step": 34860,
  "val/f1": 0.9452812671661376,
  "val/loss": 0.00011648952716895472,
  "val/threshold": 0.12999999523162842
}

Notice that the validation F1 is very high but the test F1 is very low.

shanry avatar Feb 27 '25 09:02 shanry

I found a few discrepancies in the code compared to our internal version that caused the learning rate to be higher than expected. Thanks for pointing this out. High learning rates during fine-tuning tend to overwrite pre-trained knowledge of the LM. I pushed new changes to the main branch. Could you please pull the latest commit and try repeating the experiments? Results should now align with what we reported in the paper.

You can use this command to run the experiment: python train_sec_struct_prediction.py ./ss_data/ --pretrained_rinalmo_weights ./weights/rinalmo_giga_pretrained.pt --output_dir ./tmp_out --dataset archiveII_5s --accelerator gpu --devices 1 --max_epochs 15 --tune_threshold_every_n_epoch 15 --ft_schedule ft_schedules/giga_sec_struct_ft.yaml --precision bf16-mixed --num_workers 4
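To illustrate the learning-rate point (this is not the repository's actual ft_schedule), a common way to keep fine-tuning from overwriting pre-trained knowledge is to give the backbone a much smaller learning rate than the freshly initialised prediction head, e.g. via Adam parameter groups; the module names and values below are assumptions.

```python
import torch
import torch.nn as nn

# Illustrative stand-ins for the pre-trained RiNALMo encoder and the
# secondary-structure prediction head; in practice these come from the repo.
backbone = nn.Linear(1280, 1280)
ss_head = nn.Linear(1280, 1)

optimizer = torch.optim.Adam(
    [
        {"params": backbone.parameters(), "lr": 1e-5},  # pre-trained LM: small LR
        {"params": ss_head.parameters(), "lr": 1e-4},   # new head: larger LR
    ]
)
```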

RJPenic avatar Feb 28 '25 09:02 RJPenic

Thank you for double-checking the hyperparameters and updating the code—I really appreciate it. I pulled the latest version and fine-tuned the model on three datasets: archiveII_5s, archiveII_srp, and bpRNA. Here are the test metrics for reference:

5s: F1 = 0.860, Precision = 0.969, Recall = 0.780
srp: F1 = 0.695, Precision = 0.781, Recall = 0.649
bpRNA: F1 = 0.742, Precision = 0.777, Recall = 0.726

The results for srp and bpRNA align well with the paper, though there is a gap in the F1 score for the 5s family.

Before proceeding with fine-tuning on other archiveII families, I wanted to check if there’s anything I might have overlooked. For instance, I noticed that the default batch size is set to 1, which is uncommon in machine learning. Was this the batch size used in the paper?

Looking forward to your thoughts!

shanry avatar Mar 01 '25 08:03 shanry

Yes, the batch size was set to one in the paper as well. While it is true that such a batch size is uncommon in machine learning, it isn't that unusual for DL-based SS prediction (for example, MXfold2 and UFold also train with a batch size of one). The main problem is that for SS prediction you usually need to featurize/model all possible nucleotide pairings, which leads to quadratic memory complexity. Since we conducted the fine-tuning experiments on somewhat "weaker" GPUs (12 GB - 16 GB) compared to the GPUs we used for pre-training (A100, ~80 GB), we set the batch size to one so that we could process longer sequences during training.
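As a rough back-of-the-envelope sketch of that quadratic cost (the embedding dimension here is an assumption, not necessarily RiNALMo-giga's): representing every nucleotide pair turns per-residue embeddings of shape (L, d) into a pair tensor of shape (L, L, 2d), so memory grows with the square of the sequence length.

```python
import torch

def pair_features(x):
    """Outer concatenation: pair (i, j) is represented by [x_i ; x_j]."""
    L, d = x.shape
    return torch.cat(
        (x.unsqueeze(1).expand(L, L, d), x.unsqueeze(0).expand(L, L, d)), dim=-1
    )

x = torch.randn(64, 1280)           # small toy example that is actually allocated
print(pair_features(x).shape)       # torch.Size([64, 64, 2560])

# For a longer RNA the pair tensor alone approaches a 12-16 GB GPU budget,
# which is why even a batch size of one can be tight:
L, d = 1000, 1280
print(L * L * 2 * d * 4 / 2**30, "GiB in fp32 for a single sequence")
```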

RJPenic avatar Mar 01 '25 14:03 RJPenic

Hey Rafael - Thanks for the amazing work, I enjoyed reading your manuscript. I was wondering if it would be possible to make the secondary structure prediction fine-tuned weights for the 150M parameter model available. I only see the pre-trained ones. Thanks!

fhidalgor avatar Mar 09 '25 01:03 fhidalgor

While fine-tuning on the ArchiveII dataset, the F1 score on the validation set remains consistently around 97% for every training split. However, as reported in the paper, the F1 scores on the test sets vary significantly, ranging from 0.12 to 0.93. This discrepancy suggests that the data splits may not be well balanced or representative, which could affect the reliability of the evaluation.

shanry avatar Mar 15 '25 05:03 shanry

As we stated in the paper, we used the data splits (training, validation and test) proposed in this study/benchmark. For the test sets, they excluded specific families from the rest of the dataset and then randomly split the remaining RNAs into training and validation. This training-validation split approach obviously isn't ideal and probably contains a certain level of data leakage. Additionally, certain RNA families make up a larger portion of the dataset (e.g. 5S rRNAs), giving them a "stronger influence" on the validation metrics. Still, we decided to go with the proposed validation set to ensure a fair comparison with other tools, as the training procedure for certain models proved difficult to replicate (most notably UFold). It must be noted, however, that these issues are specific to the validation set and do not affect the test sets (they are family-wise non-redundant with respect to the training/validation sets and contain only one family), which keeps the evaluation reliable.
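Under the same illustrative assumptions as the earlier split sketch, the benchmark's procedure roughly amounts to the following: the held-out family is the test set, while train and validation come from a plain random split of the pooled remaining families, which is exactly where near-identical sequences can land on both sides; the validation fraction shown is a made-up placeholder.

```python
import random

def benchmark_split(pooled_records, val_fraction=0.1, seed=0):
    """pooled_records: RNAs from the eight families that were not held out."""
    rng = random.Random(seed)
    shuffled = list(pooled_records)
    rng.shuffle(shuffled)
    n_val = int(len(shuffled) * val_fraction)
    # Random split with no family- or similarity-aware filtering between
    # training and validation -- the likely source of the inflated val F1.
    return shuffled[n_val:], shuffled[:n_val]   # train, validation
```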

RJPenic avatar Mar 15 '25 08:03 RJPenic