practical-pytorch

Error in Masked Cross Entropy

Open: cheekala opened this issue on Jul 19 '17 • 2 comments

I have built a seq2seq model that trains in batch mode, and I am hitting a runtime error during training.

I import the loss with from masked_cross_entropy import *, using the file downloaded from https://github.com/spro/practical-pytorch/blob/master/seq2seq-translation/masked_cross_entropy.py

Loss calculation and backpropagation

print(all_decoder_outputs.transpose(0, 1).contiguous().size() , target_batch.transpose(0, 1).contiguous().size())
loss = masked_cross_entropy(
    all_decoder_outputs.transpose(0, 1).contiguous(), # -> batch x seq x vocab
    target_batch.transpose(0, 1).contiguous(), # -> batch x seq
    target_batch_length
)
loss.backward()
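For reference, here is a minimal sketch of what a masked cross entropy has to compute. It is not the repo's implementation: the name `masked_cross_entropy_sketch` is hypothetical, and it uses `F.cross_entropy(..., reduction="none")` instead of the manual log-softmax plus gather. The comments in the traceback (`# losses_flat: (batch * max_len, 1)`, `# losses: (batch, max_len)`) imply that logits and target must share the same `(batch, max_len)` leading dimensions, so the sketch checks that explicitly.

```python
import torch
import torch.nn.functional as F


def masked_cross_entropy_sketch(logits, target, length):
    """Mean cross-entropy over the non-padded positions only.

    logits: (batch, max_len, vocab) unnormalized scores
    target: (batch, max_len) token indices
    length: (batch,) true length of each target sequence (list or tensor)
    """
    batch, max_len, vocab = logits.size()
    # Both tensors must agree on batch and max_len; otherwise the flattened
    # row counts differ and the gather/reshape step fails.
    assert target.size() == (batch, max_len), (
        f"logits {tuple(logits.size())} vs target {tuple(target.size())}"
    )
    # Per-token negative log-likelihood, reshaped back to (batch, max_len).
    losses = F.cross_entropy(
        logits.reshape(-1, vocab), target.reshape(-1), reduction="none"
    ).view(batch, max_len)
    # mask[b, t] is True when position t lies inside sequence b.
    lengths = torch.as_tensor(length, device=logits.device)
    positions = torch.arange(max_len, device=logits.device).unsqueeze(0)
    mask = positions < lengths.unsqueeze(1)
    # Sum over real tokens only, then average by the number of real tokens.
    return (losses * mask).sum() / mask.sum()


# Example with made-up sizes: 4 sentences, 6 steps, 10-word vocabulary.
torch.manual_seed(0)
logits = torch.randn(4, 6, 10)
target = torch.randint(0, 10, (4, 6))
length = [6, 5, 3, 2]
loss = masked_cross_entropy_sketch(logits, target, length)
print(loss.item())
```

With every length equal to `max_len`, the mask keeps all positions and the result reduces to the ordinary mean cross-entropy.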

The printed sizes and the resulting error are below (128 is the batch size, 6 is the number of words in the sentences of the batch, 42005 is the vocabulary size, and 15 is the maximum number of words allowed in a sentence):

torch.Size([128, 6, 42005]) torch.Size([128, 15])

/home/ubuntu/masked_cross_entropy.py in masked_cross_entropy(logits, target, length)
     41     target_flat = target.view(-1, 1)
     42     # losses_flat: (batch * max_len, 1)
---> 43     losses_flat = -torch.gather(log_probs_flat, dim=1, index=target_flat)
     44     # losses: (batch, max_len)
     45     losses = losses_flat.view(*target.size())

/home/ubuntu/anaconda3/envs/tensorflow/lib/python3.6/site-packages/torch/autograd/variable.py in gather(self, dim, index)
    621
    622     def gather(self, dim, index):
--> 623         return Gather(dim)(self, index)
    624
    625     def scatter(self, dim, index, source):

/home/ubuntu/anaconda3/envs/tensorflow/lib/python3.6/site-packages/torch/autograd/_functions/tensor.py in forward(self, input, index)
    539         self.input_size = input.size()
    540         self.save_for_backward(index)
--> 541         return input.gather(self.dim, index)
    542
    543     def backward(self, grad_output):

RuntimeError: Input tensor must have same size as output tensor apart from the specified dimension at /py/conda-bld/pytorch_1493681908901/work/torch/lib/THC/generic/THCTensorScatterGather.cu:29
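The failure can be reproduced with small, made-up sizes (4 sentences, a 10-word vocabulary; the 6 and 15 mirror the printed shapes). This sketch repeats the flatten-and-gather steps from the traceback: the logits flatten to 4 × 6 = 24 rows, while the target flattens to 4 × 15 = 60 rows, so `torch.gather` cannot match them. The exact error text depends on the PyTorch version; newer releases word it differently from the message above.

```python
import torch

batch, vocab = 4, 10
decoder_steps, padded_len = 6, 15  # mirrors the 6 vs 15 in the printed sizes

logits = torch.randn(batch, decoder_steps, vocab)
target = torch.randint(0, vocab, (batch, padded_len))

# Same flattening as masked_cross_entropy.py: one row per (sentence, step).
log_probs_flat = torch.log_softmax(logits.view(-1, vocab), dim=1)  # (24, 10)
target_flat = target.view(-1, 1)                                   # (60, 1)

try:
    torch.gather(log_probs_flat, dim=1, index=target_flat)
    raised = False
except RuntimeError as err:
    raised = True
    print("gather failed:", err)
```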

cheekala · Jul 19 '17 11:07

I have the same issue. Did you resolve it?

kdrivas · Jan 06 '18 01:01

Make sure all_decoder_outputs.size(1) is the same as target_batch.size(1) after the transpose(0, 1) calls (the 6 and the 15 in the printed sizes); both should equal max_length.
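A minimal sketch of one way to make the lengths agree, under two assumptions not confirmed in the thread: the decoder loop ran for `max(target_batch_length)` steps (6 here), and the targets were padded to the global maximum of 15. Under those assumptions, slicing off the padding rows that the decoder never produced brings both tensors to the same `max_len`. The sizes and values are made up; only the variable names follow the issue.

```python
import torch

# Hypothetical sizes: 4 sentences, 10-word vocabulary, targets padded to 15.
batch, vocab, padded_len = 4, 10, 15
target_batch_length = [6, 5, 3, 2]  # true target lengths; the longest is 6

# Time-major layout (seq x batch), which is why the loss call transposes.
target_batch = torch.randint(0, vocab, (padded_len, batch))

# Assumption: the decoder loop ran max(target_batch_length) steps.
max_target_length = max(target_batch_length)
all_decoder_outputs = torch.randn(max_target_length, batch, vocab)

# Drop the padding rows that have no matching decoder output.
target_batch = target_batch[:max_target_length]

logits = all_decoder_outputs.transpose(0, 1).contiguous()  # (batch, 6, vocab)
target = target_batch.transpose(0, 1).contiguous()         # (batch, 6)
assert logits.size(1) == target.size(1), (logits.size(), target.size())
print(logits.size(), target.size())
```

The opposite fix also works: run the decoder for `target_batch.size(0)` steps so the outputs cover every padded position. The mask in the loss ignores positions beyond each sentence's length either way. Which option fits depends on how the batch was padded.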

zhongpeixiang · Jan 24 '18 07:01