SuperNormal
I think there's a bug in checkpoint loading.
The code does not restore a checkpoint correctly: it only loads the `state_dict`s, so plain attributes that are not part of any `state_dict`, such as the bandwidth, are not restored.
The following code can be used to save and load the checkpoint:
`
def load_checkpoint(self, checkpoint_name):
    """Restore training state from ``<base_exp_dir>/checkpoints/<checkpoint_name>``.

    Besides the SDF-network, deviation-network and optimizer ``state_dict``s
    and the iteration counter, this restores state that the network
    ``state_dict`` does not carry: the occupancy grid and the SDF network's
    ``bindwidth`` attribute.

    Checkpoints written before the ``'occupancy'`` / ``'wb'`` entries were
    added do not contain those keys. For such checkpoints the current values
    are kept and a warning is logged, instead of failing with ``KeyError``.

    Args:
        checkpoint_name: File name inside the ``checkpoints`` directory.
    """
    import logging

    ckpt_path = os.path.join(self.base_exp_dir, 'checkpoints', checkpoint_name)
    checkpoint = torch.load(ckpt_path, map_location=self.device)

    self.sdf_network.load_state_dict(checkpoint['sdf_network_fine'])
    self.deviation_network.load_state_dict(checkpoint['variance_network_fine'])
    self.optimizer.load_state_dict(checkpoint['optimizer'])
    self.iter_step = checkpoint['iter_step']

    # The entries below are missing from older checkpoints that only stored
    # state_dicts; tolerate them rather than refusing to resume.
    if 'occupancy' in checkpoint:
        self.renderer.occupancy_grid.load_state_dict(checkpoint['occupancy'])
    else:
        logging.warning(
            "Checkpoint %s has no 'occupancy' entry; keeping the current occupancy grid.",
            ckpt_path,
        )

    # `bindwidth` is a plain attribute, so load_state_dict above does not restore it.
    # NOTE(review): spelling follows the attribute used by this patch; confirm it
    # matches the attribute actually read by the SDF network.
    if 'wb' in checkpoint:
        self.sdf_network.bindwidth = checkpoint['wb']
    else:
        logging.warning(
            "Checkpoint %s has no 'wb' entry; keeping the current bindwidth.",
            ckpt_path,
        )
def save_checkpoint(self):
    """Write the current training state to ``<base_exp_dir>/checkpoints/ckpt_<iter_step>.pth``.

    The saved dict holds the SDF-network, deviation-network, optimizer and
    occupancy-grid ``state_dict``s, the iteration counter, and the SDF
    network's ``bindwidth`` attribute (stored under ``'wb'`` because it is not
    part of the network's ``state_dict``). The iteration number in the file
    name is zero-padded to six digits.
    """
    state = dict(
        sdf_network_fine=self.sdf_network.state_dict(),
        variance_network_fine=self.deviation_network.state_dict(),
        optimizer=self.optimizer.state_dict(),
        iter_step=self.iter_step,
        occupancy=self.renderer.occupancy_grid.state_dict(),
        wb=self.sdf_network.bindwidth,
    )

    ckpt_dir = os.path.join(self.base_exp_dir, 'checkpoints')
    os.makedirs(ckpt_dir, exist_ok=True)
    file_name = f'ckpt_{self.iter_step:0>6d}.pth'
    torch.save(state, os.path.join(ckpt_dir, file_name))
`