Recommendation for real-time publishing of vector field
As far as I understand, there are two main ways of inspecting the progress of a diffeqsolve:
- progress_meter: this is called at every timestep (JAX conditionals can be used to act only every n timesteps instead), but it receives the solver_state rather than the vector field as an argument. Therefore, unless we define a custom solver that stores a copy of the vector field in its internal state, we are limited in the information we can save.
- saveat: here diffrax stores the relevant data in memory and returns a Solution object containing all of this data after the solve completes.

What I would like to do instead is push (and potentially post-process) data at each saveat timestep to a file or an external server (such as MLflow) immediately, or potentially asynchronously, so that I can monitor how the simulation is progressing in real time.
Do you have any recommendations on how to achieve this with the current API? Or, failing that, a laundry list of the PRs required to make something like this possible?
Thanks again! 🙏🏻
In the general case, this is basically live saving of data inside a JAX while loop. I've never tried it, but I assume the go-to would be an io_callback (https://docs.jax.dev/en/latest/_autosummary/jax.experimental.io_callback.html#jax.experimental.io_callback). For diffrax, it might be possible to just put these callbacks inside the SaveAt (again, never tried, but that is what I would try first).
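A minimal sketch of the general pattern, assuming a toy loop that halves a state vector five times: each iteration hands the current step and state to a host-side Python function through io_callback. The LOG list and host_log function are hypothetical stand-ins for a file writer or an MLflow client. Because io_callback is declared as side-effecting, JAX keeps the call even though its (empty) result is never used.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import io_callback

# Hypothetical host-side sink: stands in for a file handle or an MLflow run.
LOG = []


def host_log(step, x):
    # Executed on the host with concrete values once per loop iteration.
    LOG.append((int(step), np.asarray(x).copy()))


@jax.jit
def halve_five_times(x0):
    def cond(carry):
        step, _ = carry
        return step < 5

    def body(carry):
        step, x = carry
        x = 0.5 * x
        # Side-effecting callback with no return value (result_shape_dtypes=None);
        # JAX keeps it even though nothing downstream uses its output.
        io_callback(host_log, None, step, x)
        return step + 1, x

    _, x = jax.lax.while_loop(cond, body, (jnp.int32(0), x0))
    return x


out = halve_five_times(jnp.array([1.0, 2.0]))
out.block_until_ready()
jax.effects_barrier()  # wait for any outstanding host callbacks before reading LOG
```

For a slow sink such as a remote server, the host function could enqueue the data for a background thread rather than writing it directly, so the device computation is not held up by network I/O.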
I think you could probably do this with SaveAt(fn=...), where fn wraps a jax.pure_callback that performs the save.