
Visualize jaxpr graph?

Status: open · ilemhadri opened this issue 3 years ago · 0 comments

I was wondering whether there are any tools for visualizing jaxprs.

Certainly, one can visualize the entire HLO graph using jax.xla_computation, but that tends to produce overly complex graphs, especially when jitting.
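For comparison, here is a minimal sketch of how the two representations differ, assuming a recent JAX release. jax.make_jaxpr gives the jaxpr, while the ahead-of-time API (jax.jit(f).lower(...).as_text()), which newer releases point to in place of the deprecated jax.xla_computation, gives the lowered module. The function f is a hypothetical example.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical small function standing in for a real model.
    return jnp.sum(jnp.tanh(x @ x.T))


x = jnp.ones((4, 4))

# Jaxpr: JAX's own IR, roughly one line per primitive application.
jaxpr_text = str(jax.make_jaxpr(f)(x))

# Lowered module text from the ahead-of-time lowering API.
hlo_text = jax.jit(f).lower(x).as_text()

print(jaxpr_text)
print(len(jaxpr_text.splitlines()), "jaxpr lines vs",
      len(hlo_text.splitlines()), "lowered lines")
```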

It would be nice to work directly on jaxprs and visualize their dependency graphs. Does JAX have any tools for this (besides the toy example in https://gist.github.com/mattjj/a60e0991455965ae960a8d2dcddc3407)? Ideally, the tool would also allow annotation of functions or arrays.
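One possible starting point, in the spirit of the linked gist but not taken from it: a sketch that walks the top-level equations of a jaxpr (from jax.make_jaxpr) and emits Graphviz DOT text. It relies on the jaxpr attributes constvars, invars, outvars and eqns, which are internal and may change between JAX versions. The helper name jaxpr_to_dot is hypothetical, and nested sub-jaxprs (from jit, scan, cond, etc.) appear as single nodes rather than being expanded.

```python
import jax
import jax.numpy as jnp


def jaxpr_to_dot(closed_jaxpr) -> str:
    """Render a ClosedJaxpr's top-level equations as a Graphviz DOT digraph.

    Nodes are inputs, primitive applications and outputs; edges follow data flow.
    Sub-jaxprs are not expanded in this sketch.
    """
    jaxpr = closed_jaxpr.jaxpr
    producer = {}  # id(var) -> name of the DOT node that produces it
    lines = ["digraph jaxpr {", "  rankdir=TB;"]

    # Inputs (closed-over constants first, then arguments), labeled by shape/dtype.
    for i, v in enumerate(list(jaxpr.constvars) + list(jaxpr.invars)):
        name = f"in{i}"
        producer[id(v)] = name
        lines.append(f'  {name} [shape=box, label="{v.aval.str_short()}"];')

    # One node per equation, labeled by primitive name.
    for j, eqn in enumerate(jaxpr.eqns):
        name = f"eq{j}"
        lines.append(f'  {name} [label="{eqn.primitive.name}"];')
        for v in eqn.invars:
            src = producer.get(id(v))  # None for literals inlined in the equation
            if src is not None:
                lines.append(f"  {src} -> {name};")
        for v in eqn.outvars:
            producer[id(v)] = name

    # Output nodes.
    for k, v in enumerate(jaxpr.outvars):
        out = f"out{k}"
        lines.append(f'  {out} [shape=box, label="out {k}"];')
        src = producer.get(id(v))
        if src is not None:
            lines.append(f"  {src} -> {out};")

    lines.append("}")
    return "\n".join(lines)


def g(x, y):
    # Hypothetical example function.
    return jnp.sin(x) * y + jnp.cos(y)


dot = jaxpr_to_dot(jax.make_jaxpr(g)(jnp.ones(3), jnp.ones(3)))
print(dot)
```

The printed text can be rendered with the Graphviz command-line tool (for example `dot -Tsvg`). Labeling equation nodes by primitive name keeps the graph compact; richer labels (parameters, shapes of intermediates) are an easy extension.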

My motivating example is a fairly complex custom likelihood function on time-series observations that involves many lu_solve and eigh operations; its HLO graph is immense.
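For a large jaxpr like this, a textual summary can help before any drawing. A minimal sketch, assuming the same internal jaxpr attributes: it counts primitives recursively, descending into sub-jaxprs stored in equation parameters. Those sub-jaxprs are detected by duck typing, which is an assumption about their layout. The names count_primitives and toy_loglik are hypothetical; toy_loglik is a stand-in with one LU solve and one eigh, not the actual likelihood described above.

```python
from collections import Counter

import jax
import jax.numpy as jnp
import jax.scipy.linalg as jsl


def _sub_jaxprs(params):
    """Yield raw Jaxprs nested in an equation's params (duck-typed; an assumption)."""
    for val in params.values():
        items = val if isinstance(val, (tuple, list)) else (val,)
        for item in items:
            if hasattr(item, "jaxpr") and hasattr(item.jaxpr, "eqns"):
                yield item.jaxpr  # looks like a ClosedJaxpr
            elif hasattr(item, "eqns"):
                yield item  # looks like a raw Jaxpr


def count_primitives(jaxpr, counts=None):
    """Count primitive applications in a Jaxpr, including nested sub-jaxprs."""
    counts = Counter() if counts is None else counts
    for eqn in jaxpr.eqns:
        counts[eqn.primitive.name] += 1
        for sub in _sub_jaxprs(eqn.params):
            count_primitives(sub, counts)
    return counts


def toy_loglik(a, b):
    # Hypothetical stand-in: build an SPD matrix, then one LU solve and one eigh.
    cov = a @ a.T + jnp.eye(a.shape[0])
    alpha = jsl.lu_solve(jsl.lu_factor(cov), b)
    evals = jnp.linalg.eigh(cov)[0]
    return -0.5 * (b @ alpha + jnp.sum(jnp.log(evals)))


closed = jax.make_jaxpr(toy_loglik)(jnp.ones((3, 3)), jnp.ones(3))
counts = count_primitives(closed.jaxpr)
for name, n in counts.most_common():
    print(f"{name:>20s} {n}")
```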

ilemhadri, Sep 22 '22 11:09