Inference from Checkpoints
Hi, and thank you very much for your really helpful code! I am trying to test my trained model and am having problems with the inference.py file. I specified a checkpoint stored in the ckpt folder, but I get a "KeyError: 'param'". Could you please elaborate on how to use the --model_path flag? (In general, a quick overview of how to use inference.py would be useful.) Thank you very much in advance, and best regards.
I browsed a little bit through the project history and found that my problem is a result of the refactoring of the inference code. I will update the code and prepare a pull request, if you don't mind.
Yeah, I have the same problem with "KeyError: 'param'". Could you also update the README to explain how to use inference, please?
@bowphs, could you show me how to fix this issue?
If I remember correctly, the problem is in the *_run scripts, which do not save the model properly: during inference, you try to load the model, but the expected keys do not exist in the checkpoint. A quick fix is to save the model manually in your *_run script. For example, in WTM_run.py, you could add something like:
save_name = f'./ckpt/WTM_{taskname}_tp{n_topic}_{dist}_{time.strftime("%Y-%m-%d-%H-%M", time.localtime())}.ckpt'
checkpoint = {
    "net": model.wae.state_dict(),
    "optimizer": model.optimizer.state_dict(),
    "epoch": num_epochs,
    "param": {
        "bow_dim": voc_size,
        "n_topic": n_topic,
        "taskname": taskname,
        "dist": dist,
        "dropout": dropout
    }
}
torch.save(checkpoint, save_name)
print(f"Successfully saved model. Model name: {save_name}.")
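For completeness, here is a rough sketch of the loading side, assuming inference.py essentially does a torch.load and reads the hyperparameters from the 'param' key (which is what the KeyError points at). The file path, dataset name, and hyperparameter values below are placeholders, not the repo's actual defaults:

```python
import os
import tempfile
import torch

# Stand-in checkpoint with the same layout as the save snippet above;
# in the real script, "net" and "optimizer" hold the state dicts.
save_name = os.path.join(tempfile.gettempdir(), "WTM_demo.ckpt")
checkpoint = {
    "net": {},            # would be model.wae.state_dict()
    "optimizer": {},      # would be model.optimizer.state_dict()
    "epoch": 100,
    "param": {
        "bow_dim": 10000,  # placeholder values
        "n_topic": 20,
        "taskname": "demo_task",
        "dist": "gmm-std",
        "dropout": 0.0,
    },
}
torch.save(checkpoint, save_name)

# What a --model_path consumer would do: load the checkpoint, read
# 'param' to rebuild the model with the right hyperparameters, then
# restore the weights from 'net'.
loaded = torch.load(save_name)
params = loaded["param"]
print(f"Loaded model with {params['n_topic']} topics for task {params['taskname']}")
```

If the checkpoint was saved without the "param" key (as the old *_run scripts did), `loaded["param"]` raises exactly the KeyError reported above, which is why re-saving with the snippet fixes it.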
Thanks for your reply. I just tried your code and it still doesn't work for me. The other *_run files don't save a checkpoint at all, so I don't know how to handle those files to make inference work. I would be really grateful if you could share your repo with the fix for this issue. @bowphs