PyTorch support
Would it be feasible at all to add PyTorch support as a substrate?
Besides the obvious use case, this would enable researchers to implement framework-agnostic probabilistic algorithms against a NumPy-like API that executes on TF2, JAX, and PyTorch.
It's probably feasible, but it's not on our roadmap. You'd essentially replicate how the NumPy substrate is implemented: add a new directory under tfp/python/internal/backends, then add a bunch of Bazel rules to generate the rewritten sources. I'd start by trying to get gradients (tfp/python/math/gradient.py and tfp/python/internal/custom_gradient.py) working, because my limited understanding of PyTorch is that its autodiff implementation is very different from JAX's and TF's.
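For what it's worth, here is a minimal sketch of what a `value_and_gradient`-style helper could look like on top of PyTorch's autograd. This is not TFP code; the function name just mirrors `tfp.math.value_and_gradient`, and the example assumes a scalar-valued `f`:

```python
import torch

def value_and_gradient(f, x):
    """Sketch: return f(x) and df/dx using torch.autograd."""
    x = x.detach().requires_grad_(True)  # fresh leaf tensor that tracks grads
    y = f(x)
    (grad,) = torch.autograd.grad(y, x)  # y must be scalar here
    return y, grad

# Usage:
x = torch.tensor([1.0, 2.0, 3.0])
value, grad = value_and_gradient(lambda t: (t ** 2).sum(), x)
# value == 14.0, grad == [2., 4., 6.]
```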
Cool, it does sound feasible then. Custom gradients in PyTorch are quite simple to define, and computing gradients maps cleanly onto the value_and_grad() API that JAX uses; another project, eagerpy, already implements this for PyTorch. I don't know TFP's internals well enough to help with the implementation, but it seems this could grow TFP's user base quite substantially, and I'd definitely use it :)
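To illustrate the "custom gradients are simple" point: the standard PyTorch mechanism is `torch.autograd.Function` with static `forward`/`backward` methods. A hedged sketch (the `ClippedExp` op and the clipping rule are made up for illustration, not anything from TFP or eagerpy):

```python
import torch

class ClippedExp(torch.autograd.Function):
    """exp(x) forward, with a hand-written, clipped backward pass."""

    @staticmethod
    def forward(ctx, x):
        y = torch.exp(x)
        ctx.save_for_backward(y)  # stash the forward result for backward
        return y

    @staticmethod
    def backward(ctx, grad_output):
        (y,) = ctx.saved_tensors
        # Custom rule: clip d/dx exp(x) = exp(x) to avoid overflow.
        return grad_output * y.clamp(max=1e6)

x = torch.tensor([0.5, 100.0], requires_grad=True)
ClippedExp.apply(x).sum().backward()
# x.grad now holds the clipped gradients.
```

So the machinery in tfp/python/internal/custom_gradient.py would presumably map onto something like this on the PyTorch side.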