Consistent convergence checks and options
Related to #470, I have been wondering whether we should make convergence handling consistent and reusable across methods. Currently, methods handle convergence checks in slightly different ways: some use an absolute threshold (PARAFAC2), while others use a relative threshold based on the absolute difference in error (CP) or the fractional difference in error (PARAFAC2). On top of that, there is repeated code for storing a history of the error, checking these thresholds, printing the error at each iteration, and so on.
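For reference, a minimal sketch of the two threshold styles described above (function and variable names here are hypothetical, not tensorly's actual code):

```python
def converged_absolute(rec_error, tol):
    # PARAFAC2-style: stop once the reconstruction error itself is small
    return rec_error < tol

def converged_relative(prev_error, rec_error, tol):
    # CP-style: stop once the change in error between iterations is small
    return abs(prev_error - rec_error) < tol
```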
One possibility would be an error tracker class that any method could instantiate. This class would take the tolerances on instantiation and handle storing a history of the errors. It would expose one method, called on each iteration, that returns a boolean indicating whether the method should continue. Printing the per-iteration error would be handled by this method (with a verbose flag).
My impression is that this would clean up a lot of code and force methods to have consistent convergence criteria. Note that the method itself would still have to handle calculating the absolute error on each iteration (so efficiencies like reuse of MTTKRP in CP can still occur).
Thanks @aarmey. I am a little conflicted about such refactoring. Overall, I like each algorithm to be clearly readable from its own source file, without important bits spread across different files. Convergence can be pretty method-specific, as you mentioned, so I think it makes sense to handle it in each file. Plus, the code to store the error and check the threshold is pretty trivial, so I don't think there is a need to centralize it. I often use scikit-learn as a reference; they manage the balance between features and complexity really well despite the growing codebase.
I already find the callback option almost overkill in some ways, but if we really want to refactor, I guess we could use a callback to let users store the error history themselves instead of doing it in the algorithm.
In any case, we could probably use some harmonization of the error metrics across the codebase as you're saying.
Indeed, I was worried this might be too clever and impact readability. Agreed that we can do some work to unify the convergence options, but leave it at that.
Wanted to check in - what do you think @aarmey ?
I would still like to go back through the methods and define the tolerance checks in a consistent way. This shouldn't be too difficult, so I'll hopefully have a PR in a couple of weeks.
@JeanKossaifi @cohenjer what is the purpose of cvg_criterion? If the error is strictly decreasing, shouldn't both options lead to the same behavior?
https://github.com/tensorly/tensorly/blob/59c1126251b248d85d5333477d91c50eae3b2864/tensorly/decomposition/_cp.py#L285
Indeed, in some algorithms they should be the same, but in others the error might increase, for instance when using line search. It is not a big deal; we could support only the absolute relative error check. What do you think @JeanKossaifi?
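To illustrate why the two criteria can differ when the error is not strictly decreasing, here is a small sketch (function names are hypothetical, not tensorly's actual `cvg_criterion` code):

```python
def rel_decrease(prev, curr):
    # signed relative decrease in error; negative if the error went up
    return (prev - curr) / prev

def should_stop_abs(prev, curr, tol):
    # 'abs' flavor: stop only when the error barely changes in either direction
    return abs(rel_decrease(prev, curr)) < tol

def should_stop_signed(prev, curr, tol):
    # signed flavor: also stops whenever the error increases (decrease < tol)
    return rel_decrease(prev, curr) < tol
```

With a strictly decreasing error the two agree; but if a step increases the error (e.g. an accepted line-search overshoot), the signed check stops while the absolute one keeps iterating.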
Also, after #550 goes through, I plan to push another PR adding the callback API to a few decomposition functions. Is this also the kind of change you had in mind?
For parafac(), the line search step is only accepted if it decreases the fitting error. Consequently, this option should have no effect. I'll open a PR removing the option, then solicit more feedback there.