
How did you train the large-sized models without running out of memory?

jiang719 opened this issue 3 years ago

I would like to fine-tune the 2B model, but I run into out-of-memory errors even with the batch size set to 1 (on a single GPU with 24 GB of memory).

I wonder what devices you used to pre-train the 2B and 16B models. How did you address the memory issue? Did you parallelize the model by layers across different GPUs? Thank you.

Nan

jiang719 · Aug 06 '22

The models were pre-trained in JAX on TPU-v4 hardware and later converted to PyTorch for sampling.
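For reference, sampling from the converted PyTorch checkpoints works through Hugging Face transformers. A minimal sketch, assuming the released 2B mono checkpoint id `Salesforce/codegen-2B-mono` and a single CUDA device:

```python
# Minimal sampling sketch using the converted PyTorch checkpoints via transformers.
# The checkpoint id below is an assumption; substitute the variant you want to use.
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

checkpoint = "Salesforce/codegen-2B-mono"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.float16).cuda()

inputs = tokenizer("def hello_world():", return_tensors="pt").to("cuda")
with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=64, do_sample=True, temperature=0.2)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```

Loading in fp16 roughly halves the inference memory footprint, so the 2B model should fit comfortably on a 24 GB GPU for sampling.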

The training code in JAX will be released soon.

You may try to fine-tune the models in PyTorch using DeepSpeed:

https://news.ycombinator.com/item?id=32331764
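As a rough sketch of what that could look like (this is not the released training script; the ZeRO stage, offload settings, and hyperparameters below are placeholders):

```python
# Hedged sketch: fine-tuning CodeGen-2B with DeepSpeed ZeRO and optimizer-state offloading.
# All hyperparameters are placeholders, not the values used for the released models.
import deepspeed
from transformers import AutoTokenizer, AutoModelForCausalLM

ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "gradient_accumulation_steps": 8,
    "fp16": {"enabled": True},
    "zero_optimization": {
        "stage": 2,
        "offload_optimizer": {"device": "cpu"},  # keep Adam states in CPU RAM
    },
    "optimizer": {"type": "AdamW", "params": {"lr": 1e-5}},
}

model = AutoModelForCausalLM.from_pretrained("Salesforce/codegen-2B-mono")
engine, _, _, _ = deepspeed.initialize(model=model,
                                       model_parameters=model.parameters(),
                                       config=ds_config)

# One illustrative step on a toy batch; replace with your own data loader.
tokenizer = AutoTokenizer.from_pretrained("Salesforce/codegen-2B-mono")
batch = tokenizer("def add(a, b): return a + b", return_tensors="pt").to(engine.device)
loss = engine(input_ids=batch["input_ids"], labels=batch["input_ids"]).loss
engine.backward(loss)
engine.step()
```

ZeRO stage 2 with optimizer-state offloading already moves the Adam states off the GPU, which is usually the largest contributor to OOM at batch size 1.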

enijkamp · Aug 07 '22

Training code in JAX has been released: https://github.com/salesforce/CodeGen/issues/16#issuecomment-1262799121

xanderdunn · Oct 04 '22

@jiang719 Here is DeepSpeed fine-tuning code with CPU parameter offloading, which should let you avoid OOM:

https://github.com/salesforce/jaxformer/blob/main/jaxformer/hf/train.py
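For context, CPU parameter offloading corresponds to ZeRO stage 3 with `offload_param` in the DeepSpeed config. A minimal illustration (the linked `train.py` ships its own settings, so treat these values as placeholders):

```python
# Illustrative DeepSpeed config with parameters and optimizer states offloaded to CPU.
# Placeholder values only -- not the exact configuration used in the linked train.py.
zero3_offload_config = {
    "train_micro_batch_size_per_gpu": 1,
    "gradient_accumulation_steps": 8,
    "fp16": {"enabled": True},
    "zero_optimization": {
        "stage": 3,  # shard parameters, gradients, and optimizer states
        "offload_param": {"device": "cpu", "pin_memory": True},      # parameters live in CPU RAM
        "offload_optimizer": {"device": "cpu", "pin_memory": True},  # optimizer states in CPU RAM
    },
}
# Pass this dict to deepspeed.initialize(model=..., model_parameters=..., config=zero3_offload_config).
```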

enijkamp · Oct 04 '22