SKNet
error: attention_vectors = torch.cat([attention_vectors, vector], dim=1)
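I think this operation does not match the paper. It should be as follows:

```python
import torch

batch_size, ch = 5, 3
h, w = 8, 8  # spatial size, chosen only for illustration

# Per-branch attention logits (stand-ins for the outputs of the two fc layers)
feat_a = torch.randn(batch_size, ch).unsqueeze(-1)  # (B, C, 1)
feat_b = torch.randn(batch_size, ch).unsqueeze(-1)  # (B, C, 1)

# Concatenate along a new branch axis and softmax across the branches, per channel
feat_tmp = torch.cat([feat_a, feat_b], dim=-1)      # (B, C, 2)
feat_softmax = torch.softmax(feat_tmp, dim=-1)

# Per-channel branch weights; weight_a + weight_b == 1 for every channel
weight_a = feat_softmax[:, :, 0]                    # (B, C)
weight_b = feat_softmax[:, :, 1]                    # (B, C)

# Placeholder branch feature maps U1, U2 of shape (B, C, H, W)
feat_U1 = torch.randn(batch_size, ch, h, w)
feat_U2 = torch.randn(batch_size, ch, h, w)

# Reshape the weights to (B, C, 1, 1) so they broadcast over H and W
feat_c = weight_a[..., None, None] * feat_U1 + weight_b[..., None, None] * feat_U2
```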
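The same per-channel selection extends to M branches: for every channel, the softmax has to run over the branch axis, so that the branch weights for that channel sum to 1 (a_c + b_c = 1 in the two-branch case). Whichever axis the branch vectors are concatenated or stacked along, the softmax must be taken over that same axis. The sketch below is a minimal illustration under these assumptions: `sk_fuse` is a hypothetical helper name, the logits stand in for the per-branch fc outputs, and the feature maps are random placeholders. It also checks that the M-branch version reproduces the two-branch formula above.

```python
import torch


def sk_fuse(branch_feats, branch_logits):
    """Hypothetical helper: fuse M branch feature maps with per-channel softmax weights.

    branch_feats:  list of M tensors, each of shape (B, C, H, W)
    branch_logits: list of M tensors, each of shape (B, C), e.g. per-branch fc outputs
    """
    logits = torch.stack(branch_logits, dim=1)              # (B, M, C)
    weights = torch.softmax(logits, dim=1)                  # softmax across branches, per channel
    feats = torch.stack(branch_feats, dim=1)                # (B, M, C, H, W)
    fused = (weights[..., None, None] * feats).sum(dim=1)   # (B, C, H, W)
    return fused, weights


torch.manual_seed(0)
B, C, H, W = 5, 3, 8, 8
u1, u2 = torch.randn(B, C, H, W), torch.randn(B, C, H, W)  # placeholder branch features
a, b = torch.randn(B, C), torch.randn(B, C)                # placeholder attention logits

fused, weights = sk_fuse([u1, u2], [a, b])

# Reference: the two-branch formula, softmax over the last (branch) axis
w = torch.softmax(torch.stack([a, b], dim=-1), dim=-1)     # (B, C, 2)
ref = w[:, :, 0][..., None, None] * u1 + w[:, :, 1][..., None, None] * u2
```

Stacking on `dim=1` keeps the branch axis next to the batch axis, so the fused output is just a sum over that axis, and adding a third kernel size only means passing one more tensor in each list.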