
Support for exporting models in the ONNX format

Open Artoriuz opened this issue 1 year ago • 4 comments

Hi!

I know that this is something that has been asked before, but I just wanted to ask again: Is there any plan to eventually support exporting Flax models in the ONNX format?

ONNX has become a very popular way of distributing models as some kind of lingua franca that is supported by various inference engines like TensorRT, NCNN, OpenVINO, ONNX Runtime, et cetera...

The current workaround is to convert the Flax model to TensorFlow first using jax2tf, and then convert to ONNX using tf2onnx. While this works, the resulting ONNX models often contain unnecessary operations, and even simple things aren't mapped to their expected ONNX counterparts. I've also only been able to get this to work with enable_xla=False in the jax2tf conversion, which has been deprecated.
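For reference, the pipeline looks roughly like this. It's only a sketch: `model` and `params` are placeholders for your Flax module and its variables, and the input shape is just an example.

```python
import tensorflow as tf
import tf2onnx
from jax.experimental import jax2tf

def predict(x):
    # Plain inference forward pass of the (placeholder) Flax model.
    return model.apply(params, x)

# enable_xla=False emits plain TF ops that tf2onnx can map directly,
# but the flag has been deprecated in recent JAX releases.
tf_fn = tf.function(
    jax2tf.convert(predict, enable_xla=False),
    autograph=False,
)

# from_function traces tf_fn and writes the ONNX graph to disk.
onnx_model, _ = tf2onnx.convert.from_function(
    tf_fn,
    input_signature=[tf.TensorSpec([1, 28, 28, 1], tf.float32)],
    opset=13,
    output_path="model.onnx",
)
```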

I understand there's an argument to be made that this would perhaps make more sense at the JAX level instead, but honestly I don't think ONNX is very popular outside of ML, and doing it at the Flax level might make it easier to map the operations 1:1.

FWIW, Equinox supports this by bridging through TF first as well, so that seems to be the status quo everywhere.

Thanks in advance!

Artoriuz avatar Dec 12 '24 11:12 Artoriuz

Hi @Artoriuz, is there a notion of a "Module" in ONNX, or do you mean that we should provide a helper function to easily map to ONNX? If it's the latter, I agree we could add that; I've been wanting to add a saved_model helper as well.

cgarciae avatar Dec 12 '24 17:12 cgarciae

"Should" is probably a strong word; I don't think I'm qualified to tell you what you should do. I was just looking forward to adopting Flax as my main ML library going forward and found some rough edges along the way (but I really liked everything else!).

And yes, I just think it would be very convenient to have a "native" way of mapping JAX operations (and Flax modules by extension) into the corresponding ONNX operations without having to use TF as a bridge. Most things should have a 1:1 counterpart anyway.

Since this would be more oriented towards inference, "all" we need is to export the final forward step. The notion of a "Module" would be lost, but that's fine (the entire upper level nnx.Module would be a single ONNX model).
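To make that concrete, here's a tiny made-up sketch (TinyMLP is hypothetical): the closed-over `forward` callable is all an exporter would have to trace, and the Module structure disappears behind it.

```python
import jax.numpy as jnp
from flax import nnx

class TinyMLP(nnx.Module):
    def __init__(self, rngs: nnx.Rngs):
        self.dense = nnx.Linear(4, 2, rngs=rngs)

    def __call__(self, x):
        return self.dense(x)

model = TinyMLP(nnx.Rngs(0))

# The only thing an ONNX exporter needs: a pure forward function.
def forward(x):
    return model(x)

y = forward(jnp.ones((1, 4)))  # the traced inference step
```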

Just as a reference, PyTorch has two distinct exporters:
https://pytorch.org/docs/stable/onnx_dynamo.html
https://pytorch.org/docs/stable/onnx.html

And this is what is generally used to convert from either TF or Keras: https://github.com/onnx/tensorflow-onnx

Artoriuz avatar Dec 12 '24 18:12 Artoriuz

Hi @Artoriuz, I’m also interested in this functionality and couldn’t find a direct solution in the JAX ecosystem. So, I gave it a try and started working on a project for this: https://github.com/enpasos/jax2onnx
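Usage looks roughly like this; `to_onnx` is the intended entry point, but the API is still evolving, so please check the README for the current form:

```python
import onnx
from flax import nnx
from jax2onnx import to_onnx

# Any callable works; a single Linear layer keeps the example small.
model = nnx.Linear(3, 2, rngs=nnx.Rngs(0))

# "B" marks a dynamic batch dimension in the input shape.
onnx_model = to_onnx(model, [("B", 3)])
onnx.save_model(onnx_model, "linear.onnx")
```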

Would love to hear your thoughts or any feedback on how it could be improved!

enpasos avatar Feb 22 '25 11:02 enpasos

Netron has just released version 8.2.6, enhancing its support for hierarchical function-based graph views — and it genuinely feels a bit like discovering Google Maps for the first time. You can now zoom in and out of function scopes, with full shape information visible inside each nested view.

To see it in action, check out this compact Vision Encoder (sized for MNIST classification). It was generated from a JAX/Flax model using jax2onnx, which follows a robust jaxpr-based conversion approach with a user-friendly API.
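If you're wondering what "jaxpr-based" means in practice: jax.make_jaxpr exposes the flat sequence of primitives (tanh, dot_general, ...) that the converter walks and maps to ONNX operators one by one, for example:

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.tanh(x) @ jnp.ones((3, 2))

# Prints the jaxpr: the primitive-by-primitive trace of f.
print(jax.make_jaxpr(f)(jnp.ones((1, 3))))
```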

This design aligns with the ideas proposed in
🔹 Flax Feature Request #4430
🔹 JAX Feature Request #26430

And of course — there's still plenty of room for improvement in the implementation... but we're getting there! 🚀

enpasos avatar Apr 11 '25 16:04 enpasos