
Reproducing Hu et al.'s "Toward Controlled Generation of Text" (ICML 2017)

8 issues in controlled-text-generation

I notice there are two backpropagation passes, one for the generator and one for the encoder: https://github.com/wiseodd/controlled-text-generation/blob/master/train_discriminator.py#L120-L122 https://github.com/wiseodd/controlled-text-generation/blob/master/train_discriminator.py#L130-L132 After the backward pass for `loss_G`, it calls `zero_grad` to clear all of the generator's gradients...
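For context on why that `zero_grad` call matters: PyTorch accumulates gradients across backward passes, so without clearing them the second update would see the sum of both losses' gradients. A toy, torch-free sketch of that accumulation semantics (the `Param` class and `backward` helper here are hypothetical illustrations, not the repo's code):

```python
class Param:
    """Toy stand-in for a parameter with PyTorch-style .grad accumulation."""
    def __init__(self, value):
        self.value = value
        self.grad = 0.0

    def zero_grad(self):
        self.grad = 0.0


def backward(params, grads):
    # Like autograd, gradients are ADDED into .grad, not overwritten.
    for p, g in zip(params, grads):
        p.grad += g


p = Param(1.0)
backward([p], [0.5])   # gradient from the first loss (e.g. loss_G)
backward([p], [0.25])  # without zero_grad, a second pass accumulates on top
print(p.grad)          # 0.75 -- the two gradients summed

p.zero_grad()
backward([p], [0.25])  # after clearing, only the second loss's gradient remains
print(p.grad)          # 0.25
```

This is why a `zero_grad` between the two backward passes changes which gradients the optimizer step actually applies.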

Hi @wiseodd, thanks for your open-source code. I find that the `forward_encoder_embed` function in the `model.py` module cannot handle inputs of variable length. In your code, it seems that you assume...
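One common way to support variable-length inputs is to pad each batch to its longest sequence before embedding (PyTorch also provides `pad_sequence` and `pack_padded_sequence` for this). A minimal sketch, where `pad_id` is a hypothetical padding-token index, not a name from the repo:

```python
def pad_batch(seqs, pad_id=0):
    """Right-pad a list of token-id sequences to the batch's max length."""
    max_len = max(len(s) for s in seqs)
    return [s + [pad_id] * (max_len - len(s)) for s in seqs]


batch = [[5, 9, 2], [7, 2]]
print(pad_batch(batch))  # [[5, 9, 2], [7, 2, 0]]
```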

There is a small problem with the `model.py` file. Python raises an error at line 297, where a required argument to the `multinomial` function is missing.

```
Traceback (most recent call last):
  File "train_discriminator.py", line 172, in <module>
    main()
  File "train_discriminator.py", line 140, in main
    loss_G.backward()
  File "/remote-home/yrchen/anaconda3/envs/py37_cuda8/lib/python3.7/site-packages/torch/tensor.py", line 102, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/remote-home/yrchen/anaconda3/envs/py37_cuda8/lib/python3.7/site-packages/torch/autograd/__init__.py", ...
```

Fix `TypeError: multinomial() missing 1 required positional argument: "num_samples"`. Fix `RuntimeError: cudnn RNN backward can only be called in training mode`.
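For reference on the first fix: `torch.multinomial` requires a `num_samples` argument, so a bare call like `torch.multinomial(probs)` fails, and drawing one token would be `torch.multinomial(probs, 1)`. What that call computes can be sketched torch-free (this helper is an illustration, not the repo's code):

```python
import random


def sample_categorical(probs, num_samples):
    """Draw num_samples indices with probability proportional to probs,
    roughly what torch.multinomial(probs, num_samples, replacement=True) does."""
    indices = range(len(probs))
    return random.choices(indices, weights=probs, k=num_samples)


random.seed(0)
draws = sample_categorical([0.1, 0.7, 0.2], num_samples=5)
print(draws)  # five indices from {0, 1, 2}, index 1 most likely
```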

Can you give some explanation of the KL divergence term? I am a little bit confused by `kl_loss = torch.mean(0.5 * torch.sum(torch.exp(logvar) + mu**2 - 1 - logvar, 1))`. Thank you...
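That line is the standard closed-form KL divergence between the approximate posterior N(mu, sigma^2), with `logvar = log sigma^2`, and the unit-Gaussian prior N(0, 1): per latent dimension it is 0.5 * (exp(logvar) + mu^2 - 1 - logvar), summed over dimensions and averaged over the batch. A torch-free sketch of the per-dimension formula:

```python
import math


def kl_per_dim(mu, logvar):
    """KL( N(mu, exp(logvar)) || N(0, 1) ) for one latent dimension."""
    return 0.5 * (math.exp(logvar) + mu**2 - 1.0 - logvar)


print(kl_per_dim(0.0, 0.0))  # 0.0: the posterior equals the prior
print(kl_per_dim(1.0, 0.0))  # 0.5: shifting the mean by 1 costs mu**2 / 2
```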

I tried to run `train_vae.py`:
```
Traceback (most recent call last):
  File "train_vae.py", line 40, in <module>
    dataset = SST_Dataset()
  File "/n/w1-bjayakumar/Others_Models/controlled-text-generation/ctextgen/dataset.py", line 8, in __init__
    self.TEXT = data.Field(init_token='', eos_token='', lower=True, ...
```