
PyTorch support

Open danijar opened this issue 3 years ago • 2 comments

Would it be feasible at all to add PyTorch support as a substrate?

Besides the obvious use case, this would enable researchers to implement framework-agnostic probabilistic algorithms that use a NumPy-like API to execute on TF2, JAX, and PyTorch.
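To illustrate the pattern (this is just a sketch, not TFP code): a framework-agnostic function can take the backend's NumPy-like module as a parameter, so the same body runs on `numpy`, `jax.numpy`, or a hypothetical PyTorch shim.

```python
import numpy as np

def normal_log_prob(backend, x, loc=0.0, scale=1.0):
    # 'backend' is any module exposing a NumPy-like API
    # (numpy, jax.numpy, or a hypothetical torch-backed shim).
    z = (x - loc) / scale
    return -0.5 * z * z - backend.log(scale) - 0.5 * backend.log(2.0 * backend.pi)

# Standard normal log-density at 0: -0.5 * log(2*pi) ~= -0.9189
lp = normal_log_prob(np, 0.0)
```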

danijar avatar Jul 01 '22 17:07 danijar

It's probably feasible, but not on our roadmap. You'd essentially replicate how the numpy substrate is implemented: add a new directory in tfp/python/internal/backends and then add a bunch of bazel rules to generate the rewritten sources. I'd start with trying to get gradients (tfp/python/math/gradient.py and tfp/python/internal/custom_gradient.py) working, because my limited understanding of PyTorch is that its autodiff implementation is very different from JAX/TF's.
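For reference, a custom gradient in PyTorch looks roughly like this (an illustrative sketch, not TFP code; the op and clipping rule are made up): `torch.autograd.Function` is the primitive a PyTorch port of `custom_gradient.py` would presumably need to target.

```python
import torch

class ClipGrad(torch.autograd.Function):
    """Identity in the forward pass, clipped gradient in the backward pass."""

    @staticmethod
    def forward(ctx, x):
        return x.clone()

    @staticmethod
    def backward(ctx, grad_output):
        # Custom gradient: clamp the incoming cotangent to [-1, 1].
        return grad_output.clamp(-1.0, 1.0)

x = torch.tensor([3.0], requires_grad=True)
y = (5.0 * ClipGrad.apply(x)).sum()
y.backward()
# The upstream gradient is 5.0, clamped to 1.0, so x.grad == [1.0].
```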

SiegeLordEx avatar Jul 01 '22 18:07 SiegeLordEx

Cool, it does sound feasible. Custom gradients in PyTorch are quite simple to define, and computing gradients fits into the value_and_grad() API that JAX uses. For example, another project called eagerpy implements this for PyTorch. I don't know TFP's internals well enough to help with the implementation, but it seems this could increase the user base of TFP quite substantially, and I'd definitely use it :)
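A minimal sketch of what that could look like (assuming the wrapped function takes a single tensor and returns a scalar): a JAX-style `value_and_grad` built on top of PyTorch's autograd.

```python
import torch

def value_and_grad(fn):
    """JAX-style value_and_grad over PyTorch autograd (single-tensor sketch)."""
    def wrapped(x):
        # Make a leaf tensor that records gradients.
        x = x.detach().requires_grad_(True)
        value = fn(x)
        grad, = torch.autograd.grad(value, x)
        return value.detach(), grad
    return wrapped

f = value_and_grad(lambda x: (x ** 2).sum())
v, g = f(torch.tensor([3.0]))
# v == 9.0, g == [6.0]
```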

danijar avatar Jul 02 '22 02:07 danijar