torchdrug
NBFNet Loading Issue.
Hi! There is currently a bug in loading state dicts that contain data.Graph objects. For NBFNet, the two affected attributes, graph and fact_graph, are static and can be recomputed from the dataset, so you can pop them from the state dict and load the rest of it like this:
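```python
import torch

state = torch.load(checkpoint_file, map_location=solver.device)
# graph and fact_graph are static and can be recomputed from the dataset, so drop them
state["model"].pop("graph")
state["model"].pop("fact_graph")
# strict=False tolerates the two keys removed above
solver.model.load_state_dict(state["model"], strict=False)
```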
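A slightly more defensive variant of the same workaround, offered as a sketch: load_skipping_keys is a hypothetical helper, not part of torchdrug. It filters a copy of the state dict instead of popping (so an already-absent key does not raise KeyError), keeps PyTorch's _metadata attribute on the copy, and, because strict=False would also silence genuine mismatches, checks the returned missing and unexpected keys itself. This assumes torchdrug's patched loading still reports missing_keys and unexpected_keys the way torch.nn.Module.load_state_dict does.

```python
from collections import OrderedDict
from typing import Any, Iterable, Mapping


def load_skipping_keys(model: Any, state_dict: Mapping[str, Any],
                       skip_keys: Iterable[str] = ("graph", "fact_graph")):
    """Load state_dict into model, leaving out the entries named in skip_keys.

    Hypothetical helper. Assumes model follows the torch.nn.Module contract:
    load_state_dict(..., strict=False) returns an object with missing_keys
    and unexpected_keys attributes.
    """
    skip = set(skip_keys)
    # Filter a copy instead of popping: the checkpoint dict stays untouched and
    # a key that is already absent does not raise KeyError.
    filtered = OrderedDict((k, v) for k, v in state_dict.items() if k not in skip)
    # PyTorch stores per-module version info in a _metadata attribute; keep it.
    metadata = getattr(state_dict, "_metadata", None)
    if metadata is not None:
        filtered._metadata = metadata
    result = model.load_state_dict(filtered, strict=False)
    # strict=False also hides real mismatches, so only tolerate the skipped keys.
    missing = [k for k in result.missing_keys if k not in skip]
    unexpected = list(result.unexpected_keys)
    if missing or unexpected:
        raise RuntimeError(f"missing keys: {missing}, unexpected keys: {unexpected}")
    return result
```

With the objects from the original snippet, the call would be load_skipping_keys(solver.model, state["model"]) right after torch.load. The explicit key check is the main design choice: it keeps the convenience of strict=False for the two static graphs without letting an unrelated architecture mismatch load silently.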
Originally posted by @KiddoZhu in https://github.com/DeepGraphLearning/torchdrug/issues/89#issuecomment-1094501984
The error still occurs in the current version. I think this workaround should be added to the documentation.
We updated this in the newest version of A*Net, which also includes a new implementation of NBFNet.