libai icon indicating copy to clipboard operation
libai copied to clipboard

Rwkv v4 merge main

Open CPFLAME opened this issue 3 years ago • 2 comments

rwkv 代码修正版:

跑loss对齐时的需要手动修改的代码:

  • 注释libai/engine/default.py 的 729~730行:
        # Global scheduler cfg
        # cfg.train.scheduler.warmup_iter = cfg.train.warmup_iter
        # cfg.train.scheduler.max_iter = cfg.train.train_iter
  • libai/data/build.py 下所有的 shuffle改为false, 并注释掉所有的persistent_workers=True if num_workers > 0 else False,
train_sampler=LazyCall(CyclicSampler)(shuffle=False)
...
# persistent_workers=True if num_workers > 0 else False,
  • 修改projects/RWKV_v4/dataset/dataset.py 的57行
#i = np.random.randint(0, len(self.data) - (self.ctx_len + 1))
i = 1
  • libai/engine/trainer.py213行添加写loss txt的代码
            total_losses_reduced = sum(metrics_dict.values())
            if dist.is_main_process():
                txt = open("/home/zhangxiaoyu/libai_bfp16.txt", "a")
                txt.write(str(total_losses_reduced.item())+"\n")

跑loss txt的运行指令:

bash tools/train.sh projects/RWKV_v4/train_net.py projects/RWKV_v4/configs/config_test.py 1

CPFLAME avatar Aug 26 '22 02:08 CPFLAME

和 rwkv4 分支的loss做对比,结果如下:

图片

正确性和rwkv4分支一致。

BBuf avatar Aug 26 '22 08:08 BBuf

这个分支 根据用户反馈还有些吞吐和显存的问题.
在解决问题的过程中可能需要修改一下这个分支的代码. 等问题解决了以后, 再让大家review吧

CPFLAME avatar Aug 29 '22 07:08 CPFLAME