SKNet

error: attention_vectors = torch.cat([attention_vectors, vector], dim=1)

Open dangxusheng opened this issue 5 years ago • 0 comments

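I think this operation does not match the paper. It should be as follows:

```python
import torch

batch_size, ch, h, w = 5, 3, 8, 8
# Per-branch, per-channel attention logits, shape (B, C) -> (B, C, 1)
feat_a = torch.randn(batch_size, ch).unsqueeze(-1)
feat_b = torch.randn(batch_size, ch).unsqueeze(-1)
# Branch outputs U1 and U2 (random placeholders here), assumed shape (B, C, H, W)
feat_U1 = torch.randn(batch_size, ch, h, w)
feat_U2 = torch.randn(batch_size, ch, h, w)

# Stack the two branches on the last axis and softmax over the branches, per channel
feat_tmp = torch.cat([feat_a, feat_b], dim=-1)   # (B, C, 2)
feat_softmax = torch.softmax(feat_tmp, dim=-1)   # weight_a + weight_b == 1
# No .squeeze(): it would also drop the batch axis when batch_size == 1
weight_a = feat_softmax[:, :, 0]                 # (B, C)
weight_b = feat_softmax[:, :, 1]                 # (B, C)
# Broadcast the channel weights over H and W, then fuse the two branches
feat_c = weight_a[:, :, None, None] * feat_U1 + weight_b[:, :, None, None] * feat_U2
```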
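For comparison, the same per-channel fusion can be written for any number of branches M by stacking the branch logits on a dedicated branch axis and applying softmax over that axis. This is a minimal sketch under stated assumptions: the function name `sk_fuse` and the tensor layouts (logits as (B, M, C), branch features as (B, M, C, H, W)) are hypothetical and are not taken from the SKNet repository's code. The `__main__` block checks that, for M = 2, it reproduces the two-branch formula proposed in this issue.

```python
import torch


def sk_fuse(logits: torch.Tensor, feats: torch.Tensor) -> torch.Tensor:
    """Fuse M branch feature maps with per-channel softmax attention.

    logits: (B, M, C) attention logits, one row per branch (hypothetical layout)
    feats:  (B, M, C, H, W) branch outputs U_1..U_M
    returns (B, C, H, W)
    """
    # Softmax over the branch axis: for each sample and channel,
    # the M weights are positive and sum to 1.
    weights = torch.softmax(logits, dim=1)   # (B, M, C)
    weights = weights[..., None, None]       # (B, M, C, 1, 1), broadcast over H, W
    return (weights * feats).sum(dim=1)      # weighted sum over branches -> (B, C, H, W)


if __name__ == "__main__":
    torch.manual_seed(0)
    B, C, H, W = 5, 3, 8, 8
    a, b = torch.randn(B, C), torch.randn(B, C)
    u1, u2 = torch.randn(B, C, H, W), torch.randn(B, C, H, W)

    fused = sk_fuse(torch.stack([a, b], dim=1), torch.stack([u1, u2], dim=1))

    # Two-branch form from the issue: softmax over the last axis of a (B, C, 2) tensor
    w = torch.softmax(torch.stack([a, b], dim=-1), dim=-1)
    manual = w[..., 0, None, None] * u1 + w[..., 1, None, None] * u2
    print(torch.allclose(fused, manual, atol=1e-6))
```

Putting the branches on their own axis keeps the softmax dimension explicit, which is exactly the point at issue: the softmax must run across branches, not across channels.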

dangxusheng, Dec 14 '20 11:12