Complex support for implicit solver
Hi, I am trying to solve a stiff differential equation using diffrax. I need to use complex numbers and would like to use an implicit solver. I would like to try implementing complex support for the implicit solvers if it is not too difficult. Is it possible, and do you have any general advice that would help me get started?
This is a difficult thing to try and do, and comes with some questions we need to think pretty carefully about (most notably what backpropagation does in this scenario).
FWIW you should be able to make this work today by just treating the real and imaginary parts separately.
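For instance, something along these lines (a minimal sketch; the toy equation and names are mine, not anything from Diffrax's docs):

```python
import diffrax
import jax.numpy as jnp

omega = 2.0  # toy parameter

def complex_rhs(t, y, args):
    # The "true" complex vector field: dy/dt = -i * omega * y.
    return -1j * omega * y

def split_rhs(t, y_split, args):
    # y_split stacks (Re y, Im y); reassemble, evaluate, split again.
    y = y_split[0] + 1j * y_split[1]
    dy = complex_rhs(t, y, args)
    return jnp.stack([dy.real, dy.imag])

y0 = jnp.array(1.0 + 0.5j)
y0_split = jnp.stack([y0.real, y0.imag])

sol = diffrax.diffeqsolve(
    diffrax.ODETerm(split_rhs),
    diffrax.Kvaerno5(),  # an implicit solver
    t0=0.0,
    t1=1.0,
    dt0=0.01,
    y0=y0_split,
    stepsize_controller=diffrax.PIDController(rtol=1e-6, atol=1e-8),
)
y1 = sol.ys[0, 0] + 1j * sol.ys[0, 1]  # recombine into a complex result
```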
Hi @patrick-kidger, we'd be pretty interested in complex support for implicit solvers in the context of dynamiqs. I understand it's a non-trivial task, but if you're interested, we can definitely put in the effort and make the required PRs for diffrax/lineax in the coming months.
What is your concern regarding backpropagation? I'm guessing regular autodiff should work out of the box, but the recursive checkpoint method might be more subtle? Or is it something else?
Regarding the trick to separate real/imaginary parts, would it not be overall slower due to repeated memory accesses? Or is this optimized by the JIT?
Thanks :)
So the main thing that's needed here is just loads of tests for Diffrax! As it wasn't originally written with complex support in mind, it's entirely possible we have places where we write something like `x**2` intending to compute a norm, but which with complex numbers will silently misbehave.
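For instance (a toy illustration, not a snippet from Diffrax itself):

```python
import jax.numpy as jnp

x = jnp.array([3.0 + 4.0j])
print(jnp.sum(x**2))                  # (-7+24j): not a norm at all for complex x
print(jnp.sum(jnp.abs(x) ** 2))       # 25.0: the squared norm that was intended
print(jnp.sum(x * jnp.conj(x)).real)  # 25.0: equivalent formulation
```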
I imagine the actual change in lines of code for adding this feature should be fairly small:
- In Diffrax itself, fixing up any examples like the above.
- In Optimistix, being sure that `optx.implicit_jvp` does the right thing.
- In Lineax, this should already basically be done thanks to the hard efforts of @Randl! The main blocker here was that they uncovered an XLA bug (https://github.com/openxla/xla/issues/8471), but that actually got fixed just a few days ago. Once the new jaxlib release is out then I imagine things should be good to go there. (A quick sketch of the relevant Lineax operation is below.)
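For reference, the Lineax piece the implicit solvers bottom out in is just a linear solve; here's my own toy sketch of doing that with a complex matrix (the matrix and choice of solver are arbitrary), which should be fine once the fixed jaxlib is out:

```python
import jax.numpy as jnp
import lineax as lx

# An arbitrary complex system A @ x = b, just to exercise the solve.
A = jnp.array([[2.0 + 1.0j, 0.5j], [-0.5j, 1.0 + 0.0j]])
b = jnp.array([1.0 + 0.0j, 1.0j])

operator = lx.MatrixLinearOperator(A)
solution = lx.linear_solve(operator, b, solver=lx.LU())
print(solution.value)  # x such that A @ x ≈ b
```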
On backpropagation, this is in large part about making sure that `optx.implicit_jvp` does the right thing. JAX follows a pretty quirky convention when it comes to complex backpropagation, which is that the VJP is given by the transpose of the JVP, not the conjugate transpose of the JVP. This is unlike what you normally do when computing the adjoint of a complex matrix, and is also different from PyTorch. I highly recommend reading the PyTorch docs on autograd for complex numbers btw, they're very informative.
(In practice what this usually means is that when doing autodiff in complex numbers in JAX, you should compute the conjugate of your gradients before performing SGD.)
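As a concrete sketch of what that looks like (a toy loss of my own, nothing Diffrax-specific):

```python
import jax
import jax.numpy as jnp

def loss(z):
    # |z - (1 + 2j)|^2: a real-valued loss of a complex parameter.
    diff = z - (1.0 + 2.0j)
    return jnp.real(diff * jnp.conj(diff))

z = jnp.asarray(0.0 + 0.0j)
lr = 0.1
for _ in range(100):
    g = jax.grad(loss)(z)     # JAX's convention: transpose, not conjugate transpose
    z = z - lr * jnp.conj(g)  # conjugate before the gradient step

print(z)  # converges to (1 + 2j)
```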
I suspect (but am not sure) that we're already doing the right thing for backpropagation, so if you like you can imagine putting this under the 'it needs to be tested' banner.