DMPfold2
Update network.py
As of torch 2.x, torch.symeig() is no longer available. torch.linalg.eigh() can be used in its place.
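For reference, a minimal sketch of what such a replacement might look like. The wrapper name `symeig_compat` is hypothetical and is not taken from network.py. It assumes the original call requested eigenvectors and relied on symeig's default of reading the upper triangle, whereas torch.linalg.eigh reads the lower triangle by default.

```python
import torch


def symeig_compat(a, upper=True):
    """Hypothetical stand-in for torch.symeig(a, eigenvectors=True, upper=upper)."""
    # torch.symeig read the upper triangle by default; torch.linalg.eigh
    # defaults to UPLO="L", so the triangle is passed explicitly here.
    return torch.linalg.eigh(a, UPLO="U" if upper else "L")


# Small symmetric test matrix (illustrative data only).
x = torch.randn(4, 4, dtype=torch.float64)
a = x @ x.T

evals, evecs = symeig_compat(a)

# Eigenvalues come back in ascending order; columns of evecs are eigenvectors,
# so this product rebuilds the input matrix.
recon = evecs @ torch.diag(evals) @ evecs.T
```

For an exactly symmetric input the choice of triangle makes no difference; it only matters if the matrix passed in is not perfectly symmetric.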
Just leaving this note here for anyone who might consider this:
This change will indeed need to be pulled if we ever intend to support newer GPUs, for example.
However, I did test torch.linalg.eigh() when the deprecation was first announced; the output coordinates are not identical to those produced when symeig() is used, and there is a small drop in model accuracy w.r.t. symeig().
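One general property that can make eigenvector outputs differ between routines is that each eigenvector is only defined up to sign. Whether this accounts for the difference reported here is not established, so the sketch below is only an illustration; the helper `canonical_sign` is hypothetical.

```python
import torch

torch.manual_seed(0)
x = torch.randn(5, 5, dtype=torch.float64)
a = x @ x.T  # symmetric test matrix (illustrative data only)

evals, evecs = torch.linalg.eigh(a)

# Negating every eigenvector still gives a valid decomposition of `a`.
flipped = -evecs
same_matrix = torch.allclose(flipped @ torch.diag(evals) @ flipped.T, a)


def canonical_sign(v):
    """Hypothetical helper: make the largest-magnitude entry of each column positive."""
    idx = v.abs().argmax(dim=0, keepdim=True)    # row index of the largest entry per column
    signs = torch.sign(torch.gather(v, 0, idx))  # sign of that entry, shape (1, n)
    return v * signs


# After fixing the sign convention, both versions agree.
agree = torch.allclose(canonical_sign(evecs), canonical_sign(flipped))
```

Fixing a sign convention like this only makes two decompositions comparable; it would not by itself recover the behaviour the model saw during training.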
Ideally we would need to retrain/finetune the model with eigh as part of the compute graph in order to do this "properly".