Neural_Topic_Models

Inference from Checkpoints

Open ghost opened this issue 4 years ago • 5 comments

Hi, and thank you very much for your really helpful code! I am trying to test my trained model and am having problems with the inference.py file. I specified a checkpoint stored in the ckpt folder, but I get a "KeyError: 'param'". Could you please explain how to use the --model_path flag? (More generally, a quick overview of how to use inference.py would be useful.) Thank you very much in advance, and best regards.
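A "KeyError: 'param'" means that the dictionary loaded from the checkpoint file has no "param" entry. As a quick diagnostic, you could load the file with torch.load(path, map_location="cpu") and pass the result to a small check like the one below. This is a sketch, not part of the repository: the helper name missing_checkpoint_keys is hypothetical, and the expected keys ("net", "optimizer", "epoch", "param") are taken from the checkpoint layout suggested later in this thread rather than verified against inference.py.

```python
from typing import Any, Iterable, List

# Assumed checkpoint layout (from the WTM_run.py fix suggested in this thread),
# not a verified description of what inference.py reads.
EXPECTED_KEYS = ("net", "optimizer", "epoch", "param")


def missing_checkpoint_keys(checkpoint: Any,
                            expected: Iterable[str] = EXPECTED_KEYS) -> List[str]:
    """Return the expected top-level keys that the loaded checkpoint lacks.

    checkpoint is whatever torch.load(...) returned. If it is not a dict at all
    (for example a whole pickled model object), every expected key is reported.
    """
    if not isinstance(checkpoint, dict):
        return list(expected)
    return [key for key in expected if key not in checkpoint]


if __name__ == "__main__":
    # Hypothetical file holding only a bare state_dict: its keys are parameter
    # names (strings stand in for tensors here), so "param" is absent.
    bare_state_dict = {"encoder.weight": "...", "decoder.weight": "..."}
    print(missing_checkpoint_keys(bare_state_dict))
    # -> ['net', 'optimizer', 'epoch', 'param']

    full = {"net": {}, "optimizer": {}, "epoch": 10, "param": {"n_topic": 20}}
    print(missing_checkpoint_keys(full))
    # -> []
```

If "param" shows up as missing for your file, the checkpoint was saved in a different layout from the one the loading code expects, which matches the KeyError.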

ghost, Jun 14 '21 14:06

I browsed through the history of the project a little and found that my problem results from the refactoring of the inference code. I will update the code and prepare a pull request, if you don't mind.

ghost, Jun 22 '21 19:06

Yeah, I have the same problem with "KeyError: 'param'". Could you please document how to use inference in the README?

minhkids, Jun 01 '22 12:06

@bowphs, could you show me how to fix this issue?

minhkids, Jun 02 '22 06:06

If I remember correctly, the problem is the *_run scripts, which do not save the model properly: during inference, the code tries to load the model, but the expected keys do not exist in the checkpoint. A quick fix would be to save the model manually in your *_run script. For example, in WTM_run.py you could add something like:

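```python
import time

import torch

# Assumes model, taskname, n_topic, dist, voc_size, num_epochs and dropout
# are already defined earlier in WTM_run.py.
save_name = f'./ckpt/WTM_{taskname}_tp{n_topic}_{dist}_{time.strftime("%Y-%m-%d-%H-%M", time.localtime())}.ckpt'

# Store the hyperparameters under "param", the key that inference looks up.
checkpoint = {
    "net": model.wae.state_dict(),
    "optimizer": model.optimizer.state_dict(),
    "epoch": num_epochs,
    "param": {
        "bow_dim": voc_size,
        "n_topic": n_topic,
        "taskname": taskname,
        "dist": dist,
        "dropout": dropout
    }
}
torch.save(checkpoint, save_name)
print(f"Successfully saved model. Model name: {save_name}.")
```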
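The other *_run scripts could reuse the same layout. As a hedged sketch, the helpers below (build_checkpoint and checkpoint_name are hypothetical names, not part of the repository) assemble the checkpoint dictionary and file name in plain Python, so each script only has to pass its own hyperparameters. Which keys a given model's inference code actually expects under "param" is an assumption you would need to check against that model's constructor.

```python
import time
from typing import Any, Dict, Optional


def build_checkpoint(net_state: Dict[str, Any], optimizer_state: Dict[str, Any],
                     epoch: int, **params: Any) -> Dict[str, Any]:
    """Bundle weights, optimizer state and hyperparameters into one dict.

    The "param" entry holds whatever keyword arguments are needed to rebuild the
    model at inference time (hypothetical examples: bow_dim, n_topic, dropout).
    """
    if not params:
        raise ValueError("pass the model's hyperparameters, or 'param' will be empty")
    return {
        "net": net_state,
        "optimizer": optimizer_state,
        "epoch": epoch,
        "param": dict(params),
    }


def checkpoint_name(model_name: str, taskname: str, n_topic: int, *extra: Any,
                    timestamp: Optional[str] = None) -> str:
    """Build a ./ckpt/ file name like the WTM_run.py example in this thread.

    Extra parts (e.g. WTM's dist) are inserted between tp{n_topic} and the timestamp.
    """
    if timestamp is None:
        timestamp = time.strftime("%Y-%m-%d-%H-%M", time.localtime())
    parts = [model_name, taskname, f"tp{n_topic}", *map(str, extra), timestamp]
    return "./ckpt/" + "_".join(parts) + ".ckpt"


if __name__ == "__main__":
    # Placeholder values standing in for what a *_run script would have.
    ckpt = build_checkpoint({}, {}, 100, bow_dim=5000, n_topic=20,
                            taskname="my_task", dropout=0.0)
    print(sorted(ckpt))  # -> ['epoch', 'net', 'optimizer', 'param']
    print(checkpoint_name("WTM", "my_task", 20, "my_dist",
                          timestamp="2022-01-01-00-00"))
    # -> ./ckpt/WTM_my_task_tp20_my_dist_2022-01-01-00-00.ckpt
```

In a run script, the result would then be written with torch.save(checkpoint, save_name) as in WTM_run.py. For models other than WTM, the attribute holding the network (model.wae in the WTM example) and the hyperparameter names will likely differ, so check them in each script. torch.save does not create missing directories, so the ckpt folder has to exist (os.makedirs("./ckpt", exist_ok=True) handles that).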

ghost, Jun 02 '22 06:06

Thanks for your reply. I just tried your code and it is still not working for me. The other *_run files don't save a checkpoint, so I don't know how to change them to make inference work. I would be really thankful if you could share the repo in which you fixed this issue. @bowphs

minhkids, Jun 07 '22 03:06