TextGAN-PyTorch
question about seqgan rollout code
Hi @williamSYSU,
Thanks for the work.
I found a potential bug in your seqgan rollout code.
In the following snippet from the get_reward function in utils/rollout.py,
rewards = torch.zeros([rollout_num * self.max_seq_len, batch_size]).float()
......
rewards = torch.mean(rewards.view(batch_size, self.max_seq_len, rollout_num), dim=-1)
the reward tensor, whose rows are effectively laid out as [rollout_num, max_seq_len, batch_size], is viewed as [batch_size, max_seq_len, rollout_num] and then (as intended) averaged over rollout_num. However, view only reinterprets the underlying memory without moving any elements, so the result does not have the expected layout and the mean is taken over unrelated entries.
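To correct this, I think the axes need to be reordered explicitly: view the tensor as [rollout_num, max_seq_len, batch_size] first, then transpose (permute) it so that rollout_num is the last dimension before taking the mean.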
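Here is a minimal sketch of the problem and one possible fix. The sizes are made up, and the fill loop is my assumption about what the elided part does (one row per rollout and step, written in rollout-major order); it is not the actual code from utils/rollout.py.

```python
import torch

# Small hypothetical sizes so the layouts are easy to inspect.
rollout_num, max_seq_len, batch_size = 2, 3, 4

# Assumption: the elided loop writes row r * max_seq_len + t for rollout r and
# step t, so the flat buffer is laid out as [rollout_num, max_seq_len, batch_size].
rewards = torch.zeros([rollout_num * max_seq_len, batch_size]).float()
idx = 0
for r in range(rollout_num):
    for t in range(max_seq_len):
        for b in range(batch_size):
            # Encode the indices in the value: hundreds = r, tens = t, ones = b.
            rewards[idx, b] = 100 * r + 10 * t + b
        idx += 1

# Original reduction: view alone reinterprets memory, so unrelated entries are averaged.
naive = torch.mean(rewards.view(batch_size, max_seq_len, rollout_num), dim=-1)

# Possible fix: restore the true axes, move rollout_num last, then average over it.
fixed = torch.mean(
    rewards.view(rollout_num, max_seq_len, batch_size).permute(2, 1, 0), dim=-1
)

# Reference: the mean over r of 100 * r + 10 * t + b, for each (b, t).
expected = torch.tensor(
    [[50.0 + 10 * t + b for t in range(max_seq_len)] for b in range(batch_size)]
)

print(torch.allclose(fixed, expected))
print(torch.allclose(naive, expected))
```

Under this assumption, only the permuted version averages each (sample, step) pair over its own rollouts. If the buffer is actually filled in a different order, the first view and the permute arguments would need to change accordingly.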
Looking forward to your reply.