Higher memory consumption compared with PyTorch

Open Sun-Xiaohui opened this issue 1 year ago • 2 comments

Hello, Flax team. When I tried to transfer PyTorch models to Flax, I found that Flax consumes more memory than PyTorch. For example, a ResNet50 model uses 4 GB of GPU memory in PyTorch, but 6 GB in Flax. What causes this difference in memory consumption, and what can I do to reduce memory usage in Flax? Thanks!

Sun-Xiaohui • Mar 18 '24 03:03

@chiamp Hello, is this working as expected, or did I make a mistake? Thanks.

Sun-Xiaohui • Apr 01 '24 05:04

Hi @Sun-Xiaohui, how are you transferring over the model? Could you provide an example code snippet?

chiamp • May 08 '24 23:05