
I think there's a bug in loading the checkpoint

Open • therealron opened this issue 1 year ago • 1 comment

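The code does not load the checkpoint correctly. Because it only loads the `state_dict`, attributes such as the bandwidth are not restored: they are plain Python attributes, not registered parameters or buffers, so they are never written into the `state_dict` in the first place.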
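A minimal sketch of the problem follows. It uses a hypothetical `ToySDF` module as a stand-in for the real SDF network. `state_dict()` only collects registered parameters and buffers, so a plain attribute like `bindwidth` is neither saved nor restored:

```python
import torch
import torch.nn as nn


class ToySDF(nn.Module):
    """Hypothetical stand-in for the SDF network: one layer plus a plain
    Python attribute that training updates over time."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(3, 1)
        self.bindwidth = 0  # plain attribute: NOT a parameter or buffer


net = ToySDF()
net.bindwidth = 8  # value reached during training

state = net.state_dict()
print(sorted(state.keys()))  # ['linear.bias', 'linear.weight']

restored = ToySDF()
restored.load_state_dict(state)
print(restored.bindwidth)  # 0: the trained value was lost
```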

therealron • Jul 12 '24 04:07

The following code may help save and load the checkpoint, including the bandwidth.

```python
import os

import torch


# Methods of the training runner class (they rely on self.sdf_network,
# self.deviation_network, self.optimizer, self.renderer, self.device, ...).

def load_checkpoint(self, checkpoint_name):
    checkpoint = torch.load(os.path.join(self.base_exp_dir, 'checkpoints', checkpoint_name), map_location=self.device)
    self.sdf_network.load_state_dict(checkpoint['sdf_network_fine'])
    self.deviation_network.load_state_dict(checkpoint['variance_network_fine'])
    self.optimizer.load_state_dict(checkpoint['optimizer'])
    self.iter_step = checkpoint['iter_step']
    self.renderer.occupancy_grid.load_state_dict(checkpoint['occupancy'])
    # The bandwidth is a plain attribute, so it is not part of the
    # state_dict; restore it explicitly from its own key.
    self.sdf_network.bindwidth = checkpoint['wb']

def save_checkpoint(self):
    checkpoint = {
        'sdf_network_fine': self.sdf_network.state_dict(),
        'variance_network_fine': self.deviation_network.state_dict(),
        'optimizer': self.optimizer.state_dict(),
        'iter_step': self.iter_step,
        'occupancy': self.renderer.occupancy_grid.state_dict(),
        # Save the bandwidth separately, since state_dict() omits it.
        'wb': self.sdf_network.bindwidth
    }

    os.makedirs(os.path.join(self.base_exp_dir, 'checkpoints'), exist_ok=True)
    torch.save(checkpoint, os.path.join(self.base_exp_dir, 'checkpoints', 'ckpt_{:0>6d}.pth'.format(self.iter_step)))
```
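If the network class can be edited, there is an alternative to storing the value under a separate `'wb'` key. The module can carry the value inside its own `state_dict()` through the standard `nn.Module.get_extra_state` / `set_extra_state` hooks, available since PyTorch 1.10. The sketch below again uses a hypothetical `ToySDF`. It is not the repository's implementation:

```python
import io

import torch
import torch.nn as nn


class ToySDF(nn.Module):
    """Hypothetical SDF network whose bandwidth travels inside state_dict()."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(3, 1)
        self.bindwidth = 0

    def get_extra_state(self):
        # Called by state_dict(); the result is stored under "_extra_state".
        return {"bindwidth": self.bindwidth}

    def set_extra_state(self, state):
        # Called by load_state_dict() with the object returned above.
        self.bindwidth = state["bindwidth"]


net = ToySDF()
net.bindwidth = 8

# Round-trip through an in-memory file instead of a checkpoint on disk.
buffer = io.BytesIO()
torch.save(net.state_dict(), buffer)
buffer.seek(0)

restored = ToySDF()
restored.load_state_dict(torch.load(buffer))
print(restored.bindwidth)  # 8
```

There is one trade-off. With the default `strict=True`, `load_state_dict()` then expects an `_extra_state` entry. Checkpoints saved before such a change would therefore need `strict=False` or the explicit `'wb'` handling.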

Terry10086 • Oct 16 '25 12:10