Adding support for complex dtypes
Hi,
I finally have the time to finish this PR. Rather than going over all of the previous changes, I've been starting from scratch. Currently the solvers return a solution, but it's the wrong one. I'm not sure how to debug this; any suggestions?
Some suggestions:
- Go through the guts of the integration code, adding `jax.experimental.host_callback.id_print` statements to see what value each array takes, and where things go wrong.
- Try using e.g. `Tsit5().step` directly instead of the `diffeqsolve` interface. Check if that works as you expect.
- Try using only constant step sizes, and not adaptive step sizing, if you aren't already. Just trying to get the simplest possible thing working first.
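To make the "constant step sizes, simplest possible thing" suggestion concrete, here is a minimal pure-JAX sketch of a fixed-step explicit Runge-Kutta integrator that could serve as a reference to compare a solver's output against. It does not use Diffrax's API: `rk_step`, `solve_fixed` and `vector_field` are hypothetical names, and the test problem dy/dt = 1j * y (exact solution exp(1j * t) * y0) is an assumption chosen because its answer is known in closed form. Assuming, for illustration, a solver that stores its stages in a preallocated array, that array should take its dtype from the state `y` rather than from the (real) time `t`.

```python
import jax
import jax.numpy as jnp

# Classical RK4 Butcher tableau; any explicit tableau can be dropped in here.
A = jnp.array([[0.0, 0.0, 0.0, 0.0],
               [0.5, 0.0, 0.0, 0.0],
               [0.0, 0.5, 0.0, 0.0],
               [0.0, 0.0, 1.0, 0.0]])
B = jnp.array([1 / 6, 1 / 3, 1 / 3, 1 / 6])
C = jnp.array([0.0, 0.5, 0.5, 1.0])


def rk_step(f, t, y, dt):
    """One explicit Runge-Kutta step with a fixed step size dt (hypothetical helper)."""
    num_stages = B.shape[0]
    # Stage buffer uses the dtype of the *state*, not of time: a real dtype
    # here would discard the imaginary part of complex stages.
    ks = jnp.zeros((num_stages,) + y.shape, dtype=y.dtype)
    for i in range(num_stages):
        # Contract over the stage axis: sum_j A[i, j] * ks[j]
        y_stage = y + dt * jnp.tensordot(A[i], ks, axes=1)
        ks = ks.at[i].set(f(t + C[i] * dt, y_stage))
    return y + dt * jnp.tensordot(B, ks, axes=1)


def solve_fixed(f, t0, t1, y0, num_steps):
    """Integrate from t0 to t1 using num_steps equal steps (no adaptivity)."""
    dt = (t1 - t0) / num_steps

    def body(n, y):
        return rk_step(f, t0 + n * dt, y, dt)

    return jax.lax.fori_loop(0, num_steps, body, y0)


def vector_field(t, y):
    # Assumed test problem: dy/dt = 1j * y, exact solution exp(1j * t) * y0
    return 1j * y


y0 = jnp.array([1.0 + 0.0j], dtype=jnp.complex64)
y1 = solve_fixed(vector_field, 0.0, 1.0, y0, num_steps=10)
exact = jnp.exp(1j * 1.0) * y0
print(y1, exact)
```

If a hand-rolled loop like this matches the closed-form solution while the library solver does not, that points at the solver internals (for example an intermediate array created with a real dtype) rather than at the test itself.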
It seems that `jax.experimental.host_callback.id_print` is doing nothing when I run in debug mode. Should I use it directly in a print statement?
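A sketch of an alternative, assuming a reasonably recent JAX: there, `jax.experimental.host_callback` is deprecated, and `jax.debug.print` is the usual way to print runtime values from inside `jit`-compiled functions and loops such as `jax.lax.fori_loop`. A plain Python `print` in traced code runs only once, at trace time, and shows abstract tracers rather than values. For stepping through in a debugger, another option is `jax.config.update("jax_disable_jit", True)`, which in most cases lets ordinary `print` calls and breakpoints see concrete arrays. The function `euler_loop` below is a hypothetical example, not part of any library.

```python
import functools

import jax
import jax.numpy as jnp


@functools.partial(jax.jit, static_argnames="num_steps")
def euler_loop(y0, dt, num_steps):
    """Hypothetical example: fixed-step Euler for dy/dt = 1j * y, printing each step."""

    def body(n, y):
        y_next = y + dt * (1j * y)
        # Runs at execution time on every iteration, even though the function is jitted.
        # A plain print() here would fire only once, during tracing, and show tracers.
        jax.debug.print("step {n}: y = {y}", n=n, y=y_next)
        return y_next

    return jax.lax.fori_loop(0, num_steps, body, y0)


y_final = euler_loop(jnp.complex64(1.0), 0.1, 3)
```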
The default tolerances for `allclose` are `rtol=1e-05` and `atol=1e-08`. With step size `dt0=0.1` I think this precision is too demanding: for a similar test in `test_basic`, `dt0=0.01` is used, with `rtol=1e-02` and `atol=1e-02` for the `allclose` check. Your approach fixes the issue for the Euler solver; however, there is another place in the Runge-Kutta solver that needs to be assigned the correct dtype. I'll check whether that is all and test it, and will probably submit a PR if I manage to find and tweak all the lines that cause issues with complex numbers.
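The tolerance argument can be checked on a problem with a known solution. Explicit Euler is first-order, so its global error shrinks roughly linearly with the step size. The sketch below uses a hypothetical `euler_solve` helper on the assumed test problem dy/dt = 1j * y (it is not the project's actual `test_basic`). With `dt0=0.1` the error at t=1 is around 5e-2, which fails `allclose` at its default tolerances, while `dt0=0.01` gives an error around 5e-3, which passes with `rtol=atol=1e-2`.

```python
import jax
import jax.numpy as jnp


def euler_solve(y0, t1, dt):
    """Hypothetical helper: fixed-step explicit Euler for dy/dt = 1j * y on [0, t1]."""
    num_steps = int(round(t1 / dt))
    # The loop carry keeps y0's complex dtype on every step.
    return jax.lax.fori_loop(0, num_steps, lambda n, y: y + dt * (1j * y), y0)


y0 = jnp.array(1.0 + 0.0j, dtype=jnp.complex64)
exact = jnp.exp(1j * 1.0) * y0  # y(1) = exp(1j)

coarse = euler_solve(y0, 1.0, 0.1)  # |error| roughly 5e-2
fine = euler_solve(y0, 1.0, 0.01)   # |error| roughly 5e-3

print(jnp.allclose(coarse, exact))                      # False (default rtol=1e-5, atol=1e-8)
print(jnp.allclose(fine, exact, rtol=1e-2, atol=1e-2))  # True
```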
Closing in favour of #330.