WebGLM icon indicating copy to clipboard operation
WebGLM copied to clipboard

关于train_retriever.py中的loss

Open llllooong opened this issue 2 years ago • 2 comments

麻烦问一下train_retriever.py文件中第44行求loss的函数中,cross_entropy的训练target为什么是是torch.arange(0, len(l_pos)呀? image

llllooong avatar Jul 20 '23 10:07 llllooong

  1. 每一条训练数据包含一条强关联(作为 positive sample)与弱关联(作为 hard negative sample)。
  2. 训练过程中,若 batchsize 为 $n$,则同一个 batch 内将包含 $n$ 条 positive sample 和 $n$ 条 hard negative sample,对于每一条数据而言,只有它的 positive sample 是正例,其余 $2n - 1$ 条全都是负例。
  3. 将这 2n 条数据按 $(pos_1, \cdots, pos_n, neg_1, \cdots, neg_n)$ 的方式拼接起来后,第 $i$ 条数据的正样本 index 即为 $i$。

Longin-Yu avatar Jul 24 '23 08:07 Longin-Yu

那是不是得确保一个batch内,尽量少有相似问题?

llllooong avatar Jul 27 '23 02:07 llllooong