SoftMaskedBert
detector_correct and corrector_correct are computed incorrectly
corrector_correct = corrector_correct + sum( [(output*batch_mask).reshape(-1)[j].equal((batch_out_ids*batch_mask).reshape(-1)[j]) for j in range(len(output.reshape(-1)))])
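This also counts padding positions as correct: after multiplying by batch_mask, both sides are 0 at every pad position and always compare equal, which is clearly wrong.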
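A minimal reproduction of the inflation, assuming PyTorch tensors of shape (batch, seq_len) and a 0/1 `batch_mask` with 0 at padding. The toy values are hypothetical; the loop mirrors the expression above.

```python
import torch

# Hypothetical toy batch: 1 sentence, 5 positions, the last 2 are padding.
output = torch.tensor([[11, 12, 99, 7, 5]])         # predicted ids (pad positions arbitrary)
batch_out_ids = torch.tensor([[11, 12, 13, 0, 0]])  # gold ids
batch_mask = torch.tensor([[1, 1, 1, 0, 0]])        # 1 = real token, 0 = pad

# The original loop: masking turns every pad position into 0 on both sides.
masked_pred = (output * batch_mask).reshape(-1)
masked_gold = (batch_out_ids * batch_mask).reshape(-1)
loop_count = sum(
    [masked_pred[j].equal(masked_gold[j]) for j in range(len(output.reshape(-1)))]
)

# Counting only real tokens.
true_count = int(((output == batch_out_ids) * batch_mask).sum())

print(loop_count, true_count)  # 4 2
```

The two pad positions add 2 spurious "correct" tokens to the loop count.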
Why not change it to:
detector_correct += (prob.squeeze().eq(batch_labels) * batch_mask).sum()
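A sketch of the masked counting for both counters. It assumes `prob` holds per-token error probabilities of shape (batch, seq_len, 1), thresholded at 0.5, and that `batch_labels`, `output`, `batch_out_ids` and `batch_mask` are (batch, seq_len) integer tensors. The helper `masked_correct` and the 0.5 threshold are hypothetical, not taken from the SoftMaskedBert code. It compares with element-wise `eq`; `Tensor.equal` returns a single Python bool for the whole tensor, so it cannot be weighted by the mask position by position.

```python
import torch


def masked_correct(pred: torch.Tensor, gold: torch.Tensor, mask: torch.Tensor) -> int:
    """Count positions where pred == gold, ignoring positions where mask == 0."""
    # eq() compares element-wise; multiplying by the 0/1 mask zeroes out padding.
    return int((pred.eq(gold) * mask).sum())


# Hypothetical toy batch: 1 sentence, 5 positions, the last 2 are padding.
batch_mask = torch.tensor([[1, 1, 1, 0, 0]])
batch_labels = torch.tensor([[0, 1, 0, 0, 0]])              # 1 = token is an error
prob = torch.tensor([[[0.1], [0.8], [0.6], [0.3], [0.9]]])  # detector output
output = torch.tensor([[11, 12, 99, 7, 5]])                 # corrector argmax ids
batch_out_ids = torch.tensor([[11, 12, 13, 0, 0]])

# Assumed threshold of 0.5; squeeze(-1) drops only the last dim,
# so a batch of size 1 keeps its batch dimension.
detector_pred = (prob.squeeze(-1) > 0.5).long()

detector_correct = masked_correct(detector_pred, batch_labels, batch_mask)
corrector_correct = masked_correct(output, batch_out_ids, batch_mask)
total = int(batch_mask.sum())  # denominator: real tokens only

print(detector_correct, corrector_correct, total)  # 2 2 3
```

Accuracy would then be `detector_correct / total` and `corrector_correct / total`; if the existing denominator counts every position, padding included, it likely needs the same masking.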