flax
Higher memory consumption compared with PyTorch
Hello, Flax team. When I tried to transfer PyTorch models to the Flax framework, I found that Flax consumes more memory than PyTorch. For example, a ResNet50 model uses 4 GB of GPU memory in PyTorch, while the same model uses 6 GB in Flax. What causes this difference in memory consumption between PyTorch and Flax, or what can I do to reduce memory usage in Flax? Thanks!
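One factor worth ruling out before comparing numbers (an assumption, not confirmed in this thread): by default JAX reserves a large fraction of total GPU memory (75% in current releases) when the first JAX operation runs, so `nvidia-smi` may show that reservation rather than what the ResNet50 actually needs. A minimal sketch for making the measurement more comparable sets JAX's documented allocator environment variables `XLA_PYTHON_CLIENT_PREALLOCATE` and `XLA_PYTHON_CLIENT_MEM_FRACTION`; the helper name `configure_jax_gpu_memory` is hypothetical.

```python
import os
from typing import Optional


def configure_jax_gpu_memory(preallocate: bool = False,
                             mem_fraction: Optional[float] = None) -> dict:
    """Hypothetical helper: set JAX's GPU-allocator environment variables.

    Must be called before JAX initializes its GPU backend (in practice,
    before `import jax`), because the variables are read only at startup.
    """
    settings = {}
    # "false" disables the up-front reservation, so external tools report
    # memory that is actually allocated instead of a fixed large block.
    settings["XLA_PYTHON_CLIENT_PREALLOCATE"] = "true" if preallocate else "false"
    if mem_fraction is not None:
        if not 0.0 < mem_fraction <= 1.0:
            raise ValueError("mem_fraction must be in (0, 1]")
        # When preallocation is enabled, caps the reservation at this
        # fraction of total GPU memory instead of the default.
        settings["XLA_PYTHON_CLIENT_MEM_FRACTION"] = f"{mem_fraction:.2f}"
    os.environ.update(settings)
    return settings
```

Call it at the very top of the training script, before `import jax` or `import flax`; disabling preallocation can increase fragmentation, so it is mainly useful for measurement rather than as a default.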
@chiamp Hello, is this working as expected, or did I make a mistake? Thanks.
Hi @Sun-Xiaohui, how are you transferring over the model? Could you provide an example code snippet?