No Flash-attn variant
Hi RiNALMo team,
First of all, thank you for your excellent work on RiNALMo — it's a valuable contribution to RNA research!
I'm trying to use the model for inference in environments where FlashAttention isn't available (e.g., on CPU, or on GPU without FlashAttention support). Since FlashAttention is a runtime optimization and does not affect the underlying model weights, I expected that the weights would be compatible with a standard attention implementation (e.g., using torch.nn.MultiheadAttention or similar).
However, I noticed some discrepancies between the FlashAttention-based model and a standard Transformer variant. For instance, there are:
- Missing keys like: `transformer.blocks.0.mh_attn.mh_attn.to_q.weight`
- Unexpected keys like: `transformer.blocks.0.mh_attn.Wqkv.weight`
This suggests that the FlashAttention version and the standard attention version may use different internal structures or naming conventions, making it non-trivial to load the same checkpoint into a CPU-compatible version of the model.
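For context, the kind of checkpoint remap I was hoping would work looks roughly like this (a hypothetical sketch; it assumes `Wqkv.weight` is the fused q/k/v projection stacked in that order, and that the standard variant uses separate `to_q`/`to_k`/`to_v` projections):

```python
import torch

# Hypothetical remap: split fused FlashAttention-style Wqkv parameters into
# separate q/k/v projections. Key names and stacking order are assumptions.
def split_qkv_checkpoint(state_dict):
    remapped = {}
    for key, tensor in state_dict.items():
        if key.endswith("mh_attn.Wqkv.weight") or key.endswith("mh_attn.Wqkv.bias"):
            prefix, suffix = key.rsplit("Wqkv.", maxsplit=1)
            q, k, v = torch.chunk(tensor, 3, dim=0)  # assumes q, k, v stacked along dim 0
            remapped[f"{prefix}mh_attn.to_q.{suffix}"] = q
            remapped[f"{prefix}mh_attn.to_k.{suffix}"] = k
            remapped[f"{prefix}mh_attn.to_v.{suffix}"] = v
        else:
            remapped[key] = tensor
    return remapped
```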
Would it be possible to provide a model variant (or configuration) that uses standard attention and is fully compatible with the released checkpoints? Or alternatively, release a version of the pretrained weights that can be used directly with a FlashAttention-free implementation?
Guidance on converting or adapting the model would also be greatly appreciated.
Thanks again for your work, and looking forward to your response!
Best regards, Marek Justyna
Hello Marek! Hope you are doing well! Please check the non_flash branch in the repository. There we replaced the FA mechanism with "ordinary" attention, which should run on CPUs and older GPUs. One thing to note, though: this branch hasn't been thoroughly tested yet, but the outputs do seem to match the FA-based version. Let me know if you encounter any issues with it.
Thank you so much! I really appreciate your help. I will try it out soon!
Best wishes!
Ok, so I'm testing it and I encountered a couple of problems.
(1) In some places (e.g. attention.py) you import from flash-attn directly; this should be replaced with an if-else guard (if flash-attn is not installed, then do not use it) - see the sketch after these two points.
(2) I still have a problem loading the giga-v1 weights into the model. If I understand correctly, you have trained an alternative model using the standard Transformer architecture, but where can I download it from? The README file and the download links in the code keep pointing to the same giga-v1 version.
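For (1), I was imagining a guard roughly like this (just a sketch of the pattern, not the actual RiNALMo code; the names are illustrative):

```python
# Sketch of an optional flash-attn import; module/class names are illustrative
try:
    from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding
    FLASH_ATTN_AVAILABLE = True
except ImportError:
    FlashRotaryEmbedding = None
    FLASH_ATTN_AVAILABLE = False

# Later, pick the implementation depending on availability, e.g.:
# rotary = FlashRotaryEmbedding(dim) if FLASH_ATTN_AVAILABLE else RotaryPositionEmbedding(dim)
```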
(1) Ah, this is true. Currently, we use the rotary positional embedding implementation from the flash attention library. To resolve this issue, you need to either install flash-attention (which should be doable even on older non-Ampere GPUs; the rotary embedding implementation should work regardless of the hardware since it isn't tied to the flash attention mechanism itself) or add your own rotary embedding implementation to the RiNALMo code. I'll try to add a custom rotary embedding implementation to completely remove the need for an FA installation.

(2) We did not train an alternative model. The "giga-v1" weights should be loadable in the non-FA version as well. We did a few tests on the T4 GPU (which is not supported by FA), and inference works fine (at least on that GPU). Can you share the error you get when you try to load the weights?
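Regarding (1), in the meantime a FA-free rotary embedding application can be as simple as something along these lines (a minimal sketch using the split-half RoPE convention, not necessarily what will end up in the repository):

```python
import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Split the last dimension in half and rotate: (x1, x2) -> (-x2, x1)
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x: (..., seq_len, dim); cos/sin: (seq_len, dim), broadcast over the leading dims
    return x * cos + rotate_half(x) * sin
```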
Ok, so I was using the regular model, meaning in the config file I changed the field use_flash_attn to False. It then loads the regular model without any flash-attn, but loading the checkpoint still raises the same missing-keys and unexpected-keys errors. Now this makes sense: I had assumed the code was fully flash-attn free and that switching that config field was the way to load the other variant of the model.
The non-FA code on the main branch is legacy code that is not compatible with the FA weights. As I said, please check the non_flash branch (git checkout non_flash), which contains a compatible version. You don't need to make any changes to the config file. I managed to replace the FA RoPE implementation with our own, so you should be able to run inference without any flash-attention-related installations.
To sum up, use the following commands to install the non-FA version of RiNALMo:
```bash
git clone https://github.com/lbcb-sci/RiNALMo
cd RiNALMo
git checkout non_flash
pip install .
```
Let me know if this works (or if it doesn't 😅)!
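Once installed, a quick CPU smoke test along the lines of the README usage example should confirm that the giga-v1 weights load (this assumes the `get_pretrained_model` helper and output keys from the README; adjust if anything differs on the branch):

```python
import torch
from rinalmo.pretrained import get_pretrained_model

# Quick CPU smoke test (assumes the README's get_pretrained_model API)
model, alphabet = get_pretrained_model(model_name="giga-v1")
model = model.to(device="cpu").eval()

seqs = ["ACUUUGGCCA", "CCCGGU"]
tokens = torch.tensor(alphabet.batch_tokenize(seqs), dtype=torch.int64, device="cpu")
with torch.no_grad():
    outputs = model(tokens)

# Per the README, the outputs include a "representation" tensor
print(outputs["representation"].shape)
```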
@mjustynaPhD I was actually wondering the same thing. I was trying to make a submission for Kaggle's 3D RNA competition with GraphaRNA, but it turns out the P100 and T4 GPUs there can't run Flash-Attention. So, as @RJPenic said, I switched to the non-flash RiNALMo, trained on 8x A100 SXM on RunPod (about $15/hr) for roughly 25-30 hrs, and got weights at epochs 60, 70, and 80. I also ran inference with sample_rna_pdb.py in Kaggle notebooks and it worked. 😍😍
I haven't actually tried inference in Kaggle notebooks with the original flash-attention RiNALMo, though. @mjustynaPhD if you need some weights I can send you one. Thank you both!!
Hi all,
@RJPenic Thank you so much for your support!! I'm sorry, I still haven't tested it - busy times... @Beck-Pro I'm glad to hear you could run GraphaRNA without problems and that you found it somewhat useful! Good luck!
@RJPenic Thank you once again for your support and for developing RiNALMo. In my opinion, this is a very powerful model, and it has a significant impact on the RNA community. Congrats!
@Beck-Pro Glad to hear you got it working! Good luck with the competition! 😄
Flash attention does not work in the Kaggle notebook environment (neither for training nor for inference), as Kaggle only offers non-Ampere GPUs, which are pretty old. Definitely use the non-FA version of RiNALMo if you are running anything in their environment.
@mjustynaPhD Thank you for the kind words! 😄
@Beck-Pro Sir, have you succeeded in predicting RNAs over 700 nt (for example, R1138)? I'm trying to use the non_flash version, but it doesn't seem to perform well and the TM-score is very low.
@doheon114 If you give me your email, I can send you the 80-epoch weights. Since I trained GraphaRNA in coarse-grained mode, which doesn't include C1' coordinates, I have to use models such as DRFOLD2 to post-process the C1' coordinates.
@mjustynaPhD @RJPenic I successfully ran inference with RiNALMo's non-flash-attn rotary embedding and GraphaRNA's model_800epoch.h5 weights, even though model_800epoch.h5 was pretrained with RiNALMo's flash-attn rotary embedding.
I did some editing in rope.py and sample_rna_pdb.py (GraphaRNA code).
<rope.py>
```python
class RotaryPositionEmbedding(nn.Module):
    def __init__(self, dim: int, base: int = 10000, inv_freq: torch.Tensor = None):  # inv_freq argument added
        super().__init__()
        self.dim = dim  # added (not strictly needed)
        self.base = base  # added
        if inv_freq is None:  # added
            inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.seq_len_cached = None
        self.cos_cached = None
        self.sin_cached = None
```
and
<sample_rna_pdb.py>
```python
model.load_state_dict(torch.load(model_path), strict=False)  # added strict=False
```
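In case it's useful, `load_state_dict` with `strict=False` silently skips mismatched entries, but it also returns them, so the skipped keys can be inspected:

```python
# load_state_dict returns the missing/unexpected keys, so you can verify
# that only the expected entries were skipped
result = model.load_state_dict(torch.load(model_path), strict=False)
print("Missing keys:", result.missing_keys)
print("Unexpected keys:", result.unexpected_keys)
```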
Thank you both for your amazing code!