MPI+SparseTimeFunction: More efficient _dist_gather
In the context of checkpointing, it is a significant overhead that, upon returning from C-land, we redistribute the entire `SparseTimeFunction`, even though potentially only a relatively small number of time iterations have been computed.
This should be easily fixable by plumbing the args down to `_arg_apply` and then to `_dist_gather`, so that we only retain the written region of the data array.
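A minimal sketch of the idea, using plain numpy and a hypothetical helper (`gather_written_region`, `t_start`, `t_end` are all illustrative names, not Devito API): instead of redistributing the full time axis, only the slice of timesteps actually computed would be handed to the gather.

```python
import numpy as np

def gather_written_region(local_data, t_start, t_end):
    """Hypothetical helper: return only the slice of the local data
    array that was written during the time loop, so the MPI
    redistribution only has to move this region.

    local_data has shape (ntime, npoints); t_start/t_end delimit the
    time iterations computed in this chunk (e.g. one checkpointing
    window)."""
    return local_data[t_start:t_end]

# Example: out of 1000 timesteps, only 10 were computed in this chunk
data = np.zeros((1000, 50))
data[200:210] = 1.0
written = gather_written_region(data, 200, 210)
# Only the 10-row window would need redistribution, not all 1000 rows
```

The real change would pass `t_start`/`t_end` (or equivalent) through `_arg_apply` down to `_dist_gather`, which would then build its MPI counts and displacements from the sliced shape only.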
If so, the fix would probably be very similar, but does this potentially also apply to `_dist_scatter` before going to C?
Should we drop the gather completely, scatter only once, and be done?
I think there might be side effects if one expects certain data to be on a certain rank while it is actually somewhere else.
Perhaps we can:
- scatter upon startup (i.e., upon the very first `SparseFunction` initialization)
- gather on the fly each time `sf.data` is accessed (i.e., we return a view)

This way there would be essentially zero overhead during forward and backward propagation.
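The scheme above can be sketched as follows. This is a toy illustration, not Devito code: `LazySparseData`, `_scatter`, and `_gather` are placeholder names, and the MPI calls are stubbed out so the example is self-contained; the point is only the control flow (scatter once at construction, gather lazily on `.data` access).

```python
import numpy as np

class LazySparseData:
    """Sketch of the proposal: scatter once at initialization, then
    gather on the fly only when `.data` is accessed, so the time loop
    itself pays no redistribution cost."""

    def __init__(self, global_data):
        # Scatter upon startup (very first initialization)
        self._local = self._scatter(global_data)
        self._gathers = 0  # instrumentation to show when gathers happen

    def _scatter(self, data):
        # Placeholder for the real MPI scatter
        return np.array(data, copy=True)

    def _gather(self):
        # Placeholder for the real MPI gather; counts invocations so we
        # can verify no gather happens during propagation
        self._gathers += 1
        return self._local

    @property
    def data(self):
        # Gather on the fly each time `.data` is accessed
        return self._gather()

sf = LazySparseData(np.arange(4.0))
# ... forward/backward propagation would update sf._local in C-land,
#     with no gather triggered ...
view = sf.data  # the only gather happens here, on access
```

One caveat with returning a view is write semantics: if users mutate the returned array and expect the change to propagate back to the owning rank, the view would need to be backed by a scatter-on-write mechanism as well.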