deepspeed训练m3模型,OOM
四卡40gH卡,batch开32,开了gradient_checkpointing,query_max_len 512 ,passage_max_len 8192 训着训着会突然OOM,求问稳定的解决方案,(尝试了batch开4,16,32都会炸)
需要开启sub_batch_size,https://github.com/FlagOpen/FlagEmbedding/blob/master/FlagEmbedding/BGE_M3/modeling.py#L156 在比较长的长度下,能明显提高batch size。
感谢您的提醒,不知道能否获得一些参数,如果我想在bgm-m3-unsupervised上接着训练,长度为512+8192(使用bge-m3数据集),使用deepspeed,需要什么样的显存和配置(或者推荐batch_size多大呢,我看到之前是建议>64)。
如果您想在 bge-m3-unsupervised 上接着训练,且您的训练数据中既有长数据又有短数据,那推荐您使用 efficient batching 策略:
- 把训练数据按照长度划分成不同的文件,使用这里的脚本:https://github.com/FlagOpen/FlagEmbedding/blob/master/FlagEmbedding/BGE_M3/split_data_by_length.py
- 根据您的机器的显存情况在这里设置各个长度下的 batch_size:https://github.com/FlagOpen/FlagEmbedding/blob/master/FlagEmbedding/BGE_M3/data.py#L122
我们训练时使用的是 80GB 显存的 A800,用的各个长度下的 batch size 在这里也有给出,您可以参考下。
关于设置各个长度下的 batch size 的方式,我们也没有特别好的办法,只能手动去测试,测试方式大概是:
- 准备一份测试数据,数据量不用很大,文件名中不要包含长度后缀 (即用于在 get_file_batch_size 中判断 batch_size 的标志,如 len-0-500, len-500-1000 等);
- 修改这里的 padding 方式:https://github.com/FlagOpen/FlagEmbedding/blob/master/FlagEmbedding/BGE_M3/data.py#L293;
- 在脚本中设置相应的 max_length 和 batch_size 进行测试观察显存使用情况,大概跑 3 个 step 没有报 OOM 就可以停止测试了。
关于其他的一些参数,可以参考这里的样例脚本:https://github.com/FlagOpen/FlagEmbedding/blob/master/examples/unified_finetune/unified_finetune_bge-m3_exmaple.sh
Efficient Batching 在m3 的pretrain 中有脚本吗?类似https://github.com/FlagOpen/FlagEmbedding/blob/master/FlagEmbedding/BGE_M3/split_data_by_length.py 这样的脚本