``scipy.linalg.solve_banded`` and ``scipy.linalg.ldl``
These two functions, which solve banded linear systems and compute symmetric indefinite (LDL^T) factorizations, show up a lot in optimization and would be great to have in JAX.
The SciPy routines are basically just wrappers around LAPACK, so I don't think it would be too difficult. I'd be happy to take a stab at it, though I'm not too familiar with how to register JVPs for them or how to make sure they work on GPU, etc.
I'm also interested in this, particularly solveh_banded for convex optimization. Is there a general guide to contributing extensions to JAX?
We don't have anything focused on adding new linalg functions, but here are a few pointers that might be helpful:
- Here is how SVD is implemented.
- In particular, you can see the definition of a new JAX primitive and a set of rule registrations. Some of those are described in the "Defining new JAX primitives" section of the docs.
- For general guidelines regarding contributions to JAX, see here.
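For anyone who mainly wants a feel for what a JVP rule for a solver looks like before diving into primitives, here is a minimal sketch using `jax.custom_jvp`. This is a lighter-weight mechanism than the primitive-plus-rule-registration approach mentioned above, and it is not how the built-in linalg functions are implemented. The name `dense_solve` is hypothetical, and `jnp.linalg.solve` is only a stand-in for a LAPACK-backed banded routine.

```python
import jax
import jax.numpy as jnp


@jax.custom_jvp
def dense_solve(a, b):
    # Hypothetical stand-in for a LAPACK-backed solver such as a banded solve.
    return jnp.linalg.solve(a, b)


@dense_solve.defjvp
def dense_solve_jvp(primals, tangents):
    a, b = primals
    da, db = tangents
    x = dense_solve(a, b)
    # Differentiating a @ x = b gives a @ dx = db - da @ x,
    # so the tangent is obtained with one more solve against the same matrix.
    dx = dense_solve(a, db - da @ x)
    return x, dx


# Small well-conditioned example; b is kept 2-D (one right-hand-side column).
a = jnp.array([[4.0, 1.0, 0.0],
               [1.0, 3.0, 1.0],
               [0.0, 1.0, 2.0]])
b = jnp.array([[1.0], [2.0], [3.0]])
da = 0.1 * jnp.ones_like(a)
db = jnp.ones_like(b)

# Forward-mode derivative of the solution along the direction (da, db).
x, dx = jax.jvp(dense_solve, (a, b), (da, db))
```

The same identity is what a JVP rule registered on a real primitive would have to encode. A primitive additionally needs lowering, batching, and transpose rules, which this sketch does not cover.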
I'm going to (slowly) have a crack at this.
Out of curiosity, is there any update on this? I'd be interested in helping if possible!
I was looking at this issue, but it appears that at least the tridiagonal case of solve_banded is already covered by tridiagonal_solve in jax/_src/lax/linalg.py.
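In case it helps, here is a rough sketch of how a tridiagonal system stored in SciPy's `solve_banded` layout with `(l, u) = (1, 1)` could be passed to `jax.lax.linalg.tridiagonal_solve`. The helper name `solve_tridiagonal_banded` is hypothetical. The sketch assumes that `tridiagonal_solve` takes `(dl, d, du, b)` with `dl[0] = 0`, `du[-1] = 0`, and a 2-D right-hand side, and that it is supported on your backend in your JAX version. It does not handle wider bands.

```python
import jax.numpy as jnp
from jax.lax import linalg as lax_linalg


def solve_tridiagonal_banded(ab, b):
    """Hypothetical helper: solve A x = b for a tridiagonal A.

    `ab` has shape (3, m) in scipy.linalg.solve_banded layout with (l, u) = (1, 1):
    ab[0, 1:] is the superdiagonal, ab[1] the main diagonal, ab[2, :-1] the subdiagonal.
    `b` has shape (m, n).
    """
    zero = jnp.zeros((1,), dtype=ab.dtype)
    du = jnp.concatenate([ab[0, 1:], zero])   # superdiagonal, padded so du[-1] = 0
    d = ab[1]                                 # main diagonal
    dl = jnp.concatenate([zero, ab[2, :-1]])  # subdiagonal, padded so dl[0] = 0
    return lax_linalg.tridiagonal_solve(dl, d, du, b)


# Diagonally dominant 4x4 example in banded storage.
ab = jnp.array([[0.0, 1.0, 1.0, 1.0],
                [4.0, 4.0, 4.0, 4.0],
                [2.0, 2.0, 2.0, 0.0]])
b = jnp.array([[1.0], [2.0], [3.0], [4.0]])
x = solve_tridiagonal_banded(ab, b)
```

The only real work is the shift between the two storage conventions: SciPy pads the superdiagonal at the front and the subdiagonal at the back, while the JAX routine expects the opposite padding.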
Still suffering from the lack of a native LDL^T factorization in 2024 :( Would be great to have this!