
``scipy.linalg.solve_banded`` and ``scipy.linalg.ldl``

Open • f0uriest opened this issue on Oct 12 '22 • 6 comments

These two functions, for solving banded linear systems and computing symmetric indefinite (L D L^T) factorizations, come up often in optimization and would be great to have in JAX.
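
For reference, `solve_banded` takes the matrix in SciPy's compact banded storage, where `ab[u + i - j, j] == a[i, j]` for a matrix with `l` subdiagonals and `u` superdiagonals. Until a native version exists, one possible stopgap is sketched below, assuming a dense O(n^3) solve is acceptable: `banded_to_dense` and `solve_banded_dense` are hypothetical helpers (not JAX APIs) that expand the band into a full matrix and call `jnp.linalg.solve`, giving up the efficiency of LAPACK's banded routines.

```python
import jax
import jax.numpy as jnp


def banded_to_dense(l_and_u, ab):
    """Expand SciPy-style banded storage into a dense (n, n) matrix.

    Hypothetical helper. Layout as in scipy.linalg.solve_banded:
    ab[u + i - j, j] == a[i, j].
    """
    l, u = l_and_u
    n = ab.shape[1]
    a = jnp.zeros((n, n), dtype=ab.dtype)
    for k in range(-l, u + 1):  # k = j - i, the diagonal offset
        row = ab[u - k]
        # Superdiagonals occupy the tail of their row, subdiagonals the head.
        vals = row[k:] if k >= 0 else row[: n + k]
        a = a + jnp.diag(vals, k=k)
    return a


def solve_banded_dense(l_and_u, ab, b):
    # Dense fallback: ignores the band structure that LAPACK's banded solvers exploit.
    return jnp.linalg.solve(banded_to_dense(l_and_u, ab), b)


# The matrix from SciPy's solve_banded docstring example: (l, u) = (1, 2), n = 5.
ab = jnp.array([[0, 0, -1, -1, -1],
                [0, 2, 2, 2, 2],
                [5, 4, 3, 2, 1],
                [1, 1, 1, 1, 0]], dtype=jnp.float32)
b = jnp.array([0, 1, 2, 2, 3], dtype=jnp.float32)
# (l, u) must be static because it controls Python-level loops and slicing.
x = jax.jit(solve_banded_dense, static_argnums=0)((1, 2), ab, b)
```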

The SciPy routines are essentially thin wrappers around LAPACK, so I don't think it would be too difficult. I'd be happy to take a stab at it, though I'm not too familiar with how to register JVPs for them and make sure they work on GPU, etc.
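
Since the SciPy routines already wrap LAPACK, another stopgap is to call them on the host through `jax.pure_callback`. A minimal sketch, assuming host-side execution is acceptable: `solve_banded_callback` is a hypothetical name, the data is copied to host memory (so on GPU it round-trips through the host), and `pure_callback` defines no JVP, so the result is not differentiable without adding a custom rule.

```python
import numpy as np
import jax
import jax.numpy as jnp
import scipy.linalg


def solve_banded_callback(l_and_u, ab, b):
    """Solve a banded system by calling scipy.linalg.solve_banded on the host.

    Hypothetical wrapper. l_and_u must be a static Python tuple (l, u);
    ab uses SciPy's banded layout.
    """
    out_type = jax.ShapeDtypeStruct(b.shape, b.dtype)

    def host_solve(ab_host, b_host):
        x = scipy.linalg.solve_banded(l_and_u, np.asarray(ab_host), np.asarray(b_host))
        # pure_callback requires the returned array to match out_type's shape and dtype.
        return np.asarray(x, dtype=out_type.dtype)

    return jax.pure_callback(host_solve, out_type, ab, b)
```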

f0uriest • Oct 12 '22

I'm also interested in this, particularly solveh_banded for convex optimization. Is there a general guide to contributing extensions to JAX?
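
`solveh_banded` targets symmetric (Hermitian) positive-definite banded systems, which is the case that shows up in many convex problems. As a dense stand-in, one can expand the band and use a Cholesky solve. A minimal sketch, assuming SciPy's default upper-form storage (`ab[u + i - j, j] == a[i, j]` for `i <= j`) and real inputs; `solveh_banded_dense` is a hypothetical name, and unlike LAPACK's banded Cholesky this costs O(n^3).

```python
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve


def solveh_banded_dense(ab, b):
    """Dense Cholesky fallback for a symmetric positive-definite banded system.

    Hypothetical helper. ab uses SciPy's upper form for solveh_banded:
    ab[u + i - j, j] == a[i, j] for i <= j, where u = ab.shape[0] - 1.
    """
    u = ab.shape[0] - 1
    n = ab.shape[1]
    upper = jnp.zeros((n, n), dtype=ab.dtype)
    for k in range(u + 1):  # k = j - i >= 0, the superdiagonal offset
        upper = upper + jnp.diag(ab[u - k, k:], k=k)
    # Mirror the strictly upper part to build the full symmetric matrix.
    a = upper + jnp.triu(upper, k=1).T
    # If a is not positive definite, expect NaNs here rather than an exception.
    return cho_solve(cho_factor(a, lower=True), b)
```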

deasmhumhna • Oct 24 '22

We don't have a guide specifically about adding new linalg functions, but here are a few pointers that might help:

  • Here is how SVD is implemented.
  • In particular, you can see the definition of a new JAX primitive and a number of rule registrations. Some of these are described in the "Defining new JAX primitives" section of the docs.
  • For general guidelines regarding contributions to JAX, see here.
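
Short of defining a full primitive (the route the SVD code takes), `jax.custom_jvp` attaches a JVP rule to an ordinary Python function. A minimal sketch, assuming a dense `jnp.linalg.solve` as a stand-in for a banded or L D L^T kernel; `linear_solve` is a hypothetical name. The rule follows from differentiating `a @ x = b`, which gives `da @ x + a @ dx = db`.

```python
import jax
import jax.numpy as jnp


@jax.custom_jvp
def linear_solve(a, b):
    # Stand-in for a LAPACK-backed kernel (e.g. a banded or LDL^T solver).
    return jnp.linalg.solve(a, b)


@linear_solve.defjvp
def linear_solve_jvp(primals, tangents):
    a, b = primals
    da, db = tangents
    x = linear_solve(a, b)
    # From da @ x + a @ dx = db:  dx = a^{-1} (db - da @ x).
    # Reuses the same solver, so no extra factorization logic is needed.
    dx = linear_solve(a, db - da @ x)
    return x, dx
```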

apaszke • Oct 27 '22

I'm going to (slowly) have a crack at this.

harryjulian • Nov 13 '22

Out of curiosity, is there any update on this? I'd be interested in helping if possible!

arjunsavel • Mar 16 '23

I was looking at this issue, but it appears that solve_banded has been partly implemented by tridiagonal_solve in jax/_src/lax/linalg.py (for the tridiagonal case, i.e. one sub- and one superdiagonal).
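
For the `(l, u) == (1, 1)` case, `jax.lax.linalg.tridiagonal_solve(dl, d, du, b)` takes the three diagonals as separate length-n vectors (with `dl[0]` and `du[-1]` set to zero) instead of SciPy's stacked `ab` array, and expects a 2-D right-hand side. A sketch of the conversion, assuming SciPy's banded layout; `solve_tridiagonal_from_ab` is a hypothetical helper, and it does not cover wider bands.

```python
import jax.numpy as jnp
from jax.lax.linalg import tridiagonal_solve


def solve_tridiagonal_from_ab(ab, b):
    """Solve with SciPy-style (1, 1) banded storage via tridiagonal_solve.

    Hypothetical helper. ab has shape (3, n): ab[0, 1:] is the superdiagonal,
    ab[1] the main diagonal, ab[2, :-1] the subdiagonal.
    """
    zero = jnp.zeros((1,), dtype=ab.dtype)
    du = jnp.concatenate([ab[0, 1:], zero])   # du[i] = a[i, i+1]; du[-1] = 0
    d = ab[1]
    dl = jnp.concatenate([zero, ab[2, :-1]])  # dl[i] = a[i, i-1]; dl[0] = 0
    # tridiagonal_solve expects b of shape (n, k), one column per right-hand side.
    b2 = b[:, None] if b.ndim == 1 else b
    x = tridiagonal_solve(dl, d, du, b2)
    return x[:, 0] if b.ndim == 1 else x


# Example: n = 4, main diagonal 4, superdiagonal 2, subdiagonal 1.
ab = jnp.array([[0, 2, 2, 2],
                [4, 4, 4, 4],
                [1, 1, 1, 0]], dtype=jnp.float32)
b = jnp.array([1, 2, 3, 4], dtype=jnp.float32)
x = solve_tridiagonal_from_ab(ab, b)
```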

NDOWAH • Jul 26 '24

Still suffering from the lack of a native L D L^T factorization in 2024 :( Would be great to have this!
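
In the meantime, a pure-JAX fallback is possible for some matrices. Note that `scipy.linalg.ldl` uses LAPACK's Bunch-Kaufman pivoting (`?sytrf`) and can return 2x2 blocks in D; the sketch below swaps that for a plain unpivoted L D L^T, which only works when every leading principal minor is nonzero and can be numerically unstable. `ldl_unpivoted` and `ldl_solve` are hypothetical names.

```python
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular


def ldl_unpivoted(a):
    """Unpivoted LDL^T of a symmetric matrix: a == L @ diag(d) @ L.T.

    L is unit lower triangular and d is a vector. No pivoting is done, so this
    breaks down (division by zero) if a leading principal minor is singular.
    """
    a = jnp.asarray(a, dtype=float)
    n = a.shape[0]
    L = jnp.eye(n, dtype=a.dtype)
    d = jnp.zeros(n, dtype=a.dtype)
    for j in range(n):  # Python loop: unrolled at trace time, fine for small n
        lj = L[j, :j]
        dj = a[j, j] - jnp.sum(lj * lj * d[:j])
        d = d.at[j].set(dj)
        # Entries of column j below the diagonal.
        col = (a[j + 1:, j] - L[j + 1:, :j] @ (lj * d[:j])) / dj
        L = L.at[j + 1:, j].set(col)
    return L, d


def ldl_solve(L, d, b):
    """Solve a @ x = b for a 1-D b, given the factors from ldl_unpivoted."""
    y = solve_triangular(L, b, lower=True, unit_diagonal=True)  # L y = b
    z = y / d                                                     # D z = y
    return solve_triangular(L, z, trans="T", lower=True, unit_diagonal=True)  # L^T x = z
```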

joaospinto • Sep 11 '24