feat(autogram): Add `ModuleBasedGramianComputer`.
This is a big win for architectures with very large linear layers, like AlexNet. For AlexNet, on CUDA, with `batch_dim=0`, this leads to:
- Double the max batch size (from batch_size=19 to batch_size=38 on my GPU). Sadly, this is still very far from the max batch size of 1268 of SGD with autograd. Do you think there might be a theoretical way to bridge this gap even more? EDIT: I have no idea why, but re-running the same tests (or maybe I made a mistake before) yields completely different results: max batch size is now 18 for main and 468 for this PR, much, much closer to the 1268 of autograd. EDIT2: the big memory improvement (and a small speed improvement) comes from just installing opt_einsum (`uv pip install opt_einsum`), without even changing the code; see the sketch after this list.
- x2 to x4 speed (depending on the batch size) for the whole `autogram_forward_backward` function (so this includes not only the gramian computation of the linear layers, but also that of all other layers, plus the forward passes).
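To make the opt_einsum point concrete, here is a minimal sketch (not the actual `ModuleBasedGramianComputer` code; the function name and einsum equation are illustrative) of the kind of contraction involved for a `Linear` layer's weight. When opt_einsum is installed, `torch.einsum` delegates contraction-path optimization to it (see `torch.backends.opt_einsum`), which is consistent with the memory/speed change observed just from installing the package:

```python
import torch

def linear_weight_gramian(x: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
    # x: (batch, in) inputs of the Linear layer, g: (batch, out) gradients of the
    # per-sample losses wrt the layer's output. The per-sample Jacobian of the
    # weight is the outer product g_b ⊗ x_b, so the Gramian is
    # G[b, c] = (g_b . g_c) * (x_b . x_c).
    # With opt_einsum installed, torch.einsum can pick a contraction order such as
    # ("bi,ci->bc", "bo,co->bc", elementwise product) and avoid large intermediates.
    return torch.einsum("bi,ci,bo,co->bc", x, x, g, g)

# Illustrative shapes only:
x = torch.randn(38, 4096)
g = torch.randn(38, 1000)
gram = linear_weight_gramian(x, g)  # (38, 38)
```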
For other architectures, the differences are not very noticeable, though. But this is very promising; IMO we should fully focus on this direction.
Also pretty big for Transformers (with the change I suggested to handle higher-order tensors; a sketch of that generalization follows the numbers below).
Times for forward + backward on WithTransformerLarge with BS=256, A=Mean on cuda:0: reduced from 3.13 s (main) to 2.20 s (this PR).
Memory usage is higher, however: the max batch size went from 273 (main) down to 256 (this PR).
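For reference, this is how I read the higher-order generalization (my sketch, with hypothetical names and small illustrative shapes, not necessarily the exact code in the PR): for a `Linear` applied to (batch, seq, features) activations, the per-sample weight Jacobian becomes a sum over the sequence dimension, and the Gramian contraction grows one extra summed index per extra dimension:

```python
import torch

def linear_weight_gramian_3d(x: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
    # x: (batch, seq, in), g: (batch, seq, out). The per-sample weight Jacobian is
    # now sum_s g[b, s] ⊗ x[b, s], so the Gramian is
    # G[b, c] = sum_{s,t} (g[b, s] . g[c, t]) * (x[b, s] . x[c, t]).
    return torch.einsum("bsi,cti,bso,cto->bc", x, x, g, g)

# Small illustrative shapes:
x = torch.randn(8, 16, 32)
g = torch.randn(8, 16, 64)
gram = linear_weight_gramian_3d(x, g)  # (8, 8)
```

The intermediate over the two sequence indices is a plausible reason for the higher memory use on Transformers.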
This seems to break NoFreeParam (tiny errors) and ModuleReuse (large errors). Need to investigate that.
For ModuleReuse, my guess is that it simply doesn't consider cross terms anymore, so it's normal that it fails the test.
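A quick numeric check of that intuition (plain PyTorch, not autogram code): when the same parameter is used twice, the true gradient is the sum of the per-usage gradients, so a Gramian entry computed per usage misses exactly the cross-term.

```python
import torch

# g1, g2 stand for the gradients contributed by two usages of the same parameter.
g1, g2 = torch.randn(5), torch.randn(5)
true_entry = (g1 + g2) @ (g1 + g2)       # what autograd on the real graph gives
per_usage = g1 @ g1 + g2 @ g2            # what a per-usage (module-based) sum gives
print(true_entry - per_usage, 2 * (g1 @ g2))  # equal: the missing cross-term
```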
Of interest: https://optimized-einsum.readthedocs.io/en/stable/reusing_paths.html. If we explore which contraction path is optimal, and it turns out to be essentially always the same, then we may want to hard-code our own contraction path. It would be very helpful to know the optimal contraction order; the sketch below shows how to inspect it.
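A small sketch of how opt_einsum lets us do both: inspect the path it would choose, and reuse a precomputed expression when shapes are fixed. Shapes and the equation are illustrative (same hypothetical Gramian contraction as above):

```python
import torch
import opt_einsum as oe

B, I, O = 64, 512, 256                    # illustrative sizes
x, g = torch.randn(B, I), torch.randn(B, O)

# Inspect the contraction order opt_einsum picks for the Gramian equation.
path, info = oe.contract_path("bi,ci,bo,co->bc", x, x, g, g)
print(info)                               # prints the chosen path and FLOP/memory estimates

# If the path is essentially always the same, cache it once per shape signature
# and reuse it (this is what the "reusing paths" doc page describes).
expr = oe.contract_expression("bi,ci,bo,co->bc", x.shape, x.shape, g.shape, g.shape)
gram = expr(x, x, g, g)                   # reuses the precomputed path
```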
@PierreQuinton I found a way to compute the gramian with autograd with no cross-terms from module reuse / inter-module param reuse: 30fdc0078be5. Basically, the idea is to have a module pre-hook that clones each parameter before the module uses it, and a module post-hook that restores the original params. This way, each module usage corresponds to a different clone, and you can compute a gradient wrt each clone. The implementation uses a context manager, so it's quite clean IMO. A rough sketch of the idea is given after the limitations list below. Current limitations:
- Does not work on WithMultiheadAttention, WithTransformer and WithFreeParam, because they all involve some indirect parameters for the hooked module. Need to investigate and fix that (it's probably doable).
- Still counts cross-terms from intra-module parameter reuse: we'd need a node-based algo (rather than a module-based one) to fix that. But since autogram is still module-based, it doesn't matter yet.
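For reference, here is a minimal sketch of the clone-and-restore idea, assuming forward pre-/post-hooks and a context manager; the hook mechanics and names are mine and not necessarily what 30fdc0078be5 does.

```python
import contextlib
import torch
from torch import nn


@contextlib.contextmanager
def clone_params_per_usage(model: nn.Module):
    # Sketch only: every forward call of a module runs on fresh clones of its
    # parameters, and the original nn.Parameters are restored right after the
    # call. Each usage of a reused module thus gets its own graph nodes, so
    # gradients taken wrt each clone carry no cross-terms between usages.
    clones = []      # (module, name, clone), one entry per parameter per usage
    originals = {}   # module -> {name: original nn.Parameter} for the current call

    def pre_hook(module, inputs):
        saved = {}
        for name, param in list(module.named_parameters(recurse=False)):
            saved[name] = param
            clone = param.clone()              # a distinct graph node for this usage
            clones.append((module, name, clone))
            del module._parameters[name]       # unregister the Parameter...
            setattr(module, name, clone)       # ...and expose the clone instead
        originals[module] = saved

    def post_hook(module, inputs, output):
        for name, param in originals.pop(module, {}).items():
            delattr(module, name)              # drop the plain-attribute clone
            module._parameters[name] = param   # re-register the original Parameter

    handles = []
    for m in model.modules():
        handles.append(m.register_forward_pre_hook(pre_hook))
        handles.append(m.register_forward_hook(post_hook))
    try:
        yield clones
    finally:
        for h in handles:
            h.remove()


# Usage: per-usage gradients of a module that appears twice in the model.
layer = nn.Linear(4, 4)
model = nn.Sequential(layer, nn.ReLU(), layer)
x = torch.randn(8, 4)
with clone_params_per_usage(model) as clones:
    loss = model(x).sum()
grads = torch.autograd.grad(loss, [clone for _, _, clone in clones])
```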