
jax.scipy.linalg.eigh_tridiagonal() doesn't implement calculation of eigenvectors

Open crowsonkb opened this issue 3 years ago • 11 comments

I need this for finding the eigenvectors of the Hessian after tridiagonalizing it with Lanczos iteration. Right now the function looks like:

def eigh_tridiagonal(d: ArrayLike, e: ArrayLike, *, eigvals_only: bool = False,
                     select: str = 'a', select_range: Optional[Tuple[float, float]] = None,
                     tol: Optional[float] = None) -> Array:
  if not eigvals_only:
    raise NotImplementedError("Calculation of eigenvectors is not implemented")

and the documentation doesn't mention that it is unimplemented either; the function just raises a NotImplementedError when you try to use it. If I shell out to SciPy on CPU, I then can't JIT the function. Are there any plans to implement calculation of eigenvectors for tridiagonal matrices?
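For reference, a minimal Lanczos sketch that produces such a (d, e) pair (no reorthogonalization, and assuming no breakdown; matvec, v0, and the iteration count m are illustrative):

import jax
import jax.numpy as jnp

def lanczos(matvec, v0, m):
    # Minimal Lanczos tridiagonalization (no reorthogonalization).
    # Returns the diagonal d, off-diagonal e, and the Lanczos basis Q
    # such that Q.T @ A @ Q is approximately tridiag(d, e).
    n = v0.shape[0]
    Q = jnp.zeros((n, m))
    d = jnp.zeros(m)
    e = jnp.zeros(m - 1)
    q = v0 / jnp.linalg.norm(v0)
    q_prev = jnp.zeros(n)
    beta = 0.0
    for j in range(m):  # unrolled Python loop; fine for small, static m
        Q = Q.at[:, j].set(q)
        w = matvec(q) - beta * q_prev
        alpha = q @ w
        w = w - alpha * q
        d = d.at[j].set(alpha)
        if j < m - 1:
            beta = jnp.linalg.norm(w)  # assumes beta > 0 (no breakdown)
            e = e.at[j].set(beta)
            q_prev, q = q, w / beta
    return d, e, Q

# Example: a Hessian-vector product as the matvec (f and params hypothetical)
# matvec = lambda v: jax.jvp(jax.grad(f), (params,), (v,))[1]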

Thank you, Katherine Crowson

crowsonkb avatar Jan 15 '23 03:01 crowsonkb

Thanks for raising this!

Are there any plans to implement calculation of eigenvectors for tridiagonal matrices?

I don't know of any plans, but we don't really plan things like this; instead we wait for users to ask for them!

If I shell out to scipy on cpu I then can't JIT the function.

As a temporary suboptimal workaround, how about using jax.pure_callback? Then you should be able to call into scipy, and jitting will still work. If you want autodiff, you could put a custom_jvp around it, like in this example.
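A minimal sketch of that workaround (assuming static shapes; the callback wraps scipy.linalg.eigh_tridiagonal on the host, and the shape/dtype declarations here are illustrative):

import jax
import jax.numpy as jnp
import numpy as np
import scipy.linalg

def eigh_tridiagonal_cb(d, e):
    n = d.shape[0]
    # Result shapes must be declared up front for pure_callback
    result_shapes = (
        jax.ShapeDtypeStruct((n,), d.dtype),    # eigenvalues
        jax.ShapeDtypeStruct((n, n), d.dtype),  # eigenvectors
    )
    def _host(d, e):
        w, v = scipy.linalg.eigh_tridiagonal(np.asarray(d), np.asarray(e))
        return w.astype(d.dtype), v.astype(d.dtype)
    return jax.pure_callback(_host, result_shapes, d, e)

# jit works; the callback runs on the host at execution time
w, v = jax.jit(eigh_tridiagonal_cb)(jnp.array([1., 2., 3.]), jnp.array([.5, .5]))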

@hawkinsp do you happen to know/remember how hard it is to add eigenvector calculations (context: #6622)? Maybe we should ask Rasmus...

mattjj avatar Jan 15 '23 04:01 mattjj

Looks like Rasmus implemented it here for TensorFlow, with great comments, so maybe we can port it?

mattjj avatar Jan 15 '23 04:01 mattjj

This would be great to have!

HHalva avatar Jan 20 '23 17:01 HHalva

The missing piece for porting Rasmus's implementation is a batched tridiagonal solve, I believe.

hawkinsp avatar Jan 20 '23 17:01 hawkinsp

That sounds tricky. A shame, as it makes it difficult to implement any of the popular GP inference engines that use the eigendecomposition of tridiagonal matrices as a way to compute log-determinants cheaply.

HHalva avatar Jan 20 '23 21:01 HHalva

a way to compute log-determinants cheaply.

Perhaps I'm misunderstanding. Isn't log det(M) = sum_i log(eigvals[i]), so eigenvalues are sufficient?

AlexanderMath avatar Oct 01 '23 11:10 AlexanderMath

a way to compute log-determinants cheaply.

Perhaps I'm misunderstanding. Isn't log det(M) = sum_i log(eigvals[i]), so eigenvalues are sufficient?

eigh is still expensive, no? It doesn't exploit the tridiagonal structure for cheaper compute.

HHalva avatar Oct 02 '23 15:10 HHalva

Sorry, I meant to say eigh_tridiagonal(M, eigvals_only=True) instead of eigh; that is, only compute eigenvalues, exploiting the tridiagonal structure. For a tridiagonal matrix with diagonal d and off-diagonal e, you can compute the log-determinant as

import jax.numpy as jnp
from jax.scipy.linalg import eigh_tridiagonal

def log_det_trid(d, e):
    # d: main diagonal, e: off-diagonal (pass these instead of M itself)
    return jnp.sum(jnp.log(eigh_tridiagonal(d, e, eigvals_only=True)))

This is true because log det(M) = log prod_i lambda_i = sum_i log lambda_i, where the lambda_i are the eigenvalues returned by eigh_tridiagonal(d, e, eigvals_only=True) (assuming M is positive definite, so the logs are defined).

AlexanderMath avatar Oct 09 '23 15:10 AlexanderMath

Even jax.grad of log(det(M)) doesn't require eigenvectors, only the inverse (see section 2.1.4, eq. 57, in The Matrix Cookbook).
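A quick numerical check of that identity (for symmetric M, the gradient of log det(M) is M^{-1}):

import jax
import jax.numpy as jnp

def logdet(M):
    # slogdet is differentiable and avoids overflow in det
    return jnp.linalg.slogdet(M)[1]

M = jnp.array([[2.0, 0.5],
               [0.5, 1.0]])
g = jax.grad(logdet)(M)
print(jnp.allclose(g, jnp.linalg.inv(M), atol=1e-5))  # True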

AlexanderMath avatar Oct 09 '23 15:10 AlexanderMath

Here's a proof-of-concept implementation modeled after TensorFlow's tf.linalg.eigh_tridiagonal, along with some light testing:

import jax
import jax.numpy as jnp
from jax.lax.linalg import tridiagonal_solve
from jax.lax import while_loop
from jax.random import normal
from jax.scipy.linalg import eigh_tridiagonal
import scipy.linalg as la

def apply_perturbation_if_needed(alpha, beta, eps):
    """
    Perturb the diagonal and off-diagonal entries where consecutive
    diagonal values are too close together, to break up degeneracies.
    """
    n = len(alpha)

    # Compute gaps between consecutive diagonal entries
    gaps = jnp.abs(alpha[1:] - alpha[:-1])

    # Define a gap threshold (TOL * |W(I)|, where TOL = 10 * EPS)
    tol = 10 * eps
    gap_threshold = tol * jnp.abs(alpha[:-1])

    # Mask of positions where the gap is too small
    small_gap = gaps < gap_threshold

    # Apply a small, index-dependent perturbation only at those positions.
    # A masked update avoids scattering with the fill_value=-1 indices that
    # a jnp.where(..., size=...) gather would otherwise produce.
    perturbation = eps * jnp.arange(1, n)
    perturbed_alpha = alpha.at[1:].add(jnp.where(small_gap, perturbation, 0.0))
    perturbed_beta = beta + eps * jnp.arange(1, n)

    return perturbed_alpha, perturbed_beta

def compute_eigenvectors_jax(key, alpha, beta, eigvals):
    """Inverse iteration to compute eigenvectors in JAX.

    Returns an (n, k) array whose columns are the eigenvectors for the
    k eigenvalues in `eigvals`."""
    k = eigvals.shape[0]
    n = alpha.shape[0]

    eps = jnp.finfo(eigvals.dtype).eps
    alpha, beta = apply_perturbation_if_needed(alpha, beta, eps)

    t_norm = jnp.maximum(jnp.abs(eigvals[0]), jnp.abs(eigvals[-1]))
    gaptol = jnp.sqrt(eps) * t_norm

    # Identify clusters of close eigenvalues
    gap = eigvals[1:] - eigvals[:-1]
    close = gap < gaptol
    left_neighbor_close = jnp.concatenate([jnp.array([False]), close])
    right_neighbor_close = jnp.concatenate([close, jnp.array([False])])

    max_clusters = n  # Maximum possible clusters
    ortho_interval_start = jnp.where(~left_neighbor_close & right_neighbor_close, size=max_clusters, fill_value=-1)[0]
    ortho_interval_end = jnp.where(left_neighbor_close & ~right_neighbor_close, size=max_clusters, fill_value=-1)[0] + 1
    num_clusters = jnp.sum(ortho_interval_start != -1)

    # Initialize random starting vectors; row j is the iterate for eigenvalue j
    v0 = normal(key, (k, n), dtype=alpha.dtype)
    v0 = v0 / jnp.linalg.norm(v0, axis=1, keepdims=True)

    alpha_shifted = alpha[None, :] - eigvals[:, None]
    beta_tiled = jnp.tile(beta[None, :], (k, 1))

    # Pad beta_tiled to create dl and du with the required shape
    dl = jnp.pad(beta_tiled, [(0, 0), (1, 0)])  # Add leading zero
    du = jnp.pad(beta_tiled, [(0, 0), (0, 1)])  # Add trailing zero

    def orthogonalize_cluster(vectors, start, end):
        # Create a mask for the range [start:end]
        indices = jnp.arange(vectors.shape[0])
        mask = (indices >= start) & (indices < end)

        # Apply the mask to extract the cluster
        cluster = vectors * mask[:, None]  # Mask along the rows
        cluster = jnp.where(mask[:, None], cluster, 0)  # Ensure masked rows are zero

        # QR decomposition on the selected rows
        q, _ = jnp.linalg.qr(cluster.T)

        # Align q to the appropriate rows of the original array
        aligned_q = jnp.where(mask[:, None], q.T, 0)

        # Update the original vectors using jnp.where
        updated_vectors = jnp.where(mask[:, None], aligned_q, vectors)
        return updated_vectors

    def orthogonalize_close_eigenvectors(v):
        def body(i, v):
            start = ortho_interval_start[i]
            end = ortho_interval_end[i]
            return jax.lax.cond(
                (start != -1) & (end != -1),
                lambda args: orthogonalize_cluster(args[0], args[1], args[2]),
                lambda args: args[0],
                (v, start, end)
            )

        return jax.lax.fori_loop(0, max_clusters, body, v)

    def continue_iteration(state):
        i, _, nrm_v, nrm_v_old = state
        max_it = 5
        min_norm_growth = 0.1
        norm_growth_factor = 1 + min_norm_growth
        return jnp.logical_and(
            i < max_it, jnp.any(nrm_v >= norm_growth_factor * nrm_v_old)
        )

    def iteration_step(state):
        i, v, nrm_v, nrm_v_old = state

        # One inverse-iteration step: solve (T - lambda_j I) x = v_j for each
        # eigenvalue lambda_j, with row j of v as the right-hand side
        def solve_tridiagonal_single(dl_row, d_row, du_row, rhs):
            return tridiagonal_solve(dl_row, d_row, du_row, rhs[:, None])[:, 0]

        # Batched tridiagonal solve over the k shifted systems; keeping v as
        # (k, n) rows-as-vectors avoids mismatched transposes around the vmap
        v = jax.vmap(solve_tridiagonal_single)(dl, alpha_shifted, du, v)

        nrm_v_old = nrm_v
        nrm_v = jnp.linalg.norm(v, axis=1)
        v = v / nrm_v[:, None]

        v = orthogonalize_close_eigenvectors(v)
        return i + 1, v, nrm_v, nrm_v_old

    nrm_v = jnp.linalg.norm(v0, axis=1)
    zero_nrm = jnp.zeros_like(nrm_v)
    state = (0, v0, nrm_v, zero_nrm)

    _, v, _, _ = while_loop(continue_iteration, iteration_step, state)
    # Transpose so that columns are eigenvectors, matching SciPy's convention
    return v.T

def test_basic():
    # Example usage:
    alpha = jnp.array([1.0, 2.0, 3.0])
    beta = jnp.array([0.5, 0.5])

    # Compute eigenvalues with JAX's eigh_tridiagonal
    jax_eigenvalues = eigh_tridiagonal(alpha, beta, eigvals_only=True)

    # Compute eigenvalues and eigenvectors with JAX implementation
    eigvals = jax_eigenvalues
    eigenvectors = compute_eigenvectors_jax(jax.random.PRNGKey(42), alpha, beta, eigvals)

    # Compare against SciPy implementation
    scipy_eigenvalues, scipy_eigenvectors = la.eigh_tridiagonal(alpha, beta)

    print("JAX Eigenvalues:", jax_eigenvalues)
    print("SciPy Eigenvalues:", scipy_eigenvalues)
    print("Difference in Eigenvalues:", jnp.abs(jax_eigenvalues - scipy_eigenvalues))

    # Compare eigenvectors (up to sign)
    alignment = jnp.sign(jnp.sum(eigenvectors * scipy_eigenvectors, axis=0))
    adjusted_jax_eigenvectors = eigenvectors * alignment
    print("Difference in Eigenvectors:", jnp.linalg.norm(adjusted_jax_eigenvectors - scipy_eigenvectors))

    # Check normalization after computation
    norms = jnp.linalg.norm(eigenvectors, axis=0)
    print("Final norms of our eigenvectors:", norms)

    # Check orthogonality (columns are eigenvectors, so V.T @ V should be ~I)
    orthogonality = eigenvectors.T @ eigenvectors
    print("Orthogonality check (should be close to identity):", orthogonality)

    # Check alignment with scipy eigenvectors
    alignment_check = jnp.abs(jnp.sum(eigenvectors * scipy_eigenvectors, axis=0))
    print("Alignment with SciPy eigenvectors:", alignment_check)
    
    print("Our eigenvectors", eigenvectors)
    print("Norms of our eigenvectors", jnp.linalg.norm(eigenvectors, axis=1))
    print("Scipy eigenvectors", scipy_eigenvectors)
    print("Norms of Scipy eigenvectors", jnp.linalg.norm(scipy_eigenvectors, axis=1))

def test_clustered_eigenvalues_reconstruction():
    # Create a tridiagonal matrix with clustered eigenvalues
    n = 8
    eps = jnp.finfo(jnp.float32).eps
    alpha = jnp.ones(n)  # Main diagonal
    beta = 0.01 * jnp.sqrt(eps) * jnp.ones(n - 1)  # Small off-diagonal values

    # Compute eigenvalues and eigenvectors using JAX
    jax_eigenvalues = eigh_tridiagonal(alpha, beta, eigvals_only=True)
    eigenvectors = compute_eigenvectors_jax(jax.random.PRNGKey(42), alpha, beta, jax_eigenvalues)

    # Reconstruct the tridiagonal matrix A
    A = jnp.diag(alpha) + jnp.diag(beta, k=1) + jnp.diag(beta, k=-1)

    # Verify reconstruction: A @ V ≈ V @ diag(E)
    reconstructed_A = eigenvectors @ jnp.diag(jax_eigenvalues) @ eigenvectors.T
    reconstruction_error = jnp.linalg.norm(A - reconstructed_A)

    # Compare with SciPy
    scipy_eigenvalues, scipy_eigenvectors = la.eigh_tridiagonal(alpha, beta)

    print("JAX Eigenvalues:", jax_eigenvalues)
    print("SciPy Eigenvalues:", scipy_eigenvalues)
    print("Difference in Eigenvalues:", jnp.abs(jax_eigenvalues - scipy_eigenvalues))
    print("Reconstruction Error:", reconstruction_error)

    # Assertions
    assert reconstruction_error < 1e-4, "Reconstruction failed"
    assert jnp.allclose(jax_eigenvalues, scipy_eigenvalues, atol=1e-6), "Eigenvalues mismatch"

def test_complete_degeneracy():
    n = 8
    alpha = jnp.ones(n)  # Completely degenerate eigenvalues
    beta = jnp.zeros(n - 1)  # Zero off-diagonal values

    # Compute eigenvalues with JAX's eigh_tridiagonal
    jax_eigenvalues = eigh_tridiagonal(alpha, beta, eigvals_only=True)

    # Compute eigenvectors with JAX implementation
    eigenvectors = compute_eigenvectors_jax(jax.random.PRNGKey(42), alpha, beta, jax_eigenvalues)

    # Reconstruct the tridiagonal matrix A
    A = jnp.diag(alpha) + jnp.diag(beta, k=1) + jnp.diag(beta, k=-1)

    # Verify reconstruction: A @ V ≈ V @ diag(E)
    reconstructed_A = eigenvectors @ jnp.diag(jax_eigenvalues) @ eigenvectors.T
    reconstruction_error = jnp.linalg.norm(A - reconstructed_A)

    print("Reconstruction Error:", reconstruction_error)
    print("Eigenvalues:", jax_eigenvalues)
    print("Eigenvectors (orthogonal check):", jnp.dot(eigenvectors, eigenvectors.T))

    # Assertions
    assert reconstruction_error < 1e-4, "Reconstruction failed"
    assert jnp.allclose(jax_eigenvalues, alpha, atol=1e-6), "Eigenvalues mismatch"

def benchmark(n):
    import time
    a = jnp.ones(n)
    b = jnp.ones(n-1)

    for i in range(3):
        start = time.time()
        eigvals = eigh_tridiagonal(a, b, eigvals_only=True)
        eigvals.block_until_ready()
        end = time.time()
        print(f"eigvals {end-start}s")
        start = time.time()
        v = compute_eigenvectors_jax(jax.random.PRNGKey(42), a, b, eigvals)
        v.block_until_ready()
        end = time.time()
        print(f"eigvecs {end-start}s")

test_basic()
test_clustered_eigenvalues_reconstruction()
test_complete_degeneracy()
benchmark(10000)

If it is useful to others and holds up under more thorough testing, it could be converted into a PR.

jglaser avatar Jan 09 '25 00:01 jglaser

The missing piece for porting Rasmus's implementation is a batched tridiagonal solve, I believe.

I think batched tridiagonal solves are now working: https://github.com/jax-ml/jax/pull/28696
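e.g. something along these lines should now vmap (shapes and values here are illustrative):

import jax
import jax.numpy as jnp
from jax.lax.linalg import tridiagonal_solve

# `batch` independent n x n tridiagonal systems
batch, n = 4, 8
key = jax.random.PRNGKey(0)
dl = jnp.zeros((batch, n)).at[:, 1:].set(0.1)   # subdiagonal; dl[:, 0] must be 0
d  = 2.0 * jnp.ones((batch, n))                 # main diagonal
du = jnp.zeros((batch, n)).at[:, :-1].set(0.1)  # superdiagonal; du[:, -1] must be 0
b  = jax.random.normal(key, (batch, n, 1))      # right-hand sides

x = jax.vmap(tridiagonal_solve)(dl, d, du, b)   # shape (batch, n, 1)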

f0uriest avatar Jun 16 '25 23:06 f0uriest

Any chance eigenvectors will get implemented soon, then? It would be very useful for optimising some Gaussian process applications!

markfortune avatar Jul 26 '25 17:07 markfortune