14 comments by Erik Nijkamp

The models were pre-trained in JAX on TPU-v4 hardware and then converted to PyTorch for sampling. The training code in JAX will be released soon. You may try to...
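
For reference, here is a minimal sampling sketch with a converted PyTorch checkpoint via HuggingFace transformers; the checkpoint name and generation settings are only illustrative, pick the size that fits your hardware:

```python
# Minimal sampling sketch using a converted PyTorch CodeGen checkpoint via
# HuggingFace transformers. Checkpoint name and generation settings are
# illustrative, not the exact ones used for the paper.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "Salesforce/codegen-350M-mono"  # example checkpoint
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint).to(device)

prompt = "def hello_world():"
inputs = tokenizer(prompt, return_tensors="pt").to(device)
with torch.no_grad():
    output_ids = model.generate(
        **inputs,
        do_sample=True,
        temperature=0.2,
        max_new_tokens=64,
        pad_token_id=tokenizer.eos_token_id,
    )
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```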

@jiang719 Here is DeepSpeed fine-tuning code with CPU parameter offloading, which should let you avoid OOM: https://github.com/salesforce/jaxformer/blob/main/jaxformer/hf/train.py
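
The key piece is ZeRO-3 with parameters and optimizer state offloaded to CPU. A sketch of such a config is below; the values are illustrative, the actual settings live in the linked train.py:

```python
# Sketch of a DeepSpeed ZeRO-3 config with CPU parameter/optimizer offloading.
# All values are illustrative; see jaxformer/hf/train.py for the real settings.
ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "gradient_accumulation_steps": 16,
    "fp16": {"enabled": True},
    "zero_optimization": {
        "stage": 3,
        "offload_param": {"device": "cpu", "pin_memory": True},
        "offload_optimizer": {"device": "cpu", "pin_memory": True},
        "overlap_comm": True,
        "contiguous_gradients": True,
    },
    "optimizer": {
        "type": "AdamW",
        "params": {"lr": 1e-5, "weight_decay": 0.01},
    },
}
```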

The converted PyTorch models can be fine-tuned similarly to other causal LMs in HuggingFace. See tutorials like http://reyfarhan.com/posts/easy-gpt2-finetuning-huggingface/.
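
A minimal fine-tuning sketch with the HuggingFace Trainer follows; the dataset, checkpoint, and hyperparameters are placeholders, not a recommended recipe:

```python
# Minimal causal-LM fine-tuning sketch with the HuggingFace Trainer.
# Dataset, checkpoint, and hyperparameters are placeholders.
from datasets import load_dataset
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          DataCollatorForLanguageModeling, Trainer,
                          TrainingArguments)

checkpoint = "Salesforce/codegen-350M-mono"  # example checkpoint
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(checkpoint)

# Example text corpus; replace with your own code data.
dataset = load_dataset("text", data_files={"train": "train.txt"})["train"]
dataset = dataset.map(
    lambda ex: tokenizer(ex["text"], truncation=True, max_length=512),
    batched=True,
    remove_columns=["text"],
)

trainer = Trainer(
    model=model,
    args=TrainingArguments(
        output_dir="codegen-finetuned",
        per_device_train_batch_size=1,
        gradient_accumulation_steps=8,
        num_train_epochs=1,
        fp16=True,
        logging_steps=10,
    ),
    train_dataset=dataset,
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
trainer.train()
```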

@TheodoreGalanos Working on a release of the JAX training code. I trained the models on TPU-v4 and have to resolve a blocker for TPU-v3.

@thisisanshgupta @Ontopic Yes, I'm working on the release of my training library for TPU-v3/v4 and will keep you posted.

@smith-co @thisisanshgupta @tlkh For PyTorch, I wrote up a minimal example in DeepSpeed, which can train the 16B model on a ~24 GB GPU. You would need to sanity test this,...
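
The gist is to hand the model and an offload-enabled ZeRO-3 config to deepspeed.initialize and run an otherwise standard training loop. A hedged sketch, with placeholders for data and config values (launch via `deepspeed script.py`):

```python
# Sketch of a DeepSpeed training step with ZeRO-3 CPU offloading.
# Config values, checkpoint, and the toy batch are placeholders.
import deepspeed
from transformers import AutoModelForCausalLM, AutoTokenizer

ds_config = {  # minimal ZeRO-3 CPU-offload config (illustrative values)
    "train_micro_batch_size_per_gpu": 1,
    "fp16": {"enabled": True},
    "zero_optimization": {
        "stage": 3,
        "offload_param": {"device": "cpu"},
        "offload_optimizer": {"device": "cpu"},
    },
    "optimizer": {"type": "AdamW", "params": {"lr": 1e-5}},
}

checkpoint = "Salesforce/codegen-350M-mono"  # swap in the 16B checkpoint if your offload budget allows
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint)

engine, _, _, _ = deepspeed.initialize(
    model=model, model_parameters=model.parameters(), config=ds_config
)

# Toy batch; replace with a real dataloader over your code corpus.
batch = tokenizer(["def add(a, b): return a + b"], return_tensors="pt")
input_ids = batch["input_ids"].to(engine.device)

loss = engine(input_ids=input_ids, labels=input_ids).loss
engine.backward(loss)
engine.step()
```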

@smith-co @thisisanshgupta @tlkh @Ontopic @TheodoreGalanos @shmuelhizmi A first release of the training code for TPU-v3/v4 is here: https://github.com/salesforce/jaxformer

Your requirements.txt states tensorflow==2.0.0, see https://github.com/ParthaEth/Regularized_autoencoders-RAE-/blob/master/requirements.txt#L21 Please fix.

Hi @Aryagm, the log-likelihood and sampling can be run on TPUs using JAX. I'm currently verifying the training code on TPU-v3 and will then release the implementation. I can include...
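
As a rough, model-agnostic sketch of the log-likelihood side in plain JAX (the model forward pass producing `logits` is omitted and assumed given):

```python
# Sketch: per-sequence log-likelihood of target tokens from model logits.
# The model forward pass itself is not shown here.
import jax.numpy as jnp
from jax import nn

def sequence_log_likelihood(logits, targets):
    """logits: [batch, seq_len, vocab]; targets: [batch, seq_len] token ids."""
    log_probs = nn.log_softmax(logits, axis=-1)
    # Gather the log-probability assigned to each target token.
    token_ll = jnp.take_along_axis(log_probs, targets[..., None], axis=-1)[..., 0]
    return token_ll.sum(axis=-1)  # sum over the sequence dimension
```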

@xanderdunn @Aryagm I'm working on adding the pjit'ed sampling code as we speak, sorry for the delays.