Replace GMRES with a wrapper to the Jax version
Jax proper now supports GMRES, and its implementation should be better than ours in several ways. If we are willing to require the latest Jax version, we should replace our code with a wrapper around it.
I was wondering if I could contribute to this. However, this is my first time contributing to an open-source repository. Could anyone walk me through what needs to be done?
Hey there, sounds great! But first we need to decide whether we are ready to drop support for earlier Jax versions. @mganahl, what do you think?
SGTM
OK, then. @alewis, if you don't mind, would you be willing to walk me through what needs to be done? It would be really helpful if you could point me to a place to start.
Hi,
OK, so the relevant code is in backends/jax/jax_backend.py. Basically, you need to change lines 693-694 so that they call the GMRES implementation in base Jax (jax.scipy.sparse.linalg.gmres) instead of the one we wrote in backends/jax/jitted_functions.py.
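A minimal sketch of what the replacement could look like, assuming our backend receives a matrix-vector function `A_mv` plus extra positional arguments `A_args`, and that our `num_krylov_vectors` maps onto the `restart` parameter of `jax.scipy.sparse.linalg.gmres`. The name `gmres_wrapper` and its parameter list are hypothetical; the actual method in backends/jax/jax_backend.py may have a different signature and defaults.

```python
import jax.numpy as jnp
from jax.scipy.sparse.linalg import gmres as jax_gmres


def gmres_wrapper(A_mv, b, A_args=None, x0=None, tol=1e-5, atol=0.0,
                  num_krylov_vectors=20, maxiter=None, M=None):
    """Hypothetical wrapper that forwards to Jax's built-in GMRES.

    Assumption: A_mv(x, *A_args) computes the matrix-vector product A @ x.
    """
    A_args = () if A_args is None else tuple(A_args)

    # Jax expects a one-argument linear operator, so bind the extra args here.
    def matvec(x):
        return A_mv(x, *A_args)

    # Jax calls the Krylov subspace size per restart cycle `restart`.
    # Jax's `info` return value is currently only a placeholder (None).
    x, info = jax_gmres(matvec, b, x0=x0, tol=tol, atol=atol,
                        restart=num_krylov_vectors, maxiter=maxiter, M=M)
    return x, info


# Usage: solve a small dense system through the matvec interface.
def dense_mv(x, A):
    return A @ x


A = jnp.array([[4.0, 1.0], [1.0, 3.0]])
b = jnp.array([1.0, 2.0])
x, _ = gmres_wrapper(dense_mv, b, A_args=(A,), num_krylov_vectors=2)
```

Keeping the wrapper thin means convergence behavior, preconditioning via `M`, and restart handling all come from Jax itself, which is the point of dropping our jitted_functions.py version.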