jax.scipy.linalg.eigh_tridiagonal() doesn't implement calculation of eigenvectors
I need this for finding the eigenvectors of the Hessian after tridiagonalizing it with Lanczos iteration. Right now the function looks like:
def eigh_tridiagonal(d: ArrayLike, e: ArrayLike, *, eigvals_only: bool = False,
                     select: str = 'a', select_range: Optional[Tuple[float, float]] = None,
                     tol: Optional[float] = None) -> Array:
    if not eigvals_only:
        raise NotImplementedError("Calculation of eigenvectors is not implemented")
It's also not mentioned in the documentation that eigenvectors are not implemented; it just raises a NotImplementedError when you try to use it. If I shell out to scipy on cpu I then can't JIT the function. Are there any plans to implement calculation of eigenvectors for tridiagonal matrices?
Thank you, Katherine Crowson
Thanks for raising this!
Are there any plans to implement calculation of eigenvectors for tridiagonal matrices?
I don't know of any plans, but we don't really plan things like this; instead we wait for users to ask for them!
If I shell out to scipy on cpu I then can't JIT the function.
As a temporary suboptimal workaround, how about using jax.pure_callback? Then you should be able to call into scipy, and jitting will still work. If you want autodiff, you could put a custom_jvp around it, like in this example.
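Something along these lines might work as a sketch (the name eigh_tridiagonal_host is just a placeholder, and you'd still need the custom_jvp wrapper for gradients):

import jax
import jax.numpy as jnp
import numpy as np
import scipy.linalg

def eigh_tridiagonal_host(d, e):
    # Host callback into scipy; works under jit via pure_callback,
    # but is not differentiable without a custom_jvp around it.
    n = d.shape[0]
    dtype = jnp.result_type(d)
    result_shapes = (jax.ShapeDtypeStruct((n,), dtype),    # eigenvalues
                     jax.ShapeDtypeStruct((n, n), dtype))  # eigenvectors
    def _callback(d, e):
        w, v = scipy.linalg.eigh_tridiagonal(np.asarray(d), np.asarray(e))
        return w.astype(dtype), v.astype(dtype)
    return jax.pure_callback(_callback, result_shapes, d, e)

w, v = jax.jit(eigh_tridiagonal_host)(jnp.array([1.0, 2.0, 3.0]), jnp.array([0.5, 0.5]))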
@hawkinsp do you happen to know/remember how hard it is to add eigenvector calculations (context: #6622)? Maybe we should ask Rasmus...
Looks like Rasmus implemented it here for TensorFlow, with great comments, so maybe we can port it?
This would be great to have!
The missing piece for porting Rasmus's implementation is a batched tridiagonal solve, I believe.
That sounds tricky. A shame, as it makes it difficult to implement any of the popular GP inference engines that use the eigendecomposition of tridiagonal matrices as a way to compute log-determinants cheaply.
a way to compute log-determinants cheaply.
Perhaps I'm misunderstanding. Isn't log det(M) = sum_i log(lambda_i), where the lambda_i are the eigenvalues, so eigenvalues are sufficient?
eigh is still expensive, no? It doesn't exploit the tridiagonal structure for cheaper compute.
Sorry. I meant to say eigh_tridiagonal(M, eigvals_only=True) instead of eigh; that is, only compute the eigenvalues, exploiting the tridiagonal structure. For a tridiagonal matrix with main diagonal d and off-diagonal e you can compute the log-determinant as
from jax.scipy.linalg import eigh_tridiagonal
def log_det_trid(d, e):  # d: main diagonal, e: off-diagonal of M
    return jnp.sum(jnp.log(eigh_tridiagonal(d, e, eigvals_only=True)))
This is true because log det(M) = log(prod_i lambda_i) = sum_i log(lambda_i), where the lambda_i are the eigenvalues returned by eigh_tridiagonal(..., eigvals_only=True).
Even jax.grad of log(det(M)) doesn't require eigenvectors, only the inverse (see section 2.1.4, eq. 57 in the Matrix Cookbook).
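As a quick sanity check of that identity (M below is just an arbitrary symmetric positive-definite example, not tied to the tridiagonal case):

import jax
import jax.numpy as jnp

M = jnp.array([[2.0, 0.5, 0.0],
               [0.5, 2.0, 0.5],
               [0.0, 0.5, 2.0]])
# d log det(M) / dM = M^{-1} for symmetric positive-definite M
grad_logdet = jax.grad(lambda A: jnp.linalg.slogdet(A)[1])(M)
print(jnp.allclose(grad_logdet, jnp.linalg.inv(M), atol=1e-5))  # True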
Here's a proof-of-concept implementation modeled after TensorFlow's tf.linalg.eigh_tridiagonal, along with some light testing:
import jax
import jax.numpy as jnp
from jax.lax.linalg import tridiagonal_solve
from jax.lax import while_loop
from jax.random import normal
from jax.scipy.linalg import eigh_tridiagonal
import scipy.linalg as la
def apply_perturbation_if_needed(alpha, beta, eps):
    """
    Apply a perturbation to eigenvalues and off-diagonals if eigenvalue gaps are too small.
    """
    n = len(alpha)
    # Compute gaps between consecutive eigenvalues
    gaps = jnp.abs(alpha[1:] - alpha[:-1])
    # Define a gap threshold (TOL * |W(I)|, where TOL = 10 * EPS)
    tol = 10 * eps
    gap_threshold = tol * jnp.abs(alpha[:-1])
    # Apply a small perturbation only at the entries whose gap is too small
    # (masking with where avoids spurious writes from fill-value indices)
    perturbation = jnp.where(gaps < gap_threshold, eps * jnp.arange(1, n), 0.0)
    perturbed_alpha = alpha.at[1:].add(perturbation)
    perturbed_beta = beta + eps * jnp.arange(1, n)
    return perturbed_alpha, perturbed_beta
def compute_eigenvectors_jax(key, alpha, beta, eigvals):
    """Implements inverse iteration to compute eigenvectors in JAX.

    Note: this proof of concept assumes the full spectrum is requested (k == n).
    """
    k = eigvals.shape[0]
    n = alpha.shape[0]
    eps = jnp.finfo(eigvals.dtype).eps
    alpha, beta = apply_perturbation_if_needed(alpha, beta, eps)
    t_norm = jnp.maximum(jnp.abs(eigvals[0]), jnp.abs(eigvals[-1]))
    gaptol = jnp.sqrt(eps) * t_norm
    # Identify clusters of close eigenvalues
    gap = eigvals[1:] - eigvals[:-1]
    close = gap < gaptol
    left_neighbor_close = jnp.concatenate([jnp.array([False]), close])
    right_neighbor_close = jnp.concatenate([close, jnp.array([False])])
    max_clusters = n  # Maximum possible number of clusters
    ortho_interval_start = jnp.where(~left_neighbor_close & right_neighbor_close,
                                     size=max_clusters, fill_value=-1)[0]
    ortho_interval_end = jnp.where(left_neighbor_close & ~right_neighbor_close,
                                   size=max_clusters, fill_value=-1)[0] + 1
    num_clusters = jnp.sum(ortho_interval_start != -1)
    # Initialize random starting vectors
    v0 = normal(key, (k, n), dtype=alpha.dtype)
    v0 = v0 / jnp.linalg.norm(v0, axis=1, keepdims=True)
    alpha_shifted = alpha[None, :] - eigvals[:, None]
    beta_tiled = jnp.tile(beta[None, :], (k, 1))
    # Pad beta_tiled to create dl and du with the required shape
    dl = jnp.pad(beta_tiled, [(0, 0), (1, 0)])  # Add leading zero
    du = jnp.pad(beta_tiled, [(0, 0), (0, 1)])  # Add trailing zero

    def orthogonalize_cluster(vectors, start, end):
        # Create a mask for the range [start:end]
        indices = jnp.arange(vectors.shape[0])
        mask = (indices >= start) & (indices < end)
        # Apply the mask to extract the cluster
        cluster = vectors * mask[:, None]  # Mask along the rows
        cluster = jnp.where(mask[:, None], cluster, 0)  # Ensure masked rows are zero
        # QR decomposition on the selected rows
        q, _ = jnp.linalg.qr(cluster.T)
        # Align q to the appropriate rows of the original array
        aligned_q = jnp.where(mask[:, None], q.T, 0)
        # Update the original vectors using jnp.where
        updated_vectors = jnp.where(mask[:, None], aligned_q, vectors)
        return updated_vectors

    def orthogonalize_close_eigenvectors(v):
        def body(i, v):
            start = ortho_interval_start[i]
            end = ortho_interval_end[i]
            return jax.lax.cond(
                (start != -1) & (end != -1),
                lambda args: orthogonalize_cluster(args[0], args[1], args[2]),
                lambda args: args[0],
                (v, start, end),
            )
        return jax.lax.fori_loop(0, max_clusters, body, v)

    def continue_iteration(state):
        i, _, nrm_v, nrm_v_old = state
        max_it = 5
        min_norm_growth = 0.1
        norm_growth_factor = 1 + min_norm_growth
        return jnp.logical_and(
            i < max_it, jnp.any(nrm_v >= norm_growth_factor * nrm_v_old)
        )

    def iteration_step(state):
        i, v, nrm_v, nrm_v_old = state
        # Use vmap to handle batched tridiagonal solve
        def solve_tridiagonal_single(diags, rhs):
            return tridiagonal_solve(diags[0], diags[1], diags[2], rhs[:, None])[:, 0]
        dl_batched = dl            # Lower diagonal
        d_batched = alpha_shifted  # Middle diagonal, shifted by the eigenvalues
        du_batched = du            # Upper diagonal
        # Stack diagonals to match the batch structure
        batched_diags = (dl_batched, d_batched, du_batched)
        # Apply batched tridiagonal solve
        v = jax.vmap(solve_tridiagonal_single, in_axes=(0, 0))(batched_diags, v.T).T
        nrm_v_old = nrm_v
        nrm_v = jnp.linalg.norm(v, axis=0)
        v = v / nrm_v[None, :]
        v = orthogonalize_close_eigenvectors(v)
        return i + 1, v, nrm_v, nrm_v_old

    nrm_v = jnp.linalg.norm(v0, axis=1)
    zero_nrm = jnp.zeros_like(nrm_v)
    state = (0, v0, nrm_v, zero_nrm)
    _, v, _, _ = while_loop(continue_iteration, iteration_step, state)
    return v
def test_basic():
    # Example usage:
    alpha = jnp.array([1.0, 2.0, 3.0])
    beta = jnp.array([0.5, 0.5])
    # Compute eigenvalues with JAX's eigh_tridiagonal
    jax_eigenvalues = eigh_tridiagonal(alpha, beta, eigvals_only=True)
    # Compute eigenvalues and eigenvectors with JAX implementation
    eigvals = jax_eigenvalues
    eigenvectors = compute_eigenvectors_jax(jax.random.PRNGKey(42), alpha, beta, eigvals)
    # Compare against SciPy implementation
    scipy_eigenvalues, scipy_eigenvectors = la.eigh_tridiagonal(alpha, beta)
    print("JAX Eigenvalues:", jax_eigenvalues)
    print("SciPy Eigenvalues:", scipy_eigenvalues)
    print("Difference in Eigenvalues:", jnp.abs(jax_eigenvalues - scipy_eigenvalues))
    # Compare eigenvectors (up to sign)
    alignment = jnp.sign(jnp.sum(eigenvectors * scipy_eigenvectors, axis=0))
    adjusted_jax_eigenvectors = eigenvectors * alignment
    print("Difference in Eigenvectors:", jnp.linalg.norm(adjusted_jax_eigenvectors - scipy_eigenvectors))
    # Check normalization after computation
    norms = jnp.linalg.norm(eigenvectors, axis=1)
    print("Final norms of our eigenvectors:", norms)
    # Check orthogonality
    orthogonality = jnp.dot(eigenvectors, eigenvectors.T)
    print("Orthogonality check (should be close to identity):", orthogonality)
    # Check alignment with scipy eigenvectors
    alignment_check = jnp.abs(jnp.sum(eigenvectors * scipy_eigenvectors, axis=0))
    print("Alignment with SciPy eigenvectors:", alignment_check)
    print("Our eigenvectors", eigenvectors)
    print("Norms of our eigenvectors", jnp.linalg.norm(eigenvectors, axis=1))
    print("Scipy eigenvectors", scipy_eigenvectors)
    print("Norms of Scipy eigenvectors", jnp.linalg.norm(scipy_eigenvectors, axis=1))
def test_clustered_eigenvalues_reconstruction():
    # Create a tridiagonal matrix with clustered eigenvalues
    n = 8
    eps = jnp.finfo(jnp.float32).eps
    alpha = jnp.ones(n)  # Main diagonal
    beta = 0.01 * jnp.sqrt(eps) * jnp.ones(n - 1)  # Small off-diagonal values
    # Compute eigenvalues and eigenvectors using JAX
    jax_eigenvalues = eigh_tridiagonal(alpha, beta, eigvals_only=True)
    eigenvectors = compute_eigenvectors_jax(jax.random.PRNGKey(42), alpha, beta, jax_eigenvalues)
    # Reconstruct the tridiagonal matrix A
    A = jnp.diag(alpha) + jnp.diag(beta, k=1) + jnp.diag(beta, k=-1)
    # Verify reconstruction: A ≈ V @ diag(E) @ V.T
    reconstructed_A = eigenvectors @ jnp.diag(jax_eigenvalues) @ eigenvectors.T
    reconstruction_error = jnp.linalg.norm(A - reconstructed_A)
    # Compare with SciPy
    scipy_eigenvalues, scipy_eigenvectors = la.eigh_tridiagonal(alpha, beta)
    print("JAX Eigenvalues:", jax_eigenvalues)
    print("SciPy Eigenvalues:", scipy_eigenvalues)
    print("Difference in Eigenvalues:", jnp.abs(jax_eigenvalues - scipy_eigenvalues))
    print("Reconstruction Error:", reconstruction_error)
    # Assertions
    assert reconstruction_error < 1e-4, "Reconstruction failed"
    assert jnp.allclose(jax_eigenvalues, scipy_eigenvalues, atol=1e-6), "Eigenvalues mismatch"
def test_complete_degeneracy():
    n = 8
    alpha = jnp.ones(n)  # Completely degenerate eigenvalues
    beta = jnp.zeros(n - 1)  # Zero off-diagonal values
    # Compute eigenvalues with JAX's eigh_tridiagonal
    jax_eigenvalues = eigh_tridiagonal(alpha, beta, eigvals_only=True)
    # Compute eigenvectors with JAX implementation
    eigenvectors = compute_eigenvectors_jax(jax.random.PRNGKey(42), alpha, beta, jax_eigenvalues)
    # Reconstruct the tridiagonal matrix A
    A = jnp.diag(alpha) + jnp.diag(beta, k=1) + jnp.diag(beta, k=-1)
    # Verify reconstruction: A ≈ V @ diag(E) @ V.T
    reconstructed_A = eigenvectors @ jnp.diag(jax_eigenvalues) @ eigenvectors.T
    reconstruction_error = jnp.linalg.norm(A - reconstructed_A)
    print("Reconstruction Error:", reconstruction_error)
    print("Eigenvalues:", jax_eigenvalues)
    print("Eigenvectors (orthogonal check):", jnp.dot(eigenvectors, eigenvectors.T))
    # Assertions
    assert reconstruction_error < 1e-4, "Reconstruction failed"
    assert jnp.allclose(jax_eigenvalues, alpha, atol=1e-6), "Eigenvalues mismatch"
def benchmark(n):
    import time
    a = jnp.ones(n)
    b = jnp.ones(n - 1)
    for i in range(3):
        start = time.time()
        eigvals = eigh_tridiagonal(a, b, eigvals_only=True)
        eigvals.block_until_ready()
        end = time.time()
        print(f"eigvals {end-start}s")
        start = time.time()
        v = compute_eigenvectors_jax(jax.random.PRNGKey(42), a, b, eigvals)
        v.block_until_ready()
        end = time.time()
        print(f"eigvecs {end-start}s")
test_basic()
test_clustered_eigenvalues_reconstruction()
test_complete_degeneracy()
benchmark(10000)
If it would be useful to others, and once more thoroughly tested, it could be converted into a PR.
The missing piece for porting Rasmus's implementation is a batched tridiagonal solve, I believe.
I think batched tridiagonal solves are now working: https://github.com/jax-ml/jax/pull/28696
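For reference, a minimal sketch of what a batched solve via vmap could look like (the helper name batched_tridiagonal_solve is just illustrative, and this assumes the batching support from that PR):

import jax
from jax.lax.linalg import tridiagonal_solve

# Illustrative helper: solve k independent n-by-n tridiagonal systems at once.
# dl, d, du, b all have shape (k, n); dl[:, 0] and du[:, -1] should be zero.
def batched_tridiagonal_solve(dl, d, du, b):
    solve_one = lambda dl_i, d_i, du_i, b_i: tridiagonal_solve(dl_i, d_i, du_i, b_i[:, None])[:, 0]
    return jax.vmap(solve_one)(dl, d, du, b)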
Any chance eigenvectors will get implemented soon, then? It would be very useful for optimising some Gaussian process applications!