BGE-M3 unify微调时,forward函数中self.use_inbatch_neg的实际含义?
@hotchpotch @yzhliu @zh217 @neofung 首先,特别感谢您的优秀工作。在学习您的工作时遇到了理解上的问题,想请教您。如题所述,对于这个含义有点不清楚,请帮忙解答一下。我的理解如下,请帮忙看看是否正确:use_inbatch_neg:同一个batch中query对应的neg数据是否参与loss计算。
具体分析其中的代码(FlagEmbedding/baai_general_embedding/finetune/modeling.py forward函数),又有些不理解的地方:
- if self.use_inbatch_neg: 为每个查询创建了一个查询索引,将查询索引乘以 group_size,确保了每个查询都指向其对应的文档组的第一个文档,这可以视为正样本。但是后续的loss计算为self.compute_loss(scores, target),我的理解是只计算了正样本与query之间的loss,这里并没有体现出use_inbatch_neg
- else:(也即 not self.use_inbatch_neg) 为每个查询创建了一个查询索引0, 表示每个查询只考虑第一个文档。第一个文档作为batch中的第一个文档,它只与第一个query对应,且为第一个query的正样本。将第一个query的正样本与所有的query求loss,这样是为了区分不同的query吗?那为什么不用第二个query的第一个文档与所有的query求loss呢?
非常期待您的回复。
if self.use_inbatch_neg: 为每个查询创建了一个查询索引,将查询索引乘以 group_size,确保了每个查询都指向其对应的文档组的第一个文档,这可以视为正样本。但是后续的loss计算为self.compute_loss(scores, target),我的理解是只计算了正样本与query之间的loss,这里并没有体现出use_inbatch_neg
这里的scores 里是每个query对所有passage(包括in-batch的passage)的分数,
else:(也即 not self.use_inbatch_neg) 为每个查询创建了一个查询索引0, 表示每个查询只考虑第一个文档。第一个文档作为batch中的第一个文档,它只与第一个query对应,且为第一个query的正样本。将第一个query的正样本与所有的query求loss,这样是为了区分不同的query吗?那为什么不用第二个query的第一个文档与所有的query求loss呢?
这里的scores里只有每个query和其对应group里样本的分数,
@staoxiao 感谢回复。我再跟您确认一下我的理解是不是对的:
- if self.use_inbatch_neg 中的loss是所有passage和query对应的正样本的loss
- else中的loss是query对应的group中的passage和query对应的正样本的loss