AndrewY

3 comments by AndrewY

Thanks @ChrisAGBlake. I ran into a problem: my GPU has only 12 GB of memory, and an OOM error occurs when computing the loss and gradients. Is there a solution?

Maybe the CUDA and cuDNN versions don't match the jaxlib version? You can find the right version [here](https://storage.googleapis.com/jax-releases/jax_cuda_releases.html). By the way, you need to uninstall the old versions of jax and jaxlib...
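
To compare what is installed against the releases page, something like the sketch below may help. It only reports the installed jax and jaxlib versions and the devices JAX can see. It does not check CUDA or cuDNN compatibility itself, and the helper names `installed_version` and `describe_jax_install` are made up for illustration.

```python
from importlib import metadata


def installed_version(package):
    """Return the installed version string of `package`, or None if absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def describe_jax_install():
    """Collect jax/jaxlib versions and, if importable, the visible devices."""
    info = {
        "jax": installed_version("jax"),
        "jaxlib": installed_version("jaxlib"),
    }
    try:
        import jax

        # Name of the default backend JAX selected (for example "cpu" or "gpu").
        info["backend"] = jax.default_backend()
        info["devices"] = [str(d) for d in jax.devices()]
    except Exception as exc:  # import or backend initialisation failed
        info["error"] = repr(exc)
    return info


if __name__ == "__main__":
    for key, value in describe_jax_install().items():
        print(f"{key}: {value}")
```

If the backend comes back as `cpu` on a machine with a GPU, that would be consistent with a jaxlib build that does not match the local CUDA setup, though other causes are possible.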

Have you figured out the reason? I want to know too.