
Feature-Request: Matrix Exponentiation

Open · N8python opened this issue 1 year ago • 7 comments

Useful for orthogonal optimization and other algorithms where calculating unitary matrices or dealing with complex numbers is necessary.

Pytorch impl: https://discuss.pytorch.org/t/what-implementation-is-used-for-matrix-exp/159608/5

https://pytorch.org/docs/stable/generated/torch.linalg.matrix_exp.html
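For example (using the torch op linked above - an MLX equivalent is what's being requested here): exponentiating a skew-symmetric matrix gives an orthogonal matrix, which is the basic building block for orthogonal parameterizations.

import torch

W = torch.randn(4, 4)
S = W - W.T                            # skew-symmetric: S^T = -S
Q = torch.linalg.matrix_exp(S)         # exp of a skew-symmetric matrix is orthogonal
print(torch.allclose(Q.T @ Q, torch.eye(4), atol=1e-5))  # True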

N8python avatar Jan 19 '25 18:01 N8python

In-Python implementation, yoinked from torch and ported w/ Claude - it appears to work in training, though:


import mlx.core as mx


def _compute_T1(A):
    """I + A"""
    return mx.eye(A.shape[-1]) + A

def _compute_T2(A):
    """I + A + A^2/2"""
    A2 = A @ A
    return mx.eye(A.shape[-1]) + A + A2/2

def _compute_T4(A):
    """I + A + A^2 * (I/2 + A/6 + A^2/24)"""
    A2 = A @ A
    inner_term = (mx.eye(A.shape[-1])/2 + A/6 + A2/24)
    return mx.eye(A.shape[-1]) + A + (A2 @ inner_term)

def _compute_T8(A):
    """Degree-8 Taylor approximant of exp(A) (coefficients from Bader, Blanes & Casas)."""
    sqrt_177 = 0.1330413469565007072504e+2
    x3 = 2/3
    x1 = x3 * ((1 + sqrt_177) / 88)
    x2 = x3 * ((1 + sqrt_177) / 352)
    x4 = (-271 + 29 * sqrt_177) / (315 * x3)
    x5 = (-11 + 11 * sqrt_177) / (1260 * x3)
    x6 = (-99 + 11 * sqrt_177) / (5040 * x3)
    x7 = (89 - sqrt_177) / (5040 * x3)
    y2 = (857 - 58 * sqrt_177) / 630

    A2 = A @ A
    A4 = A2 @ (x1*A + x2*A2)
    A8 = (x3*A2 + A4) @ (x4*mx.eye(A.shape[-1]) + x5*A + x6*A2 + x7*A4)
    
    return mx.eye(A.shape[-1]) + A + y2*A2 + A8

def matrix_exp(A):
    """
    Computes matrix exponential using optimized Taylor series.
    Based on PyTorch's implementation from the paper:
    Bader, P.; Blanes, S.; Casas, F.
    Computing the Matrix Exponential with an Optimized Taylor Polynomial Approximation.
    """
    if A.shape[-2:] == (0, 0):
        return mx.array(A)  # mlx arrays have no .clone(); copy via the constructor
    elif A.shape[-2:] == (1, 1):
        return mx.exp(A)

    # Compute the matrix 1-norm (max absolute column sum) to choose the degree
    matrix_norm = mx.max(mx.sum(mx.abs(A), axis=-2), axis=-1)
    
    # These thresholds are from PyTorch's implementation
    # They're carefully chosen based on the paper
    if A.dtype == mx.float32:
        thresholds = [
            1.192092800768788e-07,  # deg 1
            5.978858893805233e-04,  # deg 2
            5.116619363445086e-02,  # deg 4
            5.800524627688768e-01,  # deg 8
            1.461661507209034e+00,  # deg 12
            3.010066362817634e+00   # deg 18
        ]
    else:  # float64
        thresholds = [
            2.220446049250313e-16,  # deg 1
            2.580956802971767e-08,  # deg 2
            3.397168839976962e-04,  # deg 4
            4.991228871115323e-02,  # deg 8
            2.996158913811580e-01,  # deg 12
            1.090863719290036e+00   # deg 18
        ]

    # For small norms use lower degree approximations
    if matrix_norm <= thresholds[0]:
        return _compute_T1(A)
    elif matrix_norm <= thresholds[1]:
        return _compute_T2(A)
    elif matrix_norm <= thresholds[2]:
        return _compute_T4(A)
    elif matrix_norm <= thresholds[3]:
        return _compute_T8(A)

    # For larger norms use scaling and squaring with T8
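    # exp(A) = exp(A / 2^s) squared s times: pick s so the scaled norm drops
    # below the T8 threshold, then square the result back up.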
    s = mx.maximum(
        mx.zeros_like(matrix_norm),
        mx.ceil(mx.log2(matrix_norm / thresholds[3]))
    )
    s = s.astype(mx.int32)
    A_scaled = A / mx.expand_dims(mx.expand_dims(2.0**s, -1), -1)
    
    # Compute exponential of scaled matrix
    X = _compute_T8(A_scaled)
    
    # Square back up
    max_s = int(mx.max(s).item())
    for _ in range(max_s):
        X = mx.where(s > 0, X @ X, X)
        s = s - 1
        
    return X
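A quick sanity check for the snippet above (just a sketch): the exponential of a skew-symmetric matrix should come out orthogonal, and a randn input also exercises the scaling-and-squaring branch.

W = mx.random.normal((64, 64))
S = W - W.T                      # skew-symmetric, so matrix_exp(S) should be orthogonal
Q = matrix_exp(S)
print(mx.max(mx.abs(Q.T @ Q - mx.eye(64))))  # expect something on the order of 1e-5 for float32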

N8python avatar Jan 19 '25 18:01 N8python

Full implementation w/ optimization and a custom vjp for smol kernels, all, again, yoinked from pytorch: https://gist.github.com/N8python/b3e24a4f88efa52bdd81a8762b7a7238.

For two 1024x1024 matrices initialized with randn(0, 0.1), the provided matrix_exp implementation diverges from that of pytorch by a maximum absolute difference of 0.000975.
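Roughly how the comparison can be reproduced (a sketch, assuming the matrix_exp above is in scope - not the exact script used for the number above):

import numpy as np
import torch

A_np = (0.1 * np.random.randn(1024, 1024)).astype(np.float32)
ref = torch.linalg.matrix_exp(torch.from_numpy(A_np)).numpy()
out = np.array(matrix_exp(mx.array(A_np)))
print("max abs difference:", np.abs(out - ref).max())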

N8python avatar Jan 19 '25 22:01 N8python

For two 1024x1024 matrices initialized with randn(0, 1) - intentionally made to have diverging eigenvalues - the average difference, as a percentage of the maximum element in the torch computation, is 0.03124%; unnormalized, it is roughly 1311244288.0 due to the diverging eigenvalues.

N8python avatar Jan 19 '25 22:01 N8python

Fundamental algorithm from:

https://raw.githubusercontent.com/pytorch/pytorch/7f18ef14c1fed4e4376a75d626d98ba3c074809c/aten/src/ATen/native/LinearAlgebra.cpp

N8python avatar Jan 19 '25 22:01 N8python

This is very nicely done :-). I especially love the gist with the custom function...

I think it is not quite ready to be made into an op yet, but it definitely raises some nice issues for us to solve. The main problem I see is the implicit graph evaluation in a couple of places (matrix norm conditionals and looping).

I think this may be a good example for an if and while that do not cause a graph evaluation.
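For reference, one way to sidestep the data-dependent Python if today is to compute every candidate polynomial and select with mx.where, at the cost of extra compute - a rough sketch, for a single unbatched matrix:

def _select_taylor(A, matrix_norm, thresholds):
    # Branchless degree selection: all candidates are computed and mx.where keeps
    # the cheapest one whose threshold covers the norm. No Python `if`, so no
    # implicit graph evaluation - but every polynomial is materialized.
    out = _compute_T8(A)
    out = mx.where(matrix_norm <= thresholds[2], _compute_T4(A), out)
    out = mx.where(matrix_norm <= thresholds[1], _compute_T2(A), out)
    out = mx.where(matrix_norm <= thresholds[0], _compute_T1(A), out)
    return out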

angeloskath avatar Jan 21 '25 18:01 angeloskath

Agreed. But how would I go about doing that, given that the if on the norm is necessary to choose the correct polynomial, and the scale factor also depends on the norm?

Or are you saying that this needs to be implemented first: "I think this may be a good example for an if and while that do not cause a graph evaluation."

N8python avatar Jan 23 '25 00:01 N8python

+1. I'm doing quantum computing simulation on my Mac with PyTorch, but PyTorch has incomplete and slow MPS support. I'm considering implementing quantum simulation algorithms in MLX, but I need matrix exponentiation and other ops. Maybe a workaround besides this gist is to bridge the forward and backward passes between MLX and PyTorch?
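Something along these lines might work as a stopgap (untested sketch - it assumes mx.custom_function's vjp hook receives (primals, cotangents, outputs); check the MLX docs for the exact signature. The numpy round-trips make it slow and break laziness):

import numpy as np
import torch
import mlx.core as mx

@mx.custom_function
def torch_matrix_exp(A):
    # Forward pass: round-trip through numpy to reuse torch.linalg.matrix_exp.
    A_t = torch.from_numpy(np.array(A))
    return mx.array(torch.linalg.matrix_exp(A_t).numpy())

@torch_matrix_exp.vjp
def torch_matrix_exp_vjp(primals, cotangents, outputs):
    # Backward pass: let torch.autograd compute the cotangent w.r.t. A.
    A_t = torch.from_numpy(np.array(primals[0])).requires_grad_(True)
    out = torch.linalg.matrix_exp(A_t)
    out.backward(torch.from_numpy(np.array(cotangents)))
    return mx.array(A_t.grad.numpy())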

ifsheldon avatar Apr 20 '25 14:04 ifsheldon