esmjax
Have you managed to train the model?
Training is very slow: 400 hours per epoch for the 3B-parameter model, compared to 1.5 hours per epoch for my Keras data-parallel implementation of the 650M model.
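To put the gap in perspective, here is a rough back-of-the-envelope comparison. It assumes that per-epoch time scales roughly linearly with parameter count and that both runs use the same dataset and hardware; neither assumption is confirmed above, so treat the result as an order-of-magnitude estimate only. The variable names are illustrative.

```python
# Reported figures from the two runs.
hours_3b = 400.0     # hours per epoch, 3B-parameter model
hours_650m = 1.5     # hours per epoch, 650M model (Keras, data-parallel)

# Nominal parameter counts.
params_3b = 3e9
params_650m = 650e6

# Observed slowdown between the two runs.
slowdown = hours_3b / hours_650m

# Slowdown one might expect from model size alone
# (ASSUMPTION: cost per epoch is roughly linear in parameter count).
param_ratio = params_3b / params_650m

# Remaining factor not explained by model size under that assumption.
unexplained = slowdown / param_ratio

print(f"observed slowdown:      {slowdown:.1f}x")
print(f"parameter ratio:        {param_ratio:.1f}x")
print(f"unexplained by size:    {unexplained:.1f}x")
```

Under these assumptions, model size accounts for only a small part of the slowdown, which suggests the rest may come from how the model is sharded or from the training setup rather than from the parameter count itself.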