Alex Havrilla
### System Info

```Shell
- `Accelerate` version: 0.11.0
- Platform: Linux-5.10.112-108.499.amzn2.x86_64-x86_64-with-glibc2.2.5
- Python version: 3.8.5
- Numpy version: 1.23.1
- PyTorch version (GPU?): 1.12.0+cu113 (True)
- `Accelerate` default config: -...
```
Ppo z3
Work in progress: integrating ZeRO-3 (zero3) with hydra models for PPO. The current implementation works for models smaller than 6B parameters but runs out of memory (OOMs) at 6B.
### 🚀 The feature, motivation, and pitch

Add JAX support for RLHF on TPUs.

### Alternatives

_No response_

### Additional context

_No response_
The Carp config requires a device, which needs to be changed for multi-GPU training.
Implementation of multi-generation RL in trlX. A suggested (but optional) external inference pipeline wrapper can be found [here](https://github.com/CarperAI/autocrit/pull/16).
Implementing `ref_model` as an additional reward component
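One common way to use a reference model as a reward component in RLHF is a per-token KL-style penalty between the policy and the reference model. The sketch below assumes that formulation; the function name `combine_rewards`, the `kl_coef` parameter, and the choice to add the task reward on the final token are all illustrative assumptions, not taken from the trlX change itself.

```python
from typing import List


def combine_rewards(
    task_reward: float,
    policy_logprobs: List[float],
    ref_logprobs: List[float],
    kl_coef: float = 0.1,  # hypothetical penalty weight
) -> List[float]:
    """Return per-token rewards for one sampled sequence.

    Every token gets a penalty proportional to how much more likely the
    policy found it than the reference model did. The scalar task reward
    is added on the final token (an assumption of this sketch).
    """
    if len(policy_logprobs) != len(ref_logprobs):
        raise ValueError("log-prob sequences must have the same length")
    if not policy_logprobs:
        raise ValueError("sequence must contain at least one token")

    # Per-token estimate of the log-ratio log pi(a|s) - log pi_ref(a|s).
    rewards = [
        -kl_coef * (lp - ref_lp)
        for lp, ref_lp in zip(policy_logprobs, ref_logprobs)
    ]
    rewards[-1] += task_reward
    return rewards
```

With this shape, the reference model only ever contributes through the log-ratio term, so it can be treated as one more additive reward source next to the task reward.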
### 🚀 The feature, motivation, and pitch

Implementing asynchronous PPO would mitigate model rollout/exploration, which is the largest bottleneck in the training process.

### Alternatives

_No response_

### Additional context

_No...
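The idea of overlapping rollout generation with training can be sketched as a producer/consumer loop. This is a minimal illustration using only the Python standard library, not a description of how trlX implements or would implement it; `run_async`, `generate_rollout`, `train_step`, and `max_queue` are hypothetical names. Note that rollouts consumed this way may come from a slightly stale policy, and any off-policy correction is outside the scope of the sketch.

```python
import queue
import threading
from typing import Any, Callable, List


def run_async(
    generate_rollout: Callable[[int], Any],   # hypothetical rollout function
    train_step: Callable[[List[Any]], Any],   # hypothetical training function
    num_rollouts: int,
    batch_size: int,
    max_queue: int = 8,
) -> List[Any]:
    """Generate rollouts in a background thread while training on batches."""
    buffer: "queue.Queue[Any]" = queue.Queue(maxsize=max_queue)
    done = object()  # sentinel marking the end of the rollout stream

    def producer() -> None:
        try:
            for i in range(num_rollouts):
                # Blocks when the buffer is full, which bounds staleness.
                buffer.put(generate_rollout(i))
        finally:
            # Always signal completion so the consumer cannot hang.
            buffer.put(done)

    worker = threading.Thread(target=producer, daemon=True)
    worker.start()

    results: List[Any] = []
    batch: List[Any] = []
    while True:
        item = buffer.get()
        if item is done:
            break
        batch.append(item)
        if len(batch) == batch_size:
            results.append(train_step(batch))
            batch = []
    if batch:  # train on the final partial batch, if any
        results.append(train_step(batch))

    worker.join()
    return results
```

The bounded queue is the main design choice here: a small `max_queue` keeps the consumed rollouts close to the current policy, while a larger one hides more generation latency.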