TextGAN-PyTorch

Question about SeqGAN rollout code

Open · dldaisy opened this issue 4 years ago · 0 comments

Hi @williamSYSU, thanks for the work. I think I found a potential bug in your SeqGAN rollout code, in the following snippet from the function `get_reward` in `utils/rollout.py`:

```python
rewards = torch.zeros([rollout_num * self.max_seq_len, batch_size]).float()
# ......
rewards = torch.mean(rewards.view(batch_size, self.max_seq_len, rollout_num), dim=-1)
```

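The reward tensor is logically laid out as [rollout_num, max_seq_len, batch_size], but it is reshaped to [batch_size, max_seq_len, rollout_num] and then (presumably) meant to be averaged over the rollout_num dimension. However, `view` only reinterprets the existing memory; it does not move any elements. After the view, the tensor therefore does not have the layout the code expects, and the mean over the last dimension is taken over the wrong elements (it mixes different samples instead of different rollouts).
To fix this, I think the dimensions need to be reordered explicitly: view the tensor in the order it was written, i.e. [rollout_num, max_seq_len, batch_size], and then transpose/permute it, since a `view` alone cannot reorder them. Looking forward to your reply.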
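To illustrate the problem, here is a small sketch with made-up sizes that builds a stand-in for `rewards` in which every entry encodes its own (rollout `i`, step `t`, sample `b`) indices as `100 * i + 10 * t + b`. It assumes the row order described above (row `i * max_seq_len + t` holds rollout `i`, step `t`); the actual fill loop of `get_reward` is not reproduced here.

```python
import torch

# Hypothetical small sizes, chosen only for illustration.
rollout_num, max_seq_len, batch_size = 2, 3, 4

# Assumption: row i * max_seq_len + t holds rollout i at step t, so the
# logical layout is [rollout_num, max_seq_len, batch_size].
# Each entry encodes its own indices as 100 * i + 10 * t + b.
rewards = torch.tensor(
    [[100 * i + 10 * t + b for b in range(batch_size)]
     for i in range(rollout_num) for t in range(max_seq_len)],
    dtype=torch.float,
)
assert rewards.shape == (rollout_num * max_seq_len, batch_size)

# What the current code does: view() reinterprets the same memory.
wrong = rewards.view(batch_size, max_seq_len, rollout_num)
# What mean(dim=-1) should see: all rollouts for a given (sample, step).
right = rewards.view(rollout_num, max_seq_len, batch_size).permute(2, 1, 0)

print(wrong[0, 0].tolist())  # samples 0 and 1 of rollout 0, step 0 -> [0.0, 1.0]
print(right[0, 0].tolist())  # sample 0, step 0, rollouts 0 and 1 -> [0.0, 100.0]
```

With the plain `view`, the "rollout" axis actually walks across neighbouring samples of the same rollout, so averaging it blends rewards of different sentences.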
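A minimal sketch of a corrected reduction, assuming (as described above) that the rows are filled in [rollout_num, max_seq_len, batch_size] order. The helper name `average_rollout_rewards` is hypothetical and not part of TextGAN-PyTorch; it views the tensor in the order it was written, averages over the rollout dimension, and only then transposes the result to [batch_size, max_seq_len].

```python
import torch


def average_rollout_rewards(rewards: torch.Tensor, rollout_num: int,
                            max_seq_len: int) -> torch.Tensor:
    """Average Monte Carlo rollout rewards per (sample, step).

    Assumes row i * max_seq_len + t of `rewards` holds rollout i at step t,
    i.e. a logical layout of [rollout_num, max_seq_len, batch_size].
    Returns a [batch_size, max_seq_len] tensor.
    """
    if rewards.size(0) != rollout_num * max_seq_len:
        raise ValueError("expected rollout_num * max_seq_len rows")
    batch_size = rewards.size(1)
    # Recover the layout the rows were actually written in.
    per_rollout = rewards.view(rollout_num, max_seq_len, batch_size)
    # Reduce over rollouts, then swap to [batch_size, max_seq_len].
    return per_rollout.mean(dim=0).t()
```

Equivalently, `rewards.view(rollout_num, max_seq_len, batch_size).permute(2, 1, 0).mean(dim=-1)` keeps the original `mean(dim=-1)` form. Note that `.t()` returns a non-contiguous tensor, so `.contiguous()` would be needed if a later step calls `view` on the result.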

dldaisy, Jan 05 '22 09:01