
support jax.numpy.linalg.eig eigenvector derivatives

Open froystig opened this issue 5 years ago • 54 comments

Note that eigh is already taken care of.

froystig avatar Apr 17 '20 00:04 froystig

I'm pretty sure we could achieve this with a minor variation of our existing JVP rule for eigh, replacing U.T.conj() -> inv(U) (of course it should really use an LU solve rather than computing the inverse directly).
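For concreteness, a minimal numpy sketch of that idea (the eigh-style JVP with the adjoint replaced by a solve; `eig_jvp` is a hypothetical helper, not JAX's implementation, and it assumes a diagonalizable matrix with distinct eigenvalues):

```python
import numpy as np

def eig_jvp(a, da):
    """Sketch: the eigh JVP rule with U.conj().T replaced by solve(U, .)."""
    w, u = np.linalg.eig(a)
    # Similarity-transformed tangent G = inv(U) @ dA @ U, via an LU solve.
    g = np.linalg.solve(u, da @ u)
    dw = np.diag(g).copy()  # eigenvalue tangents: dw_i = G_ii
    # Off-diagonal factors F_ij = 1 / (w_j - w_i); the diagonal is zeroed,
    # which fixes the gauge so that diag(inv(U) @ dU) = 0.
    delta = w[None, :] - w[:, None]
    f = np.where(np.eye(len(w), dtype=bool), 0.0,
                 1.0 / np.where(delta == 0, 1.0, delta))
    du = u @ (f * g)  # eigenvector tangents (gauge-dependent!)
    return (w, u), (dw, du)
```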

shoyer avatar Apr 17 '20 16:04 shoyer

Just wanted to throw in a +1 for wanting this to be implemented.

ianwilliamson avatar May 08 '20 19:05 ianwilliamson

@shoyer do you have a reference for it? I've just been working through the math by hand and it seems what you said is correct, except that you have to do a slightly awkward correction to ensure that dU.T @ U has ones down the diagonal (which I think is required - this comes from the constraint that the eigenvectors are normalized). Anyway I think I will draft an implementation today.

Edit: It's implemented in Autograd https://github.com/HIPS/autograd/blob/master/autograd/numpy/linalg.py#L152-L173, with a reference to https://arxiv.org/pdf/1701.00392.pdf, eq 4.77.

Edit 2: The jvp equations in that paper are 4.60 and 4.63, but I think 4.63 (the jvp for the eigenvectors) is wrong. The statement above 4.63 ("...can not influence the amplitude of the eigenvectors...") is correct but I don't think they translated that constraint correctly into math. I've tried implementing their version, and my own, neither are working yet so not 100% sure whether I'm right about this.

j-towns avatar May 11 '20 11:05 j-towns

Also @shoyer how should I do a solve (inv(a) @ b for square matrices a and b)? I think I can't use jax.numpy.linalg.solve from jax.lax_linalg because of a circular dependency.

For now I'll use inv as you suggested above.

j-towns avatar May 11 '20 12:05 j-towns

Section 3.1 of this reference, cited in a comment under eigh_jvp_rule (in lax_linalg.py), works through the general case of how to calculate eigenvector derivatives: https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf

shoyer avatar May 11 '20 17:05 shoyer

Also @shoyer how should I do a solve (inv(a) @ b for square matrices a and b)? I think I can't use jax.numpy.linalg.solve from jax.lax_linalg because of a circular dependency.

My suggestion would be to use a local import, in the JVP rule, e.g.,

def eig_jvp_rule(...):
  from jax.numpy.linalg import solve
  ...

You could also try refactoring, but this is the usual hack for circular dependency challenges.

shoyer avatar May 11 '20 17:05 shoyer

I'd be tempted to at least try the refactoring of moving the guts of solve into lax_linalg.py.

hawkinsp avatar May 11 '20 18:05 hawkinsp

Cool, I've moved it in the draft pr, wasn't too bad to do. Still getting incorrect values for the eig derivatives though 😔.

j-towns avatar May 12 '20 11:05 j-towns

The JVP seems to be correct now that I've relaxed the test tolerance slightly, but the VJPs are way out, I'm not sure why that is yet.

j-towns avatar May 12 '20 13:05 j-towns

I notice that testEighGrad is currently skipped because 'Test fails with numeric errors'. I wonder if the problems I'm seeing are related, since the eig jvp is mostly copied from the eigh jvp.

j-towns avatar May 12 '20 13:05 j-towns

OK, I think I know why what I have is incorrect. Eigenvectors are only unique up to a (complex) scalar multiple. The eigenvectors returned by numpy.linalg.eig are normalized so that they have length 1 (I already knew this), and also so that their largest component is real (see http://www.netlib.org/lapack/lapack-3.1.1/html/dgeev.f.html). I was not previously aware of that constraint, and I think it might take some work to correct the derivations + implementation that I have.
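For illustration, that convention can be applied after the fact by rotating each column's phase (a sketch of the dgeev-style normalization; `normalize_largest_real` is a hypothetical name, and LAPACK's exact tie-breaking may differ):

```python
import numpy as np

def normalize_largest_real(u):
    # For each column, find its largest-magnitude component...
    idx = np.argmax(np.abs(u), axis=0)
    pivots = u[idx, np.arange(u.shape[1])]
    # ...and divide out its phase, making that component real and positive.
    return u / (pivots / np.abs(pivots))
```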

j-towns avatar May 12 '20 15:05 j-towns

A similar issue might also explain why the derivative tests for eigh are failing - the eigenvectors are normalized so they have length 1, but are still only unique up to multiplication by a complex scalar whose absolute value is 1 (i.e. there is one degree of freedom per eigenvector). It's not clear from the low level eigh docs (http://www.netlib.org/lapack/lapack-3.1.1/html/zheevd.f.html) how this non-uniqueness is addressed.

Edit: just running np.linalg.eigh on a couple of inputs it looks like the eigenvectors are normalized so that the first component of each is real. It seems a bit strange that eigh uses a different convention to eig, and this means that you'll get np.linalg.eigh(x) != np.linalg.eig(x) for complex, hermitian x. The eigh convention should be easier to differentiate, and maybe we should change our eig_p primitive to match the eigh convention, so that lax_linalg.eig(x) == lax_linalg.eigh(x) for all hermitian x.

j-towns avatar May 12 '20 15:05 j-towns

We could certainly pick a new convention for normalizing vectors from eig if that makes it easier to differentiate. The downside is that this would probably require a bit more computation. If it's only O(n^2) time, I would say definitely go for it; it's maybe more questionable if we need dense matrix/matrix multiplication, which is O(n^3). In the latter case we might add an optional argument for making eig differentiable.

For what it's worth, I have a feeling that the right answer for how to differentiate eig/eigh in most cases is: don't, precisely because eigen-decomposition is often not a well-defined function. The right function to differentiate is something downstream of the eigen-decomposition, where the outputs of the numerical method become a well-defined function, e.g., the result of a matrix power series calculation. If we can characterize the full set of such "well-defined functions of eigen-decompositions", then perhaps those are the right primitives for which to define auto-diff rules.

shoyer avatar May 12 '20 16:05 shoyer

Yeah I agree. It would be very weird and likely a bug if a user implemented a function that depended on the length of an eigenvector, since the normalization is essentially an implementation detail. Catering for these design decisions with correct derivatives is also really awkward, so maybe we should indeed look for another level at which to provide derivatives.

j-towns avatar May 12 '20 16:05 j-towns

@ianwilliamson do you have a use case for eig derivatives? Would be useful to know a bit about it.

j-towns avatar May 12 '20 16:05 j-towns

Also I think it is reasonable to support eigh of a real symmetric matrix, where there is a pretty obvious and straightforward unique value and derivative.

j-towns avatar May 12 '20 16:05 j-towns

Also I think it is reasonable to support eigh of a real symmetric matrix, where there is a pretty obvious and straightforward unique value and derivative.

Even eigh is only uniquely defined if the eigenvalues are unique. If degeneracies are valid (common in many applications) it isn't a well defined function.

shoyer avatar May 12 '20 16:05 shoyer

We had a related discussion when I was fixing eigh in autograd: https://github.com/HIPS/autograd/pull/527

Essentially, the vjp there works for objective functions that do not depend on the arbitrary phase of the eigenvectors (the "gauge"), and the tests are written for such functions. This is because in a general solver this phase is just arbitrary, so even finite-difference derivatives won't work, i.e. eig(X) and eig(X + dX) can spit out eigenvectors with an arbitrary phase difference.

It sounds like in jax you are actually setting the gauge (largest element to be real), so you could try to make the vjp account for that and match the finite-difference derivative under that gauge, but I think you can't really expect the user to know that you're doing that. Meaning that if I'm a user and I have a function that depends on the phase of an eigenvector, the correct way to do it is to manually set the gauge to whatever I want it to be, in a way tracked by jax. Or in other words: you can first get the vjp to work for gauge-independent functions, and then add the normalization on top of that.
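To illustrate the gauge-independence point, a quantity like the rank-one projector v v^H is unchanged if the solver returns e^{i*theta} * v instead of v (a toy example, not library code):

```python
import numpy as np

def projector(v):
    """Rank-one projector v v^H onto the span of v. Gauge-independent:
    any phase on v cancels against its conjugate."""
    v = v / np.linalg.norm(v)
    return np.outer(v, v.conj())
```

So finite differences of projector(eig(X)[1][:, k]) can be meaningful (for a simple eigenvalue) even when finite differences of the raw eigenvector are not.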

The problem with degeneracies is harder. In one of my packages, I purposefully add small noise to avoid symmetries that result in degeneracies, but that's obviously a workaround. Here's a paper that could provide some indication on how this could be handled ("generalized gradient"), but I don't really understand it well: https://oaktrust.library.tamu.edu/bitstream/handle/1969.1/184232/document-1.pdf?sequence=1

momchilmm avatar May 12 '20 17:05 momchilmm

@ianwilliamson do you have a use case for eig derivatives? Would be useful to know a bit about it.

Just in case this is helpful, I've never actually had a use case for eig derivatives per se since I've always had Hermitian matrices available, but the last time I reached for an eigh derivative it was because I needed to find a representative set of inputs yielding a singular Jacobian (a determinant derivative would have worked fine, but that was a bit slower and more unstable iirc -- I stopped searching when eigh derivatives were good enough). The scipy.linalg package was more helpful to me than the numpy wrapper because of its ability to single out a range of eigenvalues.

Most natural uses of an eig derivative I think would follow a similar pattern of having a deterministic scheme for choosing a particular eigenvalue (smallest magnitude, largest real part, etc) that relates to the problem being studied, or perhaps as inputs to a symmetric function.

I know you asked this in the context of eigenvector normalization, and fwiw I've always had to normalize them myself in whichever way suits the current problem, and have never needed their derivatives except to compute higher-order derivatives of eigenvalues. Sorry I can't be more help there.

The problem with degeneracies is harder. In one of my packages, I purposefully add small noise to avoid symmetries that result in degeneracies, but that's obviously a workaround. Here's a paper that could provide some indication on how this could be handled ("generalized gradient"), but I don't really understand it well: https://oaktrust.library.tamu.edu/bitstream/handle/1969.1/184232/document-1.pdf?sequence=1

Jax doesn't support subgradients at all does it? E.g., grad(abs)(0.)==1 even though the subdifferential there is the entire closed interval [-1, 1].

hmusgrave avatar May 30 '20 17:05 hmusgrave

Jax doesn't support subgradients at all does it? E.g., grad(abs)(0.)==1 even though the subdifferential there is the entire closed interval [-1, 1].

Ohh I see what this is about. Yeah I wouldn't expect this to be something that will be supported in jax. By the way, #3112 and #3114 might also be of interest to you.

momchilmm avatar Jun 01 '20 16:06 momchilmm

Hello, just wondering what the status is of the implementation of the np.linalg.eig function in JAX? I am working as a quantum physicist and really like the JAX library; I successfully used it for an optimization problem involving the eigh function in a previous project. For a new project, however, I am dealing with non-hermitian matrices, so I require the eig function.

LuukCoopmans avatar Jun 05 '20 13:06 LuukCoopmans

@LuukCoopmans np.linalg.eig is implemented but its derivatives are not. Do you need to be able to differentiate eig?

j-towns avatar Jun 08 '20 11:06 j-towns

@j-towns yes I need to be able to differentiate it.

LuukCoopmans avatar Jun 09 '20 09:06 LuukCoopmans

Cool. As you can see in the comments above, the derivative for the eigenvectors is quite awkward to get right because they're only defined up to 'gauge' (that is, up to multiplication by a complex scalar with absolute value 1).

@LuukCoopmans sorry to keep quizzing you, but does your objective function depend on the whole output of eig or just on the eigenvalues? The latter might be easier to support.

In the short term you might be interested in using JAX’s custom_jvp and custom_vjp for implementing your own workarounds where we haven’t managed to implement derivatives, like in this case.
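As a sketch of that kind of workaround, an eigenvalues-only wrapper could be written with custom_jvp (hypothetical code, not part of JAX; assumes distinct eigenvalues, and note jnp.linalg.eig itself only runs on CPU):

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def eigvals(a):
    # Forward pass: eigenvalues only, so there is no gauge to worry about.
    return jnp.linalg.eigvals(a)

@eigvals.defjvp
def eigvals_jvp(primals, tangents):
    (a,), (da,) = primals, tangents
    w, u = jnp.linalg.eig(a)
    # dw_i = (inv(U) @ dA @ U)_ii, using a solve rather than inv.
    dw = jnp.diagonal(jnp.linalg.solve(u, da @ u))
    return w, dw
```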

j-towns avatar Jun 14 '20 20:06 j-towns

@j-towns actually I find this an interesting problem. In physics the quantum wavefunction (an eigenvector) is always defined up to a 'gauge' in the same way as you describe. However, this gauge is usually not important for the calculation of quantities of interest (expectation values) because it gets multiplied out: say O is a matrix and v is an eigenvector of some other matrix O', then we are interested in quantities like v.conj().T @ O @ v. Also, in my case I eventually take the absolute value squared of the eigenvector, so the phase is not important. I can however see that this might give a problem for the derivative, because the gauge on the eigenvector and on its derivative can come out different, if I am correct?

LuukCoopmans avatar Jun 16 '20 09:06 LuukCoopmans

However, this gauge is usually not important for the calculation of interested quantities (expectation values) because it gets multiplied out

This is my experience as well. I've only ever needed eigenvector derivatives in scenarios where the gauge didn't matter to the final calculation. I usually did need a particular magnitude, e.g. normalizing to |v|=1; jax does however easily support differentiating that normalization step, so I'm not sure that really matters for an eig() derivative.

Also in my case I eventually take the absolute value squared of the eigenvector so the phase is not important.

This quantity isn't uniquely defined either, and it's similar to the gauge problem. Eigenvectors are only unique up to a non-zero constant multiple from the relevant field.

hmusgrave avatar Jun 16 '20 15:06 hmusgrave

As I noted above in https://github.com/google/jax/issues/2748#issuecomment-627444706, I think np.linalg.eig is rarely the right level at which to calculate derivatives. We have conventions for how to pick the gauge for calculations, but those aren't necessarily consistent with derivatives. I think the problem of calculating reverse mode derivatives of eig may be fundamentally undefined from a mathematical perspective -- there does not necessarily exist a single choice of gauge for which the eig function is entirely continuous.

Instead, we need higher level auto-diff primitives, corresponding to well defined functions that are invariant of gauge. For example, we can calculate derivatives for any matrix-valued function of a Hermitian matrix (see https://github.com/FluxML/Zygote.jl/pull/355). We could add helper functions for calculating these sorts of things, ideally with support for calculating the underlying functions in flexible ways (e.g., using eig internally).
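As an example of such a gauge-invariant quantity, any scalar function applied elementwise to the spectrum of a Hermitian matrix is well defined, since the eigenvector phases cancel (a sketch; `hermitian_matfun` is a hypothetical name):

```python
import numpy as np

def hermitian_matfun(f, a):
    # f(A) = U f(Lambda) U^H for Hermitian A; the arbitrary phase of each
    # column of U cancels against its conjugate, so the result is gauge-free.
    w, u = np.linalg.eigh(a)
    return (u * f(w)) @ u.conj().T
```

Differentiating through a composition like this, rather than through eigh's raw outputs, side-steps the gauge problem, though degenerate eigenvalues can still make intermediate JVP factors blow up even when f(A) itself is smooth.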

shoyer avatar Jun 16 '20 16:06 shoyer

Instead, we need higher level auto-diff primitives, corresponding to well defined functions that are invariant of gauge.

That makes sense. Do you think it's still reasonable to support eigenvalue derivatives except on the measure-zero sets where they don't exist (either raising an error or providing a default value in such cases, sort of like how abs is handled)?

hmusgrave avatar Jun 16 '20 16:06 hmusgrave

Do you think it's still reasonable to support eigenvalue derivatives except on the measure-zero sets where they don't exist (either raising an error or providing a default value in such cases, sort of like how abs is handled)?

Yes, absolutely!

This is basically what we do currently for eigh. If there are degeneracies, then the derivative with respect to the eigenvectors will be all NaN.

shoyer avatar Jun 16 '20 17:06 shoyer

Is there a straightforward way for us to provide eigenvalue derivatives without providing eigenvector derivatives (since this gauge issue only affects eigenvectors, afaict)? Do you think we ought to have a primitive which only returns eigenvalues?
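For the eigenvalues-only case the reverse rule also looks simple; my understanding is it would be something like the following (a sketch, assuming a diagonalizable matrix with distinct eigenvalues; `eigvals_vjp` is a hypothetical name):

```python
import numpy as np

def eigvals_vjp(a, w_bar):
    """Pull a cotangent w_bar on the eigenvalues back to the matrix:
    a_bar = inv(U)^H @ diag(w_bar) @ U^H, written with a solve."""
    w, u = np.linalg.eig(a)
    uh = u.conj().T
    return np.linalg.solve(uh, np.diag(w_bar) @ uh)
```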

j-towns avatar Jun 17 '20 09:06 j-towns