libai
libai copied to clipboard
Rwkv v4 merge main
rwkv 代码修正版:
跑loss对齐时的需要手动修改的代码:
- 注释
libai/engine/default.py的 729~730行:
# Global scheduler cfg
# cfg.train.scheduler.warmup_iter = cfg.train.warmup_iter
# cfg.train.scheduler.max_iter = cfg.train.train_iter
- 把
libai/data/build.py下所有的 shuffle改为false, 并注释掉所有的persistent_workers=True if num_workers > 0 else False,
train_sampler=LazyCall(CyclicSampler)(shuffle=False)
...
# persistent_workers=True if num_workers > 0 else False,
- 修改
projects/RWKV_v4/dataset/dataset.py的57行
#i = np.random.randint(0, len(self.data) - (self.ctx_len + 1))
i = 1
- 在
libai/engine/trainer.py213行添加写loss txt的代码
total_losses_reduced = sum(metrics_dict.values())
if dist.is_main_process():
txt = open("/home/zhangxiaoyu/libai_bfp16.txt", "a")
txt.write(str(total_losses_reduced.item())+"\n")
跑loss txt的运行指令:
bash tools/train.sh projects/RWKV_v4/train_net.py projects/RWKV_v4/configs/config_test.py 1
和 rwkv4 分支的loss做对比,结果如下:

正确性和rwkv4分支一致。
这个分支 根据用户反馈还有些吞吐和显存的问题.
在解决问题的过程中可能需要修改一下这个分支的代码.
等问题解决了以后, 再让大家review吧