Feature-Request: Matrix Exponentiation
Useful for orthogonal optimization and other algorithms that require computing unitary matrices or working with complex numbers.
Pytorch impl: https://discuss.pytorch.org/t/what-implementation-is-used-for-matrix-exp/159608/5
https://pytorch.org/docs/stable/generated/torch.linalg.matrix_exp.html
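For reference, the PyTorch op this mirrors (shown just for context):

```python
import torch

A = torch.randn(4, 4)
E = torch.linalg.matrix_exp(A)  # e^A, computed with the same Bader/Blanes/Casas scheme
```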
Pure-Python implementation, yoinked from torch and ported w/ Claude. It appears to work in training, though:
```python
import mlx.core as mx


def _compute_T1(A):
    """I + A"""
    return mx.eye(A.shape[-1]) + A


def _compute_T2(A):
    """I + A + A^2/2"""
    A2 = A @ A
    return mx.eye(A.shape[-1]) + A + A2 / 2


def _compute_T4(A):
    """I + A + A^2 * (I/2 + A/6 + A^2/24)"""
    A2 = A @ A
    inner_term = mx.eye(A.shape[-1]) / 2 + A / 6 + A2 / 24
    return mx.eye(A.shape[-1]) + A + (A2 @ inner_term)


def _compute_T8(A):
    """Degree-8 Taylor polynomial with coefficients from Bader, Blanes & Casas."""
    sqrt_177 = 0.1330413469565007072504e+2  # sqrt(177)
    x3 = 2 / 3
    x1 = x3 * ((1 + sqrt_177) / 88)
    x2 = x3 * ((1 + sqrt_177) / 352)
    x4 = (-271 + 29 * sqrt_177) / (315 * x3)
    x5 = (-11 + 11 * sqrt_177) / (1260 * x3)
    x6 = (-99 + 11 * sqrt_177) / (5040 * x3)
    x7 = (89 - sqrt_177) / (5040 * x3)
    y2 = (857 - 58 * sqrt_177) / 630
    A2 = A @ A
    A4 = A2 @ (x1 * A + x2 * A2)
    A8 = (x3 * A2 + A4) @ (x4 * mx.eye(A.shape[-1]) + x5 * A + x6 * A2 + x7 * A4)
    return mx.eye(A.shape[-1]) + A + y2 * A2 + A8


def matrix_exp(A):
    """
    Computes the matrix exponential using an optimized Taylor series.
    Based on PyTorch's implementation of the paper:
    Bader, P.; Blanes, S.; Casas, F.
    Computing the Matrix Exponential with an Optimized Taylor Polynomial Approximation.
    """
    if A.shape[-2:] == (0, 0):
        return A  # MLX arrays are immutable, so no clone() is needed
    elif A.shape[-2:] == (1, 1):
        return mx.exp(A)

    # 1-norm (max absolute column sum), used to choose the polynomial degree
    matrix_norm = mx.max(mx.sum(mx.abs(A), axis=-2), axis=-1)

    # These thresholds are from PyTorch's implementation;
    # they are carefully chosen based on the paper.
    # (The deg-12/18 entries are kept from PyTorch but unused in this port.)
    if A.dtype == mx.float32:
        thresholds = [
            1.192092800768788e-07,  # deg 1
            5.978858893805233e-04,  # deg 2
            5.116619363445086e-02,  # deg 4
            5.800524627688768e-01,  # deg 8
            1.461661507209034e+00,  # deg 12
            3.010066362817634e+00,  # deg 18
        ]
    else:  # float64
        thresholds = [
            2.220446049250313e-16,  # deg 1
            2.580956802971767e-08,  # deg 2
            3.397168839976962e-04,  # deg 4
            4.991228871115323e-02,  # deg 8
            2.996158913811580e-01,  # deg 12
            1.090863719290036e+00,  # deg 18
        ]

    # For small norms use lower-degree approximations
    # (note: these Python-level comparisons force a graph evaluation)
    if matrix_norm <= thresholds[0]:
        return _compute_T1(A)
    elif matrix_norm <= thresholds[1]:
        return _compute_T2(A)
    elif matrix_norm <= thresholds[2]:
        return _compute_T4(A)
    elif matrix_norm <= thresholds[3]:
        return _compute_T8(A)

    # For larger norms use scaling and squaring with T8:
    # exp(A) = exp(A / 2^s)^(2^s)
    s = mx.maximum(
        mx.zeros_like(matrix_norm),
        mx.ceil(mx.log2(matrix_norm / thresholds[3])),
    )
    s = s.astype(mx.int32)
    A_scaled = A / mx.expand_dims(mx.expand_dims(2.0 ** s, -1), -1)

    # Compute exponential of the scaled matrix
    X = _compute_T8(A_scaled)

    # Square back up; the condition is expanded so it broadcasts over the
    # matrix dims, and mx.where freezes entries whose squarings are done
    max_s = int(mx.max(s).item())
    for _ in range(max_s):
        X = mx.where(mx.expand_dims(mx.expand_dims(s > 0, -1), -1), X @ X, X)
        s = s - 1
    return X
```
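A quick sanity check of the port (illustrative; `matrix_exp` is the function above): exp(0) is the identity, and the exponential of a skew-symmetric matrix is orthogonal, which is exactly the property the orthogonal-optimization use case relies on:

```python
import mlx.core as mx

# exp(0) = I
print(matrix_exp(mx.zeros((3, 3))))

# exp of a skew-symmetric matrix is orthogonal: Q^T Q = I
S = mx.array([[0.0, 1.5], [-1.5, 0.0]])  # S^T = -S
Q = matrix_exp(S)
print(Q.T @ Q)  # should be ~identity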
Full implementation w/ optimization and custom vjp for smol kernels, all, again, yoinked from pytorch: https://gist.github.com/N8python/b3e24a4f88efa52bdd81a8762b7a7238.
For two 1024x1024 matrices initialized with randn(0, 0.1), the provided matrix_exp implementation diverges from that of pytorch by a maximum absolute difference of 0.000975.
For two 1024x1024 matrices initialized with randn(0, 1) - intentionally constructed to have diverging eigenvalues - the average absolute difference is 0.03124% of the maximum element in the torch computation; unnormalized, it is roughly 1311244288.0 due to the diverging eigenvalues.
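For reproducibility, a comparison along those lines can be done like this (an illustrative script, not the exact one used for the numbers above):

```python
import numpy as np
import torch
import mlx.core as mx

# Same random matrix fed to both implementations (randn with mean 0, std 0.1)
A_np = (np.random.randn(1024, 1024) * 0.1).astype(np.float32)

E_torch = torch.linalg.matrix_exp(torch.from_numpy(A_np)).numpy()
E_mlx = np.array(matrix_exp(mx.array(A_np)))

print("max abs diff:", np.abs(E_torch - E_mlx).max())
print("avg diff as % of max element:",
      100 * np.abs(E_torch - E_mlx).mean() / np.abs(E_torch).max())
```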
Fundamental algorithm from:
https://raw.githubusercontent.com/pytorch/pytorch/7f18ef14c1fed4e4376a75d626d98ba3c074809c/aten/src/ATen/native/LinearAlgebra.cpp
This is very nicely done :-). I especially love the gist with the custom function...
I think it is not quite ready to be made into an op yet, but it definitely raises some nice issues for us to solve. The main problem I see is the implicit graph evaluation in a couple of places (matrix norm conditionals and looping).
I think this may be a good example for an if and while that do not cause a graph evaluation.
Agreed. But how would I go about doing that, given that the `if` on the norm is necessary to choose the correct polynomial, and the scale factor also depends on the norm?
Or are you saying that this needs to be implemented first: "I think this may be a good example for an if and while that do not cause a graph evaluation."
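In the meantime, one branch-free workaround might be to push the decisions into the graph: always use T8, compute the scaling exponent with clip/ceil, and run a fixed-trip-count squaring loop where mx.where freezes finished entries. A sketch (it trades wasted squarings for zero Python-level branching, and is not meant as the eventual op design):

```python
import mlx.core as mx

def matrix_exp_branchless(A, max_squarings=16):
    """Sketch: always use T8, choose the scaling without a Python-level if."""
    norm = mx.max(mx.sum(mx.abs(A), axis=-2), axis=-1)
    theta8 = 5.800524627688768e-01  # float32 deg-8 threshold from above

    # Scaling exponent as data: clip(ceil(log2(norm / theta8)), 0, max_squarings).
    # mx.maximum guards log2 against a zero norm.
    s = mx.clip(mx.ceil(mx.log2(mx.maximum(norm, 1e-30) / theta8)), 0, max_squarings)
    X = _compute_T8(A / mx.expand_dims(mx.expand_dims(2.0 ** s, -1), -1))

    # Fixed-trip-count squaring: always run max_squarings iterations and let
    # mx.where freeze X once its s squarings are done. Wasteful, but no
    # Python-level control flow depends on array values.
    for i in range(max_squarings):
        keep_squaring = mx.expand_dims(mx.expand_dims(s > i, -1), -1)
        X = mx.where(keep_squaring, X @ X, X)
    return X
```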
+1. I'm doing quantum computing simulation on my Mac with PyTorch, but PyTorch has incomplete and slow MPS support. I'm considering implementing quantum simulation algorithms in MLX, but I need matrix exponentiation and other ops. Maybe a workaround besides this gist is to bridge the forward and backward passes between MLX and PyTorch?
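For illustration, such a bridge might look like the following, assuming mx.custom_function and its vjp decorator (the exact vjp signature should be checked against the MLX custom-transforms docs; this is a sketch, not a vetted implementation):

```python
import numpy as np
import torch
import mlx.core as mx

@mx.custom_function
def torch_matrix_exp(A):
    # Forward: round-trip through NumPy and run PyTorch's matrix_exp on CPU
    return mx.array(torch.linalg.matrix_exp(torch.from_numpy(np.array(A))).numpy())

@torch_matrix_exp.vjp
def torch_matrix_exp_vjp(primals, cotangent, output):
    # Backward: let PyTorch autograd compute the adjoint of matrix_exp
    (A,) = primals
    A_t = torch.from_numpy(np.array(A)).requires_grad_(True)
    torch.linalg.matrix_exp(A_t).backward(torch.from_numpy(np.array(cotangent)))
    return mx.array(A_t.grad.numpy())
```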