
How to calculate d_train_loss_d_w

Shancong-Mou opened this issue 2 years ago • 6 comments

Hi, thanks for sharing the code!

I have a question about how to calculate d l_train / d w. Should we use all training samples, or at least a few training batches and then take the average, for d l_train / d w? I notice that in your code it is calculated for just one batch. Could you please explain more about it?

        for batch_idx, (x, y) in enumerate(train_loader):
            train_loss, _ = train_loss_func(x, y)
            optimizer.zero_grad()
            # create_graph=True keeps the graph so Hessian-vector products can be taken later
            d_train_loss_d_w += gather_flat_grad(grad(train_loss, model.parameters(), create_graph=True))
            break  # only the first mini-batch is used
        optimizer.zero_grad()

Shancong-Mou avatar Aug 29 '23 01:08 Shancong-Mou

We estimate these terms over mini-batches of the same batch size we use when optimizing the training loss. Perhaps other mini-batch sizes, e.g., larger ones, could be useful, but we did not thoroughly investigate this axis. Maybe gradient accumulation could help, or maybe different mini-batch sizes could be useful for the training vs. validation loss evaluations. It might also be interesting to investigate how batch size affects the benefit of using more Neumann terms.

But we probably want to avoid evaluating the terms over the whole dataset, or the method won't scale to large datasets. I think setting the validation and training batch sizes to be equal, and maxing them out to fill your compute's memory with 1-2 Neumann terms, is simple and likely gets most of the benefit.
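
For concreteness, here is a minimal sketch (not the repo's exact code) of averaging d_train_loss_d_w over a few mini-batches instead of one, re-using gather_flat_grad, train_loss_func, model, and train_loader from the snippet above; num_batches is a hypothetical knob:

    from torch.autograd import grad

    def avg_d_train_loss_d_w(model, train_loss_func, train_loader, num_batches=4):
        # Average d L_train / d w over up to `num_batches` mini-batches.
        total, used = None, 0
        for x, y in train_loader:
            if used == num_batches:
                break
            train_loss, _ = train_loss_func(x, y)
            # create_graph=True keeps the graph so Hessian-vector products can be taken later
            g = gather_flat_grad(grad(train_loss, model.parameters(), create_graph=True))
            total = g if total is None else total + g
            used += 1
        return total / used

Note that keeping create_graph=True across several batches retains several graphs, so the memory cost grows with num_batches, which is the trade-off discussed further down this thread.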

I hope this helps!

lorraine2 avatar Aug 29 '23 01:08 lorraine2

Thanks for your timely reply!

Could you please let me know if the following understanding is correct?

  1. In the paper, $L_V$ and $L_T$ should be the validation loss over the whole validation set and the training loss over the whole training set;
  2. When implementing the algorithm, mini-batches are used to approximate those terms (as well).

However, while a small batch size seems OK for optimizing the training loss, if the batch size is too small (like 2) for approximating d_train_loss_d_w and d_val_loss_d_theta, the mini-batch estimates can be unstable. Do you have any recommendations on this?

Thank you!

Shancong-Mou avatar Aug 29 '23 02:08 Shancong-Mou

In the paper we were not explicit about when to use mini-batches with that notation. But in practice, every time we evaluate L_t and L_v for updates we use mini-batches, or the method would not work on any reasonably sized dataset.

I suspect that the batch sizes used when evaluating these terms should be at least as large as what is stable for updating the inner neural network weights with just L_t. If this doesn't fit in memory for you, you could use gradient accumulation, or reduce the batch size until it fits. So, if you use batch size 32 to update the network weights, I'd try that first, before something very small like 2. I expect the hypergradient to have higher variance than the inner parameter gradient, which means you want a larger batch size, or a smaller learning rate / adjusted optimizer parameters, to account for the noise.

We were also not explicit about when you should re-sample the mini-batches. But I don't remember there being a significant difference from re-using the same L_t for each term in the Neumann series and for your inner parameter updates. If you re-sample, it will be lower variance but cost more memory.
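
To make that concrete, here is a rough sketch (not the repo's exact implementation) of a truncated-Neumann approximation of the inverse-Hessian-vector product on a single fixed mini-batch. Here v would be the flattened d_val_loss_d_w, train_loss is L_t evaluated on one mini-batch, and num_terms / alpha are hypothetical knobs:

    import torch
    from torch.autograd import grad

    def neumann_inverse_hvp(v, train_loss, params, num_terms=3, alpha=0.1):
        # Approximate H^{-1} v ~= alpha * sum_{j=0}^{num_terms} (I - alpha*H)^j v,
        # where H = d^2 L_t / dw^2 is taken on the single mini-batch behind train_loss.
        # alpha must be small enough for the series to be stable.
        params = list(params)
        d_train_d_w = grad(train_loss, params, create_graph=True)
        flat_grad = torch.cat([g.reshape(-1) for g in d_train_d_w])

        p = v.clone()           # current term (I - alpha*H)^j v
        series_sum = v.clone()  # running sum over j
        for _ in range(num_terms):
            # Hessian-vector product H p via double backward on the same mini-batch
            hvp = grad(flat_grad, params, grad_outputs=p, retain_graph=True)
            p = p - alpha * torch.cat([h.reshape(-1) for h in hvp])
            series_sum = series_sum + p
        return alpha * series_sum

Re-using the same train_loss for every term corresponds to the no-re-sampling option above; re-sampling would mean rebuilding flat_grad from a fresh mini-batch inside the loop.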

lorraine2 avatar Aug 29 '23 03:08 lorraine2

I see. Thanks for the detailed explanation! The reason I am using a small batch size is that each image in my application is 256x800, so it easily runs out of memory if I use more than 2 (maybe I have to do some optimization of the memory usage). For gradient accumulation, since the compute graph of d_train_loss_d_w for each batch needs to be retained (retain_graph=True) for the calculation of the Hessian, it still requires a lot of memory. I am wondering if you have any ideas on this.

Thanks!

Shancong-Mou avatar Aug 30 '23 01:08 Shancong-Mou

For gradient accumulation I meant: first evaluate the hypergradient (i.e., the hyperparameter update) using whatever batch size fits in memory (e.g., batch size 1 for Lv and Lt). Then store this hypergradient and repeat the computation N times. Finally, make the update you apply equal to the average of all your hypergradients.

I think this should work for batch size 1 on Lt and Lv, given you average enough hypergradients together. This also avoids maintaining the compute graph between hypergradient evaluations, so it shouldn't cost any extra memory. But it will be slower, since you aren't batching over examples anymore.

I think this could be implemented by taking the hyper_step function, looping over the hypergradient evaluation, storing the results in a list (or just a running average), and only applying the update with the average over N hypergradients. N would be a hyperparameter, so I'd make sure N=1 gives the same result as before and N>1 does strictly better. Maybe you need to empty a cache, but that should happen automatically, and it definitely doesn't require anything to be stored between hypergradient evaluations.
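
A rough sketch of that loop, assuming a hypothetical get_hypergrad() wrapper that draws fresh mini-batches for Lt and Lv on each call and returns a flat hypergradient tensor (e.g., factored out of the existing hyper_step logic):

    def averaged_hyper_step(get_hypergrad, hyperparams, hyper_optimizer, N=8):
        # Average N independently evaluated hypergradients, then apply one update.
        # N=1 should reproduce the original behaviour.
        running = None
        for _ in range(N):
            g = get_hypergrad().detach()  # detach: no graph is kept between evaluations
            running = g if running is None else running + g
        avg = running / N

        # Write the averaged hypergradient into the hyperparameters' .grad fields
        # and take a single optimizer step.
        hyper_optimizer.zero_grad()
        offset = 0
        for h in hyperparams:
            n = h.numel()
            h.grad = avg[offset:offset + n].view_as(h)
            offset += n
        hyper_optimizer.step()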

lorraine2 avatar Aug 30 '23 02:08 lorraine2

Thanks! That makes sense! I will try gradient accumulation this way! 😊

Shancong-Mou avatar Aug 30 '23 17:08 Shancong-Mou