Feature Request: Complex-Valued Integration With ZVODE - CVODE in Jax (autodiff)
Hi,
Unfortunately, all the available (S)ODE integration subroutines in auto-differentiable Python frameworks (RK45, Dopri, etc.) behave very poorly with complex-valued functions [*]. In the Python ecosystem, only Scipy's Fortran wrappers titled ode (ZVODE) and complex_ode (using CVODE) seem to be working fine, but obviously, they are not differentiable and not applicable to the modern applications we love.
I was wondering if anybody wants to adapt these features to Diffrax, and make them auto-differentiable.
[*] https://arxiv.org/abs/2406.06361
We have some limited support for complex numbers in Diffrax. In particular I think all of the explicit solvers (Tsit5, Dopri etc.) should behave correctly. Glancing at the paper I can see they briefly mention Diffrax, but apparently indicate they had some difficulty getting reverse-mode working. I've not seen a bug report from them though so there's not much I can do until then. 🤷
More importantly though, I believe this whole thing is essentially a non-issue. It's trivial to make any real integrator work with complex numbers: just split into real and imaginary parts before passing your initial condition into the solver, and then combine them back together inside your vector field. Job done.
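A minimal sketch of that splitting, assuming a hand-written fixed-step RK4 loop as a stand-in for a real-only integrator. It is not Diffrax's implementation, and `split_field`, `rk4`, and the test matrix `A` are illustrative names that do not come from any library:

```python
import jax
import jax.numpy as jnp


def split_field(f):
    """Wrap a complex vector field f(t, y) as a real one acting on (re, im)."""
    def real_f(t, state):
        re, im = state
        dy = f(t, re + 1j * im)      # recombine inside the vector field
        return dy.real, dy.imag      # hand only real arrays back to the solver
    return real_f


def rk4(f, y0, t0, t1, n_steps):
    """Fixed-step classical RK4 over an arbitrary pytree state."""
    h = (t1 - t0) / n_steps
    tmap = jax.tree_util.tree_map

    def axpy(y, k, c):
        # y + c * k, leaf by leaf
        return tmap(lambda a, b: a + c * b, y, k)

    def step(i, y):
        t = t0 + i * h
        k1 = f(t, y)
        k2 = f(t + h / 2, axpy(y, k1, h / 2))
        k3 = f(t + h / 2, axpy(y, k2, h / 2))
        k4 = f(t + h, axpy(y, k3, h))
        incr = tmap(lambda a, b, c, d: (a + 2 * b + 2 * c + d) / 6, k1, k2, k3, k4)
        return axpy(y, incr, h)

    return jax.lax.fori_loop(0, n_steps, step, y0)


# Illustrative linear test problem dy/dt = A @ y (assumption, not from the thread)
A = jnp.array([[-1j, 0.5], [-0.5, 2j]])
y0 = jnp.array([1.0 + 0j, 0.0 + 1j])


def f(t, y):
    return A @ y


# Solve once with a complex state, once with the state split into real parts
y_complex = rk4(f, y0, 0.0, 1.0, 200)
re, im = rk4(split_field(f), (y0.real, y0.imag), 0.0, 1.0, 200)
y_split = re + 1j * im
```

In this sketch the integrator only ever sees the real pair `(re, im)`; the complex arithmetic is confined to the vector field, and the two results should agree up to floating-point rounding.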
Hi. Thanks to @onurdanaci for asking the question and to @lockwo for pointing this question out to me. Apologies to @patrick-kidger for not posting the issue earlier (I am the author of the document mentioned above).
I have an MWE below. I would be happy to be told that this issue is minor or that I am using Diffrax incorrectly.
import diffrax
from diffrax import diffeqsolve, ODETerm, Tsit5
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def solver(y0, t_f, A, use_direct):
    def ode_fn(t, y, B):
        return jnp.matmul(B[0], y)

    term = ODETerm(ode_fn)
    ODEsolver = Tsit5()
    solver_args = dict(t0=0.0, t1=t_f.real, dt0=0.2, y0=y0, args=(A,))
    # Required for forward mode only
    if use_direct:
        solver_args |= dict(adjoint=diffrax.DirectAdjoint())
    solution = diffrax.diffeqsolve(term, ODEsolver, **solver_args)
    return solution.ys[0]


def driver(params, use_direct):
    # Create y0 from params
    time = params[0] * jnp.pi
    cos = jnp.cos(time / 2)
    sin = jnp.sin(time / 2)
    axis_angle = params[1] * jnp.pi
    KET_0 = jnp.array([1, 0], dtype=jnp.complex128)  # |0>, spin up
    KET_1 = jnp.array([0, 1], dtype=jnp.complex128)  # |1>, spin down
    y0 = cos * KET_1 - 1j * jnp.exp(-1j * axis_angle) * sin * KET_0
    A = jnp.array([[0 - 1j, 1.0 + 2j],
                   [-100.0 + 3j, 0 + 4j]], dtype=jnp.complex128)
    # Evolve y0. Time is influenced by params
    y = solver(y0, time, A, use_direct)
    return y


params_f = jnp.array([0.5, 0.4], dtype=jnp.float64)
jacfwd_fun = jax.jacfwd(driver, argnums=(0))
jac_f = jacfwd_fun(params_f, True)

# Must be complex for reverse mode
params_b = jnp.array([0.5 + 0j, 0.4 + 0j], dtype=jnp.complex128)
# Must set holomorphic=True for reverse mode
jacrev_fun = jax.jacrev(driver, argnums=(0), holomorphic=True)
jac_b = jacrev_fun(params_b, False)
print(jac_f - jac_b)
This generates
/usr/local/lib/python3.10/dist-packages/equinox/_jit.py:51: UserWarning: Complex dtype support is work in progress, please read https://github.com/patrick-kidger/diffrax/pull/197 and proceed carefully.
out = fun(*args, **kwargs)
/usr/local/lib/python3.10/dist-packages/equinox/_jit.py:51: UserWarning: Complex dtype support is work in progress, please read https://github.com/patrick-kidger/diffrax/pull/197 and proceed carefully.
out = fun(*args, **kwargs)
[[7.91624188e-09+1.36585934e+06j 1.16415322e-09-2.06637196e-09j]
[4.65661287e-08+6.69479738e+06j 1.54832378e-08-1.45519152e-11j]]
The problem might be that params influences both the initial state of the solver, y0, and the final time, t1.
Thanks for your attention and help!
Dear Patrick @patrick-kidger ,
Thank you for your answer. Indeed, I can transform my complex-valued system of equations into:
`dv/dt = M @ v,  v = vreal + 1j*vimag,  M = Mreal + 1j*Mimag
d([vreal; vimag])/dt = [[Mreal, -Mimag]; [Mimag, Mreal]] @ [vreal; vimag]
`
Then I can combine the two real vector fields in post-processing. Of course, it would have been much more convenient for the quantum technologies community to have these features predefined in libraries. But I agree that this part is a non-issue. However, I am still suspicious, for the following reason.
SciPy's VODE routine, which is inherited from Fortran libraries, uses multi-step implicit Adams methods (such as Adams-Moulton) for non-stiff problems and BDF for stiff problems. I couldn't parse all the archaic Fortran code, but my suspicion is that SciPy's ZVODE just uses this VODE library by implementing your vector-field trick.
I have doubts, based on some small numerical experiments (not systematic, elaborate, or conclusive by any metric) and the paper I shared before, that the crème de la crème explicit Runge-Kutta methods y'all provide, such as Tsit5 and Dopri5, would be as good for these non-stiff quantum problems as implicit Adams, or that KenCarp4 would be as good as BDF for stiff problems. Maybe I am wrong. I will need to run them on some important unit tests to make sure that I do not get non-physical results. I will get back to you.
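For reference, the real block form written earlier in this comment can be checked numerically. The sketch below reuses the matrix from the MWE as `M`; the test vector `v` and the names `M_block`, `v_stacked`, and `rhs_recombined` are illustrative assumptions:

```python
import jax.numpy as jnp

# Same entries as the matrix A in the MWE above
M = jnp.array([[0 - 1j, 1.0 + 2j],
               [-100.0 + 3j, 0 + 4j]])
# Arbitrary test vector (assumption, not from the thread)
v = jnp.array([0.3 - 0.2j, 1.0 + 0.5j])

# Real block form: [[Re M, -Im M], [Im M, Re M]] acting on [Re v; Im v]
M_block = jnp.block([[M.real, -M.imag],
                     [M.imag, M.real]])
v_stacked = jnp.concatenate([v.real, v.imag])

rhs_complex = M @ v              # complex right-hand side
rhs_block = M_block @ v_stacked  # real right-hand side, stacked as [Re; Im]

# Recombine the stacked real result into a complex vector for comparison
n = v.shape[0]
rhs_recombined = rhs_block[:n] + 1j * rhs_block[n:]
```

If the block form is right, `rhs_recombined` matches `rhs_complex` up to rounding, so a real-only integrator applied to `M_block` evolves the same dynamics.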
Thank you @sriharikrishna for the MWE! That's really useful. I'm going to tag @randl as our resident complex autodiff expert. Any thoughts?
Other than that, thank you @onurdanaci for your write-up above! I'd like it if Diffrax could be useful to you regardless :)
@sriharikrishna
Isn't the mismatch because, in the first case, you calculate the gradient with respect to a real parameter, which is automatically real, whereas in the second case the gradient is with respect to a complex parameter and thus also has an imaginary part? I've tried running check_grads for a function equivalent to yours:
from functools import partial

import diffrax
import jax
import jax.numpy as jnp
import pytest
from diffrax import ODETerm
from jax.test_util import check_grads


@pytest.mark.parametrize(
    "solver",
    [
        diffrax.Tsit5(),
    ],
)
def test_grad_complex(solver):
    def ode_fn(t, y, B):
        return jnp.matmul(B[0], y)

    term = ODETerm(ode_fn)

    @partial(jax.jit)
    def driver(pt, ang):
        # Create y0 from params
        time = pt * jnp.pi
        cos = jnp.cos(time / 2)
        sin = jnp.sin(time / 2)
        axis_angle = ang * jnp.pi
        KET_0 = jnp.array([1, 0], dtype=jnp.complex128)  # |0>, spin up
        KET_1 = jnp.array([0, 1], dtype=jnp.complex128)  # |1>, spin down
        y0 = cos * KET_1 - 1j * jnp.exp(-1j * axis_angle) * sin * KET_0
        jax.debug.print("{y0}", y0=y0)
        A = jnp.array([[0 - 1j, 1.0 + 2j],
                       [-100.0 + 3j, 0 + 4j]], dtype=jnp.complex128)
        solver_args = dict(t0=0.0, t1=time.real, dt0=0.2, y0=y0, args=(A,))
        # # Required for forward mode only
        # if use_direct == True:
        solver_args |= dict(adjoint=diffrax.DirectAdjoint())
        # Evolve y0. Time is influenced by params
        solution = diffrax.diffeqsolve(term, solver, **solver_args)
        return solution.ys[0]

    # check_grads(driver, (0.5,0.4), order=2, modes=["fwd"])
    check_grads(driver, (0.5 + 0.j, 0.4 + 0.j), order=2, modes=["rev"], atol=1e15)
Apart from the absolute differences being huge in the rev case, I couldn't see a failure. If you could point out the mismatch against numerical gradients (alternatively, there may be a bug in the solver itself, which would make both the analytic and numeric gradients wrong), that'd be helpful.
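To separate the real-versus-complex parameter question from the ODE solve, the comparison can be tried on a toy function first. The sketch below assumes a small hand-picked holomorphic function `f` (purely illustrative, unrelated to the solver) and compares `jax.jacfwd` on real inputs with `jax.jacrev(..., holomorphic=True)` on the same values cast to complex:

```python
import jax
import jax.numpy as jnp


def f(p):
    # Toy function that is holomorphic in both components of p,
    # with complex-valued output (illustrative assumption)
    return jnp.stack([jnp.exp(1j * p[0]) * p[1],
                      p[0] ** 2 + 1j * p[1]])


p_real = jnp.array([0.5, 0.4])
p_cplx = p_real + 0j  # same values, complex dtype

# Forward mode accepts real inputs even though the output is complex
jac_fwd = jax.jacfwd(f)(p_real)

# Reverse mode with complex outputs needs holomorphic=True and complex inputs
jac_rev = jax.jacrev(f, holomorphic=True)(p_cplx)

max_diff = jnp.max(jnp.abs(jac_fwd - jac_rev))
print(max_diff)
```

For this toy function both Jacobians are complex, since the output is complex even when the parameters are real, and they should agree up to rounding. Note that `holomorphic=True` only has this meaning when the function really is complex-differentiable in its inputs.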