flax
Flax is significantly slower than PyTorch
MNIST classification using convolutional networks, link: https://flax.readthedocs.io/en/latest/quick_start.html
The same approach in PyTorch is significantly faster (almost 5 times faster); the Flax version takes 50 seconds per epoch!
Can you point me to or provide the PyTorch code so I can do some benchmarking?
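For a fair comparison, two JAX behaviors matter: a `jax.jit`-compiled step is compiled on its first call, and JAX dispatches work asynchronously, so a naive timer can either include compilation or stop before the device has finished. Below is a minimal, framework-agnostic timing sketch. It assumes you have a step function of the form `train_step(state, batch) -> state`; `benchmark_epochs` is a hypothetical helper, not part of Flax or PyTorch.

```python
import time
from typing import Any, Callable, Iterable, List


def benchmark_epochs(
    train_step: Callable[[Any, Any], Any],
    state: Any,
    batches: Callable[[], Iterable[Any]],
    num_epochs: int = 3,
    block: Callable[[Any], Any] = lambda x: x,
) -> List[float]:
    """Return wall-clock seconds for each epoch.

    Hypothetical helper. With a jitted JAX train_step, the first epoch also
    includes compilation time, so compare the later epochs across frameworks.
    """
    times = []
    for _ in range(num_epochs):
        start = time.perf_counter()
        for batch in batches():  # fresh iterator per epoch
            state = train_step(state, batch)
        # Wait for any asynchronously dispatched work before stopping the clock.
        block(state)
        times.append(time.perf_counter() - start)
    return times
```

On the JAX side you would pass `block=jax.block_until_ready`; on a CUDA PyTorch run, something like `block=lambda _: torch.cuda.synchronize()` plays the same role. Both sides should also use the same batch size and data pipeline, since input loading can dominate an epoch on a model this small.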